use std::{pin::Pin, time::Duration};
use async_stream::stream;
use futures_core::stream::Stream;
use futures_util::{StreamExt, pin_mut};
use matrix_sdk::{Client, LEASE_DURATION_MS, SlidingSync, sleep::sleep};
use ruma::{api::client::sync::sync_events::v5 as http, assign};
use tokio::sync::OwnedMutexGuard;
use tracing::{Span, debug, instrument, trace};
/// Zero-sized marker type handed to the encryption sync entry points.
///
/// [`EncryptionSyncService::run_fixed_iterations`] and [`EncryptionSyncService::sync`] both
/// require an `OwnedMutexGuard<EncryptionSyncPermit>`. They can therefore only be started while
/// the mutex wrapping the permit is held. Presumably a single mutex is shared by all callers,
/// so that only one encryption sync runs at a time; verify at the call sites.
///
/// The unit field is private, so code outside this crate can only obtain a permit through
/// [`EncryptionSyncPermit::new_for_testing`].
pub struct EncryptionSyncPermit(());

impl EncryptionSyncPermit {
    /// Creates a new permit. Only callable from within this crate.
    pub(crate) fn new() -> Self {
        EncryptionSyncPermit(())
    }

    /// Creates a new permit from outside the crate; intended for tests only.
    #[doc(hidden)]
    pub fn new_for_testing() -> Self {
        Self::new()
    }
}
/// Whether an `EncryptionSyncService` should use the cross-process store lock.
///
/// Converts from `bool`: `true` maps to [`WithLocking::Yes`] and `false` maps to
/// [`WithLocking::No`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WithLocking {
    /// Enable the cross-process store lock and take it around sync requests.
    Yes,
    /// Do not use the cross-process store lock.
    No,
}

impl From<bool> for WithLocking {
    fn from(value: bool) -> Self {
        if value { Self::Yes } else { Self::No }
    }
}
/// Service running a dedicated sliding sync instance (named `"encryption"`) that only enables
/// the to-device and end-to-end-encryption extensions, optionally guarded by the
/// cross-process store lock.
pub struct EncryptionSyncService {
/// Client used to build the sliding sync instance and to take the cross-process store lock.
client: Client,
/// Sliding sync instance built in [`EncryptionSyncService::new`].
sliding_sync: SlidingSync,
/// Whether the cross-process store lock is taken before syncing.
with_locking: bool,
}
impl EncryptionSyncService {
pub async fn new(
client: Client,
poll_and_network_timeouts: Option<(Duration, Duration)>,
with_locking: WithLocking,
) -> Result<Self, Error> {
let mut builder = client
.sliding_sync("encryption")
.map_err(Error::SlidingSync)?
.with_to_device_extension(
assign!(http::request::ToDevice::default(), { enabled: Some(true)}),
)
.with_e2ee_extension(assign!(http::request::E2EE::default(), { enabled: Some(true)}));
if let Some((poll_timeout, network_timeout)) = poll_and_network_timeouts {
builder = builder.poll_timeout(poll_timeout).network_timeout(network_timeout);
}
let sliding_sync = builder.build().await.map_err(Error::SlidingSync)?;
let with_locking = matches!(with_locking, WithLocking::Yes);
if with_locking {
match client
.encryption()
.enable_cross_process_store_lock(
client.cross_process_store_locks_holder_name().to_owned(),
)
.await
{
Ok(()) | Err(matrix_sdk::Error::BadCryptoStoreState) => {
}
Err(err) => {
return Err(Error::ClientError(err));
}
}
}
Ok(Self { client, sliding_sync, with_locking })
}
#[instrument(skip_all, fields(store_generation))]
pub async fn run_fixed_iterations(
self,
num_iterations: u8,
_permit: OwnedMutexGuard<EncryptionSyncPermit>,
) -> Result<(), Error> {
let sync = self.sliding_sync.sync();
pin_mut!(sync);
let lock_guard = if self.with_locking {
let mut lock_guard =
self.client.encryption().try_lock_store_once().await.map_err(Error::LockError)?;
if lock_guard.is_none() {
tracing::debug!(
"Lock was already taken, and we're not the main loop; retrying in {}ms...",
LEASE_DURATION_MS
);
sleep(Duration::from_millis(LEASE_DURATION_MS.into())).await;
lock_guard = self
.client
.encryption()
.try_lock_store_once()
.await
.map_err(Error::LockError)?;
if lock_guard.is_none() {
tracing::debug!(
"Second attempt at locking outside the main app failed, aborting."
);
return Ok(());
}
}
lock_guard
} else {
None
};
Span::current().record("store_generation", lock_guard.map(|guard| guard.generation()));
for _ in 0..num_iterations {
match sync.next().await {
Some(Ok(update_summary)) => {
if !update_summary.lists.is_empty() {
debug!(?update_summary.lists, "unexpected non-empty list of lists in encryption sync API");
}
if !update_summary.rooms.is_empty() {
debug!(?update_summary.rooms, "unexpected non-empty list of rooms in encryption sync API");
}
trace!("Encryption sync received an update!");
}
Some(Err(err)) => {
trace!("Encryption sync stopped because of an error: {err:#}");
return Err(Error::SlidingSync(err));
}
None => {
trace!("Encryption sync properly terminated.");
break;
}
}
}
Ok(())
}
#[doc(hidden)] pub fn sync(
&self,
_permit: OwnedMutexGuard<EncryptionSyncPermit>,
) -> impl Stream<Item = Result<(), Error>> + '_ {
stream!({
let sync = self.sliding_sync.sync();
pin_mut!(sync);
loop {
match self.next_sync_with_lock(&mut sync).await? {
Some(Ok(update_summary)) => {
if !update_summary.lists.is_empty() {
debug!(?update_summary.lists, "unexpected non-empty list of lists in encryption sync API");
}
if !update_summary.rooms.is_empty() {
debug!(?update_summary.rooms, "unexpected non-empty list of rooms in encryption sync API");
}
trace!("Encryption sync received an update!");
yield Ok(());
continue;
}
Some(Err(err)) => {
trace!("Encryption sync stopped because of an error: {err:#}");
yield Err(Error::SlidingSync(err));
break;
}
None => {
trace!("Encryption sync properly terminated.");
break;
}
}
}
})
}
#[instrument(skip_all, fields(store_generation))]
async fn next_sync_with_lock<Item>(
&self,
sync: &mut Pin<&mut impl Stream<Item = Item>>,
) -> Result<Option<Item>, Error> {
let guard = if self.with_locking {
self.client.encryption().spin_lock_store(Some(60000)).await.map_err(Error::LockError)?
} else {
None
};
Span::current().record("store_generation", guard.map(|guard| guard.generation()));
Ok(sync.next().await)
}
pub(crate) fn stop_sync(&self) -> Result<(), Error> {
self.sliding_sync.stop_sync().map_err(Error::SlidingSync)?;
Ok(())
}
pub(crate) async fn expire_sync_session(&self) {
self.sliding_sync.expire_session().await;
}
}
/// Errors returned by `EncryptionSyncService`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// A sliding sync operation failed, e.g. creating, building, syncing or stopping the
/// `"encryption"` sliding sync instance.
#[error("Something wrong happened in sliding sync: {0:#}")]
SlidingSync(matrix_sdk::Error),
/// Taking the cross-process store lock failed.
#[error("Locking failed: {0:#}")]
LockError(matrix_sdk::Error),
/// Another client error, e.g. enabling the cross-process store lock failed in
/// `EncryptionSyncService::new`. Display and `source` are forwarded to the inner error.
#[error(transparent)]
ClientError(matrix_sdk::Error),
}