use crate::encoding::read;
use crate::sync::{awareness, Awareness, AwarenessUpdate};
use crate::updates::decoder::{Decode, Decoder};
use crate::updates::encoder::{Encode, Encoder};
use crate::{ReadTxn, StateVector, Transact, Update};
use thiserror::Error;
/// Stock [`Protocol`] implementation: it overrides nothing and therefore uses
/// every default handler provided by the trait.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultProtocol;

impl Protocol for DefaultProtocol {}
pub trait Protocol {
fn start<E: Encoder>(&self, awareness: &Awareness, encoder: &mut E) -> Result<(), Error> {
let (sv, update) = {
let sv = awareness.doc().transact().state_vector();
let update = awareness.update()?;
(sv, update)
};
Message::Sync(SyncMessage::SyncStep1(sv)).encode(encoder);
Message::Awareness(update).encode(encoder);
Ok(())
}
fn handle_sync_step1(
&self,
awareness: &Awareness,
sv: StateVector,
) -> Result<Option<Message>, Error> {
let update = awareness.doc().transact().encode_state_as_update_v1(&sv);
Ok(Some(Message::Sync(SyncMessage::SyncStep2(update))))
}
fn handle_sync_step2(
&self,
awareness: &mut Awareness,
update: Update,
) -> Result<Option<Message>, Error> {
let mut txn = awareness.doc().transact_mut();
txn.apply_update(update);
Ok(None)
}
fn handle_update(
&self,
awareness: &mut Awareness,
update: Update,
) -> Result<Option<Message>, Error> {
self.handle_sync_step2(awareness, update)
}
fn handle_auth(
&self,
_awareness: &Awareness,
deny_reason: Option<String>,
) -> Result<Option<Message>, Error> {
if let Some(reason) = deny_reason {
Err(Error::PermissionDenied { reason })
} else {
Ok(None)
}
}
fn handle_awareness_query(&self, awareness: &Awareness) -> Result<Option<Message>, Error> {
let update = awareness.update()?;
Ok(Some(Message::Awareness(update)))
}
fn handle_awareness_update(
&self,
awareness: &mut Awareness,
update: AwarenessUpdate,
) -> Result<Option<Message>, Error> {
awareness.apply_update(update)?;
Ok(None)
}
fn missing_handle(
&self,
_awareness: &mut Awareness,
tag: u8,
_data: Vec<u8>,
) -> Result<Option<Message>, Error> {
Err(Error::Unsupported(tag))
}
}
/// Top-level message tag for document synchronisation (`Message::Sync`).
pub const MSG_SYNC: u8 = 0;
/// Top-level message tag for an awareness update (`Message::Awareness`).
pub const MSG_AWARENESS: u8 = 1;
/// Top-level message tag for an authorization result (`Message::Auth`).
pub const MSG_AUTH: u8 = 2;
/// Top-level message tag for an awareness query (`Message::AwarenessQuery`).
pub const MSG_QUERY_AWARENESS: u8 = 3;

