use std::collections::BTreeMap;
use eyeball::{AsyncLock, ObservableWriteGuard};
use ruma::{
OwnedEventId, OwnedUserId,
events::{
StateEventType, SyncStateEvent,
room::member::{MembershipState, RoomMemberEventContent},
},
};
use tracing::warn;
use super::Room;
use crate::{
StateStoreDataKey, StateStoreDataValue, StoreError,
deserialized_responses::{MemberEvent, RawMemberEvent, SyncOrStrippedState},
store::{Result as StoreResult, StateStoreExt},
};
impl Room {
    /// Mark the knock requests of the given users as seen.
    ///
    /// For each user id, the current `m.room.member` state event is loaded
    /// from the store. Only non-redacted sync events whose membership is
    /// [`MembershipState::Knock`] are recorded, keyed by their event id and
    /// mapped to the event's state key (the knocking user). Any other event is
    /// skipped with a warning.
    ///
    /// If no valid knock event is found, nothing is written and the observable
    /// is left untouched.
    ///
    /// # Errors
    ///
    /// Returns an error if loading or deserializing a member event fails, or if
    /// reading or persisting the seen knock request ids fails.
    pub async fn mark_knock_requests_as_seen(&self, user_ids: &[OwnedUserId]) -> StoreResult<()> {
        let raw_user_ids: Vec<&str> = user_ids.iter().map(|id| id.as_str()).collect();
        let member_raw_events = self
            .store
            .get_state_events_for_keys(self.room_id(), StateEventType::RoomMember, &raw_user_ids)
            .await?;

        let mut event_to_user_ids = Vec::with_capacity(member_raw_events.len());
        for raw_event in member_raw_events {
            let event = raw_event.cast::<RoomMemberEventContent>().deserialize()?;
            match event {
                SyncOrStrippedState::Sync(SyncStateEvent::Original(event)) => {
                    if event.content.membership == MembershipState::Knock {
                        event_to_user_ids.push((event.event_id, event.state_key))
                    } else {
                        warn!(
                            "Could not mark knock event as seen: event {} for user {} \
                            is not in Knock membership state.",
                            event.event_id, event.state_key
                        );
                    }
                }
                _ => warn!(
                    "Could not mark knock event as seen: event for user {} is not valid.",
                    event.state_key()
                ),
            }
        }

        // Nothing to record: skip the store write and the observable
        // notification, just like `remove_outdated_seen_knock_requests_ids`
        // does when it has nothing to remove.
        if event_to_user_ids.is_empty() {
            return Ok(());
        }

        let current_seen_events_guard = self.get_write_guarded_current_knock_request_ids().await?;
        let mut current_seen_events = current_seen_events_guard.clone().unwrap_or_default();
        current_seen_events.extend(event_to_user_ids);
        self.update_seen_knock_request_ids(current_seen_events_guard, current_seen_events).await?;
        Ok(())
    }

    /// Drop seen knock request ids that no longer match the user's current
    /// member event.
    ///
    /// An entry is removed when the user has no member event in the store, when
    /// their membership is no longer [`MembershipState::Knock`], or when their
    /// member event has a different event id than the recorded one. The store
    /// and observable are only updated if at least one entry was removed.
    ///
    /// # Errors
    ///
    /// Returns an error if reading the seen ids, loading or deserializing the
    /// member events, or persisting the updated ids fails.
    pub async fn remove_outdated_seen_knock_requests_ids(&self) -> StoreResult<()> {
        let current_seen_events_guard = self.get_write_guarded_current_knock_request_ids().await?;
        let mut current_seen_events = current_seen_events_guard.clone().unwrap_or_default();

        // With no seen ids there is nothing that could be outdated, so avoid
        // querying the store at all.
        if current_seen_events.is_empty() {
            return Ok(());
        }

        let keys: Vec<OwnedUserId> = current_seen_events.values().cloned().collect();
        let raw_member_events: Vec<RawMemberEvent> =
            self.store.get_state_events_for_keys_static(self.room_id(), &keys).await?;
        let member_events = raw_member_events
            .into_iter()
            .map(|raw| raw.deserialize())
            .collect::<Result<Vec<MemberEvent>, _>>()?;

        let seen_count = current_seen_events.len();
        current_seen_events.retain(|event_id, user_id| {
            let user_id: &OwnedUserId = user_id;
            let Some(member) = member_events.iter().find(|event| event.user_id() == user_id)
            else {
                // The user has no member event anymore: the knock is outdated.
                return false;
            };
            // Keep the entry only while the user is still knocking with this
            // very event (when the member event carries an id to compare).
            *member.membership() == MembershipState::Knock
                && !member.event_id().is_some_and(|id| id != event_id)
        });

        if current_seen_events.len() == seen_count {
            return Ok(());
        }

        self.update_seen_knock_request_ids(current_seen_events_guard, current_seen_events).await?;
        Ok(())
    }

    /// Get the ids of the knock request events marked as seen, each mapped to
    /// the id of the user who knocked.
    ///
    /// Returns an empty map if nothing has been marked as seen.
    ///
    /// # Errors
    ///
    /// Returns an error if the ids have to be loaded from the store and that
    /// fails.
    pub async fn get_seen_knock_request_ids(
        &self,
    ) -> Result<BTreeMap<OwnedEventId, OwnedUserId>, StoreError> {
        // The write guard is used because it may lazily populate the cache.
        Ok(self.get_write_guarded_current_knock_request_ids().await?.clone().unwrap_or_default())
    }

    /// Acquire the write guard on the seen knock request ids observable.
    ///
    /// On first access (the observable still holds `None`), the ids are loaded
    /// from the store, defaulting to an empty map, and set on the observable
    /// before the guard is returned.
    async fn get_write_guarded_current_knock_request_ids(
        &self,
    ) -> StoreResult<ObservableWriteGuard<'_, Option<BTreeMap<OwnedEventId, OwnedUserId>>, AsyncLock>>
    {
        let mut guard = self.seen_knock_request_ids_map.write().await;
        if guard.is_none() {
            let updated_seen_ids = self
                .store
                .get_kv_data(StateStoreDataKey::SeenKnockRequests(self.room_id()))
                .await?
                .and_then(|v| v.into_seen_knock_requests())
                .unwrap_or_default();
            ObservableWriteGuard::set(&mut guard, Some(updated_seen_ids));
        }
        Ok(guard)
    }

    /// Persist `new_value` as the seen knock request ids, then publish it
    /// through the observable held by `guard`.
    ///
    /// The store is written first so that a failed write leaves the in-memory
    /// value (and what subscribers observe) unchanged, keeping it consistent
    /// with the store.
    ///
    /// # Errors
    ///
    /// Returns an error if writing to the store fails; the observable is not
    /// updated in that case.
    async fn update_seen_knock_request_ids(
        &self,
        mut guard: ObservableWriteGuard<'_, Option<BTreeMap<OwnedEventId, OwnedUserId>>, AsyncLock>,
        new_value: BTreeMap<OwnedEventId, OwnedUserId>,
    ) -> StoreResult<()> {
        self.store
            .set_kv_data(
                StateStoreDataKey::SeenKnockRequests(self.room_id()),
                StateStoreDataValue::SeenKnockRequests(new_value.clone()),
            )
            .await?;

        ObservableWriteGuard::set(&mut guard, Some(new_value));
        Ok(())
    }
}