loro_protocol/
encoding.rs

1//! Binary encoder/decoder for the Loro protocol wire format.
2//!
3//! Follows the field order and validation rules from `protocol.md` and matches
4//! the JS implementation in `packages/loro-protocol/src/encoding.ts`.
5use crate::bytes::{BytesReader, BytesWriter};
6use crate::protocol::*;
7
8const MAX_ROOM_ID_LENGTH: usize = 128;
9
10fn encode_crdt(w: &mut BytesWriter, crdt: CrdtType) {
11    w.push_bytes(&crdt.magic_bytes());
12}
13
14fn decode_crdt(r: &mut BytesReader) -> Result<CrdtType, String> {
15    if r.remaining() < 4 { return Err("Invalid message: too short for CRDT type".into()); }
16    let bytes = r.read_bytes(4)?;
17    let mut arr = [0u8; 4];
18    arr.copy_from_slice(bytes);
19    CrdtType::from_magic_bytes(arr).ok_or_else(|| format!("Invalid CRDT type: {}{}{}{}",
20        arr[0] as char, arr[1] as char, arr[2] as char, arr[3] as char))
21}
22
23fn encode_permission(w: &mut BytesWriter, p: Permission) {
24    match p {
25        Permission::Read => w.push_var_string("read"),
26        Permission::Write => w.push_var_string("write"),
27    }
28}
29
30fn decode_permission(r: &mut BytesReader) -> Result<Permission, String> {
31    match r.read_var_string()?.as_str() {
32        "read" => Ok(Permission::Read),
33        "write" => Ok(Permission::Write),
34        other => Err(format!("Invalid permission: {}", other)),
35    }
36}
37/// Encode a `ProtocolMessage` into a compact binary form according to protocol.md.
38///
39/// The encoder writes, in order: 4-byte CRDT magic, `varBytes` room id, 1-byte
40/// message type tag, and the message-specific payload. It validates room id
41/// length and ensures the final buffer does not exceed 256KB.
42///
43/// Returns an owned `Vec<u8>` containing the encoded message.
44pub fn encode(message: &ProtocolMessage) -> Result<Vec<u8>, String> {
45    let mut w = BytesWriter::new();
46
47    // Common: crdt, room_id, type
48    match message {
49        ProtocolMessage::JoinRequest { crdt, room_id, auth: _, version: _ }
50        | ProtocolMessage::JoinResponseOk { crdt, room_id, permission: _, version: _, extra: _ }
51        | ProtocolMessage::JoinError { crdt, room_id, code: _, message: _, receiver_version: _, app_code: _ }
52        | ProtocolMessage::DocUpdate { crdt, room_id, updates: _ }
53        | ProtocolMessage::DocUpdateFragmentHeader { crdt, room_id, batch_id: _, fragment_count: _, total_size_bytes: _ }
54        | ProtocolMessage::DocUpdateFragment { crdt, room_id, batch_id: _, index: _, fragment: _ }
55        | ProtocolMessage::UpdateError { crdt, room_id, code: _, message: _, batch_id: _, app_code: _ }
56        | ProtocolMessage::Leave { crdt, room_id } => {
57            let room_id_bytes = room_id.as_bytes();
58            if room_id_bytes.len() > MAX_ROOM_ID_LENGTH { return Err("Room ID too long".into()); }
59            encode_crdt(&mut w, *crdt);
60            w.push_var_bytes(room_id_bytes);
61        }
62    }
63
64    // Type byte
65    let ty: u8 = match message {
66        ProtocolMessage::JoinRequest { .. } => MessageType::JoinRequest as u8,
67        ProtocolMessage::JoinResponseOk { .. } => MessageType::JoinResponseOk as u8,
68        ProtocolMessage::JoinError { .. } => MessageType::JoinError as u8,
69        ProtocolMessage::DocUpdate { .. } => MessageType::DocUpdate as u8,
70        ProtocolMessage::DocUpdateFragmentHeader { .. } => MessageType::DocUpdateFragmentHeader as u8,
71        ProtocolMessage::DocUpdateFragment { .. } => MessageType::DocUpdateFragment as u8,
72        ProtocolMessage::UpdateError { .. } => MessageType::UpdateError as u8,
73        ProtocolMessage::Leave { .. } => MessageType::Leave as u8,
74    };
75    w.push_byte(ty);
76
77    // Payload
78    match message {
79        ProtocolMessage::JoinRequest { auth, version, .. } => {
80            w.push_var_bytes(auth);
81            w.push_var_bytes(version);
82        }
83        ProtocolMessage::JoinResponseOk { permission, version, extra, .. } => {
84            encode_permission(&mut w, *permission);
85            w.push_var_bytes(version);
86            if let Some(e) = extra { w.push_var_bytes(e); } else { w.push_var_bytes(&[]); }
87        }
88        ProtocolMessage::JoinError { code, message, receiver_version, app_code, .. } => {
89            w.push_byte(*code as u8);
90            w.push_var_string(message);
91            if matches!(code, JoinErrorCode::VersionUnknown) {
92                if let Some(v) = receiver_version { w.push_var_bytes(v); }
93            }
94            if matches!(code, JoinErrorCode::AppError) {
95                if let Some(app) = app_code { w.push_var_string(app); }
96            }
97        }
98        ProtocolMessage::DocUpdate { updates, .. } => {
99            w.push_uleb128(updates.len() as u64);
100            for u in updates { w.push_var_bytes(u); }
101        }
102        ProtocolMessage::DocUpdateFragmentHeader { batch_id, fragment_count, total_size_bytes, .. } => {
103            w.push_bytes(&batch_id.0);
104            w.push_uleb128(*fragment_count);
105            w.push_uleb128(*total_size_bytes);
106        }
107        ProtocolMessage::DocUpdateFragment { batch_id, index, fragment, .. } => {
108            w.push_bytes(&batch_id.0);
109            w.push_uleb128(*index);
110            w.push_var_bytes(fragment);
111        }
112        ProtocolMessage::UpdateError { code, message, batch_id, app_code, .. } => {
113            w.push_byte(*code as u8);
114            w.push_var_string(message);
115            if matches!(code, UpdateErrorCode::FragmentTimeout) {
116                if let Some(id) = batch_id { w.push_bytes(&id.0); }
117            }
118            if matches!(code, UpdateErrorCode::AppError) {
119                if let Some(app) = app_code { w.push_var_string(app); }
120            }
121        }
122        ProtocolMessage::Leave { .. } => {}
123    }
124
125    let out = w.finalize();
126    if out.len() > MAX_MESSAGE_SIZE {
127        return Err(format!(
128            "Message size {} exceeds maximum {}",
129            out.len(), MAX_MESSAGE_SIZE
130        ));
131    }
132    Ok(out)
133}
134
135pub fn decode(buf: &[u8]) -> Result<ProtocolMessage, String> {
136    let mut r = BytesReader::new(buf);
137
138    // CRDT
139    let crdt = decode_crdt(&mut r)?;
140    // room id (varString)
141    let room_id = r.read_var_string()?;
142    if room_id.len() > MAX_ROOM_ID_LENGTH {
143        return Err("Room ID exceeds maximum length of 128 bytes".into());
144    }
145
146    // type
147    let t = r.read_byte()?;
148    let ty = MessageType::from_u8(t).ok_or_else(|| "Invalid message type".to_string())?;
149
150    use ProtocolMessage as PM;
151    let msg = match ty {
152        MessageType::JoinRequest => {
153            let auth = r.read_var_bytes()?.to_vec();
154            let version = r.read_var_bytes()?.to_vec();
155            PM::JoinRequest { crdt, room_id, auth, version }
156        }
157        MessageType::JoinResponseOk => {
158            let permission = decode_permission(&mut r)?;
159            let version = r.read_var_bytes()?.to_vec();
160            let extra = r.read_var_bytes()?.to_vec();
161            PM::JoinResponseOk { crdt, room_id, permission, version, extra: Some(extra) }
162        }
163        MessageType::JoinError => {
164            let code = r.read_byte().ok().and_then(JoinErrorCode::from_u8).unwrap_or(JoinErrorCode::Unknown);
165            let message = r.read_var_string()?;
166            let mut receiver_version = None;
167            let mut app_code = None;
168            if matches!(code, JoinErrorCode::VersionUnknown) && r.remaining() > 0 {
169                if let Ok(bytes) = r.read_var_bytes() { receiver_version = Some(bytes.to_vec()); }
170            }
171            if matches!(code, JoinErrorCode::AppError) && r.remaining() > 0 {
172                if let Ok(app) = r.read_var_string() { app_code = Some(app); }
173            }
174            PM::JoinError { crdt, room_id, code, message, receiver_version, app_code }
175        }
176        MessageType::DocUpdate => {
177            let count = r.read_uleb128()? as usize;
178            let mut updates = Vec::with_capacity(count);
179            for _ in 0..count { updates.push(r.read_var_bytes()?.to_vec()); }
180            PM::DocUpdate { crdt, room_id, updates }
181        }
182        MessageType::DocUpdateFragmentHeader => {
183            if r.remaining() < 8 { return Err("Invalid DocUpdateFragmentHeader: missing batch ID".into()); }
184            let id = r.read_bytes(8)?;
185            let mut arr = [0u8; 8];
186            arr.copy_from_slice(id);
187            let batch_id = BatchId(arr);
188            let fragment_count = r.read_uleb128()?;
189            let total_size_bytes = r.read_uleb128()?;
190            PM::DocUpdateFragmentHeader { crdt, room_id, batch_id, fragment_count, total_size_bytes }
191        }
192        MessageType::DocUpdateFragment => {
193            if r.remaining() < 8 { return Err("Invalid DocUpdateFragment: missing batch ID".into()); }
194            let id = r.read_bytes(8)?;
195            let mut arr = [0u8; 8];
196            arr.copy_from_slice(id);
197            let batch_id = BatchId(arr);
198            let index = r.read_uleb128()?;
199            let fragment = r.read_var_bytes()?.to_vec();
200            PM::DocUpdateFragment { crdt, room_id, batch_id, index, fragment }
201        }
202        MessageType::UpdateError => {
203            let code = r.read_byte().ok().and_then(UpdateErrorCode::from_u8).unwrap_or(UpdateErrorCode::Unknown);
204            let message = r.read_var_string()?;
205            let mut batch_id = None;
206            let mut app_code = None;
207            if matches!(code, UpdateErrorCode::FragmentTimeout) && r.remaining() >= 8 {
208                let id = r.read_bytes(8)?;
209                let mut arr = [0u8; 8];
210                arr.copy_from_slice(id);
211                batch_id = Some(BatchId(arr));
212            }
213            if matches!(code, UpdateErrorCode::AppError) && r.remaining() > 0 {
214                if let Ok(app) = r.read_var_string() { app_code = Some(app); }
215            }
216            PM::UpdateError { crdt, room_id, code, message, batch_id, app_code }
217        }
218        MessageType::Leave => PM::Leave { crdt, room_id },
219    };
220
221    Ok(msg)
222}
223
224/// Attempt to decode a message, returning `None` when parsing fails.
225pub fn try_decode(buf: &[u8]) -> Option<ProtocolMessage> {
226    decode(buf).ok()
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Leave` message survives an encode/decode round trip unchanged.
    #[test]
    fn encode_decode_leave() {
        let original = ProtocolMessage::Leave {
            crdt: CrdtType::Loro,
            room_id: "room-123".to_string(),
        };
        let bytes = encode(&original).unwrap();
        let round_tripped = decode(&bytes).unwrap();
        assert_eq!(original, round_tripped);
    }

    /// `extra: None` is encoded as empty bytes and decodes to an empty value.
    #[test]
    fn encode_join_response_ok_defaults_extra() {
        let original = ProtocolMessage::JoinResponseOk {
            crdt: CrdtType::Yjs,
            room_id: "room-123".to_string(),
            permission: Permission::Read,
            version: vec![10, 20],
            extra: None,
        };
        let bytes = encode(&original).unwrap();
        let decoded = decode(&bytes).unwrap();
        if let ProtocolMessage::JoinResponseOk { permission, version, extra, .. } = decoded {
            assert!(matches!(permission, Permission::Read));
            assert_eq!(version, vec![10, 20]);
            // The decoder always reads the varBytes field, even when it is empty.
            assert_eq!(extra.unwrap_or_default(), Vec::<u8>::new());
        } else {
            panic!("wrong decoded type");
        }
    }

    /// A room id one byte longer than the limit is rejected by the encoder.
    #[test]
    fn encode_rejects_room_id_over_limit() {
        let request = ProtocolMessage::JoinRequest {
            crdt: CrdtType::Loro,
            room_id: "x".repeat(129),
            auth: vec![],
            version: vec![],
        };
        let failure = encode(&request).unwrap_err();
        assert!(failure.contains("Room ID too long"));
    }

    /// A single update larger than the message limit makes encoding fail.
    #[test]
    fn encode_rejects_payload_over_max_size() {
        // Header plus one giant update guarantees the total exceeds the limit.
        let oversized = ProtocolMessage::DocUpdate {
            crdt: CrdtType::Loro,
            room_id: "room-oversized".into(),
            updates: vec![vec![0u8; MAX_MESSAGE_SIZE + 1024]],
        };
        let failure = encode(&oversized).unwrap_err();
        assert!(failure.contains("exceeds maximum"));
    }
}