matrix_sdk_base/store/
memory_store.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, BTreeSet, HashMap},
17    sync::RwLock,
18};
19
20use async_trait::async_trait;
21use growable_bloom_filter::GrowableBloom;
22use matrix_sdk_common::{ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK};
23use ruma::{
24    CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
25    OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
26    canonical_json::{RedactedBecause, redact},
27    events::{
28        AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
29        AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
30        presence::PresenceEvent,
31        receipt::{Receipt, ReceiptThread, ReceiptType},
32        room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
33    },
34    serde::Raw,
35    time::Instant,
36};
37use tracing::{debug, instrument, warn};
38
39use super::{
40    DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
41    RoomLoadSettings, StateChanges, StateStore, StoreError, SupportedVersionsResponse,
42    TtlStoreValue, WellKnownResponse,
43    send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
44    traits::ComposerDraft,
45};
46use crate::{
47    MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
48    deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
49    store::{
50        QueueWedgeError, StoredThreadSubscription,
51        traits::{ThreadSubscriptionCatchupToken, compare_thread_subscription_bump_stamps},
52    },
53};
54
/// The actual data held by a [`MemoryStore`].
///
/// Every field is a plain in-memory collection; the whole struct is guarded
/// by a single lock in [`MemoryStore`].
#[derive(Debug, Default)]
#[allow(clippy::type_complexity)]
struct MemoryStoreInner {
    /// Per-user list of rooms, backing `StateStoreDataKey::RecentlyVisitedRooms`.
    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
    /// Composer drafts, keyed by room and an optional thread root event.
    /// Backs `StateStoreDataKey::ComposerDraft`.
    composer_drafts: HashMap<(OwnedRoomId, Option<OwnedEventId>), ComposerDraft>,
    /// Per-user avatar URL, backing `StateStoreDataKey::UserAvatarUrl`.
    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
    /// The sync token; written both through the key/value API and by
    /// `save_changes`.
    sync_token: Option<String>,
    /// Backs `StateStoreDataKey::SupportedVersions`.
    supported_versions: Option<TtlStoreValue<SupportedVersionsResponse>>,
    /// Backs `StateStoreDataKey::WellKnown`.
    well_known: Option<TtlStoreValue<Option<WellKnownResponse>>>,
    /// Filters keyed by filter name, backing `StateStoreDataKey::Filter`.
    filters: HashMap<String, String>,
    /// Backs `StateStoreDataKey::UtdHookManagerData`.
    utd_hook_manager_data: Option<GrowableBloom>,
    /// Flag backing `StateStoreDataKey::OneTimeKeyAlreadyUploaded`: `true`
    /// once the key has been set, `false` after it has been removed.
    one_time_key_uploaded_error: bool,
    /// Global account data, keyed by event type.
    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
    /// Room -> user -> profile (minimal member event).
    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
    /// Room -> display name -> users sharing that display name. Filled from
    /// the ambiguity maps of a `StateChanges`.
    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
    /// Room -> user -> membership, derived from the room member events found
    /// in the (non-stripped) room state.
    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Room -> room info.
    room_info: HashMap<OwnedRoomId, RoomInfo>,
    /// Room -> event type -> state key -> raw state event.
    room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
    /// Room -> event type -> raw room account data event.
    room_account_data:
        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
    /// Room -> event type -> state key -> raw stripped state event.
    stripped_room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
    /// Room -> user -> membership, derived from the room member events found
    /// in the stripped room state.
    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// User -> raw presence event.
    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
    /// Room -> (receipt type, thread) -> user -> (event the receipt points to,
    /// receipt). Receipt type and thread are stored in their string form.
    room_user_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
    >,
    /// Room -> (receipt type, thread) -> event -> user -> receipt. Receipt
    /// type and thread are stored in their string form.
    room_event_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
    >,
    /// Arbitrary key/value pairs, backing the `*_custom_value` methods.
    custom: HashMap<Vec<u8>, Vec<u8>>,
    /// Room -> queued send requests.
    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
    /// Room -> dependent queued send requests.
    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
    /// Room -> event -> user, backing `StateStoreDataKey::SeenKnockRequests`.
    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
    /// Room -> event -> stored thread subscription.
    // NOTE(review): the event key is presumably the thread root — confirm.
    thread_subscriptions: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, StoredThreadSubscription>>,
    /// Backs `StateStoreDataKey::ThreadSubscriptionsCatchupTokens`.
    thread_subscriptions_catchup_tokens: Option<Vec<ThreadSubscriptionCatchupToken>>,
}
95
/// In-memory, non-persistent implementation of the `StateStore`.
///
/// Default if no other is configured at startup.
///
/// All data is kept behind a single `std::sync::RwLock`. The methods in this
/// file unwrap the result of locking, so they panic if the lock has been
/// poisoned.
#[derive(Debug, Default)]
pub struct MemoryStore {
    // The whole store state, shared behind one lock.
    inner: RwLock<MemoryStoreInner>,
}
103
104impl MemoryStore {
105    /// Create a new empty MemoryStore
106    pub fn new() -> Self {
107        Self::default()
108    }
109
110    fn get_user_room_receipt_event_impl(
111        &self,
112        room_id: &RoomId,
113        receipt_type: ReceiptType,
114        thread: ReceiptThread,
115        user_id: &UserId,
116    ) -> Option<(OwnedEventId, Receipt)> {
117        self.inner
118            .read()
119            .unwrap()
120            .room_user_receipts
121            .get(room_id)?
122            .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
123            .get(user_id)
124            .cloned()
125    }
126
127    fn get_event_room_receipt_events_impl(
128        &self,
129        room_id: &RoomId,
130        receipt_type: ReceiptType,
131        thread: ReceiptThread,
132        event_id: &EventId,
133    ) -> Option<Vec<(OwnedUserId, Receipt)>> {
134        Some(
135            self.inner
136                .read()
137                .unwrap()
138                .room_event_receipts
139                .get(room_id)?
140                .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
141                .get(event_id)?
142                .iter()
143                .map(|(key, value)| (key.clone(), value.clone()))
144                .collect(),
145        )
146    }
147}
148
149#[cfg_attr(target_family = "wasm", async_trait(?Send))]
150#[cfg_attr(not(target_family = "wasm"), async_trait)]
151impl StateStore for MemoryStore {
152    type Error = StoreError;
153
154    async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
155        let inner = self.inner.read().unwrap();
156
157        Ok(match key {
158            StateStoreDataKey::SyncToken => {
159                inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
160            }
161            StateStoreDataKey::SupportedVersions => {
162                inner.supported_versions.clone().map(StateStoreDataValue::SupportedVersions)
163            }
164            StateStoreDataKey::WellKnown => {
165                inner.well_known.clone().map(StateStoreDataValue::WellKnown)
166            }
167            StateStoreDataKey::Filter(filter_name) => {
168                inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
169            }
170            StateStoreDataKey::UserAvatarUrl(user_id) => {
171                inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
172            }
173            StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
174                .recently_visited_rooms
175                .get(user_id)
176                .cloned()
177                .map(StateStoreDataValue::RecentlyVisitedRooms),
178            StateStoreDataKey::UtdHookManagerData => {
179                inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
180            }
181            StateStoreDataKey::OneTimeKeyAlreadyUploaded => inner
182                .one_time_key_uploaded_error
183                .then_some(StateStoreDataValue::OneTimeKeyAlreadyUploaded),
184            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
185                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
186                inner.composer_drafts.get(&key).cloned().map(StateStoreDataValue::ComposerDraft)
187            }
188            StateStoreDataKey::SeenKnockRequests(room_id) => inner
189                .seen_knock_requests
190                .get(room_id)
191                .cloned()
192                .map(StateStoreDataValue::SeenKnockRequests),
193            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => inner
194                .thread_subscriptions_catchup_tokens
195                .clone()
196                .map(StateStoreDataValue::ThreadSubscriptionsCatchupTokens),
197        })
198    }
199
200    async fn set_kv_data(
201        &self,
202        key: StateStoreDataKey<'_>,
203        value: StateStoreDataValue,
204    ) -> Result<()> {
205        let mut inner = self.inner.write().unwrap();
206        match key {
207            StateStoreDataKey::SyncToken => {
208                inner.sync_token =
209                    Some(value.into_sync_token().expect("Session data not a sync token"))
210            }
211            StateStoreDataKey::Filter(filter_name) => {
212                inner.filters.insert(
213                    filter_name.to_owned(),
214                    value.into_filter().expect("Session data not a filter"),
215                );
216            }
217            StateStoreDataKey::UserAvatarUrl(user_id) => {
218                inner.user_avatar_url.insert(
219                    user_id.to_owned(),
220                    value.into_user_avatar_url().expect("Session data not a user avatar url"),
221                );
222            }
223            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
224                inner.recently_visited_rooms.insert(
225                    user_id.to_owned(),
226                    value
227                        .into_recently_visited_rooms()
228                        .expect("Session data not a list of recently visited rooms"),
229                );
230            }
231            StateStoreDataKey::UtdHookManagerData => {
232                inner.utd_hook_manager_data = Some(
233                    value
234                        .into_utd_hook_manager_data()
235                        .expect("Session data not the hook manager data"),
236                );
237            }
238            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
239                inner.one_time_key_uploaded_error = true;
240            }
241            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
242                inner.composer_drafts.insert(
243                    (room_id.to_owned(), thread_root.map(ToOwned::to_owned)),
244                    value.into_composer_draft().expect("Session data not a composer draft"),
245                );
246            }
247            StateStoreDataKey::SupportedVersions => {
248                inner.supported_versions = Some(
249                    value
250                        .into_supported_versions()
251                        .expect("Session data not containing supported versions"),
252                );
253            }
254            StateStoreDataKey::WellKnown => {
255                inner.well_known =
256                    Some(value.into_well_known().expect("Session data not containing well-known"));
257            }
258            StateStoreDataKey::SeenKnockRequests(room_id) => {
259                inner.seen_knock_requests.insert(
260                    room_id.to_owned(),
261                    value
262                        .into_seen_knock_requests()
263                        .expect("Session data is not a set of seen join request ids"),
264                );
265            }
266            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
267                inner.thread_subscriptions_catchup_tokens =
268                    Some(value.into_thread_subscriptions_catchup_tokens().expect(
269                        "Session data is not a list of thread subscription catchup tokens",
270                    ));
271            }
272        }
273
274        Ok(())
275    }
276
277    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
278        let mut inner = self.inner.write().unwrap();
279        match key {
280            StateStoreDataKey::SyncToken => inner.sync_token = None,
281            StateStoreDataKey::SupportedVersions => inner.supported_versions = None,
282            StateStoreDataKey::WellKnown => inner.well_known = None,
283            StateStoreDataKey::Filter(filter_name) => {
284                inner.filters.remove(filter_name);
285            }
286            StateStoreDataKey::UserAvatarUrl(user_id) => {
287                inner.user_avatar_url.remove(user_id);
288            }
289            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
290                inner.recently_visited_rooms.remove(user_id);
291            }
292            StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
293            StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
294                inner.one_time_key_uploaded_error = false
295            }
296            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
297                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
298                inner.composer_drafts.remove(&key);
299            }
300            StateStoreDataKey::SeenKnockRequests(room_id) => {
301                inner.seen_knock_requests.remove(room_id);
302            }
303            StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
304                inner.thread_subscriptions_catchup_tokens = None;
305            }
306        }
307        Ok(())
308    }
309
    /// Apply a batch of `StateChanges` to the store.
    ///
    /// The write lock is held for the whole call. The changes are applied in
    /// the order in which they appear below; redactions come last and are
    /// the only step that can return an error. Changes applied before an
    /// error are not rolled back.
    #[instrument(skip(self, changes))]
    async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
        // Only used to log how long the call took.
        let now = Instant::now();

        let mut inner = self.inner.write().unwrap();

        if let Some(s) = &changes.sync_token {
            inner.sync_token = Some(s.to_owned());
        }

        // Deletions are applied before insertions, so a profile that is both
        // deleted and set by the same changes ends up being stored.
        for (room, users) in &changes.profiles_to_delete {
            let Some(room_profiles) = inner.profiles.get_mut(room) else {
                continue;
            };
            for user in users {
                room_profiles.remove(user);
            }
        }

        for (room, users) in &changes.profiles {
            for (user_id, profile) in users {
                inner
                    .profiles
                    .entry(room.clone())
                    .or_default()
                    .insert(user_id.clone(), profile.clone());
            }
        }

        for (room, map) in &changes.ambiguity_maps {
            for (display_name, display_names) in map {
                inner
                    .display_names
                    .entry(room.clone())
                    .or_default()
                    .insert(display_name.clone(), display_names.clone());
            }
        }

        for (event_type, event) in &changes.account_data {
            inner.account_data.insert(event_type.clone(), event.clone());
        }

        for (room, events) in &changes.room_account_data {
            for (event_type, event) in events {
                inner
                    .room_account_data
                    .entry(room.clone())
                    .or_default()
                    .insert(event_type.clone(), event.clone());
            }
        }

        for (room, event_types) in &changes.state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());
                    // As soon as a room has one non-stripped state event, all
                    // of its stripped state is dropped.
                    inner.stripped_room_state.remove(room);

                    if *event_type == StateEventType::RoomMember {
                        // A member event that fails to deserialize stays in
                        // `room_state` but does not update `members`.
                        let event =
                            match raw_event.deserialize_as_unchecked::<SyncRoomMemberEvent>() {
                                Ok(ev) => ev,
                                Err(e) => {
                                    let event_id: Option<String> =
                                        raw_event.get_field("event_id").ok().flatten();
                                    debug!(event_id, "Failed to deserialize member event: {e}");
                                    continue;
                                }
                            };

                        // Same as above, for the stripped memberships.
                        inner.stripped_members.remove(room);

                        inner
                            .members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key().to_owned(), event.membership().clone());
                    }
                }
            }
        }

        for (room_id, info) in &changes.room_infos {
            inner.room_info.insert(room_id.clone(), info.clone());
        }

        for (sender, event) in &changes.presence {
            inner.presence.insert(sender.clone(), event.clone());
        }

        // Stripped state is handled after the regular state, so stripped
        // events contained in these changes are not removed by the loop above.
        for (room, event_types) in &changes.stripped_state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .stripped_room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());

                    if *event_type == StateEventType::RoomMember {
                        // A member event that fails to deserialize stays in
                        // `stripped_room_state` but does not update
                        // `stripped_members`.
                        let event =
                            match raw_event.deserialize_as_unchecked::<StrippedRoomMemberEvent>() {
                                Ok(ev) => ev,
                                Err(e) => {
                                    let event_id: Option<String> =
                                        raw_event.get_field("event_id").ok().flatten();
                                    debug!(
                                        event_id,
                                        "Failed to deserialize stripped member event: {e}"
                                    );
                                    continue;
                                }
                            };

                        inner
                            .stripped_members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key, event.content.membership.clone());
                    }
                }
            }
        }

        // Receipts are kept in two indexes (by user and by event) that are
        // updated together.
        for (room, content) in &changes.receipts {
            for (event_id, receipts) in &content.0 {
                for (receipt_type, receipts) in receipts {
                    for (user_id, receipt) in receipts {
                        let thread = receipt.thread.as_str().map(ToOwned::to_owned);
                        // Add the receipt to the room user receipts
                        if let Some((old_event, _)) = inner
                            .room_user_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread.clone()))
                            .or_default()
                            .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
                        {
                            // Remove the old receipt from the room event receipts
                            if let Some(receipt_map) = inner.room_event_receipts.get_mut(room)
                                && let Some(event_map) =
                                    receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
                                && let Some(user_map) = event_map.get_mut(&old_event)
                            {
                                user_map.remove(user_id);
                            }
                        }

                        // Add the receipt to the room event receipts
                        inner
                            .room_event_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread))
                            .or_default()
                            .entry(event_id.clone())
                            .or_default()
                            .insert(user_id.clone(), receipt.clone());
                    }
                }
            }
        }

        // Redaction rules of a room, taken from its stored room info; falls
        // back to `ROOM_VERSION_RULES_FALLBACK` (with a warning) when no room
        // info is stored for the room.
        let make_redaction_rules = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
            room_info.get(room_id).map(|info| info.room_version_rules_or_default()).unwrap_or_else(|| {
                warn!(
                    ?room_id,
                    "Unable to get the room version rules, defaulting to rules for room version {ROOM_VERSION_FALLBACK}"
                );
                ROOM_VERSION_RULES_FALLBACK
            }).redaction
        };

        // Reborrow the guard as a plain mutable reference: the loop below
        // borrows `room_state` mutably while reading `room_info`.
        let inner = &mut *inner;
        for (room_id, redactions) in &changes.redactions {
            // Computed at most once per room, and only if an event of the
            // room actually needs to be redacted.
            let mut redaction_rules = None;

            if let Some(room) = inner.room_state.get_mut(room_id) {
                for ref_room_mu in room.values_mut() {
                    for raw_evt in ref_room_mu.values_mut() {
                        // Only stored state events whose `event_id` is a key
                        // of `redactions` are rewritten.
                        if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id")
                            && let Some(redaction) = redactions.get(&event_id)
                        {
                            let redacted = redact(
                                raw_evt.deserialize_as::<CanonicalJsonObject>()?,
                                redaction_rules.get_or_insert_with(|| {
                                    make_redaction_rules(&inner.room_info, room_id)
                                }),
                                Some(RedactedBecause::from_raw_event(redaction)?),
                            )
                            .map_err(StoreError::Redaction)?;
                            // Replace the stored event with its redacted form.
                            *raw_evt = Raw::new(&redacted)?.cast_unchecked();
                        }
                    }
                }
            }
        }

        debug!("Saved changes in {:?}", now.elapsed());

        Ok(())
    }
