1use std::collections::hash_map::Entry;
2use std::collections::HashMap;
3use std::fmt::Formatter;
4use std::sync::Arc;
5use thiserror::Error;
6use yrs::block::ClientID;
7use yrs::updates::decoder::{Decode, Decoder};
8use yrs::updates::encoder::{Encode, Encoder};
9use yrs::{Doc, Observer, Subscription};
10
/// JSON literal used in awareness updates for a client that has no state.
///
/// It is emitted when an update is built for a client without a stored state
/// and is compared against incoming entries when an update is applied.
const NULL_STR: &str = "null";

/// Observer holding the subscribed awareness callbacks. Each callback is
/// invoked with the [`Awareness`] instance and the [`Event`] describing the
/// change.
type AwarenessObserver = Observer<Arc<dyn Fn(&Awareness, &Event) + Send + Sync + 'static>>;
14
15pub struct Awareness {
28 pub doc: Doc,
29 states: HashMap<ClientID, String>,
30 meta: HashMap<ClientID, MetaClientState>,
31 on_update: Option<AwarenessObserver>,
32}
33
34impl Awareness {
35 pub fn new(doc: Doc) -> Self {
39 Awareness {
40 doc,
41 on_update: None,
42 states: HashMap::new(),
43 meta: HashMap::new(),
44 }
45 }
46
47 pub fn on_update<F>(&mut self, f: F) -> Subscription
49 where
50 F: Fn(&Awareness, &Event) + Send + Sync + 'static,
51 {
52 let eh = self.on_update.get_or_insert_with(Observer::default);
53 eh.subscribe(Arc::new(f))
54 }
55
56 pub fn doc(&self) -> &Doc {
58 &self.doc
59 }
60
61 pub fn doc_mut(&mut self) -> &mut Doc {
63 &mut self.doc
64 }
65
66 pub fn client_id(&self) -> ClientID {
68 self.doc.client_id()
69 }
70
71 pub fn clients(&self) -> &HashMap<ClientID, String> {
75 &self.states
76 }
77
78 pub fn local_state(&self) -> Option<&str> {
80 Some(self.states.get(&self.doc.client_id())?.as_str())
81 }
82
83 pub fn set_local_state<S: Into<String>>(&mut self, json: S) {
88 let client_id = self.doc.client_id();
89 self.update_meta(client_id);
90 let new: String = json.into();
91 match self.states.entry(client_id) {
92 Entry::Occupied(mut e) => {
93 e.insert(new);
94 if let Some(eh) = self.on_update.as_ref() {
95 let e = Event::new(vec![], vec![client_id], vec![]);
96 eh.trigger(|cb| {
97 cb(self, &e);
98 });
99 }
100 }
101 Entry::Vacant(e) => {
102 e.insert(new);
103 if let Some(eh) = self.on_update.as_ref() {
104 let e = Event::new(vec![client_id], vec![], vec![]);
105 eh.trigger(|cb| {
106 cb(self, &e);
107 });
108 }
109 }
110 }
111 }
112
113 pub fn remove_state(&mut self, client_id: ClientID) {
115 let prev_state = self.states.remove(&client_id);
116 self.update_meta(client_id);
117 if let Some(eh) = self.on_update.as_ref() {
118 if prev_state.is_some() {
119 let e = Event::new(Vec::default(), Vec::default(), vec![client_id]);
120 eh.trigger(|cb| {
121 cb(self, &e);
122 });
123 }
124 }
125 }
126
127 pub fn clean_local_state(&mut self) {
130 let client_id = self.doc.client_id();
131 self.remove_state(client_id);
132 }
133
134 fn update_meta(&mut self, client_id: ClientID) {
135 match self.meta.entry(client_id) {
136 Entry::Occupied(mut e) => {
137 let clock = e.get().clock + 1;
138 let meta = MetaClientState::new(clock);
139 e.insert(meta);
140 }
141 Entry::Vacant(e) => {
142 e.insert(MetaClientState::new(1));
143 }
144 }
145 }
146
147 pub fn update(&self) -> Result<AwarenessUpdate, Error> {
149 let clients = self.states.keys().cloned();
150 self.update_with_clients(clients)
151 }
152
153 pub fn update_with_clients<I: IntoIterator<Item = ClientID>>(
158 &self,
159 clients: I,
160 ) -> Result<AwarenessUpdate, Error> {
161 let mut res = HashMap::new();
162 for client_id in clients {
163 let clock = if let Some(meta) = self.meta.get(&client_id) {
164 meta.clock
165 } else {
166 return Err(Error::ClientNotFound(client_id));
167 };
168 let json = if let Some(json) = self.states.get(&client_id) {
169 json.clone()
170 } else {
171 String::from(NULL_STR)
172 };
173 res.insert(client_id, AwarenessUpdateEntry { clock, json });
174 }
175 Ok(AwarenessUpdate { clients: res })
176 }
177
178 pub fn apply_update(&mut self, update: AwarenessUpdate) -> Result<(), Error> {
181 let mut added = Vec::new();
182 let mut updated = Vec::new();
183 let mut removed = Vec::new();
184
185 for (client_id, entry) in update.clients {
186 let mut clock = entry.clock;
187 let is_null = entry.json.as_str() == NULL_STR;
188 match self.meta.entry(client_id) {
189 Entry::Occupied(mut e) => {
190 let prev = e.get();
191 let is_removed =
192 prev.clock == clock && is_null && self.states.contains_key(&client_id);
193 let is_new = prev.clock < clock;
194 if is_new || is_removed {
195 if is_null {
196 if client_id == self.doc.client_id()
198 && self.states.contains_key(&client_id)
199 {
200 clock += 1;
203 } else {
204 self.states.remove(&client_id);
205 if self.on_update.is_some() {
206 removed.push(client_id);
207 }
208 }
209 } else {
210 match self.states.entry(client_id) {
211 Entry::Occupied(mut e) => {
212 if self.on_update.is_some() {
213 updated.push(client_id);
214 }
215 e.insert(entry.json);
216 }
217 Entry::Vacant(e) => {
218 e.insert(entry.json);
219 if self.on_update.is_some() {
220 updated.push(client_id);
221 }
222 }
223 }
224 }
225 e.insert(MetaClientState::new(clock));
226 true
227 } else {
228 false
229 }
230 }
231 Entry::Vacant(e) => {
232 e.insert(MetaClientState::new(clock));
233 self.states.insert(client_id, entry.json);
234 if self.on_update.is_some() {
235 added.push(client_id);
236 }
237 true
238 }
239 };
240 }
241
242 if let Some(eh) = self.on_update.as_ref() {
243 if !added.is_empty() || !updated.is_empty() || !removed.is_empty() {
244 let e = Event::new(added, updated, removed);
245 eh.trigger(|cb| {
246 cb(self, &e);
247 });
248 }
249 }
250
251 Ok(())
252 }
253}
254
255impl Default for Awareness {
256 fn default() -> Self {
257 Awareness::new(Doc::new())
258 }
259}
260
261impl std::fmt::Debug for Awareness {
262 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263 f.debug_struct("Awareness")
264 .field("state", &self.states)
265 .field("meta", &self.meta)
266 .field("doc", &self.doc)
267 .finish()
268 }
269}
270
/// A set of per-client awareness entries that can be encoded, decoded and
/// applied to an [`Awareness`] instance.
#[derive(Debug, Eq, PartialEq)]
pub struct AwarenessUpdate {
    /// Clock and JSON state carried for each client included in the update.
    pub(crate) clients: HashMap<ClientID, AwarenessUpdateEntry>,
}
276
277impl Encode for AwarenessUpdate {
278 fn encode<E: Encoder>(&self, encoder: &mut E) {
279 encoder.write_var(self.clients.len());
280 for (&client_id, e) in self.clients.iter() {
281 encoder.write_var(client_id);
282 encoder.write_var(e.clock);
283 encoder.write_string(&e.json);
284 }
285 }
286}
287
288impl Decode for AwarenessUpdate {
289 fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, yrs::encoding::read::Error> {
290 let len: usize = decoder.read_var()?;
291 let mut clients = HashMap::with_capacity(len);
292 for _ in 0..len {
293 let client_id: ClientID = decoder.read_var()?;
294 let clock: u32 = decoder.read_var()?;
295 let json = decoder.read_string()?.to_string();
296 clients.insert(client_id, AwarenessUpdateEntry { clock, json });
297 }
298
299 Ok(AwarenessUpdate { clients })
300 }
301}
302
/// State of a single client as carried inside an [`AwarenessUpdate`].
#[derive(Debug, Eq, PartialEq)]
pub struct AwarenessUpdateEntry {
    /// Clock of the client at the time the update was built.
    pub(crate) clock: u32,
    /// JSON state of the client as a raw string; the literal `null` is used
    /// when the client has no state.
    pub(crate) json: String,
}
310
/// Errors returned by awareness operations.
#[derive(Error, Debug)]
pub enum Error {
    /// An update was requested for a client that has no recorded clock metadata.
    #[error("client ID `{0}` not found")]
    ClientNotFound(ClientID),
}
318
/// Bookkeeping kept for every client: the clock of its latest known state.
#[derive(Debug, Clone)]
struct MetaClientState {
    clock: u32,
}

