// hocuspocus_rs_ws/sync/awareness.rs

1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::fmt::Formatter;
4use std::sync::Arc;
5use thiserror::Error;
6use yrs::block::ClientID;
7use yrs::updates::decoder::{Decode, Decoder};
8use yrs::updates::encoder::{Encode, Encoder};
9use yrs::{Doc, Observer, Subscription};
10
/// JSON literal used on the wire for a client without a state; an incoming entry carrying this
/// value is treated as a removal of that client's state.
const NULL_STR: &str = "null";
12
13type AwarenessObserver = Observer<Arc<dyn Fn(&Awareness, &Event) + Send + Sync + 'static>>;
14
15/// The Awareness class implements a simple shared state protocol that can be used for non-persistent
16/// data like awareness information (cursor, username, status, ..). Each client can update its own
17/// local state and listen to state changes of remote clients.
18///
19/// Each client is identified by a unique client id (something we borrow from `doc.clientID`).
20/// A client can override its own state by propagating a message with an increasing timestamp
21/// (`clock`). If such a message is received, it is applied if the known state of that client is
22/// older than the new state (`clock < new_clock`). If a client thinks that a remote client is
23/// offline, it may propagate a message with `{ clock, state: null, client }`. If such a message is
24/// received, and the known clock of that client equals the received clock, it will clean the state.
25///
26/// Before a client disconnects, it should propagate a `null` state with an updated clock.
27pub struct Awareness {
28    pub doc: Doc,
29    states: HashMap<ClientID, String>,
30    meta: HashMap<ClientID, MetaClientState>,
31    on_update: Option<AwarenessObserver>,
32}
33
34impl Awareness {
35    /// Creates a new instance of [Awareness] struct, which operates over a given document.
36    /// Awareness instance has full ownership of that document. If necessary it can be accessed
37    /// using either [Awareness::doc] or [Awareness::doc_mut] methods.
38    pub fn new(doc: Doc) -> Self {
39        Awareness {
40            doc,
41            on_update: None,
42            states: HashMap::new(),
43            meta: HashMap::new(),
44        }
45    }
46
47    /// Returns a channel receiver for an incoming awareness events. This channel can be cloned.
48    pub fn on_update<F>(&mut self, f: F) -> Subscription
49    where
50        F: Fn(&Awareness, &Event) + Send + Sync + 'static,
51    {
52        let eh = self.on_update.get_or_insert_with(Observer::default);
53        eh.subscribe(Arc::new(f))
54    }
55
56    /// Returns a read-only reference to an underlying [Doc].
57    pub fn doc(&self) -> &Doc {
58        &self.doc
59    }
60
61    /// Returns a read-write reference to an underlying [Doc].
62    pub fn doc_mut(&mut self) -> &mut Doc {
63        &mut self.doc
64    }
65
66    /// Returns a globally unique client ID of an underlying [Doc].
67    pub fn client_id(&self) -> ClientID {
68        self.doc.client_id()
69    }
70
71    /// Returns a state map of all of the clients tracked by current [Awareness] instance. Those
72    /// states are identified by their corresponding [ClientID]s. The associated state is
73    /// represented and replicated to other clients as a JSON string.
74    pub fn clients(&self) -> &HashMap<ClientID, String> {
75        &self.states
76    }
77
78    /// Returns a JSON string state representation of a current [Awareness] instance.
79    pub fn local_state(&self) -> Option<&str> {
80        Some(self.states.get(&self.doc.client_id())?.as_str())
81    }
82
83    /// Sets a current [Awareness] instance state to a corresponding JSON string. This state will
84    /// be replicated to other clients as part of the [AwarenessUpdate] and it will trigger an event
85    /// to be emitted if current instance was created using Awareness::with_observer method.
86    ///
87    pub fn set_local_state<S: Into<String>>(&mut self, json: S) {
88        let client_id = self.doc.client_id();
89        self.update_meta(client_id);
90        let new: String = json.into();
91        match self.states.entry(client_id) {
92            Entry::Occupied(mut e) => {
93                e.insert(new);
94                if let Some(eh) = self.on_update.as_ref() {
95                    let e = Event::new(vec![], vec![client_id], vec![]);
96                    eh.trigger(|cb| {
97                        cb(self, &e);
98                    });
99                }
100            }
101            Entry::Vacant(e) => {
102                e.insert(new);
103                if let Some(eh) = self.on_update.as_ref() {
104                    let e = Event::new(vec![client_id], vec![], vec![]);
105                    eh.trigger(|cb| {
106                        cb(self, &e);
107                    });
108                }
109            }
110        }
111    }
112
113    /// Clears out a state of a given client, effectively marking it as disconnected.
114    pub fn remove_state(&mut self, client_id: ClientID) {
115        let prev_state = self.states.remove(&client_id);
116        self.update_meta(client_id);
117        if let Some(eh) = self.on_update.as_ref() {
118            if prev_state.is_some() {
119                let e = Event::new(Vec::default(), Vec::default(), vec![client_id]);
120                eh.trigger(|cb| {
121                    cb(self, &e);
122                });
123            }
124        }
125    }
126
127    /// Clears out a state of a current client (see: [Awareness::client_id]),
128    /// effectively marking it as disconnected.
129    pub fn clean_local_state(&mut self) {
130        let client_id = self.doc.client_id();
131        self.remove_state(client_id);
132    }
133
134    fn update_meta(&mut self, client_id: ClientID) {
135        match self.meta.entry(client_id) {
136            Entry::Occupied(mut e) => {
137                let clock = e.get().clock + 1;
138                let meta = MetaClientState::new(clock);
139                e.insert(meta);
140            }
141            Entry::Vacant(e) => {
142                e.insert(MetaClientState::new(1));
143            }
144        }
145    }
146
147    /// Returns a serializable update object which is representation of a current Awareness state.
148    pub fn update(&self) -> Result<AwarenessUpdate, Error> {
149        let clients = self.states.keys().cloned();
150        self.update_with_clients(clients)
151    }
152
153    /// Returns a serializable update object which is representation of a current Awareness state.
154    /// Unlike [Awareness::update], this method variant allows to prepare update only for a subset
155    /// of known clients. These clients must all be known to a current [Awareness] instance,
156    /// otherwise a [Error::ClientNotFound] error will be returned.
157    pub fn update_with_clients<I: IntoIterator<Item = ClientID>>(
158        &self,
159        clients: I,
160    ) -> Result<AwarenessUpdate, Error> {
161        let mut res = HashMap::new();
162        for client_id in clients {
163            let clock = if let Some(meta) = self.meta.get(&client_id) {
164                meta.clock
165            } else {
166                return Err(Error::ClientNotFound(client_id));
167            };
168            let json = if let Some(json) = self.states.get(&client_id) {
169                json.clone()
170            } else {
171                String::from(NULL_STR)
172            };
173            res.insert(client_id, AwarenessUpdateEntry { clock, json });
174        }
175        Ok(AwarenessUpdate { clients: res })
176    }
177
178    /// Applies an update (incoming from remote channel or generated using [Awareness::update] /
179    /// [Awareness::update_with_clients] methods) and modifies a state of a current instance.
180    pub fn apply_update(&mut self, update: AwarenessUpdate) -> Result<(), Error> {
181        let mut added = Vec::new();
182        let mut updated = Vec::new();
183        let mut removed = Vec::new();
184
185        for (client_id, entry) in update.clients {
186            let mut clock = entry.clock;
187            let is_null = entry.json.as_str() == NULL_STR;
188            match self.meta.entry(client_id) {
189                Entry::Occupied(mut e) => {
190                    let prev = e.get();
191                    let is_removed =
192                        prev.clock == clock && is_null && self.states.contains_key(&client_id);
193                    let is_new = prev.clock < clock;
194                    if is_new || is_removed {
195                        if is_null {
196                            // never let a remote client remove this local state
197                            if client_id == self.doc.client_id()
198                                && self.states.contains_key(&client_id)
199                            {
200                                // remote client removed the local state. Do not remote state. Broadcast a message indicating
201                                // that this client still exists by increasing the clock
202                                clock += 1;
203                            } else {
204                                self.states.remove(&client_id);
205                                if self.on_update.is_some() {
206                                    removed.push(client_id);
207                                }
208                            }
209                        } else {
210                            match self.states.entry(client_id) {
211                                Entry::Occupied(mut e) => {
212                                    if self.on_update.is_some() {
213                                        updated.push(client_id);
214                                    }
215                                    e.insert(entry.json);
216                                }
217                                Entry::Vacant(e) => {
218                                    e.insert(entry.json);
219                                    if self.on_update.is_some() {
220                                        updated.push(client_id);
221                                    }
222                                }
223                            }
224                        }
225                        e.insert(MetaClientState::new(clock));
226                        true
227                    } else {
228                        false
229                    }
230                }
231                Entry::Vacant(e) => {
232                    e.insert(MetaClientState::new(clock));
233                    self.states.insert(client_id, entry.json);
234                    if self.on_update.is_some() {
235                        added.push(client_id);
236                    }
237                    true
238                }
239            };
240        }
241
242        if let Some(eh) = self.on_update.as_ref() {
243            if !added.is_empty() || !updated.is_empty() || !removed.is_empty() {
244                let e = Event::new(added, updated, removed);
245                eh.trigger(|cb| {
246                    cb(self, &e);
247                });
248            }
249        }
250
251        Ok(())
252    }
253}
254
255impl Default for Awareness {
256    fn default() -> Self {
257        Awareness::new(Doc::new())
258    }
259}
260
261impl std::fmt::Debug for Awareness {
262    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263        f.debug_struct("Awareness")
264            .field("state", &self.states)
265            .field("meta", &self.meta)
266            .field("doc", &self.doc)
267            .finish()
268    }
269}
270
271/// A structure that represents an encodable state of an [Awareness] struct.
272#[derive(Debug, Eq, PartialEq)]
273pub struct AwarenessUpdate {
274    pub(crate) clients: HashMap<ClientID, AwarenessUpdateEntry>,
275}
276
277impl Encode for AwarenessUpdate {
278    fn encode<E: Encoder>(&self, encoder: &mut E) {
279        encoder.write_var(self.clients.len());
280        for (&client_id, e) in self.clients.iter() {
281            encoder.write_var(client_id);
282            encoder.write_var(e.clock);
283            encoder.write_string(&e.json);
284        }
285    }
286}
287
288impl Decode for AwarenessUpdate {
289    fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, yrs::encoding::read::Error> {
290        let len: usize = decoder.read_var()?;
291        let mut clients = HashMap::with_capacity(len);
292        for _ in 0..len {
293            let client_id: ClientID = decoder.read_var()?;
294            let clock: u32 = decoder.read_var()?;
295            let json = decoder.read_string()?.to_string();
296            clients.insert(client_id, AwarenessUpdateEntry { clock, json });
297        }
298
299        Ok(AwarenessUpdate { clients })
300    }
301}
302
/// One client's entry in an [AwarenessUpdate]: its logical clock plus its state as a JSON
/// string (the literal `null` when the client has no state).
#[derive(PartialEq, Eq, Debug)]
pub struct AwarenessUpdateEntry {
    /// Logical clock of the client when the update was produced.
    pub(crate) clock: u32,
    /// JSON-encoded client state.
    pub(crate) json: String,
}
310
311/// Errors generated by an [Awareness] struct methods.
312#[derive(Error, Debug)]
313pub enum Error {
314    /// Client ID was not found in [Awareness] metadata.
315    #[error("client ID `{0}` not found")]
316    ClientNotFound(ClientID),
317}
318
/// Per-client bookkeeping kept by `Awareness`; currently only the logical clock.
#[derive(Clone, Debug)]
struct MetaClientState {
    /// Logical clock of the latest state accepted for the client.
    clock: u32,
}

