use std::{sync::Arc, time::Duration};
use eyeball::{SharedObservable, Subscriber};
use futures_util::{
StreamExt as _,
future::{Either, select},
pin_mut,
};
use matrix_sdk::{
Client,
config::RequestConfig,
executor::{JoinHandle, spawn},
sleep::sleep,
};
use thiserror::Error;
use tokio::sync::{
Mutex as AsyncMutex, OwnedMutexGuard,
mpsc::{Receiver, Sender},
};
use tracing::{Instrument, Level, Span, error, info, instrument, trace, warn};
use crate::{
encryption_sync_service::{self, EncryptionSyncPermit, EncryptionSyncService},
room_list_service::{
self, DEFAULT_CONNECTION_ID, DEFAULT_LIST_TIMELINE_LIMIT, RoomListService,
},
};
/// Observable state of the [`SyncService`].
#[derive(Clone, Debug)]
pub enum State {
/// Initial state of a freshly built service; also set by the supervisor
/// task when it exits after receiving an error-free report that originated
/// from the supervisor itself (i.e. a requested stop).
Idle,
/// Set when the service is started, and again by the supervisor task when
/// it leaves the offline mode and restarts the child syncs.
Running,
/// Set by the supervisor task when one of the child syncs reported its
/// termination without an error.
Terminated,
/// Set by the supervisor task when a termination report carried an error
/// that is not handled by the offline mode.
Error(Arc<Error>),
/// Set by the supervisor task when a termination report carried an error
/// and the offline mode is enabled; the task then waits for the server to
/// answer again.
Offline,
}
/// The encryption sync permit, either already locked or still to be locked.
enum MaybeAcquiredPermit {
/// An owned guard on the permit that is already held.
Acquired(OwnedMutexGuard<EncryptionSyncPermit>),
/// The shared mutex protecting the permit, not locked by this value.
Unacquired(Arc<AsyncMutex<EncryptionSyncPermit>>),
}
impl MaybeAcquiredPermit {
async fn acquire(self) -> OwnedMutexGuard<EncryptionSyncPermit> {
match self {
MaybeAcquiredPermit::Acquired(owned_mutex_guard) => owned_mutex_guard,
MaybeAcquiredPermit::Unacquired(lock) => lock.lock_owned().await,
}
}
}
/// Handle on the background task that supervises the two child sync tasks.
struct SyncTaskSupervisor {
/// Join handle of the spawned supervisor task.
task: JoinHandle<()>,
/// Sending half of the channel the supervisor task listens on; used to
/// push a supervisor-originated termination report when shutting down.
termination_sender: Sender<TerminationReport>,
}
impl SyncTaskSupervisor {
/// Builds a supervisor by spawning its background task.
///
/// The returned value keeps the task's join handle and the sender that can
/// deliver termination reports to that task.
async fn new(
    inner: &SyncServiceInner,
    room_list_service: Arc<RoomListService>,
    encryption_sync_permit: Arc<AsyncMutex<EncryptionSyncPermit>>,
) -> Self {
    let spawning = Self::spawn_supervisor_task(inner, room_list_service, encryption_sync_permit);
    let (task, termination_sender) = spawning.await;

    Self { task, termination_sender }
}
/// Waits, while offline, for whichever happens first: a termination report
/// originating from the supervisor, or the server answering again.
///
/// Returns `Some(report)` when a supervisor-originated report was received
/// first (a closed channel counts as a supervisor error report), and `None`
/// when `fetch_server_versions` succeeded first.
async fn offline_check(
    client: &Client,
    receiver: &mut Receiver<TerminationReport>,
) -> Option<TerminationReport> {
    info!("Entering the offline mode");

    // Resolves with the first report whose origin is the supervisor; reports
    // coming from the child syncs are read from the channel and discarded.
    let supervisor_report = async {
        loop {
            let incoming = match receiver.recv().await {
                Some(incoming) => incoming,
                None => TerminationReport::supervisor_error(),
            };

            match incoming.origin {
                TerminationOrigin::Supervisor => break incoming,
                TerminationOrigin::EncryptionSync | TerminationOrigin::RoomList => continue,
            }
        }
    };

    // Resolves once a server versions request succeeds; after each failed
    // attempt, pauses for 100ms before trying again.
    let back_online = async move {
        while client
            .fetch_server_versions(Some(RequestConfig::default().retry_limit(5)))
            .await
            .is_err()
        {
            sleep(Duration::from_millis(100)).await;
        }
    };

    pin_mut!(supervisor_report);
    pin_mut!(back_online);

    let report = match select(supervisor_report, back_online).await {
        Either::Left((termination_report, _)) => Some(termination_report),
        Either::Right(_) => None,
    };

    info!("Exiting offline mode: {report:?}");
    report
}
/// Spawns the supervisor task.
///
/// Returns the join handle of the spawned task together with a sender that
/// can push termination reports into the channel the task listens on.
///
/// The encryption sync permit is locked here, before the task is spawned, so
/// this function only returns once the permit is held.
async fn spawn_supervisor_task(
inner: &SyncServiceInner,
room_list_service: Arc<RoomListService>,
encryption_sync_permit: Arc<AsyncMutex<EncryptionSyncPermit>>,
) -> (JoinHandle<()>, Sender<TerminationReport>) {
// One channel carries the reports of both child tasks as well as the
// shutdown request sent through the returned sender.
let (sender, mut receiver) = tokio::sync::mpsc::channel(16);
let encryption_sync = inner.encryption_sync_service.clone();
let state = inner.state.clone();
let termination_sender = sender.clone();
// The first loop iteration below consumes this already-acquired guard;
// later iterations have to lock the permit again.
let mut sync_permit_guard =
MaybeAcquiredPermit::Acquired(encryption_sync_permit.clone().lock_owned().await);
let offline_mode = inner.with_offline_mode;
let parent_span = inner.parent_span.clone();
let future = async move {
// Each iteration runs one generation of child tasks; the loop only
// repeats after the offline mode has been left because the server
// answered again.
loop {
let (room_list_task, encryption_sync_task) = Self::spawn_child_tasks(
room_list_service.clone(),
encryption_sync.clone(),
sync_permit_guard,
sender.clone(),
parent_span.clone(),
)
.await;
// The guard was moved into the child tasks; prepare an unlocked
// handle for a possible next iteration.
sync_permit_guard = MaybeAcquiredPermit::Unacquired(encryption_sync_permit.clone());
// Wait for the first report. A closed channel is handled like a
// supervisor error so that both children still get stopped below.
let report = if let Some(report) = receiver.recv().await {
report
} else {
info!("internal channel has been closed?");
TerminationReport::supervisor_error()
};
// The service that produced the report is not asked to stop, only
// its peer; a supervisor report stops both.
let (stop_room_list, stop_encryption) = match &report.origin {
TerminationOrigin::EncryptionSync => (true, false),
TerminationOrigin::RoomList => (false, true),
TerminationOrigin::Supervisor => (true, true),
};
if stop_room_list {
if let Err(err) = room_list_service.stop_sync() {
warn!(?report, "unable to stop room list service: {err:#}");
}
if report.has_expired() {
room_list_service.expire_sync_session().await;
}
}
// Always wait for the room list task to finish, whether or not it
// was asked to stop.
if let Err(err) = room_list_task.await {
error!("when awaiting room list service: {err:#}");
}
if stop_encryption {
if let Err(err) = encryption_sync.stop_sync() {
warn!(?report, "unable to stop encryption sync: {err:#}");
}
if report.has_expired() {
encryption_sync.expire_sync_session().await;
}
}
// Same for the encryption sync task.
if let Err(err) = encryption_sync_task.await {
error!("when awaiting encryption sync: {err:#}");
}
// Both child tasks are done at this point; decide the next state.
if let Some(error) = report.error {
if offline_mode {
// With the offline mode enabled the error is not published;
// wait for the server to answer or for a supervisor report.
state.set(State::Offline);
let client = room_list_service.client();
if let Some(report) = Self::offline_check(client, &mut receiver).await {
// A supervisor report arrived while offline: exit the
// task, publishing its error if it carries one.
if let Some(error) = report.error {
state.set(State::Error(Arc::new(error)));
} else {
state.set(State::Idle);
}
break;
}
// The server answered: loop again to respawn the child tasks.
state.set(State::Running);
} else {
state.set(State::Error(Arc::new(error)));
break;
}
} else if matches!(report.origin, TerminationOrigin::Supervisor) {
// Error-free report from the supervisor: a requested shutdown.
state.set(State::Idle);
break;
} else {
// A child sync ended without an error.
state.set(State::Terminated);
break;
}
}
}
.instrument(tracing::span!(Level::WARN, "supervisor task"));
let task = spawn(future);
(task, termination_sender)
}
/// Spawns the room list sync task and the encryption sync task, both
/// instrumented with `parent_span`.
///
/// The room list task is spawned first; the encryption sync task is only
/// spawned once the sync permit has been acquired.
///
/// Returns the join handles as `(room_list_task, encryption_sync_task)`.
async fn spawn_child_tasks(
    room_list_service: Arc<RoomListService>,
    encryption_sync_service: Arc<EncryptionSyncService>,
    sync_permit_guard: MaybeAcquiredPermit,
    sender: Sender<TerminationReport>,
    parent_span: Span,
) -> (JoinHandle<()>, JoinHandle<()>) {
    // Room list first, so that it is running while the permit is awaited.
    let room_list_future = Self::room_list_sync_task(room_list_service, sender.clone())
        .instrument(parent_span.clone());
    let room_list_task = spawn(room_list_future);

    // Then the encryption sync, which needs the permit guard to start.
    let encryption_sender = sender.clone();
    let permit = sync_permit_guard.acquire().await;
    let encryption_future =
        Self::encryption_sync_task(encryption_sync_service, encryption_sender, permit)
            .instrument(parent_span);
    let encryption_sync_task = spawn(encryption_future);

    (room_list_task, encryption_sync_task)
}
/// Drives the encryption sync stream until it ends or yields an error, then
/// sends a matching termination report through `sender`.
///
/// An error is logged here unless the resulting report says the session has
/// expired.
async fn encryption_sync_task(
    encryption_sync: Arc<EncryptionSyncService>,
    sender: Sender<TerminationReport>,
    sync_permit_guard: OwnedMutexGuard<EncryptionSyncPermit>,
) {
    let stream = encryption_sync.sync(sync_permit_guard);
    pin_mut!(stream);

    let outcome = loop {
        // The stream ending is reported as a termination without error.
        let Some(step) = stream.next().await else {
            break TerminationReport::encryption_sync(None);
        };

        // Successful iterations need no handling; keep polling.
        let Err(sync_error) = step else {
            continue;
        };

        let failure = TerminationReport::encryption_sync(Some(sync_error));
        if !failure.has_expired() {
            error!(
                "Error while processing encryption in sync service: {:#?}",
                failure.error
            );
        }
        break failure;
    };

    if let Err(err) = sender.send(outcome).await {
        error!("Error while sending termination report: {err:#}");
    }
}
/// Drives the room list sync stream until it ends or yields an error, then
/// sends a matching termination report through `sender`.
///
/// An error is logged here unless the resulting report says the session has
/// expired.
async fn room_list_sync_task(
    room_list_service: Arc<RoomListService>,
    sender: Sender<TerminationReport>,
) {
    let stream = room_list_service.sync();
    pin_mut!(stream);

    let outcome = loop {
        // The stream ending is reported as a termination without error.
        let Some(step) = stream.next().await else {
            break TerminationReport::room_list(None);
        };

        // Successful iterations need no handling; keep polling.
        let Err(sync_error) = step else {
            continue;
        };

        let failure = TerminationReport::room_list(Some(sync_error));
        if !failure.has_expired() {
            error!(
                "Error while processing room list in sync service: {:#?}",
                failure.error
            );
        }
        break failure;
    };

    if let Err(err) = sender.send(outcome).await {
        error!("Error while sending termination report: {err:#}");
    }
}
async fn shutdown(self) {
match self.termination_sender.send(TerminationReport::supervisor()).await {
Ok(_) => {
let _ = self.task.await.inspect_err(|err| {
error!("The supervisor task has stopped unexpectedly: {err:?}");
});
}
Err(err) => {
error!("Couldn't send the termination report to the supervisor task: {err}");
self.task.abort();
}
}
}
}
/// Mutable part of the sync service, kept behind an async mutex by
/// [`SyncService`].
struct SyncServiceInner {
/// The encryption sync service driven by the supervisor's child task.
encryption_sync_service: Arc<EncryptionSyncService>,
/// Whether the supervisor task enters the offline mode when a child sync
/// reports an error, instead of publishing that error.
with_offline_mode: bool,
/// Observable state, updated on start and by the supervisor task.
state: SharedObservable<State>,
/// Span used to instrument the child sync tasks.
parent_span: Span,
/// The supervisor spawned by the last `start()`, if any; taken out again
/// by `stop()`.
supervisor: Option<SyncTaskSupervisor>,
}
impl SyncServiceInner {
    /// Spawns a new supervisor, stores it, and sets the state to
    /// [`State::Running`].
    async fn start(
        &mut self,
        room_list_service: Arc<RoomListService>,
        encryption_sync_permit: Arc<AsyncMutex<EncryptionSyncPermit>>,
    ) {
        trace!("starting sync service");

        let supervisor =
            SyncTaskSupervisor::new(self, room_list_service, encryption_sync_permit).await;
        self.supervisor = Some(supervisor);
        self.state.set(State::Running);
    }

