use std::sync::Arc;
use eyeball::{ObservableWriteGuard, SharedObservable, Subscriber};
use eyeball_im::{ObservableVector, VectorDiff, VectorSubscriberBatchedStream};
use futures_util::future::join_all;
use imbl::Vector;
use matrix_sdk::{
Result, Room,
deserialized_responses::TimelineEvent,
event_cache::{RoomEventCacheSubscriber, RoomEventCacheUpdate},
locks::Mutex,
paginators::PaginationToken,
room::ListThreadsOptions,
task_monitor::BackgroundTaskHandle,
};
use matrix_sdk_common::serde_helpers::extract_thread_root;
use ruma::{MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedUserId};
use tokio::sync::Mutex as AsyncMutex;
use tracing::{error, trace, warn};
use crate::timeline::{Profile, TimelineDetails, TimelineItemContent, traits::RoomDataProvider};
/// A thread shown in the thread list.
#[derive(Clone, Debug)]
pub struct ThreadListItem {
    /// The thread's root event.
    pub root_event: ThreadListItemEvent,
    /// The most recent reply known for this thread, if any.
    ///
    /// Initially taken from the latest thread event bundled with the root by
    /// the server; later replaced by replies received live.
    pub latest_event: Option<ThreadListItemEvent>,
    /// The number of replies in the thread.
    ///
    /// Initially taken from the root's bundled thread summary (0 when there is
    /// none), then incremented as live replies are received.
    pub num_replies: u32,
}
/// An event (a thread root or a reply) as displayed in the thread list.
#[derive(Clone, Debug)]
pub struct ThreadListItemEvent {
    /// The event's ID.
    pub event_id: OwnedEventId,
    /// The event's timestamp, as reported by the underlying timeline event.
    pub timestamp: MilliSecondsSinceUnixEpoch,
    /// The user who sent the event.
    pub sender: OwnedUserId,
    /// Whether the event was sent by the current user.
    pub is_own: bool,
    /// The sender's profile, as loaded when this item was built.
    pub sender_profile: TimelineDetails<Profile>,
    /// The event's displayable content, or `None` if it could not be converted.
    pub content: Option<TimelineItemContent>,
}
/// The pagination state of a [`ThreadListService`].
#[cfg_attr(feature = "uniffi", derive(uniffi::Enum))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ThreadListPaginationState {
    /// No page is currently being loaded.
    Idle {
        /// Whether the last loaded page had no token for a further page, i.e.
        /// all threads have been loaded.
        end_reached: bool,
    },
    /// A page of threads is currently being loaded.
    Loading,
}
/// Errors returned by [`ThreadListService`].
#[derive(Debug, thiserror::Error)]
pub enum ThreadListServiceError {
    /// An error reported by the Matrix SDK, e.g. while fetching a page of
    /// threads.
    #[error(transparent)]
    Sdk(#[from] matrix_sdk::Error),
}
/// A paginated list of the threads of a room, with already-listed threads
/// updated from the room's event cache.
pub struct ThreadListService {
    /// The room whose threads are listed.
    room: Room,
    /// Token for the next page to request.
    ///
    /// `paginate` holds this lock for the whole request and `reset` takes it
    /// as well, so the two do not interleave once the lock is held.
    token: AsyncMutex<PaginationToken>,
    /// Observable pagination state.
    pagination_state: SharedObservable<ThreadListPaginationState>,
    /// The loaded threads, shared with the event cache listener task.
    items: Arc<Mutex<ObservableVector<ThreadListItem>>>,
    /// Handle to the event cache listener task. It is created with
    /// `abort_on_drop`, so presumably the task stops when the service is
    /// dropped.
    _event_cache_task: BackgroundTaskHandle,
}
impl ThreadListService {
pub fn new(room: Room) -> Self {
let items: Arc<Mutex<ObservableVector<ThreadListItem>>> =
Arc::new(Mutex::new(ObservableVector::new()));
if let Err(e) = room.client().event_cache().subscribe() {
warn!("ThreadListService: failed to subscribe event cache to sync: {e}");
}
let event_cache_task = room
.client()
.task_monitor()
.spawn_infinite_task("thread_list_service::event_cache_listener", {
let room = room.clone();
let items = items.clone();
async move {
let (_event_cache_drop, mut subscriber) = match async {
let (room_event_cache, drop_handles) = room.event_cache().await?;
let (_, subscriber) = room_event_cache.subscribe().await?;
matrix_sdk::event_cache::Result::Ok((drop_handles, subscriber))
}
.await
{
Ok(pair) => pair,
Err(e) => {
error!(
"ThreadListService: failed to subscribe to room event cache, \
live updates will not work: {e}"
);
return;
}
};
trace!("ThreadListService: event cache listener started");
Self::event_cache_listener_loop(&room, &mut subscriber, items).await;
}
})
.abort_on_drop();
Self {
room,
token: AsyncMutex::new(PaginationToken::None),
pagination_state: SharedObservable::new(ThreadListPaginationState::Idle {
end_reached: false,
}),
items,
_event_cache_task: event_cache_task,
}
}
pub fn pagination_state(&self) -> ThreadListPaginationState {
self.pagination_state.get()
}
pub fn subscribe_to_pagination_state_updates(&self) -> Subscriber<ThreadListPaginationState> {
self.pagination_state.subscribe()
}
pub fn items(&self) -> Vec<ThreadListItem> {
self.items.lock().iter().cloned().collect()
}
pub fn subscribe_to_items_updates(
&self,
) -> (Vector<ThreadListItem>, VectorSubscriberBatchedStream<ThreadListItem>) {
self.items.lock().subscribe().into_values_and_batched_stream()
}
pub async fn paginate(&self) -> Result<(), ThreadListServiceError> {
{
let mut pagination_state = self.pagination_state.write();
match *pagination_state {
ThreadListPaginationState::Idle { end_reached: true }
| ThreadListPaginationState::Loading => return Ok(()),
_ => {}
}
ObservableWriteGuard::set(&mut pagination_state, ThreadListPaginationState::Loading);
}
let mut pagination_token = self.token.lock().await;
let from = match &*pagination_token {
PaginationToken::HasMore(token) => Some(token.clone()),
_ => None,
};
let opts = ListThreadsOptions { from, ..Default::default() };
match self.load_thread_list(opts).await {
Ok(thread_list) => {
*pagination_token = match &thread_list.prev_batch_token {
Some(token) => PaginationToken::HasMore(token.clone()),
None => PaginationToken::HitEnd,
};
let end_reached = thread_list.prev_batch_token.is_none();
self.items.lock().append(thread_list.items.into());
self.pagination_state.set(ThreadListPaginationState::Idle { end_reached });
Ok(())
}
Err(err) => {
self.pagination_state.set(ThreadListPaginationState::Idle { end_reached: false });
Err(ThreadListServiceError::Sdk(err))
}
}
}
pub async fn reset(&self) {
let mut pagination_token = self.token.lock().await;
*pagination_token = PaginationToken::None;
self.items.lock().clear();
self.pagination_state.set(ThreadListPaginationState::Idle { end_reached: false });
}
async fn load_thread_list(&self, opts: ListThreadsOptions) -> Result<ThreadList> {
let thread_roots = self.room.list_threads(opts).await?;
let list_items = join_all(
thread_roots
.chunk
.into_iter()
.map(|timeline_event| Self::build_thread_list_item(&self.room, timeline_event))
.collect::<Vec<_>>(),
)
.await
.into_iter()
.flatten()
.collect();
Ok(ThreadList { items: list_items, prev_batch_token: thread_roots.prev_batch_token })
}
async fn build_thread_list_item(
room: &Room,
timeline_event: TimelineEvent,
) -> Option<ThreadListItem> {
let thread_summary = timeline_event.thread_summary.summary().cloned();
let bundled_latest_thread_event = timeline_event.bundled_latest_thread_event.clone();
let root_event = Self::build_event(room, timeline_event).await?;
let num_replies = thread_summary.as_ref().map(|s| s.num_replies).unwrap_or(0);
let latest_event = if let Some(ev) = bundled_latest_thread_event.map(|b| *b) {
Self::build_event(room, ev).await
} else {
None
};
Some(ThreadListItem { root_event, latest_event, num_replies })
}
async fn build_event(
room: &Room,
timeline_event: TimelineEvent,
) -> Option<ThreadListItemEvent> {
let event_id = timeline_event.event_id()?;
let timestamp = timeline_event.timestamp()?;
let sender = timeline_event.sender()?;
let is_own = room.own_user_id() == sender;
let sender_profile =
TimelineDetails::from_initial_value(Profile::load(room, &sender).await);
let content = TimelineItemContent::from_event(room, timeline_event).await;
Some(ThreadListItemEvent { event_id, timestamp, sender, is_own, sender_profile, content })
}
async fn event_cache_listener_loop(
room: &Room,
subscriber: &mut RoomEventCacheSubscriber,
items: Arc<Mutex<ObservableVector<ThreadListItem>>>,
) {
use tokio::sync::broadcast::error::RecvError;
loop {
let update = match subscriber.recv().await {
Ok(update) => update,
Err(RecvError::Closed) => {
error!("ThreadListService: event cache channel closed, stopping listener");
break;
}
Err(RecvError::Lagged(n)) => {
warn!("ThreadListService: lagged behind {n} event cache updates");
continue;
}
};
if let RoomEventCacheUpdate::UpdateTimelineEvents(timeline_diffs) = update {
let new_events = Self::collect_events_from_diffs(timeline_diffs.diffs);
for event in new_events {
let Some(thread_root) = extract_thread_root(event.raw()) else { continue };
let position = {
let guard = items.lock();
guard.iter().position(|item| item.root_event.event_id == thread_root)
};
if let Some(index) = position {
if let Some(latest_event) = Self::build_event(room, event).await {
let mut guard = items.lock();
if index < guard.len()
&& guard[index].root_event.event_id == thread_root
{
let mut updated = guard[index].clone();
updated.latest_event = Some(latest_event);
updated.num_replies = updated.num_replies.saturating_add(1);
guard.set(index, updated);
}
}
}
}
}
}
}
fn collect_events_from_diffs(
diffs: Vec<VectorDiff<matrix_sdk_base::event_cache::Event>>,
) -> Vec<matrix_sdk_base::event_cache::Event> {
let mut events = Vec::new();
for diff in diffs {
match diff {
VectorDiff::Append { values } => events.extend(values),
VectorDiff::PushBack { value }
| VectorDiff::PushFront { value }
| VectorDiff::Insert { value, .. }
| VectorDiff::Set { value, .. } => events.push(value),
VectorDiff::Reset { values } => events.extend(values),
VectorDiff::Clear
| VectorDiff::PopBack
| VectorDiff::PopFront
| VectorDiff::Remove { .. }
| VectorDiff::Truncate { .. } => {}
}
}
events
}
}
/// One page of threads, as returned by `Room::list_threads` and converted to
/// list items.
#[derive(Clone, Debug)]
struct ThreadList {
    /// The converted thread roots of this page.
    pub items: Vec<ThreadListItem>,
    /// Token to request the next page, or `None` if this was the last page.
    pub prev_batch_token: Option<String>,
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures_util::pin_mut;
    use matrix_sdk::test_utils::mocks::MatrixMockServer;
    use matrix_sdk_test::{async_test, event_factory::EventFactory};
    use ruma::{event_id, events::AnyTimelineEvent, room_id, serde::Raw, user_id};
    use serde_json::json;
    use stream_assert::{assert_next_matches, assert_pending};
    use wiremock::ResponseTemplate;