521
522    async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
523        Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
524    }
525
526    async fn get_presence_events(
527        &self,
528        user_ids: &[OwnedUserId],
529    ) -> Result<Vec<Raw<PresenceEvent>>> {
530        let presence = &self.inner.read().unwrap().presence;
531        Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
532    }
533
534    async fn get_state_event(
535        &self,
536        room_id: &RoomId,
537        event_type: StateEventType,
538        state_key: &str,
539    ) -> Result<Option<RawAnySyncOrStrippedState>> {
540        Ok(self
541            .get_state_events_for_keys(room_id, event_type, &[state_key])
542            .await?
543            .into_iter()
544            .next())
545    }
546
547    async fn get_state_events(
548        &self,
549        room_id: &RoomId,
550        event_type: StateEventType,
551    ) -> Result<Vec<RawAnySyncOrStrippedState>> {
552        fn get_events<T>(
553            state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
554            room_id: &RoomId,
555            event_type: &StateEventType,
556            to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
557        ) -> Option<Vec<RawAnySyncOrStrippedState>> {
558            let state_events = state_map.get(room_id)?.get(event_type)?;
559            Some(state_events.values().cloned().map(to_enum).collect())
560        }
561
562        let inner = self.inner.read().unwrap();
563        Ok(get_events(
564            &inner.stripped_room_state,
565            room_id,
566            &event_type,
567            RawAnySyncOrStrippedState::Stripped,
568        )
569        .or_else(|| {
570            get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
571        })
572        .unwrap_or_default())
573    }
574
575    async fn get_state_events_for_keys(
576        &self,
577        room_id: &RoomId,
578        event_type: StateEventType,
579        state_keys: &[&str],
580    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
581        let inner = self.inner.read().unwrap();
582
583        if let Some(stripped_state_events) =
584            inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
585        {
586            Ok(state_keys
587                .iter()
588                .filter_map(|k| {
589                    stripped_state_events
590                        .get(*k)
591                        .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
592                })
593                .collect())
594        } else if let Some(sync_state_events) =
595            inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
596        {
597            Ok(state_keys
598                .iter()
599                .filter_map(|k| {
600                    sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
601                })
602                .collect())
603        } else {
604            Ok(Vec::new())
605        }
606    }
607
608    async fn get_profile(
609        &self,
610        room_id: &RoomId,
611        user_id: &UserId,
612    ) -> Result<Option<MinimalRoomMemberEvent>> {
613        Ok(self
614            .inner
615            .read()
616            .unwrap()
617            .profiles
618            .get(room_id)
619            .and_then(|room_profiles| room_profiles.get(user_id))
620            .cloned())
621    }
622
623    async fn get_profiles<'a>(
624        &self,
625        room_id: &RoomId,
626        user_ids: &'a [OwnedUserId],
627    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
628        if user_ids.is_empty() {
629            return Ok(BTreeMap::new());
630        }
631
632        let profiles = &self.inner.read().unwrap().profiles;
633        let Some(room_profiles) = profiles.get(room_id) else {
634            return Ok(BTreeMap::new());
635        };
636
637        Ok(user_ids
638            .iter()
639            .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
640            .collect())
641    }
642
643    #[instrument(skip(self, memberships))]
644    async fn get_user_ids(
645        &self,
646        room_id: &RoomId,
647        memberships: RoomMemberships,
648    ) -> Result<Vec<OwnedUserId>> {
649        /// Get the user IDs for the given room with the given memberships and
650        /// stripped state.
651        ///
652        /// If `memberships` is empty, returns all user IDs in the room with the
653        /// given stripped state.
654        fn get_user_ids_inner(
655            members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
656            room_id: &RoomId,
657            memberships: RoomMemberships,
658        ) -> Vec<OwnedUserId> {
659            members
660                .get(room_id)
661                .map(|members| {
662                    members
663                        .iter()
664                        .filter_map(|(user_id, membership)| {
665                            memberships.matches(membership).then_some(user_id)
666                        })
667                        .cloned()
668                        .collect()
669                })
670                .unwrap_or_default()
671        }
672        let inner = self.inner.read().unwrap();
673        let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
674        if !v.is_empty() {
675            return Ok(v);
676        }
677        Ok(get_user_ids_inner(&inner.members, room_id, memberships))
678    }
679
680    async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
681        let memory_store_inner = self.inner.read().unwrap();
682        let room_infos = &memory_store_inner.room_info;
683
684        Ok(match room_load_settings {
685            RoomLoadSettings::All => room_infos.values().cloned().collect(),
686
687            RoomLoadSettings::One(room_id) => match room_infos.get(room_id) {
688                Some(room_info) => vec![room_info.clone()],
689                None => vec![],
690            },
691        })
692    }
693
694    async fn get_users_with_display_name(
695        &self,
696        room_id: &RoomId,
697        display_name: &DisplayName,
698    ) -> Result<BTreeSet<OwnedUserId>> {
699        Ok(self
700            .inner
701            .read()
702            .unwrap()
703            .display_names
704            .get(room_id)
705            .and_then(|room_names| room_names.get(display_name).cloned())
706            .unwrap_or_default())
707    }
708
709    async fn get_users_with_display_names<'a>(
710        &self,
711        room_id: &RoomId,
712        display_names: &'a [DisplayName],
713    ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
714        if display_names.is_empty() {
715            return Ok(HashMap::new());
716        }
717
718        let inner = self.inner.read().unwrap();
719        let Some(room_names) = inner.display_names.get(room_id) else {
720            return Ok(HashMap::new());
721        };
722
723        Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
724    }
725
726    async fn get_account_data_event(
727        &self,
728        event_type: GlobalAccountDataEventType,
729    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
730        Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
731    }
732
733    async fn get_room_account_data_event(
734        &self,
735        room_id: &RoomId,
736        event_type: RoomAccountDataEventType,
737    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
738        Ok(self
739            .inner
740            .read()
741            .unwrap()
742            .room_account_data
743            .get(room_id)
744            .and_then(|m| m.get(&event_type))
745            .cloned())
746    }
747
748    async fn get_user_room_receipt_event(
749        &self,
750        room_id: &RoomId,
751        receipt_type: ReceiptType,
752        thread: ReceiptThread,
753        user_id: &UserId,
754    ) -> Result<Option<(OwnedEventId, Receipt)>> {
755        Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
756    }
757
758    async fn get_event_room_receipt_events(
759        &self,
760        room_id: &RoomId,
761        receipt_type: ReceiptType,
762        thread: ReceiptThread,
763        event_id: &EventId,
764    ) -> Result<Vec<(OwnedUserId, Receipt)>> {
765        Ok(self
766            .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
767            .unwrap_or_default())
768    }
769
770    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
771        Ok(self.inner.read().unwrap().custom.get(key).cloned())
772    }
773
774    async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
775        Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
776    }
777
778    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
779        Ok(self.inner.write().unwrap().custom.remove(key))
780    }
781
782    async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
783        let mut inner = self.inner.write().unwrap();
784
785        inner.profiles.remove(room_id);
786        inner.display_names.remove(room_id);
787        inner.members.remove(room_id);
788        inner.room_info.remove(room_id);
789        inner.room_state.remove(room_id);
790        inner.room_account_data.remove(room_id);
791        inner.stripped_room_state.remove(room_id);
792        inner.stripped_members.remove(room_id);
793        inner.room_user_receipts.remove(room_id);
794        inner.room_event_receipts.remove(room_id);
795        inner.send_queue_events.remove(room_id);
796        inner.dependent_send_queue_events.remove(room_id);
797        inner.thread_subscriptions.remove(room_id);
798
799        Ok(())
800    }
801
802    async fn save_send_queue_request(
803        &self,
804        room_id: &RoomId,
805        transaction_id: OwnedTransactionId,
806        created_at: MilliSecondsSinceUnixEpoch,
807        kind: QueuedRequestKind,
808        priority: usize,
809    ) -> Result<(), Self::Error> {
810        self.inner
811            .write()
812            .unwrap()
813            .send_queue_events
814            .entry(room_id.to_owned())
815            .or_default()
816            .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
817        Ok(())
818    }
819
820    async fn update_send_queue_request(
821        &self,
822        room_id: &RoomId,
823        transaction_id: &TransactionId,
824        kind: QueuedRequestKind,
825    ) -> Result<bool, Self::Error> {
826        if let Some(entry) = self
827            .inner
828            .write()
829            .unwrap()
830            .send_queue_events
831            .entry(room_id.to_owned())
832            .or_default()
833            .iter_mut()
834            .find(|item| item.transaction_id == transaction_id)
835        {
836            entry.kind = kind;
837            entry.error = None;
838            Ok(true)
839        } else {
840            Ok(false)
841        }
842    }
843
844    async fn remove_send_queue_request(
845        &self,
846        room_id: &RoomId,
847        transaction_id: &TransactionId,
848    ) -> Result<bool, Self::Error> {
849        let mut inner = self.inner.write().unwrap();
850        let q = &mut inner.send_queue_events;
851
852        let entry = q.get_mut(room_id);
853        if let Some(entry) = entry {
854            // Find the event by id in its room queue, and remove it if present.
855            if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
856                entry.remove(pos);
857                // And if this was the last event before removal, remove the entire room entry.
858                if entry.is_empty() {
859                    q.remove(room_id);
860                }
861                return Ok(true);
862            }
863        }
864
865        Ok(false)
866    }
867
868    async fn load_send_queue_requests(
869        &self,
870        room_id: &RoomId,
871    ) -> Result<Vec<QueuedRequest>, Self::Error> {
872        let mut ret = self
873            .inner
874            .write()
875            .unwrap()
876            .send_queue_events
877            .entry(room_id.to_owned())
878            .or_default()
879            .clone();
880        // Inverted order of priority, use stable sort to keep insertion order.
881        ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority));
882        Ok(ret)
883    }
884
885    async fn update_send_queue_request_status(
886        &self,
887        room_id: &RoomId,
888        transaction_id: &TransactionId,
889        error: Option<QueueWedgeError>,
890    ) -> Result<(), Self::Error> {
891        if let Some(entry) = self
892            .inner
893            .write()
894            .unwrap()
895            .send_queue_events
896            .entry(room_id.to_owned())
897            .or_default()
898            .iter_mut()
899            .find(|item| item.transaction_id == transaction_id)
900        {
901            entry.error = error;
902        }
903        Ok(())
904    }
905
906    async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
907        Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
908    }
909
910    async fn save_dependent_queued_request(
911        &self,
912        room: &RoomId,
913        parent_transaction_id: &TransactionId,
914        own_transaction_id: ChildTransactionId,
915        created_at: MilliSecondsSinceUnixEpoch,
916        content: DependentQueuedRequestKind,
917    ) -> Result<(), Self::Error> {
918        self.inner
919            .write()
920            .unwrap()
921            .dependent_send_queue_events
922            .entry(room.to_owned())
923            .or_default()
924            .push(DependentQueuedRequest {
925                kind: content,
926                parent_transaction_id: parent_transaction_id.to_owned(),
927                own_transaction_id,
928                parent_key: None,
929                created_at,
930            });
931        Ok(())
932    }
933
934    async fn mark_dependent_queued_requests_as_ready(
935        &self,
936        room: &RoomId,
937        parent_txn_id: &TransactionId,
938        sent_parent_key: SentRequestKey,
939    ) -> Result<usize, Self::Error> {
940        let mut inner = self.inner.write().unwrap();
941        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
942        let mut num_updated = 0;
943        for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
944            d.parent_key = Some(sent_parent_key.clone());
945            num_updated += 1;
946        }
947        Ok(num_updated)
948    }
949
950    async fn update_dependent_queued_request(
951        &self,
952        room: &RoomId,
953        own_transaction_id: &ChildTransactionId,
954        new_content: DependentQueuedRequestKind,
955    ) -> Result<bool, Self::Error> {
956        let mut inner = self.inner.write().unwrap();
957        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
958        for d in dependents.iter_mut() {
959            if d.own_transaction_id == *own_transaction_id {
960                d.kind = new_content;
961                return Ok(true);
962            }
963        }
964        Ok(false)
965    }
966
967    async fn remove_dependent_queued_request(
968        &self,
969        room: &RoomId,
970        txn_id: &ChildTransactionId,
971    ) -> Result<bool, Self::Error> {
972        let mut inner = self.inner.write().unwrap();
973        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
974        if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
975            dependents.remove(pos);
976            Ok(true)
977        } else {
978            Ok(false)
979        }
980    }
981
982    async fn load_dependent_queued_requests(
983        &self,
984        room: &RoomId,
985    ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
986        Ok(self
987            .inner
988            .read()
989            .unwrap()
990            .dependent_send_queue_events
991            .get(room)
992            .cloned()
993            .unwrap_or_default())
994    }
995
996    async fn upsert_thread_subscription(
997        &self,
998        room: &RoomId,
999        thread_id: &EventId,
1000        mut new: StoredThreadSubscription,
1001    ) -> Result<(), Self::Error> {
1002        let mut inner = self.inner.write().unwrap();
1003        let room_subs = inner.thread_subscriptions.entry(room.to_owned()).or_default();
1004
1005        if let Some(previous) = room_subs.get(thread_id) {
1006            // Nothing to do.
1007            if *previous == new {
1008                return Ok(());
1009            }
1010            if !compare_thread_subscription_bump_stamps(previous.bump_stamp, &mut new.bump_stamp) {
1011                return Ok(());
1012            }
1013        }
1014
1015        room_subs.insert(thread_id.to_owned(), new);
1016
1017        Ok(())
1018    }
1019
1020    async fn load_thread_subscription(
1021        &self,
1022        room: &RoomId,
1023        thread_id: &EventId,
1024    ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
1025        let inner = self.inner.read().unwrap();
1026        Ok(inner
1027            .thread_subscriptions
1028            .get(room)
1029            .and_then(|subscriptions| subscriptions.get(thread_id))
1030            .copied())
1031    }
1032
1033    async fn remove_thread_subscription(
1034        &self,
1035        room: &RoomId,
1036        thread_id: &EventId,
1037    ) -> Result<(), Self::Error> {
1038        let mut inner = self.inner.write().unwrap();
1039
1040        let Some(room_subs) = inner.thread_subscriptions.get_mut(room) else {
1041            return Ok(());
1042        };
1043
1044        room_subs.remove(thread_id);
1045
1046        if room_subs.is_empty() {
1047            // If there are no more subscriptions for this room, remove the room entry.
1048            inner.thread_subscriptions.remove(room);
1049        }
1050
1051        Ok(())
1052    }
1053
1054    async fn optimize(&self) -> Result<(), Self::Error> {
1055        Ok(())
1056    }
1057
1058    async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1059        Ok(None)
1060    }
1061}
1062
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    // Builds a fresh, empty in-memory store for each test.
    // NOTE(review): presumably `statestore_integration_tests!` calls this
    // function by name — confirm against the macro before renaming it.
    async fn get_store() -> Result<impl StateStore> {
        let store = MemoryStore::new();
        Ok(store)
    }

    // Expands the shared state-store test suite inside this module.
    statestore_integration_tests!();
}