use std::sync::Arc;
use matrix_sdk_base::{
deserialized_responses::TimelineEvent,
event_cache::{Event, Gap},
linked_chunk::OwnedLinkedChunkId,
};
use matrix_sdk_common::{linked_chunk::ChunkIdentifier, serde_helpers::extract_thread_root};
use ruma::{OwnedEventId, UInt, api::Direction};
use tokio::sync::{
RwLock,
broadcast::{Receiver, Sender},
};
use tracing::{instrument, trace};
#[cfg(feature = "e2e-encryption")]
use crate::event_cache::redecryptor::ResolvedUtd;
use crate::{
Room,
event_cache::{
EventCacheError, EventsOrigin, Result, RoomEventCacheLinkedChunkUpdate,
caches::{TimelineVectorDiffs, event_linked_chunk::EventLinkedChunk},
},
paginators::{PaginationResult, Paginator, StartFromResult, thread::PaginableThread},
room::{IncludeRelations, MessagesOptions, RelationsOptions, WeakRoom},
};
/// Controls how the focused-event cache decides whether to paginate within a
/// thread or within the whole room timeline.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum EventFocusThreadMode {
    /// Always use thread pagination: if the focused event has no thread root,
    /// the focused event itself is treated as the thread root.
    ForceThread,
    /// Inspect the focused event: use thread pagination only when the event
    /// carries a thread relation, otherwise fall back to room pagination.
    Automatic,
}
/// Which pagination strategy the cache uses after the initial `/context`
/// request resolved the focused event.
#[derive(Debug, Clone)]
pub(crate) enum EventFocusedPaginationMode {
    /// Paginate over the whole room timeline (`/messages`); when
    /// `hide_thread_events` is set, events carrying a thread relation are
    /// filtered out of the results.
    Room { hide_thread_events: bool },
    /// Paginate over a single thread (`/relations`), rooted at `thread_root`.
    Thread {
        thread_root: OwnedEventId,
    },
}
/// Mutable state backing an `EventFocusedCache`: the linked chunk of events
/// loaded around the focused event, plus the channels used to broadcast
/// updates.
struct EventFocusedCacheInner {
    // Weak handle to the room; upgraded on demand so the cache does not keep
    // the client alive.
    room: WeakRoom,
    // The event this cache is focused on.
    focused_event_id: OwnedEventId,
    // Strategy selected by `start_from` (room-wide or in-thread pagination).
    pagination_mode: EventFocusedPaginationMode,
    // Linked chunk holding the loaded events and the gaps (pagination tokens)
    // surrounding them.
    chunk: EventLinkedChunk,
    // Broadcasts timeline `VectorDiff` batches to subscribers.
    sender: Sender<TimelineVectorDiffs>,
    // Forwards raw linked-chunk store updates (e.g. for persistence).
    linked_chunk_update_sender: Sender<RoomEventCacheLinkedChunkUpdate>,
}
impl EventFocusedCacheInner {
    /// Seed the cache by fetching the focused event and its surrounding
    /// context via `/context`, then decide — per `thread_mode` — whether
    /// subsequent pagination happens in the room or in a thread.
    ///
    /// The fetched events are stored in the linked chunk, bracketed by gaps
    /// carrying the pagination tokens, and the resulting linked-chunk updates
    /// are forwarded to the update channel.
    #[instrument(skip(self, room), fields(room_id = %self.room.room_id(), event_id = %self.focused_event_id))]
    async fn start_from(
        &mut self,
        room: Room,
        num_context_events: u16,
        thread_mode: EventFocusThreadMode,
    ) -> Result<StartFromResult> {
        trace!(num_context_events, "fetching event with context via /context");

        let paginator = Paginator::new(room);
        let result =
            paginator.start_from(&self.focused_event_id, UInt::from(num_context_events)).await?;

        // Resolve the thread root, if any, according to the requested mode.
        let thread_root = match thread_mode {
            EventFocusThreadMode::ForceThread => {
                // Find the focused event in the context response…
                let focused_event = result
                    .events
                    .iter()
                    .find(|event| event.event_id().as_ref() == Some(&self.focused_event_id));

                // …and take its thread root; if it has none, the focused
                // event itself acts as the thread root.
                let mut thread_root =
                    focused_event.and_then(|event| extract_thread_root(event.raw()));
                if thread_root.is_none() {
                    thread_root = Some(self.focused_event_id.clone());
                }

                trace!("force thread mode enabled, treating focused event as thread root");
                thread_root
            }

            EventFocusThreadMode::Automatic => {
                trace!(
                    "automatic thread mode enabled, checking if focused event is part of a thread"
                );
                // Only use thread pagination when the focused event actually
                // carries a thread relation.
                result
                    .events
                    .iter()
                    .find(|event| event.event_id().as_ref() == Some(&self.focused_event_id))
                    .and_then(|event| extract_thread_root(event.raw()))
            }
        };

        let tokens = paginator.tokens();

        if let Some(root_id) = thread_root {
            trace!(thread_root = %root_id, "focused event is part of a thread, setting up thread pagination");

            // If the thread root itself is present in the context response,
            // the start of the thread is already loaded — no backward gap is
            // needed then.
            let includes_root =
                result.events.iter().any(|event| event.event_id().as_ref() == Some(&root_id));

            self.pagination_mode =
                EventFocusedPaginationMode::Thread { thread_root: root_id.clone() };

            // Keep only events belonging to this thread (replies and the root
            // itself); the context response may contain unrelated room
            // events.
            let thread_events = result
                .events
                .iter()
                .filter(|event| {
                    extract_thread_root(event.raw()).as_ref() == Some(&root_id)
                        || event.event_id().as_ref() == Some(&root_id)
                })
                .cloned()
                .collect();

            let backward_token = if includes_root {
                None
            } else {
                tokens.previous.into_token()
            };
            let forward_token = tokens.next.into_token();

            self.add_initial_events_with_gaps(thread_events, backward_token, forward_token);
        } else {
            trace!("focused event is not part of a thread, setting up room pagination");

            let backward_token = tokens.previous.into_token();
            let forward_token = tokens.next.into_token();

            // NOTE(review): `thread_root.is_none()` is vacuously true in this
            // `else` branch, so this reduces to
            // `matches!(thread_mode, EventFocusThreadMode::Automatic)`.
            let hide_thread_events =
                matches!(thread_mode, EventFocusThreadMode::Automatic) && thread_root.is_none();
            self.pagination_mode = EventFocusedPaginationMode::Room { hide_thread_events };

            // In automatic mode, thread replies are hidden from the room
            // timeline view.
            let events = if hide_thread_events {
                result
                    .events
                    .iter()
                    .filter(|event| extract_thread_root(event.raw()).is_none())
                    .cloned()
                    .collect()
            } else {
                result.events.clone()
            };

            self.add_initial_events_with_gaps(events, backward_token, forward_token);
        }

        self.propagate_changes();
        // Drain the pending diffs without notifying subscribers: the initial
        // set of events is delivered through `subscribe` instead.
        let _ = self.chunk.updates_as_vector_diffs();

        Ok(result)
    }

