1use std::collections::{BTreeMap, BTreeSet, HashMap};
16
17use ruma::{
18 OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
19 events::room::member::{MembershipState, SyncRoomMemberEvent},
20};
21use tracing::{instrument, trace};
22
23use super::{Result, StateChanges};
24use crate::{
25 StateStore,
26 deserialized_responses::{AmbiguityChange, DisplayName, SyncOrStrippedState},
27 store::{SaveLockedStateStore, StateStoreExt},
28};
29
30#[derive(Debug, Clone)]
32struct DisplayNameUsers {
33 display_name: DisplayName,
34 users: BTreeSet<OwnedUserId>,
35}
36
37impl DisplayNameUsers {
38 fn remove(&mut self, user_id: &UserId) -> Option<OwnedUserId> {
41 self.users.remove(user_id);
42
43 if self.user_count() == 1 { self.users.iter().next().cloned() } else { None }
44 }
45
46 fn add(&mut self, user_id: OwnedUserId) -> Option<OwnedUserId> {
49 let ambiguous_user =
50 if self.user_count() == 1 { self.users.iter().next().cloned() } else { None };
51
52 self.users.insert(user_id);
53
54 ambiguous_user
55 }
56
57 fn user_count(&self) -> usize {
59 self.users.len()
60 }
61
62 fn is_ambiguous(&self) -> bool {
64 is_display_name_ambiguous(&self.display_name, &self.users)
65 }
66}
67
68fn is_member_active(membership: &MembershipState) -> bool {
69 use MembershipState::*;
70 matches!(membership, Join | Invite | Knock)
71}
72
/// Tracks which room members share display names, together with the ambiguity
/// changes caused by the member events handled so far.
#[derive(Debug)]
pub(crate) struct AmbiguityCache {
    /// Backing state store, used for lookups that neither `cache` nor the
    /// passed-in `StateChanges` can answer.
    pub store: SaveLockedStateStore,
    /// Per room, the set of users currently using each display name.
    pub cache: BTreeMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
    /// Per room, the ambiguity change computed for each handled member event,
    /// keyed by the event ID.
    pub changes: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, AmbiguityChange>>,
}
79
80#[instrument(ret(level = "trace"))]
81pub(crate) fn is_display_name_ambiguous(
82 display_name: &DisplayName,
83 users_with_display_name: &BTreeSet<OwnedUserId>,
84) -> bool {
85 trace!("Checking if a display name is ambiguous");
86 display_name.is_inherently_ambiguous() || users_with_display_name.len() > 1
87}
88
impl AmbiguityCache {
    /// Create an ambiguity cache backed by `store`.
    ///
    /// Both the display name cache and the collected changes start out empty.
    pub fn new(store: SaveLockedStateStore) -> Self {
        Self { store, cache: BTreeMap::new(), changes: BTreeMap::new() }
    }

    /// Update the ambiguity state of `room_id` for a single member event.
    ///
    /// Moves the event's target user from the group of users sharing their old
    /// display name to the group sharing their new one. It writes both groups
    /// back into `self.cache` and records the resulting [`AmbiguityChange`] in
    /// `self.changes`, keyed by the event ID.
    ///
    /// Returns early without recording anything if a change was already
    /// recorded for this event ID, or if the old and new display names are
    /// equal.
    ///
    /// # Errors
    ///
    /// Propagates errors from the state store lookups.
    pub async fn handle_event(
        &mut self,
        changes: &StateChanges,
        room_id: &RoomId,
        member_event: &SyncRoomMemberEvent,
    ) -> Result<()> {
        // Handling the same event twice is not idempotent: the second run would
        // read the cache already updated by the first run and overwrite the
        // recorded change with a wrong one. Skip events we have already seen.
        if self.changes.get(room_id).is_some_and(|c| c.contains_key(member_event.event_id())) {
            return Ok(());
        }

        let (mut old_map, mut new_map) =
            self.calculate_changes(changes, room_id, member_event).await?;

        let display_names_same = match (&old_map, &new_map) {
            (Some(a), Some(b)) => a.display_name == b.display_name,
            _ => false,
        };

        // An unchanged display name cannot change anyone's ambiguity.
        if display_names_same {
            return Ok(());
        }

        // The member leaves the old display name's group, then joins the new
        // one; `ambiguous` is computed on the group after the member joined.
        let disambiguated_member =
            old_map.as_mut().and_then(|o| o.remove(member_event.state_key()));
        let ambiguated_member =
            new_map.as_mut().and_then(|n| n.add(member_event.state_key().clone()));
        let ambiguous = new_map.as_ref().is_some_and(|n| n.is_ambiguous());

        self.update(room_id, old_map, new_map);

        let change = AmbiguityChange {
            member_id: member_event.state_key().clone(),
            disambiguated_member,
            ambiguated_member,
            member_ambiguous: ambiguous,
        };

        trace!(user_id = ?member_event.state_key(), "Handling display name ambiguity: {change:#?}");

        self.changes
            .entry(room_id.to_owned())
            .or_default()
            .insert(member_event.event_id().to_owned(), change);

        Ok(())
    }

