use std::{
borrow::Borrow,
collections::{BTreeMap, BTreeSet},
fmt,
sync::Arc,
};
use as_variant::as_variant;
use async_trait::async_trait;
use matrix_sdk_common::AsyncTraitDeps;
use ruma::{
events::{
presence::PresenceEvent,
receipt::{Receipt, ReceiptThread, ReceiptType},
AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, EmptyStateKey, GlobalAccountDataEvent,
GlobalAccountDataEventContent, GlobalAccountDataEventType, RedactContent,
RedactedStateEventContent, RoomAccountDataEvent, RoomAccountDataEventContent,
RoomAccountDataEventType, StateEventType, StaticEventContent, StaticStateEventContent,
},
serde::Raw,
EventId, MxcUri, OwnedEventId, OwnedUserId, RoomId, UserId,
};
use super::{StateChanges, StoreError};
use crate::{
deserialized_responses::{RawAnySyncOrStrippedState, RawMemberEvent, RawSyncOrStrippedState},
media::MediaRequest,
MinimalRoomMemberEvent, RoomInfo, RoomMemberships,
};
/// An asynchronous storage interface covering key/value data, presence events,
/// room state, member profiles, account data, receipts, custom values and
/// media content.
///
/// Every method reports failure through the associated `Error` type. On
/// `wasm32` targets the `?Send` flavour of `async_trait` is selected.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait StateStore: AsyncTraitDeps {
/// The error type returned by this store. It must be convertible into
/// `StoreError` and constructible from a `serde_json::Error`.
type Error: fmt::Debug + Into<StoreError> + From<serde_json::Error>;
/// Look up the key/value data identified by `key`.
///
/// Returns `Ok(None)` when the store has no value to return for `key`.
async fn get_kv_data(
&self,
key: StateStoreDataKey<'_>,
) -> Result<Option<StateStoreDataValue>, Self::Error>;
/// Store `value` as the key/value data identified by `key`.
///
/// NOTE(review): `value` is presumably expected to be the variant matching
/// `key` — confirm how implementations handle a mismatch.
async fn set_kv_data(
&self,
key: StateStoreDataKey<'_>,
value: StateStoreDataValue,
) -> Result<(), Self::Error>;
/// Remove the key/value data identified by `key`.
async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error>;
/// Save the given set of state changes.
async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error>;
/// Look up the raw presence event for `user_id`, if there is one.
async fn get_presence_event(
&self,
user_id: &UserId,
) -> Result<Option<Raw<PresenceEvent>>, Self::Error>;
/// Look up the raw presence events for the given users.
async fn get_presence_events(
&self,
user_ids: &[OwnedUserId],
) -> Result<Vec<Raw<PresenceEvent>>, Self::Error>;
/// Look up the raw state event of `event_type` with the given `state_key`
/// in `room_id`, if there is one.
async fn get_state_event(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_key: &str,
) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error>;
/// Look up the raw state events of `event_type` in `room_id`.
async fn get_state_events(
&self,
room_id: &RoomId,
event_type: StateEventType,
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error>;
/// Look up the raw state events of `event_type` in `room_id` for the listed
/// state keys.
async fn get_state_events_for_keys(
&self,
room_id: &RoomId,
event_type: StateEventType,
state_keys: &[&str],
) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error>;
/// Look up the profile of `user_id` in `room_id`, if there is one.
async fn get_profile(
&self,
room_id: &RoomId,
user_id: &UserId,
) -> Result<Option<MinimalRoomMemberEvent>, Self::Error>;
/// Look up the profiles of the given users in `room_id`.
///
/// The keys of the returned map borrow from `user_ids`.
async fn get_profiles<'a>(
&self,
room_id: &RoomId,
user_ids: &'a [OwnedUserId],
) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error>;
/// Look up the user IDs in `room_id`, selected by `memberships`.
async fn get_user_ids(
&self,
room_id: &RoomId,
memberships: RoomMemberships,
) -> Result<Vec<OwnedUserId>, Self::Error>;
/// Look up the invited user IDs of `room_id`.
///
/// Deprecated in favour of `get_user_ids` with `RoomMemberships::INVITE`.
#[deprecated = "Use get_user_ids with RoomMemberships::INVITE instead."]
async fn get_invited_user_ids(&self, room_id: &RoomId)
-> Result<Vec<OwnedUserId>, Self::Error>;
/// Look up the joined user IDs of `room_id`.
///
/// Deprecated in favour of `get_user_ids` with `RoomMemberships::JOIN`.
#[deprecated = "Use get_user_ids with RoomMemberships::JOIN instead."]
async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>, Self::Error>;
/// Look up the `RoomInfo`s held by the store.
async fn get_room_infos(&self) -> Result<Vec<RoomInfo>, Self::Error>;
/// Look up the stripped `RoomInfo`s held by the store.
///
/// Deprecated in favour of `get_room_infos` combined with a filter on the
/// room state.
#[deprecated = "Use get_room_infos instead and filter by RoomState"]
async fn get_stripped_room_infos(&self) -> Result<Vec<RoomInfo>, Self::Error>;
/// Look up the users in `room_id` that are associated with `display_name`.
async fn get_users_with_display_name(
&self,
room_id: &RoomId,
display_name: &str,
) -> Result<BTreeSet<OwnedUserId>, Self::Error>;
/// Look up, for each of the given display names, the users in `room_id`
/// that are associated with it.
///
/// The keys of the returned map borrow from `display_names`.
async fn get_users_with_display_names<'a>(
&self,
room_id: &RoomId,
display_names: &'a [String],
) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>, Self::Error>;
/// Look up the raw global account data event of `event_type`, if there is
/// one.
async fn get_account_data_event(
&self,
event_type: GlobalAccountDataEventType,
) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error>;
/// Look up the raw account data event of `event_type` for `room_id`, if
/// there is one.
async fn get_room_account_data_event(
&self,
room_id: &RoomId,
event_type: RoomAccountDataEventType,
) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error>;
/// Look up the receipt of `receipt_type` in `thread` for `user_id` in
/// `room_id`.
///
/// When present, the result pairs an event ID with the receipt.
async fn get_user_room_receipt_event(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
thread: ReceiptThread,
user_id: &UserId,
) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error>;
/// Look up the receipts of `receipt_type` in `thread` for `event_id` in
/// `room_id`.
///
/// Each entry pairs a user ID with the receipt.
async fn get_event_room_receipt_events(
&self,
room_id: &RoomId,
receipt_type: ReceiptType,
thread: ReceiptThread,
event_id: &EventId,
) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error>;
/// Look up the custom value stored under the byte-string `key`, if any.
async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>;
/// Store the custom `value` under the byte-string `key`.
///
/// NOTE(review): the returned `Option<Vec<u8>>` is presumably the value
/// previously stored under `key` — confirm against implementations.
async fn set_custom_value(
&self,
key: &[u8],
value: Vec<u8>,
) -> Result<Option<Vec<u8>>, Self::Error>;
/// Remove the custom value stored under the byte-string `key`.
///
/// NOTE(review): the returned `Option<Vec<u8>>` is presumably the removed
/// value — confirm against implementations.
async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error>;
/// Add media `content` for the given `request`.
async fn add_media_content(
&self,
request: &MediaRequest,
content: Vec<u8>,
) -> Result<(), Self::Error>;
/// Look up the media content for the given `request`, if any.
async fn get_media_content(
&self,
request: &MediaRequest,
) -> Result<Option<Vec<u8>>, Self::Error>;
/// Remove the media content for the given `request`.
async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error>;
/// Remove the media content associated with the given `uri`.
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error>;
/// Remove the room identified by `room_id`.
///
/// NOTE(review): presumably this also removes all data the store holds for
/// that room — confirm against implementations.
async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error>;
}
/// A transparent single-field wrapper around a value of type `T`.
///
/// `#[repr(transparent)]` gives the wrapper the same layout as `T`.
#[repr(transparent)]
struct EraseStateStoreError<T>(T);