    /// Insert the initial set of events into the linked chunk, with an
    /// optional gap (pagination token) before and after them.
    fn add_initial_events_with_gaps(
        &mut self,
        events: Vec<TimelineEvent>,
        prev_gap_token: Option<String>,
        next_gap_token: Option<String>,
    ) {
        self.chunk
            .push_live_events(prev_gap_token.map(|prev_token| Gap { token: prev_token }), &events);

        if let Some(next_token) = next_gap_token {
            trace!("inserting forward pagination gap at back");
            self.chunk.push_gap(Gap { token: next_token });
        }
    }

    /// Forward any pending store-level linked-chunk updates to the update
    /// channel, tagged with this cache's linked-chunk id.
    fn propagate_changes(&mut self) {
        let updates = self.chunk.store_updates().take();
        if !updates.is_empty() {
            // A failed send only means nobody is listening; ignore it.
            let _ = self.linked_chunk_update_sender.send(RoomEventCacheLinkedChunkUpdate {
                updates,
                linked_chunk_id: OwnedLinkedChunkId::EventFocused(
                    self.room.room_id().to_owned(),
                    self.focused_event_id.clone(),
                ),
            });
        }
    }

    /// Broadcast the accumulated timeline diffs to subscribers, if any.
    fn notify_subscribers(&mut self, origin: EventsOrigin) {
        let diffs = self.chunk.updates_as_vector_diffs();
        if !diffs.is_empty() {
            let _ = self.sender.send(TimelineVectorDiffs { diffs, origin });
        }
    }

