hocuspocus_rs_ws/sync/
mod.rs

1//! Forked from [y-sync](https://github.com/y-crdt/y-sync/tree/master)
2
3pub mod awareness;
4
5use awareness::{Awareness, AwarenessUpdate};
6use thiserror::Error;
7use tracing::debug;
8use yrs::updates::decoder::{Decode, Decoder};
9use yrs::updates::encoder::{Encode, Encoder};
10use yrs::{ReadTxn, StateVector, Transact, Update};
11
12/*
13 Core Yjs defines two message types:
14 • YjsSyncStep1: Includes the State Set of the sending client. When received, the client should reply with YjsSyncStep2.
15 • YjsSyncStep2: Includes all missing structs and the complete delete set. When received, the client is assured that it
16   received all information from the remote client.
17
18 In a peer-to-peer network, you may want to introduce a SyncDone message type. Both parties should initiate the connection
19 with SyncStep1. When a client received SyncStep2, it should reply with SyncDone. When the local client received both
20 SyncStep2 and SyncDone, it is assured that it is synced to the remote client.
21
22 In a client-server model, you want to handle this differently: The client should initiate the connection with SyncStep1.
23 When the server receives SyncStep1, it should reply with SyncStep2 immediately followed by SyncStep1. The client replies
24 with SyncStep2 when it receives SyncStep1. Optionally the server may send a SyncDone after it received SyncStep2, so the
25 client knows that the sync is finished.  There are two reasons for this more elaborated sync model: 1. This protocol can
26 easily be implemented on top of http and websockets. 2. The server should only reply to requests, and not initiate them.
27 Therefore, it is necessary that the client initiates the sync.
28
29 Construction of a message:
30 [messageType : varUint, message definition..]
31
32 Note: A message does not include information about the room name. This must be handled by the upper layer protocol!
33
34 stringify[messageType] stringifies a message definition (messageType is already read from the buffer)
35*/
36
/// A default implementation of y-sync [Protocol].
///
/// It is a stateless unit struct, so it is cheap to copy and can be created with
/// [Default::default].
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultProtocol;
39
// Nothing is overridden here: every handler falls back to the default method bodies
// declared on the [Protocol] trait.
impl Protocol for DefaultProtocol {}
41
42/// Trait implementing a y-sync protocol. The default implementation can be found in
43/// [DefaultProtocol], but its implementation steps can be potentially changed by the user if
44/// necessary.
45pub trait Protocol {
46    /// To be called whenever a new connection has been accepted. Returns an encoded list of
47    /// messages to be sent back to initiator. This binary may contain multiple messages inside,
48    /// stored one after another.
49    fn start<E: Encoder>(&self, awareness: &Awareness, encoder: &mut E) -> Result<(), Error> {
50        let (sv, update) = {
51            let sv = awareness.doc().transact().state_vector();
52            let update = awareness.update()?;
53            (sv, update)
54        };
55        Message::Sync(SyncMessage::SyncStep1(sv)).encode(encoder);
56        Message::Awareness(update).encode(encoder);
57        Ok(())
58    }
59
60    /// Y-sync protocol sync-step-1 - given a [StateVector] of a remote side, calculate missing
61    /// updates. Returns a sync-step-2 message containing a calculated update.
62    fn handle_sync_step1(
63        &self,
64        awareness: &Awareness,
65        sv: StateVector,
66    ) -> Result<Option<Message>, Error> {
67        let update = awareness.doc().transact().encode_state_as_update_v1(&sv);
68        debug!("pray - handle_sync_step1 - update: {:?}", update);
69        Ok(Some(Message::Sync(SyncMessage::SyncStep2(update))))
70    }
71
72    fn sync_step1(&self, awareness: &Awareness) -> Result<Message, Error> {
73        let sv = awareness.doc().transact().state_vector();
74        Ok(Message::Sync(SyncMessage::SyncStep1(sv)))
75    }
76
77    fn init_awareness(&self, awareness: &Awareness) -> Result<Message, Error> {
78        let update = awareness.update()?;
79        Ok(Message::Awareness(update))
80    }
81
82    fn awareness(&self, awareness: &Awareness) -> Result<Message, Error> {
83        let update = awareness.update()?;
84        Ok(Message::Awareness(update))
85    }
86
87    /// Handle reply for a sync-step-1 send from this replica previously. By default, just apply
88    /// an update to current `awareness` document instance.
89    fn handle_sync_step2(
90        &self,
91        awareness: &mut Awareness,
92        update: Update,
93    ) -> Result<Option<Message>, Error> {
94        let mut txn = awareness.doc().transact_mut();
95        txn.apply_update(update);
96        Ok(Some(Message::SyncStatus(true)))
97    }
98
99    /// Handle continuous update send from the client. By default just apply an update to a current
100    /// `awareness` document instance.
101    fn handle_update(
102        &self,
103        awareness: &mut Awareness,
104        update: Update,
105    ) -> Result<Option<Message>, Error> {
106        self.handle_sync_step2(awareness, update)
107    }
108
109    fn handle_auth_success(&self, _awareness: &Awareness, read_write: bool) -> Message {
110        Message::Auth(
111            if read_write {
112                Some("read-write".to_owned())
113            } else {
114                Some("read-only".to_owned())
115            },
116            true,
117        )
118    }
119
120    /// Handle authorization message. By default, if reason for auth denial has been provided,
121    /// send back [Error::PermissionDenied].
122    fn handle_auth_fail(
123        &self,
124        _awareness: &Awareness,
125    ) -> Message {
126        Message::Auth(
127            None,
128            false,
129        )
130    }
131
132    /// Returns an [AwarenessUpdate] which is a serializable representation of a current `awareness`
133    /// instance.
134    fn handle_awareness_query(&self, awareness: &Awareness) -> Result<Option<Message>, Error> {
135        let update = awareness.update()?;
136        Ok(Some(Message::Awareness(update)))
137    }
138
139    /// Reply to awareness query or just incoming [AwarenessUpdate], where current `awareness`
140    /// instance is being updated with incoming data.
141    fn handle_awareness_update(
142        &self,
143        awareness: &mut Awareness,
144        update: AwarenessUpdate,
145    ) -> Result<Option<Message>, Error> {
146        awareness.apply_update(update)?;
147        Ok(None)
148    }
149
150    /// Y-sync protocol enables to extend its own settings with custom handles. These can be
151    /// implemented here. By default, it returns an [Error::Unsupported].
152    fn missing_handle(
153        &self,
154        _awareness: &mut Awareness,
155        tag: u8,
156        _data: Vec<u8>,
157    ) -> Result<Option<Message>, Error> {
158        Err(Error::Unsupported(tag))
159    }
160}
161
/// Tag id for [Message::Sync].
pub const MSG_SYNC: u8 = 0;
/// Tag id for [Message::Awareness].
pub const MSG_AWARENESS: u8 = 1;
/// Tag id for [Message::Auth].
pub const MSG_AUTH: u8 = 2;
/// Tag id for [Message::AwarenessQuery].
pub const MSG_QUERY_AWARENESS: u8 = 3;
/// Tag id for [Message::SyncStatus].
pub const MSG_SYNC_STATUS: u8 = 8;