    /// Takes the stored supervisor, if any, and shuts it down.
    ///
    /// Logs an error when there is no supervisor to shut down.
    async fn stop(&mut self) {
        trace!("pausing sync service");

        let Some(supervisor) = self.supervisor.take() else {
            error!("The sync service was not properly started, the supervisor task doesn't exist");
            return;
        };

        supervisor.shutdown().await;
    }

    /// Stops the service, then starts it again with the given services.
    async fn restart(
        &mut self,
        room_list_service: Arc<RoomListService>,
        encryption_sync_permit: Arc<AsyncMutex<EncryptionSyncPermit>>,
    ) {
        self.stop().await;
        self.start(room_list_service, encryption_sync_permit).await;
    }
}
/// Service that runs the room list sync and the encryption sync together and
/// exposes their combined [`State`].
pub struct SyncService {
/// Supervisor bookkeeping, serialized behind an async mutex.
inner: Arc<AsyncMutex<SyncServiceInner>>,
/// The room list service handed to the supervisor on every start.
room_list_service: Arc<RoomListService>,
/// Observable state handed out to subscribers by `state()`.
state: SharedObservable<State>,
/// The mutex-protected permit required to run the encryption sync.
encryption_sync_permit: Arc<AsyncMutex<EncryptionSyncPermit>>,
}
impl SyncService {
pub fn builder(client: Client) -> SyncServiceBuilder {
SyncServiceBuilder::new(client)
}
pub fn room_list_service(&self) -> Arc<RoomListService> {
self.room_list_service.clone()
}
pub fn state(&self) -> Subscriber<State> {
self.state.subscribe()
}
pub async fn start(&self) {
let mut inner = self.inner.lock().await;
match inner.state.get() {
State::Running => {}
State::Offline => {
inner
.restart(self.room_list_service.clone(), self.encryption_sync_permit.clone())
.await
}
State::Idle | State::Terminated | State::Error(_) => {
inner
.start(self.room_list_service.clone(), self.encryption_sync_permit.clone())
.await
}
}
}
#[instrument(skip_all)]
pub async fn stop(&self) {
let mut inner = self.inner.lock().await;
match inner.state.get() {
State::Idle | State::Terminated | State::Error(_) => {
return;
}
State::Running | State::Offline => {}
}
inner.stop().await;
}
#[instrument(skip_all)]
pub async fn expire_sessions(&self) {
self.stop().await;
self.room_list_service.expire_sync_session().await;
self.inner.lock().await.encryption_sync_service.expire_sync_session().await;
}
pub fn try_get_encryption_sync_permit(&self) -> Option<OwnedMutexGuard<EncryptionSyncPermit>> {
self.encryption_sync_permit.clone().try_lock_owned().ok()
}
}
/// Which party produced a [`TerminationReport`].
#[derive(Debug)]
enum TerminationOrigin {
/// The report was sent by the encryption sync task.
EncryptionSync,
/// The report was sent by the room list sync task.
RoomList,
/// The report stems from the supervisor side: either a shutdown request, or
/// a failure of the supervisor's own channel.
Supervisor,
}
/// Message telling the supervisor task that something terminated, or should
/// terminate.
#[derive(Debug)]
struct TerminationReport {
/// Who produced the report.
origin: TerminationOrigin,
/// The error that caused the termination, or `None` when there was none.
error: Option<Error>,
}
impl TerminationReport {
fn encryption_sync(error: Option<encryption_sync_service::Error>) -> Self {
Self { origin: TerminationOrigin::EncryptionSync, error: error.map(Error::EncryptionSync) }
}
fn room_list(error: Option<room_list_service::Error>) -> Self {
Self { origin: TerminationOrigin::RoomList, error: error.map(Error::RoomList) }
}
fn supervisor_error() -> Self {
Self { origin: TerminationOrigin::Supervisor, error: Some(Error::Supervisor) }
}
fn supervisor() -> Self {
Self { origin: TerminationOrigin::Supervisor, error: None }
}
fn has_expired(&self) -> bool {
match &self.error {
Some(Error::RoomList(room_list_service::Error::SlidingSync(error)))
| Some(Error::EncryptionSync(encryption_sync_service::Error::SlidingSync(error))) => {
error.client_api_error_kind() == Some(&ruma::api::error::ErrorKind::UnknownPos)
}
_ => false,
}
}
}
#[doc(hidden)]
impl SyncService {
    /// Tells whether a supervisor is currently stored in the inner state.
    ///
    /// This only checks that the supervisor handle is present; it does not
    /// check whether the supervisor task itself has already finished.
    pub async fn is_supervisor_running(&self) -> bool {
        let inner = self.inner.lock().await;
        inner.supervisor.is_some()
    }
}
/// Builder for a [`SyncService`].
#[derive(Clone)]
pub struct SyncServiceBuilder {
/// The client the services are created for.
client: Client,
/// Whether the offline mode is enabled; defaults to `false`.
with_offline_mode: bool,
/// Flag forwarded to `RoomListService::new_with`; defaults to `true`.
with_share_pos: bool,
/// Connection id forwarded to `RoomListService::new_with`; defaults to
/// `DEFAULT_CONNECTION_ID`.
room_list_conn_id: String,
/// Timeline limit forwarded to `RoomListService::new_with`; defaults to
/// `DEFAULT_LIST_TIMELINE_LIMIT`.
room_list_timeline_limit: u32,
/// Span used to instrument the child sync tasks; defaults to
/// `Span::none()`.
parent_span: Span,
}
impl SyncServiceBuilder {
    /// Creates a builder with the default configuration for `client`.
    fn new(client: Client) -> Self {
        Self {
            with_offline_mode: false,
            with_share_pos: true,
            room_list_timeline_limit: DEFAULT_LIST_TIMELINE_LIMIT,
            room_list_conn_id: DEFAULT_CONNECTION_ID.to_owned(),
            parent_span: Span::none(),
            client,
        }
    }

