matrix_sdk_base/store/memory_store.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, BTreeSet, HashMap},
17    sync::RwLock,
18};
19
20use async_trait::async_trait;
21use growable_bloom_filter::GrowableBloom;
22use matrix_sdk_common::ROOM_VERSION_FALLBACK;
23use ruma::{
24    canonical_json::{redact, RedactedBecause},
25    events::{
26        presence::PresenceEvent,
27        receipt::{Receipt, ReceiptThread, ReceiptType},
28        room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
29        AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
30        AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
31    },
32    serde::Raw,
33    time::Instant,
34    CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
35    OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
36};
37use tracing::{debug, instrument, warn};
38
39use super::{
40    send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
41    traits::{ComposerDraft, ServerInfo},
42    DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
43    RoomLoadSettings, StateChanges, StateStore, StoreError,
44};
45use crate::{
46    deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
47    store::QueueWedgeError,
48    MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
49};
50
/// All mutable data of a `MemoryStore`.
///
/// The whole struct sits behind the single `RwLock` held by `MemoryStore`, so
/// every field is read and written under that one lock. `Default` yields an
/// empty store.
#[derive(Debug, Default)]
// The nested map types below are spelled out inline instead of through type
// aliases, which trips clippy's `type_complexity` lint.
#[allow(clippy::type_complexity)]
struct MemoryStoreInner {
    /// Rooms recently visited by each user
    /// (`StateStoreDataKey::RecentlyVisitedRooms`).
    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
    /// Composer drafts keyed by room and optional thread root event
    /// (`StateStoreDataKey::ComposerDraft`).
    composer_drafts: HashMap<(OwnedRoomId, Option<OwnedEventId>), ComposerDraft>,
    /// Avatar URL per user (`StateStoreDataKey::UserAvatarUrl`).
    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
    /// Sync token, set through the key-value API or by `save_changes`.
    sync_token: Option<String>,
    /// Server info (`StateStoreDataKey::ServerInfo`).
    server_info: Option<ServerInfo>,
    /// Filters keyed by filter name (`StateStoreDataKey::Filter`).
    filters: HashMap<String, String>,
    /// Data stored for the UTD hook manager
    /// (`StateStoreDataKey::UtdHookManagerData`).
    utd_hook_manager_data: Option<GrowableBloom>,
    /// Global account data events keyed by event type.
    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
    /// Room ID → user ID → minimal member event.
    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
    /// Room ID → display name → users sharing that display name (filled from
    /// `StateChanges::ambiguity_maps`).
    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
    /// Room ID → user ID → membership, derived from member events in the
    /// regular (non-stripped) room state.
    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Room ID → room info.
    room_info: HashMap<OwnedRoomId, RoomInfo>,
    /// Room ID → event type → state key → raw sync state event.
    room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
    /// Room ID → event type → raw room account data event.
    room_account_data:
        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
    /// Room ID → event type → state key → raw stripped state event.
    stripped_room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
    /// Room ID → user ID → membership, derived from stripped member events.
    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Presence event per sending user.
    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
    /// Room ID → (receipt type as a string, thread ID if any) → user ID →
    /// (receipted event ID, receipt). Holds the last receipt saved for each
    /// user; `room_event_receipts` indexes the same receipts by event.
    room_user_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
    >,