impl MetaClientState {
    /// Wraps `clock` in a new metadata record.
    fn new(clock: u32) -> Self {
        Self { clock }
    }
}
329
330#[derive(Debug, Default, Clone, Eq, PartialEq)]
332pub struct Event {
333 added: Vec<ClientID>,
334 updated: Vec<ClientID>,
335 removed: Vec<ClientID>,
336}
337
338impl Event {
339 pub fn new(added: Vec<ClientID>, updated: Vec<ClientID>, removed: Vec<ClientID>) -> Self {
340 Event {
341 added,
342 updated,
343 removed,
344 }
345 }
346
347 pub fn added(&self) -> &[ClientID] {
350 &self.added
351 }
352
353 pub fn updated(&self) -> &[ClientID] {
356 &self.updated
357 }
358
359 pub fn removed(&self) -> &[ClientID] {
362 &self.removed
363 }
364}
365
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::mpsc::{channel, Receiver};
    use yrs::Doc;

    /// Takes the next event observed on `from`, builds an update for the
    /// clients it mentions and applies that update to `to`.
    fn update(
        recv: &Receiver<Event>,
        from: &Awareness,
        to: &mut Awareness,
    ) -> Result<Event, Box<dyn std::error::Error>> {
        let e = recv.try_recv()?;
        let u = from.update_with_clients([e.added(), e.updated(), e.removed()].concat())?;
        to.apply_update(u)?;
        Ok(e)
    }

    #[test]
    fn awareness() -> Result<(), Box<dyn std::error::Error>> {
        let (s1, o_local) = channel();
        let mut local = Awareness::new(Doc::with_client_id(1));
        let _sub_local = local.on_update(move |_, e| {
            s1.send(e.clone()).unwrap();
        });

        let (s2, o_remote) = channel();
        let mut remote = Awareness::new(Doc::with_client_id(2));
        // Subscribe on `remote` itself so the assertions below check the events
        // produced by `apply_update`, not a second copy of the local events.
        let _sub_remote = remote.on_update(move |_, e| {
            s2.send(e.clone()).unwrap();
        });

        // A first local state shows up on the remote side as an added client.
        local.set_local_state("{x:3}");
        let _e_local = update(&o_local, &local, &mut remote)?;
        assert_eq!(remote.clients()[&1], "{x:3}");
        assert_eq!(remote.meta[&1].clock, 1);
        assert_eq!(o_remote.try_recv()?.added, &[1]);

        // Changing the state is reported as an update on both sides.
        local.set_local_state("{x:4}");
        let e_local = update(&o_local, &local, &mut remote)?;
        let e_remote = o_remote.try_recv()?;
        assert_eq!(remote.clients()[&1], "{x:4}");
        assert_eq!(e_remote, Event::new(vec![], vec![1], vec![]));
        assert_eq!(e_remote, e_local);

        // Clearing the local state removes it on both sides.
        local.clean_local_state();
        let e_local = update(&o_local, &local, &mut remote)?;
        let e_remote = o_remote.try_recv()?;
        assert_eq!(e_remote.removed.len(), 1);
        assert_eq!(local.clients().get(&1), None);
        assert_eq!(remote.clients().get(&1), None);
        assert_eq!(e_remote, e_local);
        Ok(())
    }
}