use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
fmt::Display,
ops::Deref,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
},
};
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, MutexGuard, OwnedRwLockReadGuard, RwLock};
use tracing::{field::display, instrument, trace, Span};
use super::{CryptoStoreError, CryptoStoreWrapper};
use crate::{identities::DeviceData, olm::Session, Account};
/// In-memory store for Olm [`Session`]s, grouped by sender key.
///
/// The map lives behind an [`Arc`], so cloning the store yields another handle
/// to the same underlying entries rather than a copy of them.
#[derive(Debug, Default, Clone)]
pub struct SessionStore {
/// Session lists keyed by sender key string (`add` uses the base64 encoding
/// of the session's `sender_key`). Each list has its own async mutex so it
/// can be handed out and locked without holding the lock on the outer map.
#[allow(clippy::type_complexity)]
pub(crate) entries: Arc<RwLock<BTreeMap<String, Arc<Mutex<Vec<Session>>>>>>,
}
impl SessionStore {
    /// Create a new, empty session store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Remove every session list from the store.
    pub async fn clear(&self) {
        let mut entries = self.entries.write().await;
        entries.clear();
    }

    /// Add `session` to the list kept for its sender key.
    ///
    /// Returns `true` if the session was inserted and `false` if an equal
    /// session was already present in that list.
    pub async fn add(&self, session: Session) -> bool {
        let sender_key = session.sender_key.to_base64();

        // Only hold the map lock long enough to fetch (or create) the
        // per-sender list; the list itself is guarded by its own mutex.
        let per_sender = {
            let mut entries = self.entries.write().await;
            entries.entry(sender_key).or_default().clone()
        };

        let mut sessions = per_sender.lock().await;
        if sessions.contains(&session) {
            return false;
        }

        sessions.push(session);
        true
    }

    /// Look up the shared session list stored for `sender_key`, if any.
    pub async fn get(&self, sender_key: &str) -> Option<Arc<Mutex<Vec<Session>>>> {
        let entries = self.entries.read().await;
        entries.get(sender_key).map(Arc::clone)
    }

    /// Replace whatever is stored for `sender_key` with the given `sessions`.
    pub async fn set_for_sender(&self, sender_key: &str, sessions: Vec<Session>) {
        let shared_sessions = Arc::new(Mutex::new(sessions));
        let mut entries = self.entries.write().await;
        entries.insert(sender_key.to_owned(), shared_sessions);
    }
}
/// In-memory store for device data, organised per user.
#[derive(Debug, Default)]
pub struct DeviceStore {
/// Devices keyed first by the owning user's ID and then by device ID.
entries: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceData>>>,
}
impl DeviceStore {
    /// Create a new, empty device store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Store `device` under its user ID and device ID, replacing any entry
    /// that was previously stored for the same pair.
    ///
    /// Returns `true` if no device was stored for that pair before, and
    /// `false` if an existing entry was overwritten.
    pub fn add(&self, device: DeviceData) -> bool {
        let user_id = device.user_id();
        self.entries
            .write()
            .entry(user_id.to_owned())
            .or_default()
            .insert(device.device_id().into(), device)
            .is_none()
    }

    /// Get a copy of the device stored for the given user and device ID.
    pub fn get(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
        Some(self.entries.read().get(user_id)?.get(device_id)?.clone())
    }

    /// Remove the device stored for the given user and device ID.
    ///
    /// Returns the removed device, or `None` if nothing was stored.
    pub fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Option<DeviceData> {
        self.entries.write().get_mut(user_id)?.remove(device_id)
    }