    /// Room ID → (receipt type as a string, thread ID if any) → event ID →
    /// user ID → receipt. Reverse index of `room_user_receipts`.
    room_event_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
    >,
    /// Arbitrary byte key/value pairs (`get_custom_value` and friends).
    custom: HashMap<Vec<u8>, Vec<u8>>,
    /// Requests queued for sending, per room, in insertion order.
    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
    /// Dependent requests queued per room.
    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
    /// Room ID → event ID → user ID of knock requests already seen
    /// (`StateStoreDataKey::SeenKnockRequests`).
    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
}
88
/// In-memory, non-persistent implementation of the `StateStore`.
///
/// Default if no other is configured at startup.
///
/// All data lives in one `MemoryStoreInner` behind a single `RwLock`. The
/// store methods call `.unwrap()` on the lock result, so they panic if the
/// lock has been poisoned by a panic that happened while it was held.
#[derive(Debug, Default)]
pub struct MemoryStore {
    /// The store contents, guarded by one lock for both reads and writes.
    inner: RwLock<MemoryStoreInner>,
}
96
97impl MemoryStore {
98    /// Create a new empty MemoryStore
99    pub fn new() -> Self {
100        Self::default()
101    }
102
103    fn get_user_room_receipt_event_impl(
104        &self,
105        room_id: &RoomId,
106        receipt_type: ReceiptType,
107        thread: ReceiptThread,
108        user_id: &UserId,
109    ) -> Option<(OwnedEventId, Receipt)> {
110        self.inner
111            .read()
112            .unwrap()
113            .room_user_receipts
114            .get(room_id)?
115            .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
116            .get(user_id)
117            .cloned()
118    }
119
120    fn get_event_room_receipt_events_impl(
121        &self,
122        room_id: &RoomId,
123        receipt_type: ReceiptType,
124        thread: ReceiptThread,
125        event_id: &EventId,
126    ) -> Option<Vec<(OwnedUserId, Receipt)>> {
127        Some(
128            self.inner
129                .read()
130                .unwrap()
131                .room_event_receipts
132                .get(room_id)?
133                .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
134                .get(event_id)?
135                .iter()
136                .map(|(key, value)| (key.clone(), value.clone()))
137                .collect(),
138        )
139    }
140}
141
142#[cfg_attr(target_family = "wasm", async_trait(?Send))]
143#[cfg_attr(not(target_family = "wasm"), async_trait)]
144impl StateStore for MemoryStore {
145    type Error = StoreError;
146
147    async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
148        let inner = self.inner.read().unwrap();
149        Ok(match key {
150            StateStoreDataKey::SyncToken => {
151                inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
152            }
153            StateStoreDataKey::ServerInfo => {
154                inner.server_info.clone().map(StateStoreDataValue::ServerInfo)
155            }
156            StateStoreDataKey::Filter(filter_name) => {
157                inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
158            }
159            StateStoreDataKey::UserAvatarUrl(user_id) => {
160                inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
161            }
162            StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
163                .recently_visited_rooms
164                .get(user_id)
165                .cloned()
166                .map(StateStoreDataValue::RecentlyVisitedRooms),
167            StateStoreDataKey::UtdHookManagerData => {
168                inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
169            }
170            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
171                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
172                inner.composer_drafts.get(&key).cloned().map(StateStoreDataValue::ComposerDraft)
173            }
174            StateStoreDataKey::SeenKnockRequests(room_id) => inner
175                .seen_knock_requests
176                .get(room_id)
177                .cloned()
178                .map(StateStoreDataValue::SeenKnockRequests),
179        })
180    }
181
182    async fn set_kv_data(
183        &self,
184        key: StateStoreDataKey<'_>,
185        value: StateStoreDataValue,
186    ) -> Result<()> {
187        let mut inner = self.inner.write().unwrap();
188        match key {
189            StateStoreDataKey::SyncToken => {
190                inner.sync_token =
191                    Some(value.into_sync_token().expect("Session data not a sync token"))
192            }
193            StateStoreDataKey::Filter(filter_name) => {
194                inner.filters.insert(
195                    filter_name.to_owned(),
196                    value.into_filter().expect("Session data not a filter"),
197                );
198            }
199            StateStoreDataKey::UserAvatarUrl(user_id) => {
200                inner.user_avatar_url.insert(
201                    user_id.to_owned(),
202                    value.into_user_avatar_url().expect("Session data not a user avatar url"),
203                );
204            }
205            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
206                inner.recently_visited_rooms.insert(
207                    user_id.to_owned(),
208                    value
209                        .into_recently_visited_rooms()
210                        .expect("Session data not a list of recently visited rooms"),
211                );
212            }
213            StateStoreDataKey::UtdHookManagerData => {
214                inner.utd_hook_manager_data = Some(
215                    value
216                        .into_utd_hook_manager_data()
217                        .expect("Session data not the hook manager data"),
218                );
219            }
220            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
221                inner.composer_drafts.insert(
222                    (room_id.to_owned(), thread_root.map(ToOwned::to_owned)),
223                    value.into_composer_draft().expect("Session data not a composer draft"),
224                );
225            }
226            StateStoreDataKey::ServerInfo => {
227                inner.server_info = Some(
228                    value.into_server_info().expect("Session data not containing server info"),
229                );
230            }
231            StateStoreDataKey::SeenKnockRequests(room_id) => {
232                inner.seen_knock_requests.insert(
233                    room_id.to_owned(),
234                    value
235                        .into_seen_knock_requests()
236                        .expect("Session data is not a set of seen join request ids"),
237                );
238            }
239        }
240
241        Ok(())
242    }
243
244    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
245        let mut inner = self.inner.write().unwrap();
246        match key {
247            StateStoreDataKey::SyncToken => inner.sync_token = None,
248            StateStoreDataKey::ServerInfo => inner.server_info = None,
249            StateStoreDataKey::Filter(filter_name) => {
250                inner.filters.remove(filter_name);
251            }
252            StateStoreDataKey::UserAvatarUrl(user_id) => {
253                inner.user_avatar_url.remove(user_id);
254            }
255            StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
256                inner.recently_visited_rooms.remove(user_id);
257            }
258            StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
259            StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
260                let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
261                inner.composer_drafts.remove(&key);
262            }
263            StateStoreDataKey::SeenKnockRequests(room_id) => {
264                inner.seen_knock_requests.remove(room_id);
265            }
266        }
267        Ok(())
268    }
269
    /// Applies a batch of [`StateChanges`] to the in-memory store.
    ///
    /// Everything is applied while holding the write lock, in this order:
    /// sync token, profile deletions then profile upserts, display-name
    /// ambiguity maps, global and room account data, room state (plus the
    /// membership derived from member events), room infos, presence, stripped
    /// state (plus stripped membership), receipts, and finally redactions of
    /// stored room state.
    ///
    /// # Errors
    ///
    /// Returns an error if a stored event targeted by a redaction cannot be
    /// deserialized, redacted or re-serialized, or if a `RedactedBecause`
    /// cannot be built from the redaction event. There is no rollback: every
    /// change applied before the failing step stays in the store.
    #[instrument(skip(self, changes))]
    async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
        let now = Instant::now();

        // The whole batch is applied under a single write lock.
        let mut inner = self.inner.write().unwrap();

        if let Some(s) = &changes.sync_token {
            inner.sync_token = Some(s.to_owned());
        }

        // Deletions run before upserts, so a profile that is both deleted and
        // updated in the same batch ends up with the updated value.
        for (room, users) in &changes.profiles_to_delete {
            let Some(room_profiles) = inner.profiles.get_mut(room) else {
                continue;
            };
            for user in users {
                room_profiles.remove(user);
            }
        }

        for (room, users) in &changes.profiles {
            for (user_id, profile) in users {
                inner
                    .profiles
                    .entry(room.clone())
                    .or_default()
                    .insert(user_id.clone(), profile.clone());
            }
        }

        // Each display name's user set replaces the previously stored set.
        for (room, map) in &changes.ambiguity_maps {
            for (display_name, display_names) in map {
                inner
                    .display_names
                    .entry(room.clone())
                    .or_default()
                    .insert(display_name.clone(), display_names.clone());
            }
        }

        for (event_type, event) in &changes.account_data {
            inner.account_data.insert(event_type.clone(), event.clone());
        }

        for (room, events) in &changes.room_account_data {
            for (event_type, event) in events {
                inner
                    .room_account_data
                    .entry(room.clone())
                    .or_default()
                    .insert(event_type.clone(), event.clone());
            }
        }

        for (room, event_types) in &changes.state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());
                    // Regular state for a room supersedes all of its stripped
                    // state. (This runs once per event, not once per room.)
                    inner.stripped_room_state.remove(room);

                    if *event_type == StateEventType::RoomMember {
                        let event = match raw_event.deserialize_as::<SyncRoomMemberEvent>() {
                            Ok(ev) => ev,
                            Err(e) => {
                                // The raw event is already stored above; only the
                                // membership maps are skipped for it.
                                let event_id: Option<String> =
                                    raw_event.get_field("event_id").ok().flatten();
                                debug!(event_id, "Failed to deserialize member event: {e}");
                                continue;
                            }
                        };

                        // Regular membership supersedes stripped membership.
                        inner.stripped_members.remove(room);

                        inner
                            .members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key().to_owned(), event.membership().clone());
                    }
                }
            }
        }

        for (room_id, info) in &changes.room_infos {
            inner.room_info.insert(room_id.clone(), info.clone());
        }

        for (sender, event) in &changes.presence {
            inner.presence.insert(sender.clone(), event.clone());
        }

        for (room, event_types) in &changes.stripped_state {
            for (event_type, events) in event_types {
                for (state_key, raw_event) in events {
                    inner
                        .stripped_room_state
                        .entry(room.clone())
                        .or_default()
                        .entry(event_type.clone())
                        .or_default()
                        .insert(state_key.to_owned(), raw_event.clone());

                    if *event_type == StateEventType::RoomMember {
                        let event = match raw_event.deserialize_as::<StrippedRoomMemberEvent>() {
                            Ok(ev) => ev,
                            Err(e) => {
                                let event_id: Option<String> =
                                    raw_event.get_field("event_id").ok().flatten();
                                debug!(
                                    event_id,
                                    "Failed to deserialize stripped member event: {e}"
                                );
                                continue;
                            }
                        };

                        inner
                            .stripped_members
                            .entry(room.clone())
                            .or_default()
                            .insert(event.state_key, event.content.membership.clone());
                    }
                }
            }
        }

        // Receipts are indexed twice: per user (the last receipt of each
        // type/thread) and per event. When a user's receipt moves to another
        // event, the user is unlinked from the old event; maps emptied by
        // that removal are left in place.
        for (room, content) in &changes.receipts {
            for (event_id, receipts) in &content.0 {
                for (receipt_type, receipts) in receipts {
                    for (user_id, receipt) in receipts {
                        let thread = receipt.thread.as_str().map(ToOwned::to_owned);
                        // Add the receipt to the room user receipts
                        if let Some((old_event, _)) = inner
                            .room_user_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread.clone()))
                            .or_default()
                            .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
                        {
                            // Remove the old receipt from the room event receipts
                            if let Some(receipt_map) = inner.room_event_receipts.get_mut(room) {
                                if let Some(event_map) =
                                    receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
                                {
                                    if let Some(user_map) = event_map.get_mut(&old_event) {
                                        user_map.remove(user_id);
                                    }
                                }
                            }
                        }

                        // Add the receipt to the room event receipts
                        inner
                            .room_event_receipts
                            .entry(room.clone())
                            .or_default()
                            .entry((receipt_type.to_string(), thread))
                            .or_default()
                            .entry(event_id.clone())
                            .or_default()
                            .insert(user_id.clone(), receipt.clone());
                    }
                }
            }
        }

        // Resolves the room version needed by `redact`, falling back to
        // `ROOM_VERSION_FALLBACK` (with a warning) when no `RoomInfo` is stored.
        let make_room_version = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
            room_info.get(room_id).map(|info| info.room_version_or_default()).unwrap_or_else(|| {
                warn!(
                    ?room_id,
                    "Unable to find the room version, assuming {ROOM_VERSION_FALLBACK}"
                );
                ROOM_VERSION_FALLBACK
            })
        };

        // Reborrow the guard as a plain `&mut MemoryStoreInner` so that
        // `room_state` (borrowed mutably) and `room_info` (read in the closure
        // below) can be borrowed as disjoint fields.
        let inner = &mut *inner;
        // Only regular room state is redacted here; stripped state is untouched.
        for (room_id, redactions) in &changes.redactions {
            // Resolved lazily, at most once per room, and only if some stored
            // event actually matches a redaction.
            let mut room_version = None;

            if let Some(room) = inner.room_state.get_mut(room_id) {
                for ref_room_mu in room.values_mut() {
                    for raw_evt in ref_room_mu.values_mut() {
                        if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id") {
                            if let Some(redaction) = redactions.get(&event_id) {
                                let redacted = redact(
                                    raw_evt.deserialize_as::<CanonicalJsonObject>()?,
                                    room_version.get_or_insert_with(|| {
                                        make_room_version(&inner.room_info, room_id)
                                    }),
                                    Some(RedactedBecause::from_raw_event(redaction)?),
                                )
                                .map_err(StoreError::Redaction)?;
                                // Replace the stored event with its redacted form.
                                *raw_evt = Raw::new(&redacted)?.cast();
                            }
                        }
                    }
                }
            }
        }

        debug!("Saved changes in {:?}", now.elapsed());

        Ok(())
    }
