1pub mod awareness;
4
5use awareness::{Awareness, AwarenessUpdate};
6use thiserror::Error;
7use tracing::debug;
8use yrs::updates::decoder::{Decode, Decoder};
9use yrs::updates::encoder::{Encode, Encoder};
10use yrs::{ReadTxn, StateVector, Transact, Update};
11
/// The stock [`Protocol`] implementation: it relies entirely on the trait's default methods.
///
/// It is a stateless unit type, so it derives the common interoperability traits
/// (`Debug`, `Clone`, `Copy`, `Default`) and can be freely copied, printed or defaulted.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultProtocol;
39
40impl Protocol for DefaultProtocol {}
41
42pub trait Protocol {
46 fn start<E: Encoder>(&self, awareness: &Awareness, encoder: &mut E) -> Result<(), Error> {
50 let (sv, update) = {
51 let sv = awareness.doc().transact().state_vector();
52 let update = awareness.update()?;
53 (sv, update)
54 };
55 Message::Sync(SyncMessage::SyncStep1(sv)).encode(encoder);
56 Message::Awareness(update).encode(encoder);
57 Ok(())
58 }
59
60 fn handle_sync_step1(
63 &self,
64 awareness: &Awareness,
65 sv: StateVector,
66 ) -> Result<Option<Message>, Error> {
67 let update = awareness.doc().transact().encode_state_as_update_v1(&sv);
68 debug!("pray - handle_sync_step1 - update: {:?}", update);
69 Ok(Some(Message::Sync(SyncMessage::SyncStep2(update))))
70 }
71
72 fn sync_step1(&self, awareness: &Awareness) -> Result<Message, Error> {
73 let sv = awareness.doc().transact().state_vector();
74 Ok(Message::Sync(SyncMessage::SyncStep1(sv)))
75 }
76
77 fn init_awareness(&self, awareness: &Awareness) -> Result<Message, Error> {
78 let update = awareness.update()?;
79 Ok(Message::Awareness(update))
80 }
81
82 fn awareness(&self, awareness: &Awareness) -> Result<Message, Error> {
83 let update = awareness.update()?;
84 Ok(Message::Awareness(update))
85 }
86
87 fn handle_sync_step2(
90 &self,
91 awareness: &mut Awareness,
92 update: Update,
93 ) -> Result<Option<Message>, Error> {
94 let mut txn = awareness.doc().transact_mut();
95 txn.apply_update(update);
96 Ok(Some(Message::SyncStatus(true)))
97 }
98
99 fn handle_update(
102 &self,
103 awareness: &mut Awareness,
104 update: Update,
105 ) -> Result<Option<Message>, Error> {
106 self.handle_sync_step2(awareness, update)
107 }
108
109 fn handle_auth_success(&self, _awareness: &Awareness, read_write: bool) -> Message {
110 Message::Auth(
111 if read_write {
112 Some("read-write".to_owned())
113 } else {
114 Some("read-only".to_owned())
115 },
116 true,
117 )
118 }
119
120 fn handle_auth_fail(
123 &self,
124 _awareness: &Awareness,
125 ) -> Message {
126 Message::Auth(
127 None,
128 false,
129 )
130 }
131
132 fn handle_awareness_query(&self, awareness: &Awareness) -> Result<Option<Message>, Error> {
135 let update = awareness.update()?;
136 Ok(Some(Message::Awareness(update)))
137 }
138
139 fn handle_awareness_update(
142 &self,
143 awareness: &mut Awareness,
144 update: AwarenessUpdate,
145 ) -> Result<Option<Message>, Error> {
146 awareness.apply_update(update)?;
147 Ok(None)
148 }
149
150 fn missing_handle(
153 &self,
154 _awareness: &mut Awareness,
155 tag: u8,
156 _data: Vec<u8>,
157 ) -> Result<Option<Message>, Error> {
158 Err(Error::Unsupported(tag))
159 }
160}
161
/// Top-level message tag: a nested sync sub-message follows.
pub const MSG_SYNC: u8 = 0;
/// Top-level message tag: a length-prefixed, encoded awareness update follows.
pub const MSG_AWARENESS: u8 = 1;
/// Top-level message tag: an authentication status byte (and optional string) follows.
pub const MSG_AUTH: u8 = 2;
/// Top-level message tag: a request for the receiver's awareness state; carries no payload.
pub const MSG_QUERY_AWARENESS: u8 = 3;
/// Top-level message tag: a single var-int flag (0 or 1) reporting the sync status.
pub const MSG_SYNC_STATUS: u8 = 8;

