use std::collections::HashSet;
use matrix_sdk_base::{RoomStateFilter, deserialized_responses::TimelineEvent};
use matrix_sdk_search::error::IndexError;
#[cfg(doc)]
use matrix_sdk_search::index::RoomIndex;
use ruma::{OwnedEventId, OwnedRoomId};
use crate::{Client, Room};
impl Room {
    /// Searches the client's search index for `query`, restricted to this room.
    ///
    /// `max_number_of_results` and `pagination_offset` are forwarded unchanged to the
    /// index's `search` call, together with this room's ID. The client's search index
    /// stays locked until that call returns.
    ///
    /// # Errors
    ///
    /// Returns the [`IndexError`] reported by the index, if any.
    pub async fn search(
        &self,
        query: &str,
        max_number_of_results: usize,
        pagination_offset: Option<usize>,
    ) -> Result<Vec<OwnedEventId>, IndexError> {
        let this_room = self.room_id();
        let mut index = self.client.search_index().lock().await;
        index.search(query, max_number_of_results, pagination_offset, this_room)
    }
}
/// Errors that can happen while iterating over search results.
#[derive(thiserror::Error, Debug)]
pub enum SearchError {
    /// Querying the search index failed.
    #[error(transparent)]
    IndexError(#[from] IndexError),

    /// Loading one of the matched events failed (raised by the `next_events`
    /// methods while turning event IDs into events).
    #[error(transparent)]
    EventLoadError(#[from] crate::Error),
}
impl Room {
    /// Creates a [`RoomSearchIterator`] that pages through this room's search results
    /// for `query`, requesting `num_results_per_batch` results per search.
    ///
    /// This only constructs the iterator; no search happens until
    /// [`RoomSearchIterator::next`] or [`RoomSearchIterator::next_events`] is called.
    pub fn search_messages(
        &self,
        query: String,
        num_results_per_batch: usize,
    ) -> RoomSearchIterator {
        let room = self.clone();
        RoomSearchIterator { room, query, num_results_per_batch, offset: None, is_done: false }
    }
}
/// Pages through the search results of a single room.
///
/// Created by [`Room::search_messages`]; advance it with `next` (event IDs) or
/// `next_events` (loaded events).
#[derive(Debug)]
pub struct RoomSearchIterator {
    /// The room being searched.
    room: Room,
    /// The search query.
    query: String,
    /// Total number of results returned so far, passed as the pagination offset of the
    /// next search; `None` until the first non-empty batch.
    offset: Option<usize>,
    /// Set once a search returned no results; afterwards `next` returns `None` without
    /// querying the index again.
    is_done: bool,
    /// Number of results requested from the index per search.
    num_results_per_batch: usize,
}
impl RoomSearchIterator {
    /// Returns the IDs of the next batch of matching events in this room.
    ///
    /// Each call asks the index for up to `num_results_per_batch` results, starting after
    /// the results already returned. Returns `Ok(None)` as soon as a search comes back
    /// empty; from then on every call returns `Ok(None)` without querying the index.
    ///
    /// # Errors
    ///
    /// Returns the [`IndexError`] reported by the index. The iterator is not marked as
    /// done in that case, so the call may be retried.
    pub async fn next(&mut self) -> Result<Option<Vec<OwnedEventId>>, IndexError> {
        if self.is_done {
            return Ok(None);
        }

        let result = self.room.search(&self.query, self.num_results_per_batch, self.offset).await?;

        if result.is_empty() {
            // An empty page means there is nothing left to paginate through.
            self.is_done = true;
            return Ok(None);
        }

        // Skip past everything returned so far on the next call.
        self.offset = Some(self.offset.unwrap_or(0) + result.len());
        Ok(Some(result))
    }

    /// Like [`Self::next`], but loads each matched event with
    /// `Room::load_or_fetch_event`. Events are returned in the order the index returned
    /// their IDs.
    ///
    /// # Errors
    ///
    /// Returns [`SearchError::IndexError`] if the search fails, or
    /// [`SearchError::EventLoadError`] if loading any of the events fails.
    pub async fn next_events(&mut self) -> Result<Option<Vec<TimelineEvent>>, SearchError> {
        let Some(event_ids) = self.next().await? else {
            return Ok(None);
        };

        // The final length is known up front, so allocate once.
        let mut results = Vec::with_capacity(event_ids.len());
        for event_id in event_ids {
            results.push(self.room.load_or_fetch_event(&event_id, None).await?);
        }
        Ok(Some(results))
    }
}
/// Per-room pagination state used by [`GlobalSearchIterator`].
#[derive(Debug)]
struct GlobalSearchRoomState {
    /// The room being searched.
    room: Room,
    /// Number of results already fetched from this room, used as the pagination offset
    /// of its next search; `None` until the room first returns results.
    offset: Option<usize>,
}
impl GlobalSearchRoomState {
fn new(room: Room) -> Self {
Self { room, offset: None }
}
}
/// Builder for a [`GlobalSearchIterator`], used to choose which rooms get searched.
#[derive(Debug)]
pub struct GlobalSearchBuilder {
    /// Client the search runs on.
    client: Client,
    /// The search query.
    query: String,
    /// Number of results per batch, handed over to the iterator.
    num_results_per_batch: usize,
    /// Rooms to search; starts as all joined rooms and can be narrowed by the DM filters.
    room_set: Vec<Room>,
}
impl GlobalSearchBuilder {
fn new(client: Client, query: String, num_results_per_batch: usize) -> Self {
let room_set = client.rooms_filtered(RoomStateFilter::JOINED);
Self { client, query, room_set, num_results_per_batch }
}
pub async fn only_dm_rooms(mut self) -> Result<Self, crate::Error> {
let mut to_remove = HashSet::new();
for room in &self.room_set {
if !room.compute_is_dm().await? {
to_remove.insert(room.room_id().to_owned());
}
}
self.room_set.retain(|room| !to_remove.contains(room.room_id()));
Ok(self)
}
pub async fn no_dms(mut self) -> Result<Self, crate::Error> {
let mut to_remove = HashSet::new();
for room in &self.room_set {
if room.compute_is_dm().await? {
to_remove.insert(room.room_id().to_owned());
}
}
self.room_set.retain(|room| !to_remove.contains(room.room_id()));
Ok(self)
}
pub fn build(self) -> GlobalSearchIterator {
GlobalSearchIterator {
client: self.client,
query: self.query,
room_state: Vec::from_iter(self.room_set.into_iter().map(GlobalSearchRoomState::new)),
current_batch: Vec::new(),
num_results_per_batch: self.num_results_per_batch,
}
}
}
impl Client {
    /// Starts building a search for `query` across this client's rooms.
    ///
    /// The returned [`GlobalSearchBuilder`] initially targets every joined room. Narrow it
    /// with [`GlobalSearchBuilder::only_dm_rooms`] or [`GlobalSearchBuilder::no_dms`],
    /// then call [`GlobalSearchBuilder::build`] to get the iterator.
    pub fn search_messages(
        &self,
        query: String,
        num_results_per_batch: usize,
    ) -> GlobalSearchBuilder {
        let client = self.clone();
        GlobalSearchBuilder::new(client, query, num_results_per_batch)
    }
}
/// Pages through search results across several rooms, yielding `(room_id, event_id)`
/// pairs.
///
/// Created by [`GlobalSearchBuilder::build`].
#[derive(Debug)]
pub struct GlobalSearchIterator {
    /// Used to look rooms up again when loading events in `next_events`.
    client: Client,
    /// The search query.
    query: String,
    /// Rooms that have not been exhausted yet, with their pagination offsets.
    room_state: Vec<GlobalSearchRoomState>,
    /// Results fetched from the index but not yet handed to the caller.
    current_batch: Vec<(OwnedRoomId, OwnedEventId)>,
    /// Maximum number of results returned per call, also requested per room search.
    num_results_per_batch: usize,
}
impl GlobalSearchIterator {
    /// Returns the next batch of `(room_id, event_id)` results across the rooms still
    /// being searched.
    ///
    /// Each call returns at most `num_results_per_batch` results. Rooms are searched in
    /// order until enough results are buffered. Results beyond the limit stay buffered
    /// and are served by later calls before the index is queried again. A room whose
    /// search comes back empty is dropped from the iteration.
    ///
    /// Returns `Ok(None)` once no room produces results and nothing is buffered.
    ///
    /// # Errors
    ///
    /// Returns [`SearchError::IndexError`] if searching any room fails.
    pub async fn next(&mut self) -> Result<Option<Vec<(OwnedRoomId, OwnedEventId)>>, SearchError> {
        // Stop only when there is nothing left to search *and* nothing buffered, so any
        // buffered results are always delivered.
        if self.room_state.is_empty() && self.current_batch.is_empty() {
            return Ok(None);
        }

        // A full batch is already buffered: serve it without touching the index.
        if self.current_batch.len() >= self.num_results_per_batch {
            return Ok(Some(self.current_batch.drain(0..self.num_results_per_batch).collect()));
        }

        // Rooms whose search came back empty during this pass.
        let mut exhausted = HashSet::new();

        for room_state in &mut self.room_state {
            let room_results = room_state
                .room
                .search(&self.query, self.num_results_per_batch, room_state.offset)
                .await?;

            if room_results.is_empty() {
                exhausted.insert(room_state.room.room_id().to_owned());
                continue;
            }

            room_state.offset = Some(room_state.offset.unwrap_or(0) + room_results.len());

            let room_id = room_state.room.room_id();
            self.current_batch
                .extend(room_results.into_iter().map(|event_id| (room_id.to_owned(), event_id)));

            if self.current_batch.len() >= self.num_results_per_batch {
                break;
            }
        }

        // One pass over the room list, rather than a full `retain` per exhausted room.
        self.room_state.retain(|room_state| !exhausted.contains(room_state.room.room_id()));

        if self.current_batch.is_empty() {
            // Nothing was buffered and every visited room came back empty, so all rooms
            // have just been removed.
            debug_assert!(self.room_state.is_empty());
            return Ok(None);
        }

        let high = self.num_results_per_batch.min(self.current_batch.len());
        Ok(Some(self.current_batch.drain(..high).collect()))
    }

    /// Like [`Self::next`], but loads each matched event with
    /// `Room::load_or_fetch_event`.
    ///
    /// Results whose room `Client::get_room` no longer returns are skipped. A batch may
    /// therefore contain fewer entries than [`Self::next`] produced.
    ///
    /// # Errors
    ///
    /// Returns [`SearchError::IndexError`] if the search fails, or
    /// [`SearchError::EventLoadError`] if loading any of the events fails.
    pub async fn next_events(
        &mut self,
    ) -> Result<Option<Vec<(OwnedRoomId, TimelineEvent)>>, SearchError> {
        let Some(event_ids) = self.next().await? else {
            return Ok(None);
        };

        let mut results = Vec::with_capacity(event_ids.len());
        for (room_id, event_id) in event_ids {
            let Some(room) = self.client.get_room(&room_id) else {
                continue;
            };
            results.push((room_id, room.load_or_fetch_event(&event_id, None).await?));
        }
        Ok(Some(results))
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use matrix_sdk_test::{BOB, JoinedRoomBuilder, async_test, event_factory::EventFactory};
    use ruma::{event_id, room_id, user_id};

    use crate::{sleep::sleep, test_utils::mocks::MatrixMockServer};

    /// Checks single-room search pagination. A query with no hits yields `None`
    /// repeatedly. A query with one hit yields it once, via both `next` and
    /// `next_events`, and then `None`.
    #[async_test]
    async fn test_room_message_search() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id = room_id!("!room_id:localhost");
        let room = server.sync_joined_room(&client, room_id).await;

        let f = EventFactory::new().room(room_id).sender(user_id!("@user_id:localhost"));
        let event_id = event_id!("$event_id:localhost");

        // Sync a single "hello world" message into the room.
        server
            .sync_room(
                &client,
                JoinedRoomBuilder::new(room_id)
                    .add_timeline_event(f.text_msg("hello world").event_id(event_id)),
            )
            .await;

        // NOTE(review): presumably gives background processing time to index the synced
        // event before searching — confirm; a fixed sleep may be flaky.
        sleep(Duration::from_millis(200)).await;

        // No match: `None` on the first call, and it stays `None`.
        {
            let mut room_search = room.search_messages("search query".to_owned(), 5);
            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());
            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        // One match, returned as an event ID, then exhaustion.
        {
            let mut room_search = room.search_messages("world".to_owned(), 5);
            let maybe_results = room_search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(&results[0], event_id,);
            let maybe_results = room_search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        // Same match, returned as a loaded event, then exhaustion.
        {
            let mut room_search = room.search_messages("world".to_owned(), 5);
            let maybe_results = room_search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].event_id().as_deref().unwrap(), event_id,);
            let maybe_results = room_search.next_events().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }

    /// Checks search across two joined rooms. A query with no hits yields `None`. A
    /// query matching one message in each room returns both in one batch, via both
    /// `next` and `next_events`, and then `None`.
    #[async_test]
    async fn test_global_message_search() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id1 = room_id!("!r1:localhost");
        let room_id2 = room_id!("!r2:localhost");

        let f = EventFactory::new().sender(user_id!("@user_id:localhost"));

        let result_event_id1 = event_id!("$result1:localhost");
        let result_event_id2 = event_id!("$result2:localhost");

        // Room 1 gets a message containing "world" plus one that doesn't; room 2 gets
        // one message containing "world".
        server
            .mock_sync()
            .ok_and_run(&client, |sync_builder| {
                sync_builder
                    .add_joined_room(
                        JoinedRoomBuilder::new(room_id1)
                            .add_timeline_event(
                                f.text_msg("hello world").room(room_id1).event_id(result_event_id1),
                            )
                            .add_timeline_event(f.text_msg("hello back").room(room_id1)),
                    )
                    .add_joined_room(JoinedRoomBuilder::new(room_id2).add_timeline_event(
                        f.text_msg("it's a mad world").room(room_id2).event_id(result_event_id2),
                    ));
            })
            .await;

        // NOTE(review): presumably waits for the synced events to be indexed — confirm.
        sleep(Duration::from_millis(200)).await;

        // No match in any room.
        {
            let mut search = client.search_messages("search query".to_owned(), 5).build();
            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        // One match per room, as `(room_id, event_id)` pairs; order is not asserted.
        {
            let mut search = client.search_messages("world".to_owned(), 5).build();
            let maybe_results = search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 2);
            assert!(results.contains(&(room_id1.to_owned(), result_event_id1.to_owned())));
            assert!(results.contains(&(room_id2.to_owned(), result_event_id2.to_owned())));
            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        // Same matches, as loaded events.
        {
            let mut search = client.search_messages("world".to_owned(), 5).build();
            let maybe_results = search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 2);
            assert!(results.iter().any(|(room_id, event)| {
                room_id == room_id1 && event.event_id().as_deref() == Some(result_event_id1)
            }));
            assert!(results.iter().any(|(room_id, event)| {
                room_id == room_id2 && event.event_id().as_deref() == Some(result_event_id2)
            }));
            let maybe_results = search.next_events().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }

    /// Checks the DM filters of the global search builder. With room 1 marked as a DM,
    /// `only_dm_rooms` finds only room 1's match and `no_dms` finds only room 2's.
    #[async_test]
    async fn test_global_message_search_dm_or_groups() {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;

        let event_cache = client.event_cache();
        event_cache.subscribe().unwrap();

        let room_id1 = room_id!("!r1:localhost");
        let room_id2 = room_id!("!r2:localhost");

        let f = EventFactory::new().sender(user_id!("@user_id:localhost"));

        let result_event_id1 = event_id!("$result1:localhost");
        let result_event_id2 = event_id!("$result2:localhost");

        // Same rooms and messages as `test_global_message_search`, plus global account
        // data (presumably `m.direct`) registering room 1 as a DM with BOB.
        server
            .mock_sync()
            .ok_and_run(&client, |sync_builder| {
                sync_builder
                    .add_joined_room(
                        JoinedRoomBuilder::new(room_id1)
                            .add_timeline_event(
                                f.text_msg("hello world").room(room_id1).event_id(result_event_id1),
                            )
                            .add_timeline_event(f.text_msg("hello back").room(room_id1)),
                    )
                    .add_joined_room(JoinedRoomBuilder::new(room_id2).add_timeline_event(
                        f.text_msg("it's a mad world").room(room_id2).event_id(result_event_id2),
                    ))
                    .add_global_account_data(
                        f.direct().add_user((*BOB).to_owned().into(), room_id1),
                    );
            })
            .await;

        // NOTE(review): presumably waits for the synced events to be indexed — confirm.
        sleep(Duration::from_millis(200)).await;

        // Only DM rooms: just room 1's match.
        {
            let mut search = client
                .search_messages("world".to_owned(), 5)
                .only_dm_rooms()
                .await
                .unwrap()
                .build();
            let maybe_results = search.next().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(&results[0], &(room_id1.to_owned(), result_event_id1.to_owned()));
            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }

        // No DMs: just room 2's match, loaded as an event.
        {
            let mut search =
                client.search_messages("world".to_owned(), 5).no_dms().await.unwrap().build();
            let maybe_results = search.next_events().await.unwrap();
            let results = maybe_results.unwrap();
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].0, room_id2);
            assert_eq!(results[0].1.event_id().as_deref().unwrap(), result_event_id2);
            let maybe_results = search.next().await.unwrap();
            assert!(maybe_results.is_none());
        }
    }
}