    /// Get copies of all devices stored for `user_id`, keyed by device ID.
    ///
    /// Returns an empty map if nothing is stored for the user.
    pub fn user_devices(&self, user_id: &UserId) -> HashMap<OwnedDeviceId, DeviceData> {
        // This is a pure read: take the shared lock and do a plain lookup so
        // that querying an unknown user neither blocks other readers nor
        // leaves an empty per-user map behind in the store.
        self.entries
            .read()
            .get(user_id)
            .map(|devices| {
                devices.iter().map(|(key, value)| (key.to_owned(), value.clone())).collect()
            })
            .unwrap_or_default()
    }
}
/// A sequence number backed by an `i64`.
///
/// The arithmetic helpers and the `Ord` implementation in this file use
/// wrapping operations, so the value is allowed to roll over rather than
/// overflow. It is (de)serialized transparently as the inner integer.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(transparent)]
pub struct SequenceNumber(i64);
impl Display for SequenceNumber {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl PartialOrd for SequenceNumber {
    /// Always returns `Some`, delegating to the `Ord` implementation.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for SequenceNumber {
    /// Compares two sequence numbers by the sign of their wrapping difference.
    ///
    /// This lets a number that has rolled over past `i64::MAX` still compare
    /// as greater than its predecessor. Two numbers whose wrapping difference
    /// is exactly `i64::MIN` each compare as `Less` than the other.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering as CmpOrdering;

        let distance = self.0.wrapping_sub(other.0);
        match distance {
            0 => CmpOrdering::Equal,
            d if d > 0 => CmpOrdering::Greater,
            _ => CmpOrdering::Less,
        }
    }
}
impl SequenceNumber {
pub(crate) fn increment(&mut self) {
self.0 = self.0.wrapping_add(1)
}
fn previous(&self) -> Self {
Self(self.0.wrapping_sub(1))
}
}
/// Record of a task that is waiting for a key query for `user` to complete.
///
/// Instances are created by `UsersForKeyQuery::maybe_register_waiting_task`,
/// which keeps only a weak reference to them.
#[derive(Debug)]
pub(super) struct KeysQueryWaiter {
/// The user whose key query is being waited on.
user: OwnedUserId,
/// The sequence number the user was flagged with when the waiter was
/// registered.
sequence_number: SequenceNumber,
/// Set to `true` by `UsersForKeyQuery::maybe_remove_user` once a query whose
/// sequence number is at or after `sequence_number` is processed for `user`.
pub(super) completed: AtomicBool,
}
/// Bookkeeping for the users that have been flagged as needing a key query,
/// together with the tasks waiting on those queries.
#[derive(Debug, Default)]
pub(super) struct UsersForKeyQuery {
/// The sequence number that the next call to `insert_user` will assign.
next_sequence_number: SequenceNumber,
/// Currently flagged users, mapped to the sequence number of their most
/// recent flagging (a repeated `insert_user` overwrites the older value).
user_map: HashMap<OwnedUserId, SequenceNumber>,
/// Weak handles to registered waiters; entries whose waiter has been
/// dropped are pruned in `maybe_remove_user`.
tasks_awaiting_key_query: Vec<Weak<KeysQueryWaiter>>,
}
impl UsersForKeyQuery {
    /// Flag `user` as needing a key query.
    ///
    /// The user is recorded with the current sequence number (overwriting any
    /// earlier flagging), and the counter is then advanced.
    pub(super) fn insert_user(&mut self, user: &UserId) {
        let sequence_number = self.next_sequence_number;
        self.next_sequence_number.increment();

        trace!(?user, %sequence_number, "Flagging user for key query");
        self.user_map.insert(user.to_owned(), sequence_number);
    }