/// Auth payload flag: access denied; a reason string follows it on the wire.
pub const PERMISSION_DENIED: u8 = 0;
/// Auth payload flag: access granted; nothing follows it.
pub const PERMISSION_GRANTED: u8 = 1;
/// A single top-level y-sync protocol message.
///
/// On the wire each built-in variant is introduced by its tag (`MSG_SYNC`,
/// `MSG_AWARENESS`, `MSG_AUTH`, `MSG_QUERY_AWARENESS`); any other tag is
/// decoded as [`Message::Custom`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Message {
    /// Document synchronisation sub-message.
    Sync(SyncMessage),
    /// Authorization result: `Some(reason)` means permission was denied,
    /// `None` means it was granted.
    Auth(Option<String>),
    /// Request for the receiver's current awareness state.
    AwarenessQuery,
    /// Awareness state update.
    Awareness(AwarenessUpdate),
    /// Application-defined message: tag plus raw payload. A tag equal to one of
    /// the built-in tags is decoded as that built-in variant instead.
    Custom(u8, Vec<u8>),
}
impl Encode for Message {
    /// Serializes the message as a var-int tag followed by a tag-specific payload:
    ///
    /// * `Sync` — `MSG_SYNC`, then the encoded [`SyncMessage`];
    /// * `Auth` — `MSG_AUTH`, then `PERMISSION_DENIED` + reason string, or
    ///   `PERMISSION_GRANTED` alone;
    /// * `AwarenessQuery` — `MSG_QUERY_AWARENESS` with no payload;
    /// * `Awareness` — `MSG_AWARENESS`, then the v1-encoded update as a buffer;
    /// * `Custom` — the custom tag, then the raw data as a buffer.
    fn encode<E: Encoder>(&self, encoder: &mut E) {
        match self {
            Message::Sync(msg) => {
                encoder.write_var(MSG_SYNC);
                msg.encode(encoder);
            }
            Message::Auth(reason) => {
                encoder.write_var(MSG_AUTH);
                if let Some(reason) = reason {
                    encoder.write_var(PERMISSION_DENIED);
                    encoder.write_string(reason);
                } else {
                    encoder.write_var(PERMISSION_GRANTED);
                }
            }
            Message::AwarenessQuery => {
                encoder.write_var(MSG_QUERY_AWARENESS);
            }
            Message::Awareness(update) => {
                encoder.write_var(MSG_AWARENESS);
                encoder.write_buf(&update.encode_v1());
            }
            Message::Custom(tag, data) => {
                // `Message::decode` reads every tag with `read_var`, so the tag
                // must be written as a var-int as well. A raw `write_u8` only
                // round-trips for tags below 128: larger values set the
                // continuation bit and desynchronise the decoder.
                encoder.write_var(*tag);
                encoder.write_buf(data);
            }
        }
    }
}
impl Decode for Message {
    /// Reads a var-int tag and the payload that belongs to it.
    ///
    /// For `MSG_AUTH`, any permission flag other than `PERMISSION_DENIED` is
    /// treated as granted. Unknown tags become [`Message::Custom`] with a
    /// length-prefixed buffer as payload.
    fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, read::Error> {
        let tag: u8 = decoder.read_var()?;
        let message = match tag {
            MSG_SYNC => Message::Sync(SyncMessage::decode(decoder)?),
            MSG_AWARENESS => {
                let payload = decoder.read_buf()?;
                Message::Awareness(AwarenessUpdate::decode_v1(payload)?)
            }
            MSG_AUTH => {
                let permission: u8 = decoder.read_var()?;
                let deny_reason = if permission == PERMISSION_DENIED {
                    Some(decoder.read_string()?.to_string())
                } else {
                    None
                };
                Message::Auth(deny_reason)
            }
            MSG_QUERY_AWARENESS => Message::AwarenessQuery,
            other => Message::Custom(other, decoder.read_buf()?.to_vec()),
        };
        Ok(message)
    }
}
/// Sync sub-message tag: handshake step 1 (`SyncMessage::SyncStep1`).
pub const MSG_SYNC_STEP_1: u8 = 0;
/// Sync sub-message tag: handshake step 2 (`SyncMessage::SyncStep2`).
pub const MSG_SYNC_STEP_2: u8 = 1;
/// Sync sub-message tag: document update (`SyncMessage::Update`).
pub const MSG_SYNC_UPDATE: u8 = 2;
/// Document-synchronisation sub-message carried inside [`Message::Sync`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SyncMessage {
    /// Handshake step 1: the sender's state vector.
    SyncStep1(StateVector),
    /// Handshake step 2: encoded update bytes answering a step 1.
    SyncStep2(Vec<u8>),
    /// Encoded document update bytes.
    Update(Vec<u8>),
}
impl Encode for SyncMessage {
    /// Writes the sub-message tag as a var-int followed by a length-prefixed
    /// buffer: the v1-encoded state vector for step 1, or the raw update bytes
    /// for step 2 and updates.
    fn encode<E: Encoder>(&self, encoder: &mut E) {
        match self {
            SyncMessage::SyncStep1(state_vector) => {
                encoder.write_var(MSG_SYNC_STEP_1);
                let encoded_sv = state_vector.encode_v1();
                encoder.write_buf(&encoded_sv);
            }
            SyncMessage::SyncStep2(diff) => {
                encoder.write_var(MSG_SYNC_STEP_2);
                encoder.write_buf(diff);
            }
            SyncMessage::Update(diff) => {
                encoder.write_var(MSG_SYNC_UPDATE);
                encoder.write_buf(diff);
            }
        }
    }
}
impl Decode for SyncMessage {
    /// Reads a var-int sub-message tag and its length-prefixed payload.
    ///
    /// # Errors
    ///
    /// Returns `read::Error::UnexpectedValue` for an unknown tag, or whatever
    /// error the decoder reports while reading the payload.
    fn decode<D: Decoder>(decoder: &mut D) -> Result<Self, read::Error> {
        let tag: u8 = decoder.read_var()?;
        match tag {
            MSG_SYNC_STEP_1 => {
                let state_vector = StateVector::decode_v1(decoder.read_buf()?)?;
                Ok(SyncMessage::SyncStep1(state_vector))
            }
            MSG_SYNC_STEP_2 => Ok(SyncMessage::SyncStep2(decoder.read_buf()?.to_vec())),
            MSG_SYNC_UPDATE => Ok(SyncMessage::Update(decoder.read_buf()?.to_vec())),
            _ => Err(read::Error::UnexpectedValue),
        }
    }
}
/// Errors raised while encoding, decoding or handling y-sync protocol messages.
#[derive(Debug, Error)]
pub enum Error {
    /// A message could not be deserialized from its binary form.
    #[error("failed to deserialize message: {0}")]
    DecodingError(#[from] read::Error),

    /// Wraps an awareness error, e.g. from building or applying an awareness update.
    #[error("failed to process awareness update: {0}")]
    AwarenessEncoding(#[from] awareness::Error),

    /// An auth message denied permission; `reason` is the text it carried.
    #[error("permission denied to access: {reason}")]
    PermissionDenied { reason: String },

    /// A message tag that the protocol has no handler for.
    #[error("unsupported message tag identifier: {0}")]
    Unsupported(u8),

    /// An underlying I/O operation failed.
    #[error("IO error: {0}")]
    IO(#[from] std::io::Error),

    /// Any other failure, carried as a boxed error.
    #[error("internal failure: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
#[cfg(feature = "net")]
impl From<tokio::task::JoinError> for Error {
    /// Wraps a failed tokio task join (the task panicked or was cancelled)
    /// as [`Error::Other`].
    fn from(join_error: tokio::task::JoinError) -> Self {
        let boxed: Box<dyn std::error::Error + Send + Sync> = Box::new(join_error);
        Error::Other(boxed)
    }
}
/// Iterator adapter that yields the [`Message`]s stored back-to-back in a
/// decoder's buffer, borrowing the decoder mutably for its whole lifetime.
pub struct MessageReader<'a, D: Decoder>(&'a mut D);

impl<'a, D: Decoder> MessageReader<'a, D> {
    /// Creates a reader that pulls consecutive messages out of `decoder`.
    pub fn new(decoder: &'a mut D) -> Self {
        Self(decoder)
    }
}
impl<'a, D: Decoder> Iterator for MessageReader<'a, D> {
    type Item = Result<Message, read::Error>;

    /// Decodes the next message. An `EndOfBuffer` error from the decoder ends
    /// iteration (`None`); every other result, success or failure, is yielded.
    ///
    /// NOTE(review): an `EndOfBuffer` raised partway through a message also
    /// ends iteration, so a truncated trailing message appears to be dropped
    /// silently rather than reported — confirm whether callers rely on this.
    fn next(&mut self) -> Option<Self::Item> {
        match Message::decode(self.0) {
            Err(read::Error::EndOfBuffer(_)) => None,
            outcome => Some(outcome),
        }
    }
}
#[cfg(test)]
mod test {
    use crate::encoding::read::Cursor;
    use crate::sync::protocol::MessageReader;
    use crate::sync::{Awareness, Protocol};
    use crate::updates::decoder::{Decode, DecoderV1};
    use crate::updates::encoder::{Encode, Encoder, EncoderV1};
    use crate::{Doc, GetString, ReadTxn, StateVector, Text, Transact, Update};
    use std::collections::HashMap;

    /// Every message variant must survive an encode/decode round trip unchanged.
    #[test]
    fn message_encoding() {
        let doc = Doc::new();
        let txt = doc.get_or_insert_text("text");
        txt.push(&mut doc.transact_mut(), "hello world");
        let mut awareness = Awareness::new(doc);
        awareness.set_local_state("{\"user\":{\"name\":\"Anonymous 50\",\"color\":\"#30bced\",\"colorLight\":\"#30bced33\"}}");
        let messages = [
            crate::sync::Message::Sync(crate::sync::SyncMessage::SyncStep1(
                awareness.doc().transact().state_vector(),
            )),
            crate::sync::Message::Sync(crate::sync::SyncMessage::SyncStep2(
                awareness
                    .doc()
                    .transact()
                    .encode_state_as_update_v1(&StateVector::default()),
            )),
            crate::sync::Message::Awareness(awareness.update().unwrap()),
            crate::sync::Message::Auth(Some("reason".to_string())),
            crate::sync::Message::AwarenessQuery,
            // Application-defined message with a tag outside the reserved range.
            crate::sync::Message::Custom(5, vec![1, 2, 3]),
        ];
        for msg in messages {
            let encoded = msg.encode_v1();
            // Lazily format the failure message (only on error) and include the
            // decoder's error in it.
            let decoded = crate::sync::Message::decode_v1(&encoded)
                .unwrap_or_else(|e| panic!("failed to decode {:?}: {}", msg, e));
            assert_eq!(decoded, msg);
        }
    }

    /// `start` on an empty document emits SyncStep1 and then an awareness update.
    #[test]
    fn protocol_init() {
        let awareness = Awareness::default();
        let protocol = crate::sync::DefaultProtocol;
        let mut encoder = EncoderV1::new();
        protocol.start(&awareness, &mut encoder).unwrap();
        let data = encoder.to_vec();
        let mut decoder = DecoderV1::new(Cursor::new(&data));
        let mut reader = MessageReader::new(&mut decoder);
        assert_eq!(
            reader.next().unwrap().unwrap(),
            crate::sync::Message::Sync(crate::sync::SyncMessage::SyncStep1(StateVector::default()))
        );
        assert_eq!(
            reader.next().unwrap().unwrap(),
            crate::sync::Message::Awareness(awareness.update().unwrap())
        );
        assert!(reader.next().is_none());
    }

    /// A SyncStep1/SyncStep2 exchange transfers a peer's text to an empty peer.
    #[test]
    fn protocol_sync_steps() {
        let protocol = crate::sync::DefaultProtocol;
        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));
        let expected = {
            let txt = a1.doc_mut().get_or_insert_text("test");
            let mut txn = a1.doc_mut().transact_mut();
            txt.push(&mut txn, "hello");
            txn.encode_state_as_update_v1(&StateVector::default())
        };
        let result = protocol
            .handle_sync_step1(&a1, a2.doc().transact().state_vector())
            .unwrap();
        assert_eq!(
            result,
            Some(crate::sync::Message::Sync(
                crate::sync::SyncMessage::SyncStep2(expected)
            ))
        );
        if let Some(crate::sync::Message::Sync(crate::sync::SyncMessage::SyncStep2(u))) = result {
            let result2 = protocol
                .handle_sync_step2(&mut a2, Update::decode_v1(&u).unwrap())
                .unwrap();
            assert!(result2.is_none());
        }
        let txt = a2.doc().transact().get_text("test").unwrap();
        assert_eq!(txt.get_string(&a2.doc().transact()), "hello".to_owned());
    }

    /// An incremental update applied through `handle_update` reaches the peer.
    #[test]
    fn protocol_sync_step_update() {
        let protocol = crate::sync::DefaultProtocol;
        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));
        let data = {
            let txt = a1.doc_mut().get_or_insert_text("test");
            let mut txn = a1.doc_mut().transact_mut();
            txt.push(&mut txn, "hello");
            txn.encode_update_v1()
        };
        let result = protocol
            .handle_update(&mut a2, Update::decode_v1(&data).unwrap())
            .unwrap();
        assert!(result.is_none());
        let txt = a2.doc().transact().get_text("test").unwrap();
        assert_eq!(txt.get_string(&a2.doc().transact()), "hello".to_owned());
    }

    /// An awareness query answer, applied on another peer, replicates local state.
    #[test]
    fn protocol_awareness_sync() {
        let protocol = crate::sync::DefaultProtocol;
        let mut a1 = Awareness::new(Doc::with_client_id(1));
        let mut a2 = Awareness::new(Doc::with_client_id(2));
        a1.set_local_state("{x:3}");
        let result = protocol.handle_awareness_query(&a1).unwrap();
        assert_eq!(
            result,
            Some(crate::sync::Message::Awareness(a1.update().unwrap()))
        );
        if let Some(crate::sync::Message::Awareness(u)) = result {
            let result = protocol.handle_awareness_update(&mut a2, u).unwrap();
            assert!(result.is_none());
        }
        assert_eq!(a2.clients(), &HashMap::from([(1, "{x:3}".to_owned())]));
    }
}