1use std::{
16 collections::{BTreeMap, BTreeSet, HashMap},
17 sync::RwLock,
18};
19
20use async_trait::async_trait;
21use growable_bloom_filter::GrowableBloom;
22use matrix_sdk_common::{ROOM_VERSION_FALLBACK, ROOM_VERSION_RULES_FALLBACK};
23use ruma::{
24 CanonicalJsonObject, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri,
25 OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId,
26 canonical_json::{RedactedBecause, redact},
27 events::{
28 AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent,
29 AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType,
30 presence::PresenceEvent,
31 receipt::{Receipt, ReceiptThread, ReceiptType},
32 room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent},
33 },
34 serde::Raw,
35 time::Instant,
36};
37use tracing::{debug, instrument, warn};
38
39use super::{
40 DependentQueuedRequest, DependentQueuedRequestKind, QueuedRequestKind, Result, RoomInfo,
41 RoomLoadSettings, StateChanges, StateStore, StoreError, SupportedVersionsResponse,
42 TtlStoreValue, WellKnownResponse,
43 send_queue::{ChildTransactionId, QueuedRequest, SentRequestKey},
44 traits::ComposerDraft,
45};
46use crate::{
47 MinimalRoomMemberEvent, RoomMemberships, StateStoreDataKey, StateStoreDataValue,
48 deserialized_responses::{DisplayName, RawAnySyncOrStrippedState},
49 store::{
50 QueueWedgeError, StoredThreadSubscription,
51 traits::{ThreadSubscriptionCatchupToken, compare_thread_subscription_bump_stamps},
52 },
53};
54
/// All mutable state of a [`MemoryStore`], kept behind the store's single lock.
#[derive(Debug, Default)]
// The nested per-room maps below are spelled out in full instead of being hidden behind aliases.
#[allow(clippy::type_complexity)]
struct MemoryStoreInner {
    /// Per-user list of recently visited rooms (`StateStoreDataKey::RecentlyVisitedRooms`).
    recently_visited_rooms: HashMap<OwnedUserId, Vec<OwnedRoomId>>,
    /// Composer drafts, keyed by room and optional thread root event.
    composer_drafts: HashMap<(OwnedRoomId, Option<OwnedEventId>), ComposerDraft>,
    /// Per-user avatar URL (`StateStoreDataKey::UserAvatarUrl`).
    user_avatar_url: HashMap<OwnedUserId, OwnedMxcUri>,
    /// Last saved sync token, written by `set_kv_data` and by `save_changes`.
    sync_token: Option<String>,
    /// Value stored under `StateStoreDataKey::SupportedVersions`.
    supported_versions: Option<TtlStoreValue<SupportedVersionsResponse>>,
    /// Value stored under `StateStoreDataKey::WellKnown`.
    well_known: Option<TtlStoreValue<Option<WellKnownResponse>>>,
    /// Filters, keyed by filter name.
    filters: HashMap<String, String>,
    /// Bloom filter stored under `StateStoreDataKey::UtdHookManagerData`.
    utd_hook_manager_data: Option<GrowableBloom>,
    /// Set by storing `StateStoreDataKey::OneTimeKeyAlreadyUploaded`, cleared by removing it.
    one_time_key_uploaded_error: bool,
    /// Global account data, one event per event type.
    account_data: HashMap<GlobalAccountDataEventType, Raw<AnyGlobalAccountDataEvent>>,
    /// room -> user -> minimal member event.
    profiles: HashMap<OwnedRoomId, HashMap<OwnedUserId, MinimalRoomMemberEvent>>,
    /// room -> display name -> users sharing that display name (from `StateChanges::ambiguity_maps`).
    display_names: HashMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
    /// room -> user -> membership, derived from synced `m.room.member` state events.
    members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Room info per room.
    room_info: HashMap<OwnedRoomId, RoomInfo>,
    /// room -> event type -> state key -> raw synced state event.
    room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnySyncStateEvent>>>>,
    /// room -> event type -> room account data event.
    room_account_data:
        HashMap<OwnedRoomId, HashMap<RoomAccountDataEventType, Raw<AnyRoomAccountDataEvent>>>,
    /// room -> event type -> state key -> raw stripped state event.
    stripped_room_state:
        HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<AnyStrippedStateEvent>>>>,
    /// room -> user -> membership, derived from stripped `m.room.member` events.
    stripped_members: HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
    /// Latest presence event per user.
    presence: HashMap<OwnedUserId, Raw<PresenceEvent>>,
    /// room -> (receipt type string, thread string if any) -> user -> (receipted event, receipt).
    room_user_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedUserId, (OwnedEventId, Receipt)>>,
    >,
    /// Reverse index of `room_user_receipts`:
    /// room -> (receipt type string, thread string if any) -> event -> user -> receipt.
    room_event_receipts: HashMap<
        OwnedRoomId,
        HashMap<(String, Option<String>), HashMap<OwnedEventId, HashMap<OwnedUserId, Receipt>>>,
    >,
    /// Arbitrary byte key/value pairs (`get_custom_value` / `set_custom_value`).
    custom: HashMap<Vec<u8>, Vec<u8>>,
    /// Queued outgoing requests per room; `remove_send_queue_request` drops emptied rooms.
    send_queue_events: BTreeMap<OwnedRoomId, Vec<QueuedRequest>>,
    /// Requests that depend on a parent queued request, per room.
    dependent_send_queue_events: BTreeMap<OwnedRoomId, Vec<DependentQueuedRequest>>,
    /// room -> event ID -> user, stored under `StateStoreDataKey::SeenKnockRequests`.
    seen_knock_requests: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, OwnedUserId>>,
    /// room -> thread root event -> stored thread subscription.
    thread_subscriptions: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, StoredThreadSubscription>>,
    /// Value stored under `StateStoreDataKey::ThreadSubscriptionsCatchupTokens`.
    thread_subscriptions_catchup_tokens: Option<Vec<ThreadSubscriptionCatchupToken>>,
}
95
#[derive(Debug, Default)]
/// In-memory implementation of [`StateStore`].
///
/// All data is kept in plain collections behind a single `RwLock`. Nothing is
/// written to disk, so the contents are lost when the store is dropped.
pub struct MemoryStore {
    /// All store data, guarded by one lock.
    inner: RwLock<MemoryStoreInner>,
}
103
104impl MemoryStore {
105 pub fn new() -> Self {
107 Self::default()
108 }
109
110 fn get_user_room_receipt_event_impl(
111 &self,
112 room_id: &RoomId,
113 receipt_type: ReceiptType,
114 thread: ReceiptThread,
115 user_id: &UserId,
116 ) -> Option<(OwnedEventId, Receipt)> {
117 self.inner
118 .read()
119 .unwrap()
120 .room_user_receipts
121 .get(room_id)?
122 .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
123 .get(user_id)
124 .cloned()
125 }
126
127 fn get_event_room_receipt_events_impl(
128 &self,
129 room_id: &RoomId,
130 receipt_type: ReceiptType,
131 thread: ReceiptThread,
132 event_id: &EventId,
133 ) -> Option<Vec<(OwnedUserId, Receipt)>> {
134 Some(
135 self.inner
136 .read()
137 .unwrap()
138 .room_event_receipts
139 .get(room_id)?
140 .get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned)))?
141 .get(event_id)?
142 .iter()
143 .map(|(key, value)| (key.clone(), value.clone()))
144 .collect(),
145 )
146 }
147}
148
149#[cfg_attr(target_family = "wasm", async_trait(?Send))]
150#[cfg_attr(not(target_family = "wasm"), async_trait)]
151impl StateStore for MemoryStore {
    /// Errors from this store are reported as the crate's [`StoreError`].
    type Error = StoreError;
153
154 async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<Option<StateStoreDataValue>> {
155 let inner = self.inner.read().unwrap();
156
157 Ok(match key {
158 StateStoreDataKey::SyncToken => {
159 inner.sync_token.clone().map(StateStoreDataValue::SyncToken)
160 }
161 StateStoreDataKey::SupportedVersions => {
162 inner.supported_versions.clone().map(StateStoreDataValue::SupportedVersions)
163 }
164 StateStoreDataKey::WellKnown => {
165 inner.well_known.clone().map(StateStoreDataValue::WellKnown)
166 }
167 StateStoreDataKey::Filter(filter_name) => {
168 inner.filters.get(filter_name).cloned().map(StateStoreDataValue::Filter)
169 }
170 StateStoreDataKey::UserAvatarUrl(user_id) => {
171 inner.user_avatar_url.get(user_id).cloned().map(StateStoreDataValue::UserAvatarUrl)
172 }
173 StateStoreDataKey::RecentlyVisitedRooms(user_id) => inner
174 .recently_visited_rooms
175 .get(user_id)
176 .cloned()
177 .map(StateStoreDataValue::RecentlyVisitedRooms),
178 StateStoreDataKey::UtdHookManagerData => {
179 inner.utd_hook_manager_data.clone().map(StateStoreDataValue::UtdHookManagerData)
180 }
181 StateStoreDataKey::OneTimeKeyAlreadyUploaded => inner
182 .one_time_key_uploaded_error
183 .then_some(StateStoreDataValue::OneTimeKeyAlreadyUploaded),
184 StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
185 let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
186 inner.composer_drafts.get(&key).cloned().map(StateStoreDataValue::ComposerDraft)
187 }
188 StateStoreDataKey::SeenKnockRequests(room_id) => inner
189 .seen_knock_requests
190 .get(room_id)
191 .cloned()
192 .map(StateStoreDataValue::SeenKnockRequests),
193 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => inner
194 .thread_subscriptions_catchup_tokens
195 .clone()
196 .map(StateStoreDataValue::ThreadSubscriptionsCatchupTokens),
197 })
198 }
199
200 async fn set_kv_data(
201 &self,
202 key: StateStoreDataKey<'_>,
203 value: StateStoreDataValue,
204 ) -> Result<()> {
205 let mut inner = self.inner.write().unwrap();
206 match key {
207 StateStoreDataKey::SyncToken => {
208 inner.sync_token =
209 Some(value.into_sync_token().expect("Session data not a sync token"))
210 }
211 StateStoreDataKey::Filter(filter_name) => {
212 inner.filters.insert(
213 filter_name.to_owned(),
214 value.into_filter().expect("Session data not a filter"),
215 );
216 }
217 StateStoreDataKey::UserAvatarUrl(user_id) => {
218 inner.user_avatar_url.insert(
219 user_id.to_owned(),
220 value.into_user_avatar_url().expect("Session data not a user avatar url"),
221 );
222 }
223 StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
224 inner.recently_visited_rooms.insert(
225 user_id.to_owned(),
226 value
227 .into_recently_visited_rooms()
228 .expect("Session data not a list of recently visited rooms"),
229 );
230 }
231 StateStoreDataKey::UtdHookManagerData => {
232 inner.utd_hook_manager_data = Some(
233 value
234 .into_utd_hook_manager_data()
235 .expect("Session data not the hook manager data"),
236 );
237 }
238 StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
239 inner.one_time_key_uploaded_error = true;
240 }
241 StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
242 inner.composer_drafts.insert(
243 (room_id.to_owned(), thread_root.map(ToOwned::to_owned)),
244 value.into_composer_draft().expect("Session data not a composer draft"),
245 );
246 }
247 StateStoreDataKey::SupportedVersions => {
248 inner.supported_versions = Some(
249 value
250 .into_supported_versions()
251 .expect("Session data not containing supported versions"),
252 );
253 }
254 StateStoreDataKey::WellKnown => {
255 inner.well_known =
256 Some(value.into_well_known().expect("Session data not containing well-known"));
257 }
258 StateStoreDataKey::SeenKnockRequests(room_id) => {
259 inner.seen_knock_requests.insert(
260 room_id.to_owned(),
261 value
262 .into_seen_knock_requests()
263 .expect("Session data is not a set of seen join request ids"),
264 );
265 }
266 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
267 inner.thread_subscriptions_catchup_tokens =
268 Some(value.into_thread_subscriptions_catchup_tokens().expect(
269 "Session data is not a list of thread subscription catchup tokens",
270 ));
271 }
272 }
273
274 Ok(())
275 }
276
277 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> {
278 let mut inner = self.inner.write().unwrap();
279 match key {
280 StateStoreDataKey::SyncToken => inner.sync_token = None,
281 StateStoreDataKey::SupportedVersions => inner.supported_versions = None,
282 StateStoreDataKey::WellKnown => inner.well_known = None,
283 StateStoreDataKey::Filter(filter_name) => {
284 inner.filters.remove(filter_name);
285 }
286 StateStoreDataKey::UserAvatarUrl(user_id) => {
287 inner.user_avatar_url.remove(user_id);
288 }
289 StateStoreDataKey::RecentlyVisitedRooms(user_id) => {
290 inner.recently_visited_rooms.remove(user_id);
291 }
292 StateStoreDataKey::UtdHookManagerData => inner.utd_hook_manager_data = None,
293 StateStoreDataKey::OneTimeKeyAlreadyUploaded => {
294 inner.one_time_key_uploaded_error = false
295 }
296 StateStoreDataKey::ComposerDraft(room_id, thread_root) => {
297 let key = (room_id.to_owned(), thread_root.map(ToOwned::to_owned));
298 inner.composer_drafts.remove(&key);
299 }
300 StateStoreDataKey::SeenKnockRequests(room_id) => {
301 inner.seen_knock_requests.remove(room_id);
302 }
303 StateStoreDataKey::ThreadSubscriptionsCatchupTokens => {
304 inner.thread_subscriptions_catchup_tokens = None;
305 }
306 }
307 Ok(())
308 }
309
310 #[instrument(skip(self, changes))]
311 async fn save_changes(&self, changes: &StateChanges) -> Result<()> {
312 let now = Instant::now();
313
314 let mut inner = self.inner.write().unwrap();
315
316 if let Some(s) = &changes.sync_token {
317 inner.sync_token = Some(s.to_owned());
318 }
319
320 for (room, users) in &changes.profiles_to_delete {
321 let Some(room_profiles) = inner.profiles.get_mut(room) else {
322 continue;
323 };
324 for user in users {
325 room_profiles.remove(user);
326 }
327 }
328
329 for (room, users) in &changes.profiles {
330 for (user_id, profile) in users {
331 inner
332 .profiles
333 .entry(room.clone())
334 .or_default()
335 .insert(user_id.clone(), profile.clone());
336 }
337 }
338
339 for (room, map) in &changes.ambiguity_maps {
340 for (display_name, display_names) in map {
341 inner
342 .display_names
343 .entry(room.clone())
344 .or_default()
345 .insert(display_name.clone(), display_names.clone());
346 }
347 }
348
349 for (event_type, event) in &changes.account_data {
350 inner.account_data.insert(event_type.clone(), event.clone());
351 }
352
353 for (room, events) in &changes.room_account_data {
354 for (event_type, event) in events {
355 inner
356 .room_account_data
357 .entry(room.clone())
358 .or_default()
359 .insert(event_type.clone(), event.clone());
360 }
361 }
362
363 for (room, event_types) in &changes.state {
364 for (event_type, events) in event_types {
365 for (state_key, raw_event) in events {
366 inner
367 .room_state
368 .entry(room.clone())
369 .or_default()
370 .entry(event_type.clone())
371 .or_default()
372 .insert(state_key.to_owned(), raw_event.clone());
373 inner.stripped_room_state.remove(room);
374
375 if *event_type == StateEventType::RoomMember {
376 let event =
377 match raw_event.deserialize_as_unchecked::<SyncRoomMemberEvent>() {
378 Ok(ev) => ev,
379 Err(e) => {
380 let event_id: Option<String> =
381 raw_event.get_field("event_id").ok().flatten();
382 debug!(event_id, "Failed to deserialize member event: {e}");
383 continue;
384 }
385 };
386
387 inner.stripped_members.remove(room);
388
389 inner
390 .members
391 .entry(room.clone())
392 .or_default()
393 .insert(event.state_key().to_owned(), event.membership().clone());
394 }
395 }
396 }
397 }
398
399 for (room_id, info) in &changes.room_infos {
400 inner.room_info.insert(room_id.clone(), info.clone());
401 }
402
403 for (sender, event) in &changes.presence {
404 inner.presence.insert(sender.clone(), event.clone());
405 }
406
407 for (room, event_types) in &changes.stripped_state {
408 for (event_type, events) in event_types {
409 for (state_key, raw_event) in events {
410 inner
411 .stripped_room_state
412 .entry(room.clone())
413 .or_default()
414 .entry(event_type.clone())
415 .or_default()
416 .insert(state_key.to_owned(), raw_event.clone());
417
418 if *event_type == StateEventType::RoomMember {
419 let event =
420 match raw_event.deserialize_as_unchecked::<StrippedRoomMemberEvent>() {
421 Ok(ev) => ev,
422 Err(e) => {
423 let event_id: Option<String> =
424 raw_event.get_field("event_id").ok().flatten();
425 debug!(
426 event_id,
427 "Failed to deserialize stripped member event: {e}"
428 );
429 continue;
430 }
431 };
432
433 inner
434 .stripped_members
435 .entry(room.clone())
436 .or_default()
437 .insert(event.state_key, event.content.membership.clone());
438 }
439 }
440 }
441 }
442
443 for (room, content) in &changes.receipts {
444 for (event_id, receipts) in &content.0 {
445 for (receipt_type, receipts) in receipts {
446 for (user_id, receipt) in receipts {
447 let thread = receipt.thread.as_str().map(ToOwned::to_owned);
448 if let Some((old_event, _)) = inner
450 .room_user_receipts
451 .entry(room.clone())
452 .or_default()
453 .entry((receipt_type.to_string(), thread.clone()))
454 .or_default()
455 .insert(user_id.clone(), (event_id.clone(), receipt.clone()))
456 {
457 if let Some(receipt_map) = inner.room_event_receipts.get_mut(room)
459 && let Some(event_map) =
460 receipt_map.get_mut(&(receipt_type.to_string(), thread.clone()))
461 && let Some(user_map) = event_map.get_mut(&old_event)
462 {
463 user_map.remove(user_id);
464 }
465 }
466
467 inner
469 .room_event_receipts
470 .entry(room.clone())
471 .or_default()
472 .entry((receipt_type.to_string(), thread))
473 .or_default()
474 .entry(event_id.clone())
475 .or_default()
476 .insert(user_id.clone(), receipt.clone());
477 }
478 }
479 }
480 }
481
482 let make_redaction_rules = |room_info: &HashMap<OwnedRoomId, RoomInfo>, room_id| {
483 room_info.get(room_id).map(|info| info.room_version_rules_or_default()).unwrap_or_else(|| {
484 warn!(
485 ?room_id,
486 "Unable to get the room version rules, defaulting to rules for room version {ROOM_VERSION_FALLBACK}"
487 );
488 ROOM_VERSION_RULES_FALLBACK
489 }).redaction
490 };
491
492 let inner = &mut *inner;
493 for (room_id, redactions) in &changes.redactions {
494 let mut redaction_rules = None;
495
496 if let Some(room) = inner.room_state.get_mut(room_id) {
497 for ref_room_mu in room.values_mut() {
498 for raw_evt in ref_room_mu.values_mut() {
499 if let Ok(Some(event_id)) = raw_evt.get_field::<OwnedEventId>("event_id")
500 && let Some(redaction) = redactions.get(&event_id)
501 {
502 let redacted = redact(
503 raw_evt.deserialize_as::<CanonicalJsonObject>()?,
504 redaction_rules.get_or_insert_with(|| {
505 make_redaction_rules(&inner.room_info, room_id)
506 }),
507 Some(RedactedBecause::from_raw_event(redaction)?),
508 )
509 .map_err(StoreError::Redaction)?;
510 *raw_evt = Raw::new(&redacted)?.cast_unchecked();
511 }
512 }
513 }
514 }
515 }
516
517 debug!("Saved changes in {:?}", now.elapsed());
518
519 Ok(())
520 }
521
522 async fn get_presence_event(&self, user_id: &UserId) -> Result<Option<Raw<PresenceEvent>>> {
523 Ok(self.inner.read().unwrap().presence.get(user_id).cloned())
524 }
525
526 async fn get_presence_events(
527 &self,
528 user_ids: &[OwnedUserId],
529 ) -> Result<Vec<Raw<PresenceEvent>>> {
530 let presence = &self.inner.read().unwrap().presence;
531 Ok(user_ids.iter().filter_map(|user_id| presence.get(user_id).cloned()).collect())
532 }
533
534 async fn get_state_event(
535 &self,
536 room_id: &RoomId,
537 event_type: StateEventType,
538 state_key: &str,
539 ) -> Result<Option<RawAnySyncOrStrippedState>> {
540 Ok(self
541 .get_state_events_for_keys(room_id, event_type, &[state_key])
542 .await?
543 .into_iter()
544 .next())
545 }
546
547 async fn get_state_events(
548 &self,
549 room_id: &RoomId,
550 event_type: StateEventType,
551 ) -> Result<Vec<RawAnySyncOrStrippedState>> {
552 fn get_events<T>(
553 state_map: &HashMap<OwnedRoomId, HashMap<StateEventType, HashMap<String, Raw<T>>>>,
554 room_id: &RoomId,
555 event_type: &StateEventType,
556 to_enum: fn(Raw<T>) -> RawAnySyncOrStrippedState,
557 ) -> Option<Vec<RawAnySyncOrStrippedState>> {
558 let state_events = state_map.get(room_id)?.get(event_type)?;
559 Some(state_events.values().cloned().map(to_enum).collect())
560 }
561
562 let inner = self.inner.read().unwrap();
563 Ok(get_events(
564 &inner.stripped_room_state,
565 room_id,
566 &event_type,
567 RawAnySyncOrStrippedState::Stripped,
568 )
569 .or_else(|| {
570 get_events(&inner.room_state, room_id, &event_type, RawAnySyncOrStrippedState::Sync)
571 })
572 .unwrap_or_default())
573 }
574
575 async fn get_state_events_for_keys(
576 &self,
577 room_id: &RoomId,
578 event_type: StateEventType,
579 state_keys: &[&str],
580 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
581 let inner = self.inner.read().unwrap();
582
583 if let Some(stripped_state_events) =
584 inner.stripped_room_state.get(room_id).and_then(|events| events.get(&event_type))
585 {
586 Ok(state_keys
587 .iter()
588 .filter_map(|k| {
589 stripped_state_events
590 .get(*k)
591 .map(|e| RawAnySyncOrStrippedState::Stripped(e.clone()))
592 })
593 .collect())
594 } else if let Some(sync_state_events) =
595 inner.room_state.get(room_id).and_then(|events| events.get(&event_type))
596 {
597 Ok(state_keys
598 .iter()
599 .filter_map(|k| {
600 sync_state_events.get(*k).map(|e| RawAnySyncOrStrippedState::Sync(e.clone()))
601 })
602 .collect())
603 } else {
604 Ok(Vec::new())
605 }
606 }
607
608 async fn get_profile(
609 &self,
610 room_id: &RoomId,
611 user_id: &UserId,
612 ) -> Result<Option<MinimalRoomMemberEvent>> {
613 Ok(self
614 .inner
615 .read()
616 .unwrap()
617 .profiles
618 .get(room_id)
619 .and_then(|room_profiles| room_profiles.get(user_id))
620 .cloned())
621 }
622
623 async fn get_profiles<'a>(
624 &self,
625 room_id: &RoomId,
626 user_ids: &'a [OwnedUserId],
627 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>> {
628 if user_ids.is_empty() {
629 return Ok(BTreeMap::new());
630 }
631
632 let profiles = &self.inner.read().unwrap().profiles;
633 let Some(room_profiles) = profiles.get(room_id) else {
634 return Ok(BTreeMap::new());
635 };
636
637 Ok(user_ids
638 .iter()
639 .filter_map(|user_id| room_profiles.get(user_id).map(|p| (&**user_id, p.clone())))
640 .collect())
641 }
642
643 #[instrument(skip(self, memberships))]
644 async fn get_user_ids(
645 &self,
646 room_id: &RoomId,
647 memberships: RoomMemberships,
648 ) -> Result<Vec<OwnedUserId>> {
649 fn get_user_ids_inner(
655 members: &HashMap<OwnedRoomId, HashMap<OwnedUserId, MembershipState>>,
656 room_id: &RoomId,
657 memberships: RoomMemberships,
658 ) -> Vec<OwnedUserId> {
659 members
660 .get(room_id)
661 .map(|members| {
662 members
663 .iter()
664 .filter_map(|(user_id, membership)| {
665 memberships.matches(membership).then_some(user_id)
666 })
667 .cloned()
668 .collect()
669 })
670 .unwrap_or_default()
671 }
672 let inner = self.inner.read().unwrap();
673 let v = get_user_ids_inner(&inner.stripped_members, room_id, memberships);
674 if !v.is_empty() {
675 return Ok(v);
676 }
677 Ok(get_user_ids_inner(&inner.members, room_id, memberships))
678 }
679
680 async fn get_room_infos(&self, room_load_settings: &RoomLoadSettings) -> Result<Vec<RoomInfo>> {
681 let memory_store_inner = self.inner.read().unwrap();
682 let room_infos = &memory_store_inner.room_info;
683
684 Ok(match room_load_settings {
685 RoomLoadSettings::All => room_infos.values().cloned().collect(),
686
687 RoomLoadSettings::One(room_id) => match room_infos.get(room_id) {
688 Some(room_info) => vec![room_info.clone()],
689 None => vec![],
690 },
691 })
692 }
693
694 async fn get_users_with_display_name(
695 &self,
696 room_id: &RoomId,
697 display_name: &DisplayName,
698 ) -> Result<BTreeSet<OwnedUserId>> {
699 Ok(self
700 .inner
701 .read()
702 .unwrap()
703 .display_names
704 .get(room_id)
705 .and_then(|room_names| room_names.get(display_name).cloned())
706 .unwrap_or_default())
707 }
708
709 async fn get_users_with_display_names<'a>(
710 &self,
711 room_id: &RoomId,
712 display_names: &'a [DisplayName],
713 ) -> Result<HashMap<&'a DisplayName, BTreeSet<OwnedUserId>>> {
714 if display_names.is_empty() {
715 return Ok(HashMap::new());
716 }
717
718 let inner = self.inner.read().unwrap();
719 let Some(room_names) = inner.display_names.get(room_id) else {
720 return Ok(HashMap::new());
721 };
722
723 Ok(display_names.iter().filter_map(|n| room_names.get(n).map(|d| (n, d.clone()))).collect())
724 }
725
726 async fn get_account_data_event(
727 &self,
728 event_type: GlobalAccountDataEventType,
729 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>> {
730 Ok(self.inner.read().unwrap().account_data.get(&event_type).cloned())
731 }
732
733 async fn get_room_account_data_event(
734 &self,
735 room_id: &RoomId,
736 event_type: RoomAccountDataEventType,
737 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>> {
738 Ok(self
739 .inner
740 .read()
741 .unwrap()
742 .room_account_data
743 .get(room_id)
744 .and_then(|m| m.get(&event_type))
745 .cloned())
746 }
747
748 async fn get_user_room_receipt_event(
749 &self,
750 room_id: &RoomId,
751 receipt_type: ReceiptType,
752 thread: ReceiptThread,
753 user_id: &UserId,
754 ) -> Result<Option<(OwnedEventId, Receipt)>> {
755 Ok(self.get_user_room_receipt_event_impl(room_id, receipt_type, thread, user_id))
756 }
757
758 async fn get_event_room_receipt_events(
759 &self,
760 room_id: &RoomId,
761 receipt_type: ReceiptType,
762 thread: ReceiptThread,
763 event_id: &EventId,
764 ) -> Result<Vec<(OwnedUserId, Receipt)>> {
765 Ok(self
766 .get_event_room_receipt_events_impl(room_id, receipt_type, thread, event_id)
767 .unwrap_or_default())
768 }
769
770 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
771 Ok(self.inner.read().unwrap().custom.get(key).cloned())
772 }
773
774 async fn set_custom_value(&self, key: &[u8], value: Vec<u8>) -> Result<Option<Vec<u8>>> {
775 Ok(self.inner.write().unwrap().custom.insert(key.to_vec(), value))
776 }
777
778 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
779 Ok(self.inner.write().unwrap().custom.remove(key))
780 }
781
782 async fn remove_room(&self, room_id: &RoomId) -> Result<()> {
783 let mut inner = self.inner.write().unwrap();
784
785 inner.profiles.remove(room_id);
786 inner.display_names.remove(room_id);
787 inner.members.remove(room_id);
788 inner.room_info.remove(room_id);
789 inner.room_state.remove(room_id);
790 inner.room_account_data.remove(room_id);
791 inner.stripped_room_state.remove(room_id);
792 inner.stripped_members.remove(room_id);
793 inner.room_user_receipts.remove(room_id);
794 inner.room_event_receipts.remove(room_id);
795 inner.send_queue_events.remove(room_id);
796 inner.dependent_send_queue_events.remove(room_id);
797 inner.thread_subscriptions.remove(room_id);
798
799 Ok(())
800 }
801
802 async fn save_send_queue_request(
803 &self,
804 room_id: &RoomId,
805 transaction_id: OwnedTransactionId,
806 created_at: MilliSecondsSinceUnixEpoch,
807 kind: QueuedRequestKind,
808 priority: usize,
809 ) -> Result<(), Self::Error> {
810 self.inner
811 .write()
812 .unwrap()
813 .send_queue_events
814 .entry(room_id.to_owned())
815 .or_default()
816 .push(QueuedRequest { kind, transaction_id, error: None, priority, created_at });
817 Ok(())
818 }
819
820 async fn update_send_queue_request(
821 &self,
822 room_id: &RoomId,
823 transaction_id: &TransactionId,
824 kind: QueuedRequestKind,
825 ) -> Result<bool, Self::Error> {
826 if let Some(entry) = self
827 .inner
828 .write()
829 .unwrap()
830 .send_queue_events
831 .entry(room_id.to_owned())
832 .or_default()
833 .iter_mut()
834 .find(|item| item.transaction_id == transaction_id)
835 {
836 entry.kind = kind;
837 entry.error = None;
838 Ok(true)
839 } else {
840 Ok(false)
841 }
842 }
843
844 async fn remove_send_queue_request(
845 &self,
846 room_id: &RoomId,
847 transaction_id: &TransactionId,
848 ) -> Result<bool, Self::Error> {
849 let mut inner = self.inner.write().unwrap();
850 let q = &mut inner.send_queue_events;
851
852 let entry = q.get_mut(room_id);
853 if let Some(entry) = entry {
854 if let Some(pos) = entry.iter().position(|item| item.transaction_id == transaction_id) {
856 entry.remove(pos);
857 if entry.is_empty() {
859 q.remove(room_id);
860 }
861 return Ok(true);
862 }
863 }
864
865 Ok(false)
866 }
867
868 async fn load_send_queue_requests(
869 &self,
870 room_id: &RoomId,
871 ) -> Result<Vec<QueuedRequest>, Self::Error> {
872 let mut ret = self
873 .inner
874 .write()
875 .unwrap()
876 .send_queue_events
877 .entry(room_id.to_owned())
878 .or_default()
879 .clone();
880 ret.sort_by(|lhs, rhs| rhs.priority.cmp(&lhs.priority));
882 Ok(ret)
883 }
884
885 async fn update_send_queue_request_status(
886 &self,
887 room_id: &RoomId,
888 transaction_id: &TransactionId,
889 error: Option<QueueWedgeError>,
890 ) -> Result<(), Self::Error> {
891 if let Some(entry) = self
892 .inner
893 .write()
894 .unwrap()
895 .send_queue_events
896 .entry(room_id.to_owned())
897 .or_default()
898 .iter_mut()
899 .find(|item| item.transaction_id == transaction_id)
900 {
901 entry.error = error;
902 }
903 Ok(())
904 }
905
906 async fn load_rooms_with_unsent_requests(&self) -> Result<Vec<OwnedRoomId>, Self::Error> {
907 Ok(self.inner.read().unwrap().send_queue_events.keys().cloned().collect())
908 }
909
910 async fn save_dependent_queued_request(
911 &self,
912 room: &RoomId,
913 parent_transaction_id: &TransactionId,
914 own_transaction_id: ChildTransactionId,
915 created_at: MilliSecondsSinceUnixEpoch,
916 content: DependentQueuedRequestKind,
917 ) -> Result<(), Self::Error> {
918 self.inner
919 .write()
920 .unwrap()
921 .dependent_send_queue_events
922 .entry(room.to_owned())
923 .or_default()
924 .push(DependentQueuedRequest {
925 kind: content,
926 parent_transaction_id: parent_transaction_id.to_owned(),
927 own_transaction_id,
928 parent_key: None,
929 created_at,
930 });
931 Ok(())
932 }
933
934 async fn mark_dependent_queued_requests_as_ready(
935 &self,
936 room: &RoomId,
937 parent_txn_id: &TransactionId,
938 sent_parent_key: SentRequestKey,
939 ) -> Result<usize, Self::Error> {
940 let mut inner = self.inner.write().unwrap();
941 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
942 let mut num_updated = 0;
943 for d in dependents.iter_mut().filter(|item| item.parent_transaction_id == parent_txn_id) {
944 d.parent_key = Some(sent_parent_key.clone());
945 num_updated += 1;
946 }
947 Ok(num_updated)
948 }
949
950 async fn update_dependent_queued_request(
951 &self,
952 room: &RoomId,
953 own_transaction_id: &ChildTransactionId,
954 new_content: DependentQueuedRequestKind,
955 ) -> Result<bool, Self::Error> {
956 let mut inner = self.inner.write().unwrap();
957 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
958 for d in dependents.iter_mut() {
959 if d.own_transaction_id == *own_transaction_id {
960 d.kind = new_content;
961 return Ok(true);
962 }
963 }
964 Ok(false)
965 }
966
967 async fn remove_dependent_queued_request(
968 &self,
969 room: &RoomId,
970 txn_id: &ChildTransactionId,
971 ) -> Result<bool, Self::Error> {
972 let mut inner = self.inner.write().unwrap();
973 let dependents = inner.dependent_send_queue_events.entry(room.to_owned()).or_default();
974 if let Some(pos) = dependents.iter().position(|item| item.own_transaction_id == *txn_id) {
975 dependents.remove(pos);
976 Ok(true)
977 } else {
978 Ok(false)
979 }
980 }
981
982 async fn load_dependent_queued_requests(
983 &self,
984 room: &RoomId,
985 ) -> Result<Vec<DependentQueuedRequest>, Self::Error> {
986 Ok(self
987 .inner
988 .read()
989 .unwrap()
990 .dependent_send_queue_events
991 .get(room)
992 .cloned()
993 .unwrap_or_default())
994 }
995
996 async fn upsert_thread_subscription(
997 &self,
998 room: &RoomId,
999 thread_id: &EventId,
1000 mut new: StoredThreadSubscription,
1001 ) -> Result<(), Self::Error> {
1002 let mut inner = self.inner.write().unwrap();
1003 let room_subs = inner.thread_subscriptions.entry(room.to_owned()).or_default();
1004
1005 if let Some(previous) = room_subs.get(thread_id) {
1006 if *previous == new {
1008 return Ok(());
1009 }
1010 if !compare_thread_subscription_bump_stamps(previous.bump_stamp, &mut new.bump_stamp) {
1011 return Ok(());
1012 }
1013 }
1014
1015 room_subs.insert(thread_id.to_owned(), new);
1016
1017 Ok(())
1018 }
1019
1020 async fn load_thread_subscription(
1021 &self,
1022 room: &RoomId,
1023 thread_id: &EventId,
1024 ) -> Result<Option<StoredThreadSubscription>, Self::Error> {
1025 let inner = self.inner.read().unwrap();
1026 Ok(inner
1027 .thread_subscriptions
1028 .get(room)
1029 .and_then(|subscriptions| subscriptions.get(thread_id))
1030 .copied())
1031 }
1032
1033 async fn remove_thread_subscription(
1034 &self,
1035 room: &RoomId,
1036 thread_id: &EventId,
1037 ) -> Result<(), Self::Error> {
1038 let mut inner = self.inner.write().unwrap();
1039
1040 let Some(room_subs) = inner.thread_subscriptions.get_mut(room) else {
1041 return Ok(());
1042 };
1043
1044 room_subs.remove(thread_id);
1045
1046 if room_subs.is_empty() {
1047 inner.thread_subscriptions.remove(room);
1049 }
1050
1051 Ok(())
1052 }
1053
1054 async fn optimize(&self) -> Result<(), Self::Error> {
1055 Ok(())
1056 }
1057
1058 async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
1059 Ok(None)
1060 }
1061}
1062
#[cfg(test)]
mod tests {
    use super::{MemoryStore, Result, StateStore};

    /// Returns a fresh, empty [`MemoryStore`].
    async fn get_store() -> Result<impl StateStore> {
        Ok(MemoryStore::new())
    }

    // Expands to the shared state-store integration test suite, which is defined elsewhere.
    // NOTE(review): the macro presumably obtains its stores through `get_store` above; confirm
    // before renaming it.
    statestore_integration_tests!();
}