    /// Write the given display name groups back into the cache of `room_id`,
    /// replacing whatever user sets were cached for those display names.
    fn update(
        &mut self,
        room_id: &RoomId,
        old_map: Option<DisplayNameUsers>,
        new_map: Option<DisplayNameUsers>,
    ) {
        let entry = self.cache.entry(room_id.to_owned()).or_default();

        if let Some(old) = old_map {
            entry.insert(old.display_name, old.users);
        }

        if let Some(new) = new_map {
            entry.insert(new.display_name, new.users);
        }
    }

    /// Determine the display name the target user of `new_event` had before
    /// this event.
    ///
    /// The previous member event is looked up in `changes` first, then in the
    /// store. Returns `Ok(None)` in three cases: there is no previous member
    /// event, the stored one fails to deserialize, or the previous membership
    /// was not active.
    ///
    /// Otherwise the display name is taken, in order of preference, from:
    /// - the profile in `changes`,
    /// - the profile in the store,
    /// - the previous member event.
    ///
    /// If none of these has one, it falls back to the user ID's localpart.
    ///
    /// # Errors
    ///
    /// Propagates errors from the state store lookups.
    async fn get_old_display_name(
        &self,
        changes: &StateChanges,
        room_id: &RoomId,
        new_event: &SyncRoomMemberEvent,
    ) -> Result<Option<String>> {
        let user_id = new_event.state_key();

        // Pending changes are checked before the store (presumably because
        // they haven't been persisted yet).
        let old_event = if let Some(member) = changes.member(room_id, user_id) {
            Some(SyncOrStrippedState::Stripped(member))
        } else {
            self.store.get_member_event(room_id, user_id).await?.and_then(|r| r.deserialize().ok())
        };

        let Some(old_event) = old_event else { return Ok(None) };

        if is_member_active(old_event.membership()) {
            // NOTE(review): a profile in `changes` whose display name is unset
            // falls through to the stored profile, which may be older — confirm
            // this is intended.
            let display_name = if let Some(d) = changes
                .profiles
                .get(room_id)
                .and_then(|p| p.get(user_id)?.content.displayname.as_deref())
            {
                Some(d.to_owned())
            } else if let Some(d) =
                self.store.get_profile(room_id, user_id).await?.and_then(|p| p.content.displayname)
            {
                Some(d)
            } else {
                old_event.displayname_value().map(ToOwned::to_owned)
            };

            // No display name anywhere: fall back to the user ID's localpart.
            Ok(Some(display_name.unwrap_or_else(|| user_id.localpart().to_owned())))
        } else {
            Ok(None)
        }
    }

    /// Collect the users of `room_id` that currently use `display_name`.
    ///
    /// Uses the in-memory cache when it holds an entry for this display name,
    /// otherwise asks the store. Note that `entry(..).or_default()` inserts an
    /// empty map for `room_id` into the cache even on a miss.
    ///
    /// # Errors
    ///
    /// Propagates errors from the state store lookup.
    async fn get_users_with_display_name(
        &mut self,
        room_id: &RoomId,
        display_name: &DisplayName,
    ) -> Result<DisplayNameUsers> {
        Ok(if let Some(u) = self.cache.entry(room_id.to_owned()).or_default().get(display_name) {
            DisplayNameUsers { display_name: display_name.clone(), users: u.clone() }
        } else {
            let users_with_display_name =
                self.store.get_users_with_display_name(room_id, display_name).await?;

            DisplayNameUsers { display_name: display_name.clone(), users: users_with_display_name }
        })
    }

