1use std::{
16 collections::{BTreeMap, BTreeSet, HashMap},
17 sync::Arc,
18};
19
20use ruma::{
21 OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
22 events::room::member::{MembershipState, SyncRoomMemberEvent},
23};
24use tracing::{instrument, trace};
25
26use super::{DynStateStore, Result, StateChanges};
27use crate::{
28 deserialized_responses::{AmbiguityChange, DisplayName, SyncOrStrippedState},
29 store::StateStoreExt,
30};
31
32#[derive(Debug, Clone)]
34struct DisplayNameUsers {
35 display_name: DisplayName,
36 users: BTreeSet<OwnedUserId>,
37}
38
39impl DisplayNameUsers {
40 fn remove(&mut self, user_id: &UserId) -> Option<OwnedUserId> {
43 self.users.remove(user_id);
44
45 if self.user_count() == 1 { self.users.iter().next().cloned() } else { None }
46 }
47
48 fn add(&mut self, user_id: OwnedUserId) -> Option<OwnedUserId> {
51 let ambiguous_user =
52 if self.user_count() == 1 { self.users.iter().next().cloned() } else { None };
53
54 self.users.insert(user_id);
55
56 ambiguous_user
57 }
58
59 fn user_count(&self) -> usize {
61 self.users.len()
62 }
63
64 fn is_ambiguous(&self) -> bool {
66 is_display_name_ambiguous(&self.display_name, &self.users)
67 }
68}
69
70fn is_member_active(membership: &MembershipState) -> bool {
71 use MembershipState::*;
72 matches!(membership, Join | Invite | Knock)
73}
74
75#[derive(Debug)]
76pub(crate) struct AmbiguityCache {
77 pub store: Arc<DynStateStore>,
78 pub cache: BTreeMap<OwnedRoomId, HashMap<DisplayName, BTreeSet<OwnedUserId>>>,
79 pub changes: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, AmbiguityChange>>,
80}
81
82#[instrument(ret(level = "trace"))]
83pub(crate) fn is_display_name_ambiguous(
84 display_name: &DisplayName,
85 users_with_display_name: &BTreeSet<OwnedUserId>,
86) -> bool {
87 trace!("Checking if a display name is ambiguous");
88 display_name.is_inherently_ambiguous() || users_with_display_name.len() > 1
89}
90
91impl AmbiguityCache {
92 pub fn new(store: Arc<DynStateStore>) -> Self {
94 Self { store, cache: BTreeMap::new(), changes: BTreeMap::new() }
95 }
96
97 pub async fn handle_event(
99 &mut self,
100 changes: &StateChanges,
101 room_id: &RoomId,
102 member_event: &SyncRoomMemberEvent,
103 ) -> Result<()> {
104 if self.changes.get(room_id).is_some_and(|c| c.contains_key(member_event.event_id())) {
113 return Ok(());
114 }
115
116 let (mut old_map, mut new_map) =
117 self.calculate_changes(changes, room_id, member_event).await?;
118
119 let display_names_same = match (&old_map, &new_map) {
120 (Some(a), Some(b)) => a.display_name == b.display_name,
121 _ => false,
122 };
123
124 if display_names_same {
127 return Ok(());
128 }
129
130 let disambiguated_member =
131 old_map.as_mut().and_then(|o| o.remove(member_event.state_key()));
132 let ambiguated_member =
133 new_map.as_mut().and_then(|n| n.add(member_event.state_key().clone()));
134 let ambiguous = new_map.as_ref().is_some_and(|n| n.is_ambiguous());
135
136 self.update(room_id, old_map, new_map);
137
138 let change = AmbiguityChange {
139 member_id: member_event.state_key().clone(),
140 disambiguated_member,
141 ambiguated_member,
142 member_ambiguous: ambiguous,
143 };
144
145 trace!(user_id = ?member_event.state_key(), "Handling display name ambiguity: {change:#?}");
146
147 self.changes
148 .entry(room_id.to_owned())
149 .or_default()
150 .insert(member_event.event_id().to_owned(), change);
151
152 Ok(())
153 }
154
155 fn update(
158 &mut self,
159 room_id: &RoomId,
160 old_map: Option<DisplayNameUsers>,
161 new_map: Option<DisplayNameUsers>,
162 ) {
163 let entry = self.cache.entry(room_id.to_owned()).or_default();
164
165 if let Some(old) = old_map {
166 entry.insert(old.display_name, old.users);
167 }
168
169 if let Some(new) = new_map {
170 entry.insert(new.display_name, new.users);
171 }
172 }
173
174 async fn get_old_display_name(
177 &self,
178 changes: &StateChanges,
179 room_id: &RoomId,
180 new_event: &SyncRoomMemberEvent,
181 ) -> Result<Option<String>> {
182 let user_id = new_event.state_key();
183
184 let old_event = if let Some(member) = changes.member(room_id, user_id) {
185 Some(SyncOrStrippedState::Stripped(member))
186 } else {
187 self.store.get_member_event(room_id, user_id).await?.and_then(|r| r.deserialize().ok())
188 };
189
190 let Some(old_event) = old_event else { return Ok(None) };
191
192 if is_member_active(old_event.membership()) {
193 let display_name = if let Some(d) = changes
194 .profiles
195 .get(room_id)
196 .and_then(|p| p.get(user_id)?.as_original()?.content.displayname.as_deref())
197 {
198 Some(d.to_owned())
199 } else if let Some(d) = self
200 .store
201 .get_profile(room_id, user_id)
202 .await?
203 .and_then(|p| p.into_original()?.content.displayname)
204 {
205 Some(d)
206 } else {
207 old_event.original_content().and_then(|c| c.displayname.clone())
208 };
209
210 Ok(Some(display_name.unwrap_or_else(|| user_id.localpart().to_owned())))
211 } else {
212 Ok(None)
213 }
214 }
215
216 async fn get_users_with_display_name(
223 &mut self,
224 room_id: &RoomId,
225 display_name: &DisplayName,
226 ) -> Result<DisplayNameUsers> {
227 Ok(if let Some(u) = self.cache.entry(room_id.to_owned()).or_default().get(display_name) {
228 DisplayNameUsers { display_name: display_name.clone(), users: u.clone() }
229 } else {
230 let users_with_display_name =
231 self.store.get_users_with_display_name(room_id, display_name).await?;
232
233 DisplayNameUsers { display_name: display_name.clone(), users: users_with_display_name }
234 })
235 }
236
237 async fn calculate_changes(
244 &mut self,
245 changes: &StateChanges,
246 room_id: &RoomId,
247 member_event: &SyncRoomMemberEvent,
248 ) -> Result<(Option<DisplayNameUsers>, Option<DisplayNameUsers>)> {
249 let old_display_name = self.get_old_display_name(changes, room_id, member_event).await?;
250
251 let old_map = if let Some(old_name) = old_display_name.as_deref() {
252 let old_display_name = DisplayName::new(old_name);
253 Some(self.get_users_with_display_name(room_id, &old_display_name).await?)
254 } else {
255 None
256 };
257
258 let new_map = if is_member_active(member_event.membership()) {
259 let new = member_event
260 .as_original()
261 .and_then(|ev| ev.content.displayname.as_deref())
262 .unwrap_or_else(|| member_event.state_key().localpart());
263
264 let new_display_name = if member_event.sender().as_str() == member_event.state_key() {
267 new
268 } else if let Some(old) = old_display_name.as_deref() {
269 old
270 } else {
271 new
272 };
273
274 let new_display_name = DisplayName::new(new_display_name);
275
276 Some(self.get_users_with_display_name(room_id, &new_display_name).await?)
277 } else {
278 None
279 };
280
281 Ok((old_map, new_map))
282 }
283
284 #[cfg(test)]
285 fn check(&self, room_id: &RoomId, display_name: &DisplayName) -> bool {
286 self.cache
287 .get(room_id)
288 .and_then(|display_names| {
289 display_names
290 .get(display_name)
291 .map(|user_ids| is_display_name_ambiguous(display_name, user_ids))
292 })
293 .unwrap_or_else(|| {
294 panic!(
295 "The display name {:?} should be part of the cache {:?}",
296 display_name, self.cache
297 )
298 })
299 }
300}
301
#[cfg(test)]
mod test {
    use matrix_sdk_test::async_test;
    use ruma::{EventId, room_id, server_name, user_id};
    use serde_json::json;