    /// Gap at the front of the chunk, holding the token for paginating
    /// backwards; `None` once the timeline start has been reached.
    fn first_chunk_as_gap(&self) -> Option<(ChunkIdentifier, Gap)> {
        self.chunk.first_chunk_as_gap()
    }

    /// Gap at the back of the chunk, holding the token for paginating
    /// forwards; `None` once the timeline end has been reached.
    fn last_chunk_as_gap(&self) -> Option<(ChunkIdentifier, Gap)> {
        self.chunk.last_chunk_as_gap()
    }

    /// Load up to `num_events` older events using the token stored in the
    /// front gap, and insert them in place of that gap.
    #[instrument(skip(self), fields(room_id = %self.room.room_id()))]
    async fn paginate_backwards(&mut self, num_events: u16) -> Result<PaginationResult> {
        let room = self.room.get().ok_or(EventCacheError::ClientDropped)?;

        // No front gap means the start of the timeline is already loaded.
        let Some((gap_id, gap)) = self.first_chunk_as_gap() else {
            trace!("no front gap found, already at timeline start");
            return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
        };

        let token = gap.token;
        trace!(?token, "paginating backwards with token from front gap");

        let (mut events, new_token) = match &self.pagination_mode {
            EventFocusedPaginationMode::Room { .. } => {
                Self::fetch_room_backwards(&room, num_events, &token).await?
            }
            EventFocusedPaginationMode::Thread { thread_root } => {
                Self::fetch_thread_backwards(&room, num_events, &token, thread_root.clone()).await?
            }
        };

        // Backwards responses list events newest-first; flip into
        // chronological order before inserting.
        events.reverse();

        // No continuation token: the start of the timeline was reached.
        let hit_end = new_token.is_none();
        let new_gap = new_token.map(|t| Gap { token: t });

        // Thread events are filtered out only in room mode with
        // `hide_thread_events` set (see `start_from`).
        let hide_thread_events = match &self.pagination_mode {
            EventFocusedPaginationMode::Room { hide_thread_events } => *hide_thread_events,
            EventFocusedPaginationMode::Thread { .. } => false,
        };
        let events = if hide_thread_events {
            events.into_iter().filter(|event| extract_thread_root(event.raw()).is_none()).collect()
        } else {
            events
        };

        self.chunk.push_backwards_pagination_events(Some(gap_id), new_gap, &events);

        self.propagate_changes();
        self.notify_subscribers(EventsOrigin::Pagination);

        Ok(PaginationResult { events, hit_end_of_timeline: hit_end })
    }

    /// Fetch older room events via `/messages` (backward direction).
    ///
    /// Returns the events plus the token to continue from, if any.
    async fn fetch_room_backwards(
        room: &Room,
        num_events: u16,
        token: &str,
    ) -> Result<(Vec<Event>, Option<String>)> {
        let mut options = MessagesOptions::backward().from(token);
        options.limit = UInt::from(num_events);

        let messages = room
            .messages(options)
            .await
            .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;

        Ok((messages.chunk, messages.end))
    }