    /// Compute the groups of users sharing the member's old and new display
    /// names.
    ///
    /// The first element is `None` when the member had no active previous
    /// membership. The second is `None` when `member_event`'s own membership
    /// is not active.
    ///
    /// # Errors
    ///
    /// Propagates errors from the state store lookups.
    async fn calculate_changes(
        &mut self,
        changes: &StateChanges,
        room_id: &RoomId,
        member_event: &SyncRoomMemberEvent,
    ) -> Result<(Option<DisplayNameUsers>, Option<DisplayNameUsers>)> {
        let old_display_name = self.get_old_display_name(changes, room_id, member_event).await?;

        let old_map = if let Some(old_name) = old_display_name.as_deref() {
            let old_display_name = DisplayName::new(old_name);
            Some(self.get_users_with_display_name(room_id, &old_display_name).await?)
        } else {
            None
        };

        let new_map = if is_member_active(member_event.membership()) {
            // Non-original (redacted) events and events without a display
            // name fall back to the localpart.
            let new = member_event
                .as_original()
                .and_then(|ev| ev.content.displayname.as_deref())
                .unwrap_or_else(|| member_event.state_key().localpart());

            // Only a member's own event may set their new display name; when
            // someone else sent the event, keep the old display name if one is
            // known (presumably because it is the more trustworthy value).
            let new_display_name = if member_event.sender().as_str() == member_event.state_key() {
                new
            } else if let Some(old) = old_display_name.as_deref() {
                old
            } else {
                new
            };

            let new_display_name = DisplayName::new(new_display_name);

            Some(self.get_users_with_display_name(room_id, &new_display_name).await?)
        } else {
            None
        };

        Ok((old_map, new_map))
    }

    /// Test helper: whether `display_name` is ambiguous according to the cache
    /// of `room_id`.
    ///
    /// # Panics
    ///
    /// Panics if the cache holds no entry for `display_name` in `room_id`.
    #[cfg(test)]
    fn check(&self, room_id: &RoomId, display_name: &DisplayName) -> bool {
        self.cache
            .get(room_id)
            .and_then(|display_names| {
                display_names
                    .get(display_name)
                    .map(|user_ids| is_display_name_ambiguous(display_name, user_ids))
            })
            .unwrap_or_else(|| {
                panic!(
                    "The display name {:?} should be part of the cache {:?}",
                    display_name, self.cache
                )
            })
    }
}
296
#[cfg(test)]
mod test {
    use matrix_sdk_test::async_test;
    use ruma::{EventId, room_id, server_name, user_id};
    use serde_json::json;

    use super::*;
    use crate::store::{IntoStateStore, MemoryStore};

    /// Build a joined `m.room.member` event in which `user_id` sets their own
    /// display name to `display_name`.
    fn generate_event(user_id: &UserId, display_name: &str) -> SyncRoomMemberEvent {
        let server_name = server_name!("localhost");
        serde_json::from_value(json!({
            "content": {
                "displayname": display_name,
                "membership": "join"
            },
            "event_id": EventId::new_v1(server_name),
            "origin_server_ts": 152037280,
            "sender": user_id,
            "state_key": user_id,
            "type": "m.room.member",
        }))
        .expect("We should be able to deserialize the static member event")
    }