// Authentication result codes, written right after the [MSG_AUTH] tag.

/// Code written for a [Message::Auth] whose `authenticated` flag is `false`. When decoding,
/// this value marks an auth message that is followed by a string payload.
// NOTE(review): the original remark here read "this serverside only, client side use this
// TOKEN" - presumably clients reuse value 0 to submit their auth token; confirm against the
// client implementation.
pub const PERMISSION_DENIED: u8 = 0;
/// Not referenced by the encoder or decoder in this module.
// NOTE(review): presumably kept for compatibility with upstream y-sync - confirm.
pub const PERMISSION_GRANTED: u8 = 1;
/// Code written for a [Message::Auth] whose `authenticated` flag is `true`.
pub const AUTHENTICATED: u8 = 2;
177
/// A single top-level protocol message. On the wire every message starts with its tag id,
/// followed by a variant-specific payload.
#[derive(Debug, Eq, PartialEq)]
pub enum Message {
    /// Document synchronization message, see [SyncMessage]. Tagged with [MSG_SYNC].
    Sync(SyncMessage),
    /// Authentication message, tagged with [MSG_AUTH]. The first field is an optional string
    /// payload, the second one tells whether authentication succeeded.
    Auth(Option<String>, bool),
    /// Request for the awareness state of the other side. Tagged with
    /// [MSG_QUERY_AWARENESS] and carries no payload.
    AwarenessQuery,
    /// Awareness state update. Tagged with [MSG_AWARENESS].
    Awareness(AwarenessUpdate),
    /// Boolean status flag, tagged with [MSG_SYNC_STATUS].
    SyncStatus(bool),
    /// Message with a tag id and an opaque payload. Decoding produces this variant for any
    /// tag it does not recognise.
    Custom(u8, Vec<u8>),
}
187
188impl Encode for Message {
189    fn encode<E: Encoder>(&self, encoder: &mut E) {
190        match self {
191            Message::Sync(msg) => {
192                encoder.write_var(MSG_SYNC);
193                msg.encode(encoder);
194            }
195            Message::Auth(reason, authenticated) => {
196                encoder.write_var(MSG_AUTH);
197                if *authenticated {
198                    encoder.write_var(AUTHENTICATED);
199                } else {
200                    encoder.write_var(PERMISSION_DENIED);
201                }
202                if let Some(reason) = reason {
203                    encoder.write_string(reason);
204                }
205            }
206            Message::AwarenessQuery => {
207                encoder.write_var(MSG_QUERY_AWARENESS);
208            }
209            Message::Awareness(update) => {
210                encoder.write_var(MSG_AWARENESS);
211                encoder.write_buf(update.encode_v1())
212            }
213            Message::SyncStatus(connected) => {
214                encoder.write_var(MSG_SYNC_STATUS);
215                encoder.write_var(*connected as u8);
216            }
217            Message::Custom(tag, data) => {
218                encoder.write_u8(*tag);
219                encoder.write_buf(data);
220            }
221        }
222    }
223}
224
225impl Decode for Message {
226    fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, yrs::encoding::read::Error> {
227        let tag: u8 = decoder.read_var()?;
228        match tag {
229            MSG_SYNC => {
230                let msg = SyncMessage::decode(decoder)?;
231                Ok(Message::Sync(msg))
232            }
233            MSG_AWARENESS => {
234                let data = decoder.read_buf()?;
235                let update = AwarenessUpdate::decode_v1(data)?;
236                Ok(Message::Awareness(update))
237            }
238            MSG_AUTH => {
239                let token = if decoder.read_var::<u8>()? == PERMISSION_DENIED {
240                    Some(decoder.read_string()?.to_string())
241                } else {
242                    None
243                };
244                Ok(Message::Auth(token, false))
245            }
246            MSG_QUERY_AWARENESS => Ok(Message::AwarenessQuery),
247            tag => {
248                let data = decoder.read_buf()?;
249                Ok(Message::Custom(tag, data.to_vec()))
250            }
251        }
252    }
253}
254
/// Tag id for [SyncMessage::SyncStep1].
pub const MSG_SYNC_STEP_1: u8 = 0;
/// Tag id for [SyncMessage::SyncStep2].
pub const MSG_SYNC_STEP_2: u8 = 1;
/// Tag id for [SyncMessage::Update].
pub const MSG_SYNC_UPDATE: u8 = 2;