    /// Enables the offline mode: when a child sync reports an error, the
    /// service moves to [`State::Offline`] and waits for the server to answer
    /// again, instead of moving to [`State::Error`].
    pub fn with_offline_mode(self) -> Self {
        Self { with_offline_mode: true, ..self }
    }

    /// Sets the `with_share_pos` flag forwarded to the room list service.
    pub fn with_share_pos(self, enable: bool) -> Self {
        Self { with_share_pos: enable, ..self }
    }

    /// Sets the connection id forwarded to the room list service.
    pub fn with_room_list_conn_id(self, conn_id: String) -> Self {
        Self { room_list_conn_id: conn_id, ..self }
    }

    /// Sets the timeline limit forwarded to the room list service.
    pub fn with_room_list_timeline_limit(self, limit: u32) -> Self {
        Self { room_list_timeline_limit: limit, ..self }
    }

    /// Sets the span used to instrument the child sync tasks.
    pub fn with_parent_span(self, parent_span: Span) -> Self {
        Self { parent_span, ..self }
    }

    /// Creates the room list service and the encryption sync service, in that
    /// order, and assembles a [`SyncService`] in [`State::Idle`].
    ///
    /// # Errors
    ///
    /// Returns the error of whichever service failed to be created.
    pub async fn build(self) -> Result<SyncService, Error> {
        let encryption_sync_permit = Arc::new(AsyncMutex::new(EncryptionSyncPermit::new()));

        let room_list_service = Arc::new(
            RoomListService::new_with(
                self.client.clone(),
                self.with_share_pos,
                &self.room_list_conn_id,
                self.room_list_timeline_limit,
            )
            .await?,
        );

        let encryption_sync_service =
            Arc::new(EncryptionSyncService::new(self.client, None).await?);

        // The same observable is shared between the public handle and the
        // inner state updated by the supervisor.
        let state = SharedObservable::new(State::Idle);

        let inner = SyncServiceInner {
            encryption_sync_service,
            with_offline_mode: self.with_offline_mode,
            state: state.clone(),
            parent_span: self.parent_span,
            supervisor: None,
        };

        Ok(SyncService {
            inner: Arc::new(AsyncMutex::new(inner)),
            room_list_service,
            state,
            encryption_sync_permit,
        })
    }
}
/// Errors reported by the [`SyncService`] and its builder.
#[derive(Debug, Error)]
pub enum Error {
/// An error coming from the room list service, wrapped transparently.
#[error(transparent)]
RoomList(#[from] room_list_service::Error),
/// An error coming from the encryption sync service, wrapped
/// transparently.
#[error(transparent)]
EncryptionSync(#[from] encryption_sync_service::Error),
/// The supervisor's internal report channel was closed while the
/// supervisor task was still waiting on it.
#[error("the supervisor channel has run into an unexpected error")]
Supervisor,
}