    /// Feed `(user_id, display_name)` join events into a fresh
    /// [`AmbiguityCache`], then assert for every `(display_name, ambiguous)`
    /// pair whether that display name is considered ambiguous.
    ///
    /// An optional trailing string literal describes the scenario; it is part
    /// of the panic message when an assertion fails.
    macro_rules! assert_ambiguity {
        (
            [ $( ($user:literal, $display_name:literal) ),* ],
            [ $( ($check_display_name:literal, $ambiguous:expr) ),* ] $(,)?
        ) => {
            assert_ambiguity!(
                [ $( ($user, $display_name) ),* ],
                [ $( ($check_display_name, $ambiguous) ),* ],
                "The test failed the ambiguity assertions"
            )
        };

        (
            [ $( ($user:literal, $display_name:literal) ),* ],
            [ $( ($check_display_name:literal, $ambiguous:expr) ),* ],
            $description:literal $(,)?
        ) => {
            let store = MemoryStore::new();
            let mut ambiguity_cache = AmbiguityCache::new(SaveLockedStateStore::new(store.into_state_store()));

            let changes = Default::default();
            let room_id = room_id!("!foo:bar");

            // Handle a self-sent join event setting `$u`'s display name to `$n`.
            macro_rules! add_display_name {
                ($u:literal, $n:literal) => {
                    let event = generate_event(user_id!($u), $n);

                    ambiguity_cache
                        .handle_event(&changes, room_id, &event)
                        .await
                        .expect("We should be able to handle a member event to calculate the ambiguity.");
                };
            }

            // Check the cached ambiguity of display name `$n` against `$a`.
            macro_rules! assert_display_name_ambiguity {
                ($n:literal, $a:expr) => {
                    let display_name = DisplayName::new($n);

                    if ambiguity_cache.check(room_id, &display_name) != $a {
                        let expectation = if $a { "be" } else { "not be" };
                        panic!("{}: the display name {} should {} ambiguous", $description, $n, expectation);
                    }
                };
            }

            $(
                add_display_name!($user, $display_name);
            )*

            $(
                assert_display_name_ambiguity!($check_display_name, $ambiguous);
            )*
        };
    }

    #[async_test]
    async fn test_disambiguation() {
        assert_ambiguity!(
            [("@alice:localhost", "alice")],
            [("alice", false)],
            "Alice is alone in the room"
        );

        assert_ambiguity!(
            [("@alice:localhost", "alice")],
            [("Alice", false)],
            "Alice is alone in the room and has a capitalized display name"
        );

        assert_ambiguity!(
            [("@alice:localhost", "alice"), ("@bob:localhost", "alice")],
            [("alice", true)],
            "Alice and bob share a display name"
        );

        assert_ambiguity!(
            [
                ("@alice:localhost", "alice"),
                ("@bob:localhost", "alice"),
                ("@carol:localhost", "carol")
            ],
            [("alice", true), ("carol", false)],
            "Alice and Bob share a display name, while Carol is unique"
        );

        assert_ambiguity!(
            [("@alice:localhost", "alice"), ("@bob:localhost", "ALICE")],
            [("alice", true)],
            "Alice and Bob share a display name that is differently capitalized"
        );

        // U+0430 CYRILLIC SMALL LETTER A, escaped because it is visually
        // indistinguishable from the Latin `a`.
        assert_ambiguity!(
            [("@alice:localhost", "alice"), ("@bob:localhost", "\u{430}lice")],
            [("alice", true)],
            "Bob tries to impersonate Alice using a Cyrillic a (U+0430)"
        );

        assert_ambiguity!(
            [("@alice:localhost", "@bob:localhost"), ("@bob:localhost", "\u{430}lice")],
            [("@bob:localhost", true)],
            "Alice tries to impersonate bob using an mxid"
        );

        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "𝒮𝒶𝒽𝒶𝓈𝓇𝒶𝒽𝓁𝒶")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using script symbols"
        );

        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "𝔖𝔞𝔥𝔞𝔰𝔯𝔞𝔥𝔩𝔞")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using fraktur symbols"
        );

        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Ⓢⓐⓗⓐⓢⓡⓐⓗⓛⓐ")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using circled symbols"
        );

        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "🅂🄰🄷🄰🅂🅁🄰🄷🄻🄰")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using squared symbols"
        );

        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Ｓａｈａｓｒａｈｌａ")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using big unicode letters"
        );

        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "\u{202e}alharsahas")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using right-to-left override shenanigans"
        );

        // U+0334 COMBINING TILDE OVERLAY.
        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sa\u{334}hasrahla")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using a diacritical mark"
        );

        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200B}rahla")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using a zero-width space"
        );

        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200D}rahla")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using a zero-width joiner"
        );

        assert_ambiguity!(
            [("@alice:localhost", "ff"), ("@bob:localhost", "\u{FB00}")],
            [("ff", true)],
            "Bob tries to impersonate Alice using a ligature"
        );
    }
}