    use super::*;
    use crate::store::{IntoStateStore, MemoryStore};

    /// Builds a `join` member event in which `user_id` is both the sender and
    /// the state key, carrying `display_name` and an event ID obtained from
    /// `EventId::new`.
    fn generate_event(user_id: &UserId, display_name: &str) -> SyncRoomMemberEvent {
        let server_name = server_name!("localhost");
        serde_json::from_value(json!({
            "content": {
                "displayname": display_name,
                "membership": "join"
            },
            "event_id": EventId::new(server_name),
            "origin_server_ts": 152037280,
            "sender": user_id,
            "state_key": user_id,
            "type": "m.room.member",

        }))
        .expect("We should be able to deserialize the static member event")
    }

    /// Feeds a list of `(user ID, display name)` join events into a fresh
    /// [`AmbiguityCache`] backed by an empty [`MemoryStore`], then asserts for
    /// each `(display name, expected)` pair that the cache reports the
    /// expected ambiguity.
    ///
    /// An optional trailing string literal describes the scenario and is
    /// included in the panic message on failure.
    macro_rules! assert_ambiguity {
        (
            [ $( ($user:literal, $display_name:literal) ),* ],
            [ $( ($check_display_name:literal, $ambiguous:expr) ),* ] $(,)?
        ) => {
            assert_ambiguity!(
                [ $( ($user, $display_name) ),* ],
                [ $( ($check_display_name, $ambiguous) ),* ],
                "The test failed the ambiguity assertions"
            )
        };

        (
            [ $( ($user:literal, $display_name:literal) ),* ],
            [ $( ($check_display_name:literal, $ambiguous:expr) ),* ],
            $description:literal $(,)?
        ) => {
            let store = MemoryStore::new();
            let mut ambiguity_cache = AmbiguityCache::new(store.into_state_store());

            let changes = Default::default();
            let room_id = room_id!("!foo:bar");

            // Handles a join event of user `$u` with display name `$n`.
            macro_rules! add_display_name {
                ($u:literal, $n:literal) => {
                    let event = generate_event(user_id!($u), $n);

                    ambiguity_cache
                        .handle_event(&changes, room_id, &event)
                        .await
                        .expect("We should be able to handle a member event to calculate the ambiguity.");
                };
            }

            // Panics unless the cached ambiguity of display name `$n` is `$a`.
            macro_rules! assert_display_name_ambiguity {
                ($n:literal, $a:expr) => {
                    let display_name = DisplayName::new($n);

                    if ambiguity_cache.check(room_id, &display_name) != $a {
                        let expectation = if $a { "be" } else { "not be" };
                        panic!(
                            "{}: the display name {} should {} ambiguous",
                            $description, $n, expectation
                        );
                    }
                };
            }

            $(
                add_display_name!($user, $display_name);
            )*

            $(
                assert_display_name_ambiguity!($check_display_name, $ambiguous);
            )*
        };
    }