    /// Process the completion of a key query for `user` that was started at
    /// `query_sequence`.
    ///
    /// Waiters for this user registered at or before `query_sequence` are
    /// marked completed and dropped from the list, as are waiters whose task
    /// has gone away. The user is only unflagged if they were not flagged
    /// again after `query_sequence`.
    ///
    /// Returns `true` if the user is now considered up-to-date (or was not
    /// flagged at all), `false` if they were invalidated again in the
    /// meantime.
    #[instrument(level = "trace", skip(self), fields(invalidation_sequence))]
    pub(super) fn maybe_remove_user(
        &mut self,
        user: &UserId,
        query_sequence: SequenceNumber,
    ) -> bool {
        self.tasks_awaiting_key_query.retain(|weak_waiter| {
            // A dead weak reference means the waiting side dropped its handle;
            // there is nobody left to notify, so just forget the entry.
            let Some(waiter) = weak_waiter.upgrade() else {
                trace!("removing expired waiting task");
                return false;
            };

            let satisfied = waiter.user == user && waiter.sequence_number <= query_sequence;
            if satisfied {
                trace!(
                    ?user,
                    %query_sequence,
                    waiter_sequence = %waiter.sequence_number,
                    "Removing completed waiting task"
                );
                waiter.completed.store(true, Ordering::Relaxed);
            } else {
                trace!(
                    ?user,
                    %query_sequence,
                    waiter_user = ?waiter.user,
                    waiter_sequence = %waiter.sequence_number,
                    "Retaining still-waiting task"
                );
            }

            // Keep only the waiters this query did not satisfy.
            !satisfied
        });

        let Some(last_invalidation) = self.user_map.get(user).copied() else {
            trace!("User already up-to-date, nothing to do");
            return true;
        };
        Span::current().record("invalidation_sequence", display(last_invalidation));

        let invalidated_since_query = last_invalidation > query_sequence;
        if invalidated_since_query {
            trace!("User invalidated since this query started: still not up-to-date");
            return false;
        }

        trace!("User now up-to-date");
        self.user_map.remove(user);
        true
    }

    /// Return the set of currently flagged users, together with the sequence
    /// number immediately preceding the next one to be assigned (i.e. the one
    /// handed out by the most recent `insert_user` call).
    pub(super) fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
        let flagged_users = self.user_map.keys().cloned().collect();
        let latest_sequence = self.next_sequence_number.previous();
        (flagged_users, latest_sequence)
    }

    /// Register a waiter for `user`, if the user is currently flagged.
    ///
    /// Returns `None` when the user is not flagged. Otherwise the returned
    /// waiter carries the user's current sequence number, and only a weak
    /// reference to it is retained here.
    pub(super) fn maybe_register_waiting_task(
        &mut self,
        user: &UserId,
    ) -> Option<Arc<KeysQueryWaiter>> {
        let sequence_number = *self.user_map.get(user)?;
        trace!(?user, %sequence_number, "Registering new waiting task");

        let waiter = Arc::new(KeysQueryWaiter {
            user: user.to_owned(),
            sequence_number,
            completed: AtomicBool::new(false),
        });
        self.tasks_awaiting_key_query.push(Arc::downgrade(&waiter));

        Some(waiter)
    }
}
/// Cached state that sits in front of the crypto store.
#[derive(Debug)]
pub(crate) struct StoreCache {
/// The wrapped store; `account()` loads the account from it on first use.
pub(super) store: Arc<CryptoStoreWrapper>,
/// Set of user IDs held in the cache.
/// NOTE(review): presumably the users whose devices are being tracked; not
/// used in this part of the file, confirm against the callers.
pub(super) tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
/// NOTE(review): presumably records whether `tracked_users` has already been
/// populated from the store; confirm against the callers.
pub(super) loaded_tracked_users: RwLock<bool>,
/// The cached account; `None` until `account()` has loaded it successfully.
pub(super) account: Mutex<Option<Account>>,
}
impl StoreCache {
    /// Borrow the underlying store wrapper.
    pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper {
        &self.store
    }

    /// Get a guard dereferencing to the cached [`Account`], loading it from
    /// the store first if it has not been cached yet.
    ///
    /// The account mutex stays locked for as long as the returned guard is
    /// alive.
    ///
    /// # Errors
    ///
    /// Propagates any error from loading the account, and returns
    /// [`CryptoStoreError::AccountUnset`] if the store holds no account.
    pub(super) async fn account(&self) -> super::Result<impl Deref<Target = Account> + '_> {
        let mut slot = self.account.lock().await;