/// Authentication status byte: access was refused.
pub const PERMISSION_DENIED: u8 = 0;
/// Authentication status byte: permission granted (not emitted by this module's encoder).
pub const PERMISSION_GRANTED: u8 = 1;
/// Authentication status byte: the peer is authenticated.
pub const AUTHENTICATED: u8 = 2;
177
178#[derive(Debug, Eq, PartialEq)]
179pub enum Message {
180 Sync(SyncMessage),
181 Auth(Option<String>, bool),
182 AwarenessQuery,
183 Awareness(AwarenessUpdate),
184 SyncStatus(bool),
185 Custom(u8, Vec<u8>),
186}
187
188impl Encode for Message {
189 fn encode<E: Encoder>(&self, encoder: &mut E) {
190 match self {
191 Message::Sync(msg) => {
192 encoder.write_var(MSG_SYNC);
193 msg.encode(encoder);
194 }
195 Message::Auth(reason, authenticated) => {
196 encoder.write_var(MSG_AUTH);
197 if *authenticated {
198 encoder.write_var(AUTHENTICATED);
199 } else {
200 encoder.write_var(PERMISSION_DENIED);
201 }
202 if let Some(reason) = reason {
203 encoder.write_string(reason);
204 }
205 }
206 Message::AwarenessQuery => {
207 encoder.write_var(MSG_QUERY_AWARENESS);
208 }
209 Message::Awareness(update) => {
210 encoder.write_var(MSG_AWARENESS);
211 encoder.write_buf(update.encode_v1())
212 }
213 Message::SyncStatus(connected) => {
214 encoder.write_var(MSG_SYNC_STATUS);
215 encoder.write_var(*connected as u8);
216 }
217 Message::Custom(tag, data) => {
218 encoder.write_u8(*tag);
219 encoder.write_buf(data);
220 }
221 }
222 }
223}
224
225impl Decode for Message {
226 fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, yrs::encoding::read::Error> {
227 let tag: u8 = decoder.read_var()?;
228 match tag {
229 MSG_SYNC => {
230 let msg = SyncMessage::decode(decoder)?;
231 Ok(Message::Sync(msg))
232 }
233 MSG_AWARENESS => {
234 let data = decoder.read_buf()?;
235 let update = AwarenessUpdate::decode_v1(data)?;
236 Ok(Message::Awareness(update))
237 }
238 MSG_AUTH => {
239 let token = if decoder.read_var::<u8>()? == PERMISSION_DENIED {
240 Some(decoder.read_string()?.to_string())
241 } else {
242 None
243 };
244 Ok(Message::Auth(token, false))
245 }
246 MSG_QUERY_AWARENESS => Ok(Message::AwarenessQuery),
247 tag => {
248 let data = decoder.read_buf()?;
249 Ok(Message::Custom(tag, data.to_vec()))
250 }
251 }
252 }
253}
254
/// Sync sub-message tag: a length-prefixed, encoded state vector follows.
pub const MSG_SYNC_STEP_1: u8 = 0;
/// Sync sub-message tag: a length-prefixed document update answering a step-1 follows.
pub const MSG_SYNC_STEP_2: u8 = 1;
/// Sync sub-message tag: a length-prefixed incremental document update follows.
pub const MSG_SYNC_UPDATE: u8 = 2;
261
262#[derive(Debug, PartialEq, Eq)]
263pub enum SyncMessage {
264 SyncStep1(StateVector),
265 SyncStep2(Vec<u8>),
266 Update(Vec<u8>),
267}
268
269impl Encode for SyncMessage {
270 fn encode<E: Encoder>(&self, encoder: &mut E) {
271 match self {
272 SyncMessage::SyncStep1(sv) => {
273 encoder.write_var(MSG_SYNC_STEP_1);
274 encoder.write_buf(sv.encode_v1());
275 }
276 SyncMessage::SyncStep2(u) => {
277 encoder.write_var(MSG_SYNC_STEP_2);
278 encoder.write_buf(u);
279 }
280 SyncMessage::Update(u) => {
281 encoder.write_var(MSG_SYNC_UPDATE);
282 encoder.write_buf(u);
283 }
284 }
285 }
286}
287
288impl Decode for SyncMessage {
289 fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, yrs::encoding::read::Error> {
290 let tag: u8 = decoder.read_var()?;
291 match tag {
292 MSG_SYNC_STEP_1 => {
293 let buf = decoder.read_buf()?;
294 let sv = StateVector::decode_v1(buf)?;
295 Ok(SyncMessage::SyncStep1(sv))
296 }
297 MSG_SYNC_STEP_2 => {
298 let buf = decoder.read_buf()?;
299 Ok(SyncMessage::SyncStep2(buf.into()))
300 }
301 MSG_SYNC_UPDATE => {
302 let buf = decoder.read_buf()?;
303 Ok(SyncMessage::Update(buf.into()))
304 }
305 _ => Err(yrs::encoding::read::Error::UnexpectedValue),
306 }
307 }
308}
309
310#[derive(Debug, Error)]
312pub enum Error {
313 #[error("failed to deserialize message: {0}")]
315 EncodingError(#[from] yrs::encoding::read::Error),
316
317 #[error("failed to process awareness update: {0}")]
319 AwarenessEncoding(#[from] awareness::Error),
320
321 #[error("permission denied to access: {reason}")]
323 PermissionDenied { reason: String },
324
325 #[error("unsupported message tag identifier: {0}")]
327 Unsupported(u8),
328
329 #[error("internal failure: {0}")]
331 Other(#[from] Box<dyn std::error::Error + Send + Sync>),
332
333 #[error("{0}")]
335 Anyhow(#[from] anyhow::Error),
336}
337
338pub struct MessageReader<'a, D: Decoder>(&'a mut D);
342
343impl<'a, D: Decoder> MessageReader<'a, D> {
344 pub fn new(decoder: &'a mut D) -> Self {
345 MessageReader(decoder)
346 }
347}
348
349impl<'a, D: Decoder> Iterator for MessageReader<'a, D> {
350 type Item = Result<Message, yrs::encoding::read::Error>;
351
352 fn next(&mut self) -> Option<Self::Item> {
353 match Message::decode(self.0) {
354 Ok(msg) => Some(Ok(msg)),
355 Err(yrs::encoding::read::Error::EndOfBuffer(_)) => None,
356 Err(error) => Some(Err(error)),
357 }
358 }
359}
360
#[cfg(test)]
mod test {
    use super::{Message, SyncMessage};
    use crate::sync::awareness::Awareness;
    use crate::sync::{DefaultProtocol, MessageReader, Protocol};
    use std::collections::HashMap;
    use yrs::encoding::read::Cursor;
    use yrs::updates::decoder::{Decode, DecoderV1};
    use yrs::updates::encoder::{Encode, Encoder, EncoderV1};
    use yrs::{Doc, GetString, ReadTxn, StateVector, Text, Transact, Update};

    #[test]
    fn message_encoding() {
        let doc = Doc::new();
        let txt = doc.get_or_insert_text("text");
        txt.push(&mut doc.transact_mut(), "hello world");
        let mut awareness = Awareness::new(doc);
        awareness.set_local_state("{\"user\":{\"name\":\"Anonymous 50\",\"color\":\"#30bced\",\"colorLight\":\"#30bced33\"}}");

        let messages = [
            Message::Sync(SyncMessage::SyncStep1(
                awareness.doc().transact().state_vector(),
            )),
            Message::Sync(SyncMessage::SyncStep2(
                awareness
                    .doc()
                    .transact()
                    .encode_state_as_update_v1(&StateVector::default()),
            )),
            Message::Awareness(awareness.update().unwrap()),
            Message::Auth(Some("reason".to_string()), false),
            Message::AwarenessQuery,
        ];

        for msg in messages {
            let encoded = msg.encode_v1();
            let decoded = Message::decode_v1(&encoded)
                .unwrap_or_else(|_| panic!("failed to decode {:?}", msg));
            assert_eq!(decoded, msg);
        }
    }

    #[test]
    fn protocol_init() {
        let awareness = Awareness::default();
        let protocol = DefaultProtocol;
        let mut encoder = EncoderV1::new();
        protocol.start(&awareness, &mut encoder).unwrap();
        let data = encoder.to_vec();
        let mut decoder = DecoderV1::new(Cursor::new(&data));
        let mut reader = MessageReader::new(&mut decoder);

        assert_eq!(
            reader.next().unwrap().unwrap(),
            Message::Sync(SyncMessage::SyncStep1(StateVector::default()))
        );

        assert_eq!(
            reader.next().unwrap().unwrap(),
            Message::Awareness(awareness.update().unwrap())
        );

        assert!(reader.next().is_none());
    }

    #[test]
    fn protocol_sync_steps() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        let expected = {
            let txt = a1.doc_mut().get_or_insert_text("test");
            let mut txn = a1.doc_mut().transact_mut();
            txt.push(&mut txn, "hello");
            txn.encode_state_as_update_v1(&StateVector::default())
        };

        let result = protocol
            .handle_sync_step1(&a1, a2.doc().transact().state_vector())
            .unwrap();

        assert_eq!(
            result,
            Some(Message::Sync(SyncMessage::SyncStep2(expected)))
        );

        // Fail loudly instead of silently skipping step 2 if the reply has the wrong shape.
        let step2 = match result {
            Some(Message::Sync(SyncMessage::SyncStep2(u))) => u,
            other => panic!("expected a SyncStep2 reply, got {:?}", other),
        };
        let result2 = protocol
            .handle_sync_step2(&mut a2, Update::decode_v1(&step2).unwrap())
            .unwrap();

        // `handle_sync_step2` acknowledges an applied update with a sync-status message.
        assert_eq!(result2, Some(Message::SyncStatus(true)));

        let txt = a2.doc().transact().get_text("test").unwrap();
        assert_eq!(txt.get_string(&a2.doc().transact()), "hello".to_owned());
    }

    #[test]
    fn protocol_sync_step_update() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        let data = {
            let txt = a1.doc_mut().get_or_insert_text("test");
            let mut txn = a1.doc_mut().transact_mut();
            txt.push(&mut txn, "hello");
            txn.encode_update_v1()
        };

        let result = protocol
            .handle_update(&mut a2, Update::decode_v1(&data).unwrap())
            .unwrap();

        // `handle_update` delegates to `handle_sync_step2`, which replies with a sync status.
        assert_eq!(result, Some(Message::SyncStatus(true)));

        let txt = a2.doc().transact().get_text("test").unwrap();
        assert_eq!(txt.get_string(&a2.doc().transact()), "hello".to_owned());
    }

    #[test]
    fn protocol_awareness_sync() {
        let protocol = DefaultProtocol;

        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));

        a1.set_local_state("{x:3}");
        let result = protocol.handle_awareness_query(&a1).unwrap();

        assert_eq!(result, Some(Message::Awareness(a1.update().unwrap())));

        if let Some(Message::Awareness(u)) = result {
            let result = protocol.handle_awareness_update(&mut a2, u).unwrap();
            assert!(result.is_none());
        }

        assert_eq!(a2.clients(), &HashMap::from([(1, "{x:3}".to_owned())]));
    }
}