481
482    async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
483        Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
484    }
485
486    async fn get_presence_events(
487        &self,
488        user_ids: &[OwnedUserId],
489    ) -> Result<Vec<Raw<PresenceEvent>>> {
490        let presence = &self.inner.read().unwrap().presence;
491        Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
492    }
493
494    async fn get_state_event(
495        &self,
496        room_id: &RoomId,
497        event_type: StateEventType,
498        state_key: &str,
499    ) -> Result<Option<RawAnySyncOrStrippedState>> {
500        Ok(self
501            .get_state_events_for_keys(room_id, event_type, &[state_key])
502            .await?
503            .into_iter()
504            .next())
505    }
506
507    async fn get_state_events(
508        &self,
509        room_id: &RoomId,
510        event_type: StateEventType,
511    ) -> Result<Vec<RawAnySyncOrStrippedState>> {
512        fn get_events<T>(
513            state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
514            room_id: &RoomId,
515            event_type: &StateEventType,
516            to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
517        ) -> Option<Vec<RawAnySyncOrStrippedState>> {
518            let state_events = state_map.get(room_id)?.get(event_type)?;
519            Some(state_events.values().cloned().map(to_enum).collect())
520        }
521
522        let inner = self.inner.read().unwrap();
523        Ok(get_events(
524            &inner.stripped_room_state,
525            room_id,
526            &event_type,
527            RawAnySyncOrStrippedState::Stripped,
528        )
529        .or_else(|| {
530            get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
531        })
532        .unwrap_or_default())
533    }
534
535    async fn get_state_events_for_keys(
536        &self,
537        room_id: &RoomId,
538        event_type: StateEventType,
539        state_keys: &[&str],
540    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
541        let inner = self.inner.read().unwrap();
542
543        if let Some(stripped_state_events) =
544            inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
545        {
546            Ok(state_keys
547                .iter()
548                .filter_map(|k| {
549                    stripped_state_events
550                        .get(*k)
551                        .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
552                })
553                .collect())
554        } else if let Some(sync_state_events) =
555            inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
556        {
557            Ok(state_keys
558                .iter()
559                .filter_map(|k| {
560                    sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
561                })
562                .collect())
563        } else {
564            Ok(Vec::new())
565        }
566    }
567
568    async fn get_profile(
569        &self,
570        room_id: &RoomId,
571        user_id: &UserId,
572    ) -> Result<Option<MinimalRoomMemberEvent>> {
573        Ok(self
574            .inner
575            .read()
576            .unwrap()
577            .profiles
578            .get(room_id)
579            .and_then(|room_profiles| room_profiles.get(user_id))
580            .cloned())
581    }
582
583    async fn get_profiles<'a>(
584        &self,
585        room_id: &RoomId,
586        user_ids: &'a [OwnedUserId],
587    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
588        if user_ids.is_empty() {
589            return Ok(BTreeMap::new());
590        }
591
592        let profiles = &self.inner.read().unwrap().profiles;
593        let Some(room_profiles) = profiles.get(room_id) else {
594            return Ok(BTreeMap::new());
595        };
596
597        Ok(user_ids
598            .iter()
599            .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
600            .collect())
601    }
602
603    #[instrument(skip(self, memberships))]
604    async fn get_user_ids(
605        &self,
606        room_id: &RoomId,
607        memberships: RoomMemberships,
608    ) -> Result<Vec<OwnedUserId>> {
609        /// Get the user IDs for the given room with the given memberships and
610        /// stripped state.
611        ///
612        /// If `memberships` is empty, returns all user IDs in the room with the
613        /// given stripped state.
614        fn get_user_ids_inner(
615            members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
616            room_id: &RoomId,
617            memberships: RoomMemberships,
618        ) -> Vec<OwnedUserId> {
619            members
620                .get(room_id)
621                .map(|members| {
622                    members
623                        .iter()
624                        .filter_map(|(user_id, membership)| {
625                            memberships.matches(membership).then_some(user_id)
626                        })
627                        .cloned()
628                        .collect()
629                })
630                .unwrap_or_default()
631        }
632        let inner = self.inner.read().unwrap();
633        let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
634        if !v.is_empty() {
635            return Ok(v);
636        }
637        Ok(get_user_ids_inner(&inner.members, room_id, memberships))
638    }
639
640    async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
641        let memory_store_inner = self.inner.read().unwrap();
642        let room_infos = &memory_store_inner.room_info;
643
644        Ok(match room_load_settings {
645            RoomLoadSettings::All => room_infos.values().cloned().collect(),
646
647            RoomLoadSettings::One(room_id) => match room_infos.get(room_id) {
648                Some(room_info) => vec![room_info.clone()],
649                None => vec![],
650            },
651        })
652    }
653
654    async fn get_users_with_display_name(
655        &self,
656        room_id: &RoomId,
657        display_name: &DisplayName,
658    ) -> Result<BTreeSet<OwnedUserId>> {
659        Ok(self
660            .inner
661            .read()
662            .unwrap()
663            .display_names
664            .get(room_id)
665            .and_then(|room_names| room_names.get(display_name).cloned())
666            .unwrap_or_default())
667    }
668
669    async fn get_users_with_display_names<'a>(
670        &self,
671        room_id: &RoomId,
672        display_names: &'a [DisplayName],
673    ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
674        if display_names.is_empty() {
675            return Ok(HashMap::new());
676        }
677
678        let inner = self.inner.read().unwrap();
679        let Some(room_names) = inner.display_names.get(room_id) else {
680            return Ok(HashMap::new());
681        };
682
683        Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
684    }
685
686    async fn get_account_data_event(
687        &self,
688        event_type: GlobalAccountDataEventType,
689    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
690        Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
691    }
692
693    async fn get_room_account_data_event(
694        &self,
695        room_id: &RoomId,
696        event_type: RoomAccountDataEventType,
697    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
698        Ok(self
699            .inner
700            .read()
701            .unwrap()
702            .room_account_data
703            .get(room_id)
704            .and_then(|m| m.get(&event_type))
705            .cloned())
706    }
707
708    async fn get_user_room_receipt_event(
709        &self,
710        room_id: &RoomId,
711        receipt_type: ReceiptType,
712        thread: ReceiptThread,
713        user_id: &UserId,
714    ) -> Result<Option<(OwnedEventId, Receipt)>> {
715        Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
716    }
717
718    async fn get_event_room_receipt_events(
719        &self,
720        room_id: &RoomId,
721        receipt_type: ReceiptType,
722        thread: ReceiptThread,
723        event_id: &EventId,
724    ) -> Result<Vec<(OwnedUserId, Receipt)>> {
725        Ok(self
726            .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
727            .unwrap_or_default())
728    }
729
730    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
731        Ok(self.inner.read().unwrap().custom.get(key).cloned())
732    }
733
734    async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
735        Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
736    }
737
738    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
739        Ok(self.inner.write().unwrap().custom.remove(key))
740    }
741
742    async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
743        let mut inner = self.inner.write().unwrap();
744
745        inner.profiles.remove(room_id);
746        inner.display_names.remove(room_id);
747        inner.members.remove(room_id);
748        inner.room_info.remove(room_id);
749        inner.room_state.remove(room_id);
750        inner.room_account_data.remove(room_id);
751        inner.stripped_room_state.remove(room_id);
752        inner.stripped_members.remove(room_id);
753        inner.room_user_receipts.remove(room_id);
754        inner.room_event_receipts.remove(room_id);
755        inner.send_queue_events.remove(room_id);
756        inner.dependent_send_queue_events.remove(room_id);
757
758        Ok(())
759    }
760
761    async fn save_send_queue_request(
762        &self,
763        room_id: &RoomId,
764        transaction_id: OwnedTransactionId,
765        created_at: MilliSecondsSinceUnixEpoch,
766        kind: QueuedRequestKind,
767        priority: usize,
768    ) -> Result<(), Self::Error> {
769        self.inner
770            .write()
771            .unwrap()
772            .send_queue_events
773            .entry(room_id.to_owned())
774            .or_default()
775            .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
776        Ok(())
777    }
778
779    async fn update_send_queue_request(
780        &self,
781        room_id: &RoomId,
782        transaction_id: &TransactionId,
783        kind: QueuedRequestKind,
784    ) -> Result<bool, Self::Error> {
785        if let Some(entry) = self
786            .inner
787            .write()
788            .unwrap()
789            .send_queue_events
790            .entry(room_id.to_owned())
791            .or_default()
792            .iter_mut()
793            .find(|item| item.transaction_id == transaction_id)
794        {
795            entry.kind = kind;
796            entry.error = None;
797            Ok(true)
798        } else {
799            Ok(false)
800        }
801    }
802
803    async fn remove_send_queue_request(
804        &self,
805        room_id: &RoomId,
806        transaction_id: &TransactionId,
807    ) -> Result<bool, Self::Error> {
808        let mut inner = self.inner.write().unwrap();
809        let q = &mut inner.send_queue_events;
810
811        let entry = q.get_mut(room_id);
812        if let Some(entry) = entry {
813            // Find the event by id in its room queue, and remove it if present.
814            if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
815                entry.remove(pos);
816                // And if this was the last event before removal, remove the entire room entry.
817                if entry.is_empty() {
818                    q.remove(room_id);
819                }
820                return Ok(true);
821            }
822        }
823
824        Ok(false)
825    }
826
827    async fn load_send_queue_requests(
828        &self,
829        room_id: &RoomId,
830    ) -> Result<Vec<QueuedRequest>, Self::Error> {
831        let mut ret = self
832            .inner
833            .write()
834            .unwrap()
835            .send_queue_events
836            .entry(room_id.to_owned())
837            .or_default()
838            .clone();
839        // Inverted order of priority, use stable sort to keep insertion order.
840        ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority));
841        Ok(ret)
842    }
843
844    async fn update_send_queue_request_status(
845        &self,
846        room_id: &RoomId,
847        transaction_id: &TransactionId,
848        error: Option<QueueWedgeError>,
849    ) -> Result<(), Self::Error> {
850        if let Some(entry) = self
851            .inner
852            .write()
853            .unwrap()
854            .send_queue_events
855            .entry(room_id.to_owned())
856            .or_default()
857            .iter_mut()
858            .find(|item| item.transaction_id == transaction_id)
859        {
860            entry.error = error;
861        }
862        Ok(())
863    }
864
865    async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
866        Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
867    }
868
869    async fn save_dependent_queued_request(
870        &self,
871        room: &RoomId,
872        parent_transaction_id: &TransactionId,
873        own_transaction_id: ChildTransactionId,
874        created_at: MilliSecondsSinceUnixEpoch,
875        content: DependentQueuedRequestKind,
876    ) -> Result<(), Self::Error> {
877        self.inner
878            .write()
879            .unwrap()
880            .dependent_send_queue_events
881            .entry(room.to_owned())
882            .or_default()
883            .push(DependentQueuedRequest {
884                kind: content,
885                parent_transaction_id: parent_transaction_id.to_owned(),
886                own_transaction_id,
887                parent_key: None,
888                created_at,
889            });
890        Ok(())
891    }
892
893    async fn mark_dependent_queued_requests_as_ready(
894        &self,
895        room: &RoomId,
896        parent_txn_id: &TransactionId,
897        sent_parent_key: SentRequestKey,
898    ) -> Result<usize, Self::Error> {
899        let mut inner = self.inner.write().unwrap();
900        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
901        let mut num_updated = 0;
902        for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
903            d.parent_key = Some(sent_parent_key.clone());
904            num_updated += 1;
905        }
906        Ok(num_updated)
907    }
908
909    async fn update_dependent_queued_request(
910        &self,
911        room: &RoomId,
912        own_transaction_id: &ChildTransactionId,
913        new_content: DependentQueuedRequestKind,
914    ) -> Result<bool, Self::Error> {
915        let mut inner = self.inner.write().unwrap();
916        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
917        for d in dependents.iter_mut() {
918            if d.own_transaction_id == *own_transaction_id {
919                d.kind = new_content;
920                return Ok(true);
921            }
922        }
923        Ok(false)
924    }
925
926    async fn remove_dependent_queued_request(
927        &self,
928        room: &RoomId,
929        txn_id: &ChildTransactionId,
930    ) -> Result<bool, Self::Error> {
931        let mut inner = self.inner.write().unwrap();
932        let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
933        if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
934            dependents.remove(pos);
935            Ok(true)
936        } else {
937            Ok(false)
938        }
939    }
940
941    /// List all the dependent send queue events.
942    ///
943    /// This returns absolutely all the dependent send queue events, whether
944    /// they have an event id or not.
945    async fn load_dependent_queued_requests(
946        &self,
947        room: &RoomId,
948    ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
949        Ok(self
950            .inner
951            .read()
952            .unwrap()
953            .dependent_send_queue_events
954            .get(room)
955            .cloned()
956            .unwrap_or_default())
957    }
958}
959
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    /// Builds a new `MemoryStore` for the shared state-store tests.
    // NOTE(review): `statestore_integration_tests!` presumably looks up a `get_store` fn with
    // exactly this name and signature in scope. Confirm before renaming or changing it.
    async fn get_store() -> Result<impl StateStore> {
        Ok(MemoryStore::new())
    }

    // Expands the shared state-store test suite (defined elsewhere) against `get_store`.
    statestore_integration_tests!();
}