    /// Fetch older thread events via `/relations` on the thread root.
    ///
    /// When the server returns no next-batch token (nothing left to
    /// paginate), the thread root event is loaded and appended, since the
    /// relations response is assumed not to contain the root itself.
    async fn fetch_thread_backwards(
        room: &Room,
        num_events: u16,
        token: &str,
        thread_root: OwnedEventId,
    ) -> Result<(Vec<Event>, Option<String>)> {
        let options = RelationsOptions {
            from: Some(token.to_owned()),
            dir: Direction::Backward,
            limit: Some(UInt::from(num_events)),
            include_relations: IncludeRelations::AllRelations,
            recurse: true,
        };

        let mut result = room
            .relations(thread_root.clone(), options)
            .await
            .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;

        if result.next_batch_token.is_none() {
            let root_event = room
                .load_event(&thread_root)
                .await
                .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;
            result.chunk.push(root_event);
        }

        Ok((result.chunk, result.next_batch_token))
    }

    /// Load up to `num_events` newer events using the token stored in the
    /// back gap, and insert them in place of that gap.
    #[instrument(skip(self), fields(room_id = %self.room.room_id()))]
    async fn paginate_forwards(&mut self, num_events: u16) -> Result<PaginationResult> {
        let room = self.room.get().ok_or(EventCacheError::ClientDropped)?;

        // No back gap means the end of the timeline is already loaded.
        let Some((gap_id, gap)) = self.last_chunk_as_gap() else {
            trace!("no back gap found, already at timeline end");
            return Ok(PaginationResult { events: Vec::new(), hit_end_of_timeline: true });
        };

        let token = gap.token;
        trace!(?token, "paginating forwards with token from back gap");

        // Forward responses are already in chronological order; no reversal
        // needed (unlike `paginate_backwards`).
        let (events, new_token) = match &self.pagination_mode {
            EventFocusedPaginationMode::Room { .. } => {
                Self::fetch_room_forwards(&room, num_events, &token).await?
            }
            EventFocusedPaginationMode::Thread { thread_root } => {
                Self::fetch_thread_forwards(&room, num_events, &token, thread_root.clone()).await?
            }
        };

        // No continuation token: the end of the timeline was reached.
        let hit_end = new_token.is_none();
        let new_gap = new_token.map(|t| Gap { token: t });

        let hide_thread_events = match &self.pagination_mode {
            EventFocusedPaginationMode::Room { hide_thread_events } => *hide_thread_events,
            EventFocusedPaginationMode::Thread { .. } => false,
        };
        let events = if hide_thread_events {
            events.into_iter().filter(|event| extract_thread_root(event.raw()).is_none()).collect()
        } else {
            events
        };

        self.chunk.push_forwards_pagination_events(Some(gap_id), new_gap, &events);

        self.propagate_changes();
        self.notify_subscribers(EventsOrigin::Pagination);

        Ok(PaginationResult { events, hit_end_of_timeline: hit_end })
    }

    /// Fetch newer room events via `/messages` (forward direction).
    async fn fetch_room_forwards(
        room: &Room,
        num_events: u16,
        token: &str,
    ) -> Result<(Vec<Event>, Option<String>)> {
        let mut options = MessagesOptions::new(Direction::Forward);
        options = options.from(Some(token));
        options.limit = UInt::from(num_events);

        let messages = room
            .messages(options)
            .await
            .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;

        Ok((messages.chunk, messages.end))
    }