/// Debug output is exactly that of the wrapped value; the wrapper adds nothing
/// of its own and the formatter (including its flags) is passed through as-is.
#[cfg(not(tarpaulin_include))]
impl<T: fmt::Debug> fmt::Debug for EraseStateStoreError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(inner) = self;
        fmt::Debug::fmt(inner, f)
    }
}
/// `StateStore` implementation that forwards every call to the wrapped store
/// and converts the wrapped store's error into `StoreError`.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<T: StateStore> StateStore for EraseStateStoreError<T> {
    type Error = StoreError;

    // Forwards to the wrapped store's `get_kv_data`.
    async fn get_kv_data(
        &self,
        key: StateStoreDataKey<'_>,
    ) -> Result<Option<StateStoreDataValue>, Self::Error> {
        self.0.get_kv_data(key).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `set_kv_data`.
    async fn set_kv_data(
        &self,
        key: StateStoreDataKey<'_>,
        value: StateStoreDataValue,
    ) -> Result<(), Self::Error> {
        self.0.set_kv_data(key, value).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `remove_kv_data`.
    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
        self.0.remove_kv_data(key).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `save_changes`.
    async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
        self.0.save_changes(changes).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_presence_event`.
    async fn get_presence_event(
        &self,
        user_id: &UserId,
    ) -> Result<Option<Raw<PresenceEvent>>, Self::Error> {
        self.0.get_presence_event(user_id).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_presence_events`.
    async fn get_presence_events(
        &self,
        user_ids: &[OwnedUserId],
    ) -> Result<Vec<Raw<PresenceEvent>>, Self::Error> {
        self.0.get_presence_events(user_ids).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_state_event`.
    async fn get_state_event(
        &self,
        room_id: &RoomId,
        event_type: StateEventType,
        state_key: &str,
    ) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error> {
        self.0
            .get_state_event(room_id, event_type, state_key)
            .await
            .map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_state_events`.
    async fn get_state_events(
        &self,
        room_id: &RoomId,
        event_type: StateEventType,
    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
        self.0.get_state_events(room_id, event_type).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_state_events_for_keys`.
    async fn get_state_events_for_keys(
        &self,
        room_id: &RoomId,
        event_type: StateEventType,
        state_keys: &[&str],
    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
        self.0
            .get_state_events_for_keys(room_id, event_type, state_keys)
            .await
            .map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_profile`.
    async fn get_profile(
        &self,
        room_id: &RoomId,
        user_id: &UserId,
    ) -> Result<Option<MinimalRoomMemberEvent>, Self::Error> {
        self.0.get_profile(room_id, user_id).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_profiles`.
    async fn get_profiles<'a>(
        &self,
        room_id: &RoomId,
        user_ids: &'a [OwnedUserId],
    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error> {
        self.0.get_profiles(room_id, user_ids).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_user_ids`.
    async fn get_user_ids(
        &self,
        room_id: &RoomId,
        memberships: RoomMemberships,
    ) -> Result<Vec<OwnedUserId>, Self::Error> {
        self.0.get_user_ids(room_id, memberships).await.map_err(|err| err.into())
    }

    // Served through the wrapped store's `get_user_ids` with
    // `RoomMemberships::INVITE`, not through its deprecated counterpart.
    async fn get_invited_user_ids(
        &self,
        room_id: &RoomId,
    ) -> Result<Vec<OwnedUserId>, Self::Error> {
        self.0
            .get_user_ids(room_id, RoomMemberships::INVITE)
            .await
            .map_err(|err| err.into())
    }

    // Served through the wrapped store's `get_user_ids` with
    // `RoomMemberships::JOIN`, not through its deprecated counterpart.
    async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>, Self::Error> {
        self.0
            .get_user_ids(room_id, RoomMemberships::JOIN)
            .await
            .map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_room_infos`.
    async fn get_room_infos(&self) -> Result<Vec<RoomInfo>, Self::Error> {
        self.0.get_room_infos().await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's deprecated `get_stripped_room_infos`;
    // the lint is silenced because this wrapper has to call it.
    #[allow(deprecated)]
    async fn get_stripped_room_infos(&self) -> Result<Vec<RoomInfo>, Self::Error> {
        self.0.get_stripped_room_infos().await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_users_with_display_name`.
    async fn get_users_with_display_name(
        &self,
        room_id: &RoomId,
        display_name: &str,
    ) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
        self.0
            .get_users_with_display_name(room_id, display_name)
            .await
            .map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_users_with_display_names`.
    async fn get_users_with_display_names<'a>(
        &self,
        room_id: &RoomId,
        display_names: &'a [String],
    ) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>, Self::Error> {
        self.0
            .get_users_with_display_names(room_id, display_names)
            .await
            .map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_account_data_event`.
    async fn get_account_data_event(
        &self,
        event_type: GlobalAccountDataEventType,
    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error> {
        self.0.get_account_data_event(event_type).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_room_account_data_event`.
    async fn get_room_account_data_event(
        &self,
        room_id: &RoomId,
        event_type: RoomAccountDataEventType,
    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error> {
        self.0
            .get_room_account_data_event(room_id, event_type)
            .await
            .map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_user_room_receipt_event`.
    async fn get_user_room_receipt_event(
        &self,
        room_id: &RoomId,
        receipt_type: ReceiptType,
        thread: ReceiptThread,
        user_id: &UserId,
    ) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error> {
        self.0
            .get_user_room_receipt_event(room_id, receipt_type, thread, user_id)
            .await
            .map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_event_room_receipt_events`.
    async fn get_event_room_receipt_events(
        &self,
        room_id: &RoomId,
        receipt_type: ReceiptType,
        thread: ReceiptThread,
        event_id: &EventId,
    ) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error> {
        self.0
            .get_event_room_receipt_events(room_id, receipt_type, thread, event_id)
            .await
            .map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_custom_value`.
    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
        self.0.get_custom_value(key).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `set_custom_value`.
    async fn set_custom_value(
        &self,
        key: &[u8],
        value: Vec<u8>,
    ) -> Result<Option<Vec<u8>>, Self::Error> {
        self.0.set_custom_value(key, value).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `remove_custom_value`.
    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
        self.0.remove_custom_value(key).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `add_media_content`.
    async fn add_media_content(
        &self,
        request: &MediaRequest,
        content: Vec<u8>,
    ) -> Result<(), Self::Error> {
        self.0.add_media_content(request, content).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `get_media_content`.
    async fn get_media_content(
        &self,
        request: &MediaRequest,
    ) -> Result<Option<Vec<u8>>, Self::Error> {
        self.0.get_media_content(request).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `remove_media_content`.
    async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error> {
        self.0.remove_media_content(request).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `remove_media_content_for_uri`.
    async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error> {
        self.0.remove_media_content_for_uri(uri).await.map_err(|err| err.into())
    }

    // Forwards to the wrapped store's `remove_room`.
    async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> {
        self.0.remove_room(room_id).await.map_err(|err| err.into())
    }
}
/// Typed convenience methods layered on top of [`StateStore`].
///
/// Each method derives the event type from the content type's `TYPE` constant,
/// calls the matching untyped `StateStore` method and casts the raw result to
/// the requested content type.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait StateStoreExt: StateStore {
    /// Look up the state event with content type `C` in `room_id`, using the
    /// empty string as state key.
    async fn get_state_event_static<C>(
        &self,
        room_id: &RoomId,
    ) -> Result<Option<RawSyncOrStrippedState<C>>, Self::Error>
    where
        C: StaticEventContent + StaticStateEventContent<StateKey = EmptyStateKey> + RedactContent,
        C::Redacted: RedactedStateEventContent,
    {
        let stored = self.get_state_event(room_id, C::TYPE.into(), "").await?;
        Ok(stored.map(|raw| raw.cast()))
    }

    /// Look up the state event with content type `C` and the given
    /// `state_key` in `room_id`.
    async fn get_state_event_static_for_key<C, K>(
        &self,
        room_id: &RoomId,
        state_key: &K,
    ) -> Result<Option<RawSyncOrStrippedState<C>>, Self::Error>
    where
        C: StaticEventContent + StaticStateEventContent + RedactContent,
        C::StateKey: Borrow<K>,
        C::Redacted: RedactedStateEventContent,
        K: AsRef<str> + ?Sized + Sync,
    {
        let stored = self.get_state_event(room_id, C::TYPE.into(), state_key.as_ref()).await?;
        Ok(stored.map(|raw| raw.cast()))
    }

    /// Look up the state events with content type `C` in `room_id`.
    async fn get_state_events_static<C>(
        &self,
        room_id: &RoomId,
    ) -> Result<Vec<RawSyncOrStrippedState<C>>, Self::Error>
    where
        C: StaticEventContent + StaticStateEventContent + RedactContent,
        C::Redacted: RedactedStateEventContent,
    {
        let stored = self.get_state_events(room_id, C::TYPE.into()).await?;
        Ok(stored.into_iter().map(|raw| raw.cast()).collect())
    }

    /// Look up the state events with content type `C` in `room_id` for the
    /// given state keys.
    async fn get_state_events_for_keys_static<'a, C, K, I>(
        &self,
        room_id: &RoomId,
        state_keys: I,
    ) -> Result<Vec<RawSyncOrStrippedState<C>>, Self::Error>
    where
        C: StaticEventContent + StaticStateEventContent + RedactContent,
        C::StateKey: Borrow<K>,
        C::Redacted: RedactedStateEventContent,
        K: AsRef<str> + Sized + Sync + 'a,
        I: IntoIterator<Item = &'a K> + Send,
        I::IntoIter: Send,
    {
        // The untyped method takes `&[&str]`, so borrow every key as a string
        // slice first.
        let keys: Vec<&str> = state_keys.into_iter().map(|key| key.as_ref()).collect();
        let stored = self.get_state_events_for_keys(room_id, C::TYPE.into(), &keys).await?;
        Ok(stored.into_iter().map(|raw| raw.cast()).collect())
    }

    /// Look up the global account data event with content type `C`.
    async fn get_account_data_event_static<C>(
        &self,
    ) -> Result<Option<Raw<GlobalAccountDataEvent<C>>>, Self::Error>
    where
        C: StaticEventContent + GlobalAccountDataEventContent,
    {
        let stored = self.get_account_data_event(C::TYPE.into()).await?;
        Ok(stored.map(|raw| raw.cast()))
    }

    /// Look up the account data event with content type `C` for `room_id`.
    async fn get_room_account_data_event_static<C>(
        &self,
        room_id: &RoomId,
    ) -> Result<Option<Raw<RoomAccountDataEvent<C>>>, Self::Error>
    where
        C: StaticEventContent + RoomAccountDataEventContent,
    {
        let stored = self.get_room_account_data_event(room_id, C::TYPE.into()).await?;
        Ok(stored.map(|raw| raw.cast()))
    }

    /// Look up the member event in `room_id` whose state key is the given
    /// user ID.
    ///
    /// Delegates to `get_state_event_static_for_key`; the content type is
    /// inferred from the `RawMemberEvent` return type.
    async fn get_member_event(
        &self,
        room_id: &RoomId,
        state_key: &UserId,
    ) -> Result<Option<RawMemberEvent>, Self::Error> {
        self.get_state_event_static_for_key(room_id, state_key).await
    }
}
/// Blanket implementation: every `StateStore`, sized or not, gets the
/// `StateStoreExt` methods through their default bodies.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl<T> StateStoreExt for T where T: StateStore + ?Sized {}
/// A `StateStore` trait object whose error type is fixed to `StoreError`.
pub type DynStateStore = dyn StateStore<Error = StoreError>;
/// Conversion of a value into a shared, type-erased state store
/// (`Arc<DynStateStore>`).
pub trait IntoStateStore {
/// Consume `self` and return it as an `Arc<DynStateStore>`.
#[doc(hidden)]
fn into_state_store(self) -> Arc<DynStateStore>;
}
/// Any owned `StateStore` can become an `Arc<DynStateStore>`: it is wrapped in
/// `EraseStateStoreError` and placed in a new `Arc`.
impl<T> IntoStateStore for T
where
    T: StateStore + Sized + 'static,
{
    fn into_state_store(self) -> Arc<DynStateStore> {
        let erased = EraseStateStoreError(self);
        Arc::new(erased)
    }
}
/// An already shared `StateStore` is converted without creating a new `Arc`:
/// the existing pointer is reinterpreted as pointing to an
/// `EraseStateStoreError<T>`.
impl<T> IntoStateStore for Arc<T>
where
    T: StateStore + 'static,
{
    fn into_state_store(self) -> Arc<DynStateStore> {
        let erased = Arc::into_raw(self).cast::<EraseStateStoreError<T>>();
        // SAFETY: `erased` comes straight from `Arc::into_raw`, and
        // `EraseStateStoreError<T>` is a `#[repr(transparent)]` wrapper around
        // its single `T` field, so it has the same size and alignment as `T`.
        unsafe { Arc::from_raw(erased) }
    }
}
/// A value for key/value data, as passed to `StateStore::set_kv_data` and
/// returned by `StateStore::get_kv_data`.
#[derive(Debug, Clone)]
pub enum StateStoreDataValue {
/// A sync token, held as a string.
SyncToken(String),
/// A filter, held as a string.
Filter(String),
/// A user avatar URL, held as a string.
UserAvatarUrl(String),
}
impl StateStoreDataValue {
/// Consume the value and return the inner string if it is the `SyncToken`
/// variant, `None` otherwise.
pub fn into_sync_token(self) -> Option<String> {
as_variant!(self, Self::SyncToken)
}
/// Consume the value and return the inner string if it is the `Filter`
/// variant, `None` otherwise.
pub fn into_filter(self) -> Option<String> {
as_variant!(self, Self::Filter)
}
/// Consume the value and return the inner string if it is the
/// `UserAvatarUrl` variant, `None` otherwise.
pub fn into_user_avatar_url(self) -> Option<String> {
as_variant!(self, Self::UserAvatarUrl)
}
}
/// A key identifying key/value data in the `get_kv_data`, `set_kv_data` and
/// `remove_kv_data` methods of `StateStore`.
#[derive(Debug, Clone, Copy)]
pub enum StateStoreDataKey<'a> {
/// The key of the sync token; it carries no further data.
SyncToken,
/// The key of a filter, qualified by a borrowed string.
// NOTE(review): the string is presumably the filter's name — confirm with callers.
Filter(&'a str),
/// The key of a user avatar URL, qualified by the ID of the user.
UserAvatarUrl(&'a UserId),
}
// NOTE(review): these strings are presumably what store implementations use to
// build the persisted key for each variant — confirm against the
// implementations before changing any of them.
impl StateStoreDataKey<'_> {
/// String constant associated with the `SyncToken` variant.
pub const SYNC_TOKEN: &'static str = "sync_token";
/// String constant associated with the `Filter` variant.
pub const FILTER: &'static str = "filter";
/// String constant associated with the `UserAvatarUrl` variant.
pub const USER_AVATAR_URL: &'static str = "user_avatar_url";
}