use std::future::Future;
use std::pin::Pin;
use matrix_sdk::Room;
use matrix_sdk::ruma;
use matrix_sdk::ruma::OwnedUserId;
use matrix_sdk::ruma::api::client::config::set_room_account_data;
use matrix_sdk::ruma::events::{
RoomAccountDataEvent, RoomAccountDataEventContent, StaticEventContent,
};
use quick_cache::sync::Cache;
use super::ConfigError;
use crate::helpers::encryption::Manager as EncryptionManager;
/// Marker trait for the per-room configuration type handled by [`Manager`].
///
/// It has no methods; it only bundles the bounds the manager relies on:
/// `Clone` (values are copied into and out of the cache) and serde
/// (de)serialization (values are turned into JSON before being stored).
pub trait RoomConfig: Clone + serde::Serialize + serde::de::DeserializeOwned {}
/// Room account data event content that carries a room config as a single
/// string payload.
///
/// [`Manager`] stores the encrypted, JSON-serialized config in this payload
/// when persisting, and reads it back from the payload when loading.
pub trait RoomConfigCarrierContent:
StaticEventContent<IsPrefix = ruma::events::False>
+ RoomAccountDataEventContent
+ serde::de::DeserializeOwned
{
/// Builds the event content around the given payload string.
fn new(payload: String) -> Self;
/// Returns the payload string carried by this event content.
fn payload(&self) -> &str;
}
/// Loads, creates, caches and persists per-room configuration stored in room
/// account data.
///
/// `ConfigType` is the configuration value handed to callers;
/// `ConfigCarrierContentType` is the account data event content used to carry
/// its encrypted form.
pub struct Manager<ConfigType, ConfigCarrierContentType> {
// User whose room account data is written when a config is persisted.
user_id: OwnedUserId,
// Encrypts the serialized config before upload; also handed to the parsing
// helper when an existing config is read back.
encryption_manager: EncryptionManager,
// Invoked with the room to produce a config when no usable one exists.
// The boxed async-callback type is inherently verbose, hence the lint
// exemption below.
#[allow(clippy::type_complexity)]
initial_room_config_callback:
Box<dyn Fn(Room) -> Pin<Box<dyn Future<Output = ConfigType> + Send>> + Send + Sync>,
// Optional cache of configs keyed by the room ID string; `None` disables
// caching entirely.
lru_cache: Option<Cache<String, ConfigType>>,
// Taken by every public async method, so their bodies never interleave.
lock: tokio::sync::Mutex<()>,
// Zero-sized markers tying the generic parameters to the struct.
_marker_config: std::marker::PhantomData<ConfigType>,
_marker_carrier: std::marker::PhantomData<ConfigCarrierContentType>,
}
impl<ConfigType, ConfigCarrierContentType> Manager<ConfigType, ConfigCarrierContentType>
where
ConfigType: RoomConfig,
ConfigCarrierContentType: RoomConfigCarrierContent,
{
/// Builds a manager.
///
/// * `user_id` - user whose room account data is written on persist.
/// * `encryption_manager` - used to encrypt configs before they are stored
///   and to read stored configs back.
/// * `initial_room_config_callback` - async factory invoked with the room
///   whenever a brand-new config has to be produced.
/// * `lru_cache_size` - capacity passed to the cache constructor; `None`
///   disables caching.
pub fn new<InitialRoomConfigCallback>(
    user_id: OwnedUserId,
    encryption_manager: EncryptionManager,
    initial_room_config_callback: InitialRoomConfigCallback,
    lru_cache_size: Option<usize>,
) -> Self
where
    InitialRoomConfigCallback:
        Fn(Room) -> Pin<Box<dyn Future<Output = ConfigType> + Send>> + Send + Sync + 'static,
{
    // Type-erase the callback so the struct does not need a third generic.
    let boxed_callback = Box::new(initial_room_config_callback);

    Self {
        user_id,
        encryption_manager,
        initial_room_config_callback: boxed_callback,
        lru_cache: lru_cache_size.map(Cache::new),
        lock: tokio::sync::Mutex::new(()),
        _marker_config: std::marker::PhantomData,
        _marker_carrier: std::marker::PhantomData,
    }
}
/// Returns the config for `room`, creating and persisting a new one when no
/// usable config exists.
///
/// The manager lock is held for the whole call. When caching is enabled, a
/// cached value is returned directly; otherwise the config is loaded (or
/// created) and then inserted into the cache.
///
/// # Errors
///
/// Propagates any [`ConfigError`] from loading or creating the config.
#[tracing::instrument(skip_all, name="room_config_get_or_create", fields(room_id = room.room_id().as_str()))]
pub async fn get_or_create_for_room(&self, room: &Room) -> Result<ConfigType, ConfigError> {
    let started_at = std::time::Instant::now();
    tracing::debug!("Request for room config");

    // Bound to a named variable so the guard lives until the function returns.
    let _serialized = self.lock.lock().await;

    let cache = match self.lru_cache.as_ref() {
        Some(cache) => cache,
        None => {
            // Caching disabled: load (or create) and hand the result back as-is.
            let outcome = self
                .do_get_or_create_for_room_without_locking_and_caching(room)
                .await;
            tracing::trace!(
                "Returning uncached room config (after {:?}) for room {}",
                started_at.elapsed(),
                room.room_id()
            );
            return outcome;
        }
    };

    let room_key = room.room_id().as_str();
    let placeholder = match cache.get_value_or_guard_async(room_key).await {
        Ok(cached) => {
            tracing::trace!("Returning existing cached room config..");
            return Ok(cached);
        }
        Err(placeholder) => placeholder,
    };

    // Cache miss: on error the placeholder is dropped without inserting.
    let fresh = self
        .do_get_or_create_for_room_without_locking_and_caching(room)
        .await?;
    // The insert result is deliberately ignored; the value is returned either way.
    let _ = placeholder.insert(fresh.clone());
    tracing::trace!(
        "Returning now-cached room config (after {:?}) for room {}",
        started_at.elapsed(),
        room.room_id()
    );
    Ok(fresh)
}
/// Loads the config from the room's account data, falling back to creating
/// (and persisting) a new one.
///
/// A new config is created when no account data event exists, when the event
/// cannot be deserialized, or when its payload cannot be decrypted/parsed.
/// Takes no lock and does not touch the cache; callers handle both.
///
/// # Errors
///
/// Propagates failures from reading the account data and from creating a new
/// config.
async fn do_get_or_create_for_room_without_locking_and_caching(
    &self,
    room: &Room,
) -> Result<ConfigType, ConfigError> {
    tracing::trace!("Fetching config for room: {}..", room.room_id());
    let stored = room
        .account_data_static::<ConfigCarrierContentType>()
        .await?;

    // `Some` only when a stored config exists and could be fully decoded.
    let reusable: Option<ConfigType> = match stored {
        None => None,
        Some(raw_event) => {
            tracing::trace!("Found existing room config: {:?}", raw_event);
            let parsed_event: serde_json::Result<RoomAccountDataEvent<ConfigCarrierContentType>> =
                raw_event.deserialize();
            match parsed_event {
                Err(err) => {
                    tracing::warn!("Failed parsing existing room config: {:?}", err);
                    None
                }
                Ok(event) => {
                    let decoded: Option<ConfigType> = super::utils::parse_encrypted_config(
                        &self.encryption_manager,
                        event.content.payload(),
                    );
                    if decoded.is_some() {
                        tracing::trace!("Reusing existing room config");
                    } else {
                        tracing::warn!(
                            "Found existing room config, but failed decrypting/parsing it.. Making new.."
                        );
                    }
                    decoded
                }
            }
        }
    };

    // Single fallback point for every "nothing usable" case above.
    let room_config = match reusable {
        Some(existing) => existing,
        None => self.do_create_new_for_room_without_locking(room).await?,
    };

    tracing::trace!("Returning room config");
    Ok(room_config)
}
/// Creates a brand-new config for `room` via the initial-config callback,
/// persists it, and returns it.
///
/// Holds the manager lock for the duration of the call. After a successful
/// persist, the cache entry for the room is refreshed in the same way
/// `persist` does it, so a later cached lookup does not hand out the config
/// that was just replaced.
///
/// # Errors
///
/// Propagates any [`ConfigError`] from persisting the new config; the cache
/// is left untouched in that case.
#[tracing::instrument(skip_all, name="room_config_create_new", fields(room_id = room.room_id().as_str()))]
pub async fn create_new_for_room(&self, room: &Room) -> Result<ConfigType, ConfigError> {
    let _lock = self.lock.lock().await;
    let config = self.do_create_new_for_room_without_locking(room).await?;
    if let Some(lru_cache) = &self.lru_cache {
        // Result ignored on purpose, as in `persist`: presumably `replace` only
        // updates an entry that is already cached and reports a miss otherwise
        // (confirm against the quick_cache docs).
        let _ = lru_cache.replace(room.room_id().as_str().to_owned(), config.clone(), false);
    }
    Ok(config)
}
/// Produces a new config through the initial-config callback and persists it.
///
/// Takes no lock and does not touch the cache.
///
/// # Errors
///
/// Propagates any [`ConfigError`] from persisting the new config.
async fn do_create_new_for_room_without_locking(
    &self,
    room: &Room,
) -> Result<ConfigType, ConfigError> {
    tracing::info!("Creating new room config");
    // The callback takes the room by value, hence the clone.
    let initial_config_future = (self.initial_room_config_callback)(room.clone());
    let fresh_config = initial_config_future.await;

    tracing::trace!("Persisting new room config");
    self.persist_without_locking(room, &fresh_config).await?;
    tracing::trace!("Persisted new room config");

    Ok(fresh_config)
}
/// Persists `config` for `room` and, when caching is enabled, refreshes the
/// cache entry for that room.
///
/// Holds the manager lock for the duration of the call.
///
/// # Errors
///
/// Propagates any [`ConfigError`] from persisting; the cache is not touched
/// in that case.
#[tracing::instrument(skip_all, name="room_config_persist", fields(room_id = room.room_id().as_str()))]
pub async fn persist(&self, room: &Room, config: &ConfigType) -> Result<(), ConfigError> {
    let _serialized = self.lock.lock().await;
    self.persist_without_locking(room, config).await?;

    let Some(cache) = self.lru_cache.as_ref() else {
        return Ok(());
    };

    // The outcome of the replace is deliberately ignored.
    let room_key = room.room_id().as_str().to_owned();
    let _ = cache.replace(room_key, config.clone(), false);
    Ok(())
}
/// Serializes `config` to JSON, encrypts it, wraps it in the carrier event
/// content and uploads it as room account data for `room`.
///
/// Takes no lock and does not touch the cache.
///
/// # Errors
///
/// * [`ConfigError::SerializeDeserialize`] - serializing the config or
///   building the request failed.
/// * [`ConfigError::Encryption`] - encrypting the serialized config failed.
/// * [`ConfigError::SdkHttp`] - sending the request failed.
async fn persist_without_locking(
    &self,
    room: &Room,
    config: &ConfigType,
) -> Result<(), ConfigError> {
    let plaintext = serde_json::to_string(config).map_err(ConfigError::SerializeDeserialize)?;
    let ciphertext = self
        .encryption_manager
        .encrypt_string(&plaintext)
        .map_err(ConfigError::Encryption)?;

    let carrier = ConfigCarrierContentType::new(ciphertext);
    let owner = self.user_id.clone();
    let target_room = room.room_id().to_owned();
    let request = set_room_account_data::v3::Request::new(owner, target_room, &carrier)
        .map_err(ConfigError::SerializeDeserialize)?;

    let client = room.client();
    client.send(request).await.map_err(ConfigError::SdkHttp)?;
    Ok(())
}
}