    use super::{ThreadListPaginationState, ThreadListService};

    /// A new service is idle, has not reached the end, and has no items.
    #[async_test]
    async fn test_initial_state() {
        let server = MatrixMockServer::new().await;
        let service = make_service(&server).await;
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: false }
        );
        assert!(service.items().is_empty());
    }

    /// Two pages are loaded in order. The first page's token leads to the
    /// second page, and the missing token on the second page marks the end.
    #[async_test]
    async fn test_pagination() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");
        let eid2 = event_id!("$2");
        // First page: one thread root plus a token for the next page.
        server
            .mock_room_threads()
            .ok(
                vec![f.text_msg("Thread root 1").event_id(eid1).into_raw()],
                Some("next_page_token".to_owned()),
            )
            .mock_once()
            .mount()
            .await;
        // Second page: `match_from` presumably restricts this mock to requests
        // made with the first page's token. It returns no token (end of list).
        server
            .mock_room_threads()
            .match_from("next_page_token")
            .ok(vec![f.text_msg("Thread root 2").event_id(eid2).into_raw()], None)
            .mock_once()
            .mount()
            .await;
        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);
        service.paginate().await.expect("first paginate failed");
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: false }
        );
        assert_eq!(service.items().len(), 1);
        assert_eq!(service.items()[0].root_event.event_id, eid1);
        service.paginate().await.expect("second paginate failed");
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
        // The second page is appended after the first one.
        assert_eq!(service.items().len(), 2);
        assert_eq!(service.items()[1].root_event.event_id, eid2);
    }

    /// Once a page without a next token has been loaded, further `paginate`
    /// calls are no-ops.
    #[async_test]
    async fn test_pagination_end_reached() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");
        // A single page, with no token for a further page.
        server
            .mock_room_threads()
            .ok(vec![f.text_msg("Thread root").event_id(eid1).into_raw()], None)
            .mock_once()
            .mount()
            .await;
        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);
        service.paginate().await.expect("paginate failed");
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
        assert_eq!(service.items().len(), 1);
        // The mock above answers only once, so this call must not hit the server.
        service.paginate().await.expect("second paginate should be a no-op");
        assert_eq!(service.items().len(), 1);
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
    }

    /// Two overlapping `paginate` calls result in a single request; the second
    /// call returns `Ok(())` without doing anything.
    #[async_test]
    async fn test_concurrent_pagination_is_not_possible() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");
        let chunk: Vec<Raw<AnyTimelineEvent>> =
            vec![f.text_msg("Thread root").event_id(eid1).into_raw()];
        // The delayed response keeps the first request in flight while the second
        // call runs; `.expect(1)` requires exactly one request to be made.
        server
            .mock_room_threads()
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(json!({ "chunk": chunk, "next_batch": null }))
                    .set_delay(Duration::from_millis(100)),
            )
            .expect(1)
            .mount()
            .await;
        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);
        let (first, second) = tokio::join!(service.paginate(), service.paginate());
        first.expect("first paginate should succeed");
        second.expect("second (concurrent) paginate should succeed as a no-op");
        assert_eq!(service.items().len(), 1);
        assert_eq!(service.items()[0].root_event.event_id, eid1);
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
    }

    /// A failed request returns an error, leaves the list empty, and goes back
    /// to a retryable idle state.
    #[async_test]
    async fn test_pagination_error() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        server.mock_room_threads().error500().mock_once().mount().await;
        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);
        service.paginate().await.expect_err("paginate should fail on a 500 response");
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: false }
        );
        assert!(service.items().is_empty());
    }

    /// `reset` clears the items and the end-reached flag, so the first page can
    /// be loaded again.
    #[async_test]
    async fn test_reset() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");
        // The same single page is served twice: before and after the reset.
        server
            .mock_room_threads()
            .ok(vec![f.text_msg("Thread root").event_id(eid1).into_raw()], None)
            .expect(2)
            .mount()
            .await;
        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);
        service.paginate().await.expect("first paginate failed");
        assert_eq!(service.items().len(), 1);
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: true }
        );
        service.reset().await;
        assert!(service.items().is_empty());
        assert_eq!(
            service.pagination_state(),
            ThreadListPaginationState::Idle { end_reached: false }
        );
        service.paginate().await.expect("paginate after reset failed");
        assert_eq!(service.items().len(), 1);
    }

    /// The pagination state subscriber is notified once pagination finishes.
    #[async_test]
    async fn test_pagination_state_subscriber() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");
        server
            .mock_room_threads()
            .ok(
                vec![f.text_msg("Thread root").event_id(eid1).into_raw()],
                Some("next_token".to_owned()),
            )
            .mock_once()
            .mount()
            .await;
        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);
        let subscriber = service.subscribe_to_pagination_state_updates();
        pin_mut!(subscriber);
        assert_pending!(subscriber);
        service.paginate().await.expect("paginate failed");
        // NOTE: only the final `Idle` state is asserted; the intermediate
        // `Loading` state is presumably coalesced by the subscriber since it is
        // not polled during `paginate`.
        assert_next_matches!(subscriber, ThreadListPaginationState::Idle { end_reached: false });
    }

    /// A root without a bundled thread summary gets zero replies and no latest
    /// event.
    #[async_test]
    async fn test_paginated_items_have_num_replies_zero_without_summary() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let eid1 = event_id!("$1");
        server
            .mock_room_threads()
            .ok(vec![f.text_msg("Thread root").event_id(eid1).into_raw()], None)
            .mock_once()
            .mount()
            .await;
        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);
        service.paginate().await.expect("paginate failed");
        let items = service.items();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].num_replies, 0);
        assert!(items[0].latest_event.is_none());
    }

    /// The reply count and latest event are taken from the root's bundled
    /// thread summary.
    #[async_test]
    async fn test_paginated_items_have_num_replies_from_bundled_summary() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let sender_id = user_id!("@alice:b.c");
        let f = EventFactory::new().room(room_id).sender(sender_id);
        let root_id = event_id!("$root");
        let reply_id = event_id!("$reply");
        let reply_event =
            f.text_msg("Reply in thread").event_id(reply_id).into_raw_sync().cast_unchecked();
        // The root bundles a summary of 3 replies, with `reply_event` as the
        // latest one.
        let thread_root = f
            .text_msg("Thread root")
            .event_id(root_id)
            .with_bundled_thread_summary(reply_event, 3, false)
            .into_raw();
        server.mock_room_threads().ok(vec![thread_root], None).mock_once().mount().await;
        let room = server.sync_joined_room(&client, room_id).await;
        let service = ThreadListService::new(room);
        service.paginate().await.expect("paginate failed");
        let items = service.items();
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].root_event.event_id, root_id);
        assert_eq!(items[0].num_replies, 3);
        let latest = items[0].latest_event.as_ref().expect("should have latest_event");
        assert_eq!(latest.event_id, reply_id);
        assert_eq!(latest.sender.as_str(), sender_id.as_str());
    }

    /// Builds a service for the synced, joined room `!a:b.c`. No `/threads`
    /// response is mocked here.
    async fn make_service(server: &MatrixMockServer) -> ThreadListService {
        let client = server.client_builder().build().await;
        let room_id = room_id!("!a:b.c");
        let room = server.sync_joined_room(&client, room_id).await;
        ThreadListService::new(room)
    }
}