        if slot.is_none() {
            let Some(account) = self.store.load_account().await? else {
                return Err(CryptoStoreError::AccountUnset);
            };
            *slot = Some(account);
        }

        // The slot is guaranteed to be populated at this point.
        Ok(MutexGuard::map(slot, |cached| cached.as_mut().unwrap()))
    }
}
/// Read access to a [`StoreCache`], backed by an owned read guard on the lock
/// that protects it.
pub(crate) struct StoreCacheGuard {
/// The held read guard; it dereferences to the cache.
pub(super) cache: OwnedRwLockReadGuard<StoreCache>,
}
impl StoreCacheGuard {
pub async fn account(&self) -> super::Result<impl Deref<Target = Account> + '_> {
self.cache.account().await
}
}
impl Deref for StoreCacheGuard {
    type Target = StoreCache;

    /// Expose the guarded [`StoreCache`] by reference.
    fn deref(&self) -> &Self::Target {
        &*self.cache
    }
}
#[cfg(test)]
mod tests {
    use matrix_sdk_test::async_test;
    use proptest::prelude::*;

    use super::{DeviceStore, SequenceNumber, SessionStore};
    use crate::{
        identities::device::testing::get_device, olm::tests::get_account_and_session_test_helper,
    };

    /// Checks that `first`, its wrapping successor and its wrapping
    /// predecessor compare in the expected order, in both directions.
    fn assert_neighbours_are_ordered(first: SequenceNumber) {
        let second = SequenceNumber(first.0.wrapping_add(1));
        let third = SequenceNumber(first.0.wrapping_sub(1));

        assert!(second > first);
        assert!(first < second);
        assert!(third < first);
        assert!(first > third);
        assert!(second > third);
        assert!(third < second);
    }

    #[async_test]
    async fn test_session_store() {
        let (_, session) = get_account_and_session_test_helper();
        let sender_key = session.sender_key.to_base64();
        let store = SessionStore::new();

        // The first insertion succeeds, a duplicate is rejected.
        assert!(store.add(session.clone()).await);
        assert!(!store.add(session.clone()).await);

        let stored = store.get(&sender_key).await.unwrap();
        let stored = stored.lock().await;
        assert_eq!(&session, &stored[0]);
    }

    #[async_test]
    async fn test_session_store_bulk_storing() {
        let (_, session) = get_account_and_session_test_helper();
        let sender_key = session.sender_key.to_base64();
        let store = SessionStore::new();

        store.set_for_sender(&sender_key, vec![session.clone()]).await;

        let stored = store.get(&sender_key).await.unwrap();
        let stored = stored.lock().await;
        assert_eq!(&session, &stored[0]);
    }

    #[async_test]
    async fn test_device_store() {
        let device = get_device();
        let store = DeviceStore::new();

        // The first insertion is new, the second replaces the existing entry.
        assert!(store.add(device.clone()));
        assert!(!store.add(device.clone()));

        let fetched = store.get(device.user_id(), device.device_id()).unwrap();
        assert_eq!(device, fetched);

        let all_devices = store.user_devices(device.user_id());
        assert_eq!(&**all_devices.keys().next().unwrap(), device.device_id());
        assert_eq!(all_devices.values().next().unwrap(), &device);

        let from_map = all_devices.get(device.device_id()).unwrap();
        assert_eq!(&device, from_map);

        store.remove(device.user_id(), device.device_id());

        let after_removal = store.get(device.user_id(), device.device_id());
        assert!(after_removal.is_none());
    }

    #[test]
    fn sequence_at_boundary() {
        assert_neighbours_are_ordered(SequenceNumber(i64::MAX));
    }

    proptest! {
        #[test]
        fn partial_eq_sequence_number(sequence in i64::MIN..i64::MAX) {
            assert_neighbours_are_ordered(SequenceNumber(sequence));
        }
    }
}