impl MetaClientState {
    /// Wraps the given clock value.
    fn new(clock: u32) -> Self {
        Self { clock }
    }
}
329
330/// Event type emitted by an [Awareness] struct.
331#[derive(Debug, Default, Clone, Eq, PartialEq)]
332pub struct Event {
333    added: Vec<ClientID>,
334    updated: Vec<ClientID>,
335    removed: Vec<ClientID>,
336}
337
338impl Event {
339    pub fn new(added: Vec<ClientID>, updated: Vec<ClientID>, removed: Vec<ClientID>) -> Self {
340        Event {
341            added,
342            updated,
343            removed,
344        }
345    }
346
347    /// Collection of new clients that have been added to an [Awareness] struct, that was not known
348    /// before. Actual client state can be accessed via `awareness.clients().get(client_id)`.
349    pub fn added(&self) -> &[ClientID] {
350        &self.added
351    }
352
353    /// Collection of new clients that have been updated within an [Awareness] struct since the last
354    /// update. Actual client state can be accessed via `awareness.clients().get(client_id)`.
355    pub fn updated(&self) -> &[ClientID] {
356        &self.updated
357    }
358
359    /// Collection of new clients that have been removed from [Awareness] struct since the last
360    /// update.
361    pub fn removed(&self) -> &[ClientID] {
362        &self.removed
363    }
364}
365
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::mpsc::{channel, Receiver};
    use yrs::Doc;

    /// Pops the next event emitted by `from`, builds an update for exactly the clients named
    /// in that event, applies it to `to` and returns the popped event.
    fn update(
        recv: &Receiver<Event>,
        from: &Awareness,
        to: &mut Awareness,
    ) -> Result<Event, Box<dyn std::error::Error>> {
        let e = recv.try_recv()?;
        let u = from.update_with_clients([e.added(), e.updated(), e.removed()].concat())?;
        to.apply_update(u)?;
        Ok(e)
    }

    #[test]
    fn awareness() -> Result<(), Box<dyn std::error::Error>> {
        let (s1, o_local) = channel();
        let mut local = Awareness::new(Doc::with_client_id(1));
        let _sub_local = local.on_update(move |_, e| {
            s1.send(e.clone()).unwrap();
        });

        let (s2, o_remote) = channel();
        let mut remote = Awareness::new(Doc::with_client_id(2));
        // Subscribe on `remote` (not `local`) so `o_remote` observes the events raised by
        // `apply_update` on the receiving side.
        let _sub_remote = remote.on_update(move |_, e| {
            s2.send(e.clone()).unwrap();
        });

        local.set_local_state("{x:3}");
        let _e_local = update(&o_local, &local, &mut remote)?;
        assert_eq!(remote.clients()[&1], "{x:3}");
        assert_eq!(remote.meta[&1].clock, 1);
        assert_eq!(o_remote.try_recv()?.added, &[1]);

        local.set_local_state("{x:4}");
        let e_local = update(&o_local, &local, &mut remote)?;
        let e_remote = o_remote.try_recv()?;
        assert_eq!(remote.clients()[&1], "{x:4}");
        assert_eq!(e_remote, Event::new(vec![], vec![1], vec![]));
        assert_eq!(e_remote, e_local);

        local.clean_local_state();
        let e_local = update(&o_local, &local, &mut remote)?;
        let e_remote = o_remote.try_recv()?;
        assert_eq!(e_remote.removed.len(), 1);
        assert_eq!(local.clients().get(&1), None);
        assert_eq!(remote.clients().get(&1), None);
        assert_eq!(e_remote, e_local);

        // every applied update produced exactly one event on the remote side
        assert!(o_remote.try_recv().is_err());
        Ok(())
    }
}