/// Payload of a [Message::Sync]. On the wire each variant is written as its tag id followed
/// by a length-prefixed byte buffer.
#[derive(Debug, PartialEq, Eq)]
pub enum SyncMessage {
    /// Carries the state vector of the sending side.
    SyncStep1(StateVector),
    /// Carries the bytes of an encoded document update, sent in reply to a sync-step-1.
    SyncStep2(Vec<u8>),
    /// Carries the bytes of a document update.
    // NOTE(review): presumably a v1-encoded update, like sync-step-2 - confirm with callers.
    Update(Vec<u8>),
}
268
269impl Encode for SyncMessage {
270    fn encode<E: Encoder>(&self, encoder: &mut E) {
271        match self {
272            SyncMessage::SyncStep1(sv) => {
273                encoder.write_var(MSG_SYNC_STEP_1);
274                encoder.write_buf(sv.encode_v1());
275            }
276            SyncMessage::SyncStep2(u) => {
277                encoder.write_var(MSG_SYNC_STEP_2);
278                encoder.write_buf(u);
279            }
280            SyncMessage::Update(u) => {
281                encoder.write_var(MSG_SYNC_UPDATE);
282                encoder.write_buf(u);
283            }
284        }
285    }
286}
287
288impl Decode for SyncMessage {
289    fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, yrs::encoding::read::Error> {
290        let tag: u8 = decoder.read_var()?;
291        match tag {
292            MSG_SYNC_STEP_1 => {
293                let buf = decoder.read_buf()?;
294                let sv = StateVector::decode_v1(buf)?;
295                Ok(SyncMessage::SyncStep1(sv))
296            }
297            MSG_SYNC_STEP_2 => {
298                let buf = decoder.read_buf()?;
299                Ok(SyncMessage::SyncStep2(buf.into()))
300            }
301            MSG_SYNC_UPDATE => {
302                let buf = decoder.read_buf()?;
303                Ok(SyncMessage::Update(buf.into()))
304            }
305            _ => Err(yrs::encoding::read::Error::UnexpectedValue),
306        }
307    }
308}
309
/// An error type returned in response from y-sync [Protocol].
#[derive(Debug, Error)]
pub enum Error {
    /// Incoming Y-protocol message couldn't be deserialized. Converted automatically from
    /// a decoder read error.
    #[error("failed to deserialize message: {0}")]
    EncodingError(#[from] yrs::encoding::read::Error),

    /// Producing or applying a Y-protocol awareness update has failed. Converted
    /// automatically from an awareness error.
    #[error("failed to process awareness update: {0}")]
    AwarenessEncoding(#[from] awareness::Error),

    /// An incoming Y-protocol authorization request has been denied.
    #[error("permission denied to access: {reason}")]
    PermissionDenied { reason: String },

    /// Thrown whenever an unknown message tag has been sent. Carries the offending tag.
    #[error("unsupported message tag identifier: {0}")]
    Unsupported(u8),

    /// Custom dynamic kind of error, wrapping any boxed error value.
    #[error("internal failure: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// Generic error, usually used for custom protocol handling.
    #[error("{0}")]
    Anyhow(#[from] anyhow::Error),
}
337
/// Since the y-sync protocol allows multiple messages to be packed into a single byte payload,
/// [MessageReader] can be used over the decoder to read these messages one by one in an
/// iterable fashion.
pub struct MessageReader<'a, D: Decoder>(&'a mut D);
342
343impl<'a, D: Decoder> MessageReader<'a, D> {
344    pub fn new(decoder: &'a mut D) -> Self {
345        MessageReader(decoder)
346    }
347}
348
349impl<'a, D: Decoder> Iterator for MessageReader<'a, D> {
350    type Item = Result<Message, yrs::encoding::read::Error>;
351
352    fn next(&mut self) -> Option<Self::Item> {
353        match Message::decode(self.0) {
354            Ok(msg) => Some(Ok(msg)),
355            Err(yrs::encoding::read::Error::EndOfBuffer(_)) => None,
356            Err(error) => Some(Err(error)),
357        }
358    }
359}
360
#[cfg(test)]
mod test {
    use super::{Message, SyncMessage};
    use crate::sync::awareness::Awareness;
    use crate::sync::{DefaultProtocol, MessageReader, Protocol};
    use std::collections::HashMap;
    use yrs::encoding::read::Cursor;
    use yrs::updates::decoder::{Decode, DecoderV1};
    use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
    use yrs::{Doc, GetString, ReadTxn, StateVector, Text, Transact, Update};

    /// Every listed message must survive an encode/decode round trip unchanged.
    #[test]
    fn message_encoding() {
        let doc = Doc::new();
        let txt = doc.get_or_insert_text("text");
        txt.push(&mut doc.transact_mut(), "hello world");
        let mut awareness = Awareness::new(doc);
        awareness.set_local_state("{\"user\":{\"name\":\"Anonymous 50\",\"color\":\"#30bced\",\"colorLight\":\"#30bced33\"}}");

        let messages = [
            Message::Sync(SyncMessage::SyncStep1(
                awareness.doc().transact().state_vector(),
            )),
            Message::Sync(SyncMessage::SyncStep2(
                awareness
                    .doc()
                    .transact()
                    .encode_state_as_update_v1(&StateVector::default()),
            )),
            Message::Awareness(awareness.update().unwrap()),
            Message::Auth(Some("reason".to_string()), false),
            Message::AwarenessQuery,
        ];

        for msg in messages {
            let encoded = msg.encode_v1();
            let decoded = Message::decode_v1(&encoded)
                .unwrap_or_else(|_| panic!("failed to decode {:?}", msg));
            assert_eq!(decoded, msg);
        }
    }

    /// `start` must emit a sync-step-1 followed by an awareness message and nothing else.
    #[test]
    fn protocol_init() {
        let awareness = Awareness::default();
        let protocol = DefaultProtocol;
        let mut encoder = EncoderV1::new();
        protocol.start(&awareness, &mut encoder).unwrap();
        let data = encoder.to_vec();
        let mut decoder = DecoderV1::new(Cursor::new(&data));
        let mut reader = MessageReader::new(&mut decoder);

        assert_eq!(
            reader.next().unwrap().unwrap(),
            Message::Sync(SyncMessage::SyncStep1(StateVector::default()))
        );

        assert_eq!(
            reader.next().unwrap().unwrap(),
            Message::Awareness(awareness.update().unwrap())
        );

        assert!(reader.next().is_none());
    }

    /// Sync-step-1 yields the missing update; applying it through sync-step-2 brings the
    /// second document up to date.
    #[test]
    fn protocol_sync_steps() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        let expected = {
            let txt = a1.doc_mut().get_or_insert_text("test");
            let mut txn = a1.doc_mut().transact_mut();
            txt.push(&mut txn, "hello");
            txn.encode_state_as_update_v1(&StateVector::default())
        };

        let result = protocol
            .handle_sync_step1(&a1, a2.doc().transact().state_vector())
            .unwrap();

        assert_eq!(
            result,
            Some(Message::Sync(SyncMessage::SyncStep2(expected)))
        );

        if let Some(Message::Sync(SyncMessage::SyncStep2(u))) = result {
            let result2 = protocol
                .handle_sync_step2(&mut a2, Update::decode_v1(&u).unwrap())
                .unwrap();

            // The default handler acknowledges an applied update with a sync status message.
            assert_eq!(result2, Some(Message::SyncStatus(true)));
        }

        let txt = a2.doc().transact().get_text("test").unwrap();
        assert_eq!(txt.get_string(&a2.doc().transact()), "hello".to_owned());
    }

    /// An incremental update applied through `handle_update` reaches the second document.
    #[test]
    fn protocol_sync_step_update() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        let data = {
            let txt = a1.doc_mut().get_or_insert_text("test");
            let mut txn = a1.doc_mut().transact_mut();
            txt.push(&mut txn, "hello");
            txn.encode_update_v1()
        };

        let result = protocol
            .handle_update(&mut a2, Update::decode_v1(&data).unwrap())
            .unwrap();

        // `handle_update` delegates to `handle_sync_step2`, hence the same acknowledgement.
        assert_eq!(result, Some(Message::SyncStatus(true)));

        let txt = a2.doc().transact().get_text("test").unwrap();
        assert_eq!(txt.get_string(&a2.doc().transact()), "hello".to_owned());
    }

    /// An awareness query returns the local state, which the other side can then apply.
    #[test]
    fn protocol_awareness_sync() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        a1.set_local_state("{x:3}");
        let result = protocol.handle_awareness_query(&a1).unwrap();

        assert_eq!(result, Some(Message::Awareness(a1.update().unwrap())));

        if let Some(Message::Awareness(u)) = result {
            let result = protocol.handle_awareness_update(&mut a2, u).unwrap();
            assert!(result.is_none());
        }

        assert_eq!(a2.clients(), &HashMap::from([(1, "{x:3}".to_owned())]));
    }
}