    /// Fetch newer thread events via `/relations` on the thread root.
    async fn fetch_thread_forwards(
        room: &Room,
        num_events: u16,
        token: &str,
        thread_root: OwnedEventId,
    ) -> Result<(Vec<Event>, Option<String>)> {
        let options = RelationsOptions {
            from: Some(token.to_owned()),
            dir: Direction::Forward,
            limit: Some(UInt::from(num_events)),
            include_relations: IncludeRelations::AllRelations,
            recurse: true,
        };

        let result = room
            .relations(thread_root, options)
            .await
            .map_err(|err| EventCacheError::PaginationError(Arc::new(err)))?;

        Ok((result.chunk, result.next_batch_token))
    }
}
/// Cache of events loaded around a single focused event (e.g. a permalink
/// view), supporting backwards and forwards pagination from that point.
///
/// Cheap to clone: all clones share the same inner state.
#[derive(Clone)]
pub struct EventFocusedCache {
    inner: Arc<RwLock<EventFocusedCacheInner>>,
}
impl EventFocusedCache {
    /// Create a new, empty focused cache for `focused_event_id` in the given
    /// room.
    ///
    /// Pagination starts in room mode (without thread filtering) until
    /// `start_from` determines the actual mode.
    pub(super) fn new(
        room: WeakRoom,
        focused_event_id: OwnedEventId,
        linked_chunk_update_sender: Sender<RoomEventCacheLinkedChunkUpdate>,
    ) -> Self {
        let state = EventFocusedCacheInner {
            room,
            focused_event_id,
            pagination_mode: EventFocusedPaginationMode::Room { hide_thread_events: false },
            chunk: EventLinkedChunk::new(),
            sender: Sender::new(32),
            linked_chunk_update_sender,
        };

        Self { inner: Arc::new(RwLock::new(state)) }
    }

    /// Return the currently loaded events along with a receiver for
    /// subsequent timeline updates.
    pub async fn subscribe(&self) -> (Vec<Event>, Receiver<TimelineVectorDiffs>) {
        let guard = self.inner.read().await;

        let snapshot: Vec<Event> =
            guard.chunk.events().map(|(_position, item)| item.clone()).collect();
        let receiver = guard.sender.subscribe();

        (snapshot, receiver)
    }

    /// Whether the start of the timeline has been reached (no front gap left
    /// to paginate from).
    pub async fn hit_timeline_start(&self) -> bool {
        let guard = self.inner.read().await;
        guard.first_chunk_as_gap().is_none()
    }

    /// Whether the end of the timeline has been reached (no back gap left to
    /// paginate from).
    pub async fn hit_timeline_end(&self) -> bool {
        let guard = self.inner.read().await;
        guard.last_chunk_as_gap().is_none()
    }

    /// Initialize the cache from the focused event's `/context` response; see
    /// `EventFocusedCacheInner::start_from`.
    pub(super) async fn start_from(
        &self,
        room: Room,
        num_context_events: u16,
        thread_mode: EventFocusThreadMode,
    ) -> Result<StartFromResult> {
        let mut guard = self.inner.write().await;
        guard.start_from(room, num_context_events, thread_mode).await
    }

    /// Paginate backwards by up to `num_events` events.
    pub async fn paginate_backwards(&self, num_events: u16) -> Result<PaginationResult> {
        let mut guard = self.inner.write().await;
        guard.paginate_backwards(num_events).await
    }

    /// Paginate forwards by up to `num_events` events.
    pub async fn paginate_forwards(&self, num_events: u16) -> Result<PaginationResult> {
        let mut guard = self.inner.write().await;
        guard.paginate_forwards(num_events).await
    }

    /// The thread root, if the cache operates in thread pagination mode.
    pub async fn thread_root(&self) -> Option<OwnedEventId> {
        let guard = self.inner.read().await;

        if let EventFocusedPaginationMode::Thread { thread_root } = &guard.pagination_mode {
            Some(thread_root.clone())
        } else {
            None
        }
    }

    /// Replace previously undecryptable events with their resolved versions,
    /// propagating and broadcasting updates only if something changed.
    #[cfg(feature = "e2e-encryption")]
    pub async fn replace_utds(&self, events: &[ResolvedUtd]) {
        let mut guard = self.inner.write().await;

        if !guard.chunk.replace_utds(events) {
            return;
        }

        guard.propagate_changes();
        guard.notify_subscribers(EventsOrigin::Cache);
    }
}
#[cfg(not(tarpaulin_include))]
impl std::fmt::Debug for EventFocusedCache {
    // Manual impl: the inner state sits behind an async lock, so only the
    // type name is printed (as `EventFocusedCache { .. }`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EventFocusedCache").finish_non_exhaustive()
    }
}