    /// Checks plain, case-insensitive and homoglyph-based display name
    /// collisions.
    ///
    /// All non-ASCII test data is written with `\u{...}` escapes so that the
    /// exact code points survive any re-encoding of this file.
    #[async_test]
    async fn test_disambiguation() {
        assert_ambiguity!(
            [("@alice:localhost", "alice")],
            [("alice", false)],
            "Alice is alone in the room"
        );

        assert_ambiguity!(
            [("@alice:localhost", "alice")],
            [("Alice", false)],
            "Alice is alone in the room and has a capitalized display name"
        );

        assert_ambiguity!(
            [("@alice:localhost", "alice"), ("@bob:localhost", "alice")],
            [("alice", true)],
            "Alice and bob share a display name"
        );

        assert_ambiguity!(
            [
                ("@alice:localhost", "alice"),
                ("@bob:localhost", "alice"),
                ("@carol:localhost", "carol")
            ],
            [("alice", true), ("carol", false)],
            "Alice and Bob share a display name, while Carol is unique"
        );

        assert_ambiguity!(
            [("@alice:localhost", "alice"), ("@bob:localhost", "ALICE")],
            [("alice", true)],
            "Alice and Bob share a display name that is differently capitalized"
        );

        // U+0430 CYRILLIC SMALL LETTER A followed by Latin "lice".
        assert_ambiguity!(
            [("@alice:localhost", "alice"), ("@bob:localhost", "\u{0430}lice")],
            [("alice", true)],
            "Bob tries to impersonate Alice using a cyrillic \u{0430}"
        );

        assert_ambiguity!(
            [("@alice:localhost", "@bob:localhost"), ("@bob:localhost", "\u{0430}lice")],
            [("@bob:localhost", true)],
            "Alice tries to impersonate bob using an mxid"
        );

        // "Sahasrahla" in MATHEMATICAL SCRIPT letters.
        assert_ambiguity!(
            [
                ("@alice:localhost", "Sahasrahla"),
                (
                    "@bob:localhost",
                    "\u{1D4AE}\u{1D4B6}\u{1D4BD}\u{1D4B6}\u{1D4C8}\u{1D4C7}\u{1D4B6}\u{1D4BD}\u{1D4C1}\u{1D4B6}"
                )
            ],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using scripture symbols"
        );

        // "Sahasrahla" in MATHEMATICAL FRAKTUR letters.
        assert_ambiguity!(
            [
                ("@alice:localhost", "Sahasrahla"),
                (
                    "@bob:localhost",
                    "\u{1D516}\u{1D51E}\u{1D525}\u{1D51E}\u{1D530}\u{1D52F}\u{1D51E}\u{1D525}\u{1D529}\u{1D51E}"
                )
            ],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using fraktur symbols"
        );

        // "Sahasrahla" in CIRCLED LATIN letters.
        assert_ambiguity!(
            [
                ("@alice:localhost", "Sahasrahla"),
                (
                    "@bob:localhost",
                    "\u{24C8}\u{24D0}\u{24D7}\u{24D0}\u{24E2}\u{24E1}\u{24D0}\u{24D7}\u{24DB}\u{24D0}"
                )
            ],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using circled symbols"
        );

        // "SAHASRAHLA" in SQUARED LATIN CAPITAL letters.
        assert_ambiguity!(
            [
                ("@alice:localhost", "Sahasrahla"),
                (
                    "@bob:localhost",
                    "\u{1F142}\u{1F130}\u{1F137}\u{1F130}\u{1F142}\u{1F141}\u{1F130}\u{1F137}\u{1F13B}\u{1F130}"
                )
            ],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using squared symbols"
        );

        // "Sahasrahla" in FULLWIDTH LATIN letters.
        assert_ambiguity!(
            [
                ("@alice:localhost", "Sahasrahla"),
                (
                    "@bob:localhost",
                    "\u{FF33}\u{FF41}\u{FF48}\u{FF41}\u{FF53}\u{FF52}\u{FF41}\u{FF48}\u{FF4C}\u{FF41}"
                )
            ],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using big unicode letters"
        );

        // U+202E RIGHT-TO-LEFT OVERRIDE in front of the reversed name.
        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "\u{202e}alharsahas")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using right to left shenanigans"
        );

        // U+0334 COMBINING TILDE OVERLAY after the first "a".
        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sa\u{0334}hasrahla")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using a diacritical mark"
        );

        // U+200B ZERO WIDTH SPACE.
        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200B}rahla")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using a zero-width space"
        );

        // U+200D ZERO WIDTH JOINER.
        assert_ambiguity!(
            [("@alice:localhost", "Sahasrahla"), ("@bob:localhost", "Sahas\u{200D}rahla")],
            [("Sahasrahla", true)],
            "Bob tries to impersonate Alice using a zero-width joiner"
        );

        // U+FB00 LATIN SMALL LIGATURE FF.
        assert_ambiguity!(
            [("@alice:localhost", "ff"), ("@bob:localhost", "\u{FB00}")],
            [("ff", true)],
            "Bob tries to impersonate Alice using a ligature"
        );
    }
}