1use crate::bytes::{BytesReader, BytesWriter};
6use crate::protocol::*;
7
/// Maximum room ID length, in bytes. `encode` rejects longer room IDs and
/// `decode` refuses to parse them.
const MAX_ROOM_ID_LENGTH: usize = 128;
9
10fn encode_crdt(w: &mut BytesWriter, crdt: CrdtType) {
11 w.push_bytes(&crdt.magic_bytes());
12}
13
14fn decode_crdt(r: &mut BytesReader) -> Result<CrdtType, String> {
15 if r.remaining() < 4 { return Err("Invalid message: too short for CRDT type".into()); }
16 let bytes = r.read_bytes(4)?;
17 let mut arr = [0u8; 4];
18 arr.copy_from_slice(bytes);
19 CrdtType::from_magic_bytes(arr).ok_or_else(|| format!("Invalid CRDT type: {}{}{}{}",
20 arr[0] as char, arr[1] as char, arr[2] as char, arr[3] as char))
21}
22
23fn encode_permission(w: &mut BytesWriter, p: Permission) {
24 match p {
25 Permission::Read => w.push_var_string("read"),
26 Permission::Write => w.push_var_string("write"),
27 }
28}
29
30fn decode_permission(r: &mut BytesReader) -> Result<Permission, String> {
31 match r.read_var_string()?.as_str() {
32 "read" => Ok(Permission::Read),
33 "write" => Ok(Permission::Write),
34 other => Err(format!("Invalid permission: {}", other)),
35 }
36}
37pub fn encode(message: &ProtocolMessage) -> Result<Vec<u8>, String> {
45 let mut w = BytesWriter::new();
46
47 match message {
49 ProtocolMessage::JoinRequest { crdt, room_id, auth: _, version: _ }
50 | ProtocolMessage::JoinResponseOk { crdt, room_id, permission: _, version: _, extra: _ }
51 | ProtocolMessage::JoinError { crdt, room_id, code: _, message: _, receiver_version: _, app_code: _ }
52 | ProtocolMessage::DocUpdate { crdt, room_id, updates: _ }
53 | ProtocolMessage::DocUpdateFragmentHeader { crdt, room_id, batch_id: _, fragment_count: _, total_size_bytes: _ }
54 | ProtocolMessage::DocUpdateFragment { crdt, room_id, batch_id: _, index: _, fragment: _ }
55 | ProtocolMessage::UpdateError { crdt, room_id, code: _, message: _, batch_id: _, app_code: _ }
56 | ProtocolMessage::Leave { crdt, room_id } => {
57 let room_id_bytes = room_id.as_bytes();
58 if room_id_bytes.len() > MAX_ROOM_ID_LENGTH { return Err("Room ID too long".into()); }
59 encode_crdt(&mut w, *crdt);
60 w.push_var_bytes(room_id_bytes);
61 }
62 }
63
64 let ty: u8 = match message {
66 ProtocolMessage::JoinRequest { .. } => MessageType::JoinRequest as u8,
67 ProtocolMessage::JoinResponseOk { .. } => MessageType::JoinResponseOk as u8,
68 ProtocolMessage::JoinError { .. } => MessageType::JoinError as u8,
69 ProtocolMessage::DocUpdate { .. } => MessageType::DocUpdate as u8,
70 ProtocolMessage::DocUpdateFragmentHeader { .. } => MessageType::DocUpdateFragmentHeader as u8,
71 ProtocolMessage::DocUpdateFragment { .. } => MessageType::DocUpdateFragment as u8,
72 ProtocolMessage::UpdateError { .. } => MessageType::UpdateError as u8,
73 ProtocolMessage::Leave { .. } => MessageType::Leave as u8,
74 };
75 w.push_byte(ty);
76
77 match message {
79 ProtocolMessage::JoinRequest { auth, version, .. } => {
80 w.push_var_bytes(auth);
81 w.push_var_bytes(version);
82 }
83 ProtocolMessage::JoinResponseOk { permission, version, extra, .. } => {
84 encode_permission(&mut w, *permission);
85 w.push_var_bytes(version);
86 if let Some(e) = extra { w.push_var_bytes(e); } else { w.push_var_bytes(&[]); }
87 }
88 ProtocolMessage::JoinError { code, message, receiver_version, app_code, .. } => {
89 w.push_byte(*code as u8);
90 w.push_var_string(message);
91 if matches!(code, JoinErrorCode::VersionUnknown) {
92 if let Some(v) = receiver_version { w.push_var_bytes(v); }
93 }
94 if matches!(code, JoinErrorCode::AppError) {
95 if let Some(app) = app_code { w.push_var_string(app); }
96 }
97 }
98 ProtocolMessage::DocUpdate { updates, .. } => {
99 w.push_uleb128(updates.len() as u64);
100 for u in updates { w.push_var_bytes(u); }
101 }
102 ProtocolMessage::DocUpdateFragmentHeader { batch_id, fragment_count, total_size_bytes, .. } => {
103 w.push_bytes(&batch_id.0);
104 w.push_uleb128(*fragment_count);
105 w.push_uleb128(*total_size_bytes);
106 }
107 ProtocolMessage::DocUpdateFragment { batch_id, index, fragment, .. } => {
108 w.push_bytes(&batch_id.0);
109 w.push_uleb128(*index);
110 w.push_var_bytes(fragment);
111 }
112 ProtocolMessage::UpdateError { code, message, batch_id, app_code, .. } => {
113 w.push_byte(*code as u8);
114 w.push_var_string(message);
115 if matches!(code, UpdateErrorCode::FragmentTimeout) {
116 if let Some(id) = batch_id { w.push_bytes(&id.0); }
117 }
118 if matches!(code, UpdateErrorCode::AppError) {
119 if let Some(app) = app_code { w.push_var_string(app); }
120 }
121 }
122 ProtocolMessage::Leave { .. } => {}
123 }
124
125 let out = w.finalize();
126 if out.len() > MAX_MESSAGE_SIZE {
127 return Err(format!(
128 "Message size {} exceeds maximum {}",
129 out.len(), MAX_MESSAGE_SIZE
130 ));
131 }
132 Ok(out)
133}
134
135pub fn decode(buf: &[u8]) -> Result<ProtocolMessage, String> {
136 let mut r = BytesReader::new(buf);
137
138 let crdt = decode_crdt(&mut r)?;
140 let room_id = r.read_var_string()?;
142 if room_id.len() > MAX_ROOM_ID_LENGTH {
143 return Err("Room ID exceeds maximum length of 128 bytes".into());
144 }
145
146 let t = r.read_byte()?;
148 let ty = MessageType::from_u8(t).ok_or_else(|| "Invalid message type".to_string())?;
149
150 use ProtocolMessage as PM;
151 let msg = match ty {
152 MessageType::JoinRequest => {
153 let auth = r.read_var_bytes()?.to_vec();
154 let version = r.read_var_bytes()?.to_vec();
155 PM::JoinRequest { crdt, room_id, auth, version }
156 }
157 MessageType::JoinResponseOk => {
158 let permission = decode_permission(&mut r)?;
159 let version = r.read_var_bytes()?.to_vec();
160 let extra = r.read_var_bytes()?.to_vec();
161 PM::JoinResponseOk { crdt, room_id, permission, version, extra: Some(extra) }
162 }
163 MessageType::JoinError => {
164 let code = r.read_byte().ok().and_then(JoinErrorCode::from_u8).unwrap_or(JoinErrorCode::Unknown);
165 let message = r.read_var_string()?;
166 let mut receiver_version = None;
167 let mut app_code = None;
168 if matches!(code, JoinErrorCode::VersionUnknown) && r.remaining() > 0 {
169 if let Ok(bytes) = r.read_var_bytes() { receiver_version = Some(bytes.to_vec()); }
170 }
171 if matches!(code, JoinErrorCode::AppError) && r.remaining() > 0 {
172 if let Ok(app) = r.read_var_string() { app_code = Some(app); }
173 }
174 PM::JoinError { crdt, room_id, code, message, receiver_version, app_code }
175 }
176 MessageType::DocUpdate => {
177 let count = r.read_uleb128()? as usize;
178 let mut updates = Vec::with_capacity(count);
179 for _ in 0..count { updates.push(r.read_var_bytes()?.to_vec()); }
180 PM::DocUpdate { crdt, room_id, updates }
181 }
182 MessageType::DocUpdateFragmentHeader => {
183 if r.remaining() < 8 { return Err("Invalid DocUpdateFragmentHeader: missing batch ID".into()); }
184 let id = r.read_bytes(8)?;
185 let mut arr = [0u8; 8];
186 arr.copy_from_slice(id);
187 let batch_id = BatchId(arr);
188 let fragment_count = r.read_uleb128()?;
189 let total_size_bytes = r.read_uleb128()?;
190 PM::DocUpdateFragmentHeader { crdt, room_id, batch_id, fragment_count, total_size_bytes }
191 }
192 MessageType::DocUpdateFragment => {
193 if r.remaining() < 8 { return Err("Invalid DocUpdateFragment: missing batch ID".into()); }
194 let id = r.read_bytes(8)?;
195 let mut arr = [0u8; 8];
196 arr.copy_from_slice(id);
197 let batch_id = BatchId(arr);
198 let index = r.read_uleb128()?;
199 let fragment = r.read_var_bytes()?.to_vec();
200 PM::DocUpdateFragment { crdt, room_id, batch_id, index, fragment }
201 }
202 MessageType::UpdateError => {
203 let code = r.read_byte().ok().and_then(UpdateErrorCode::from_u8).unwrap_or(UpdateErrorCode::Unknown);
204 let message = r.read_var_string()?;
205 let mut batch_id = None;
206 let mut app_code = None;
207 if matches!(code, UpdateErrorCode::FragmentTimeout) && r.remaining() >= 8 {
208 let id = r.read_bytes(8)?;
209 let mut arr = [0u8; 8];
210 arr.copy_from_slice(id);
211 batch_id = Some(BatchId(arr));
212 }
213 if matches!(code, UpdateErrorCode::AppError) && r.remaining() > 0 {
214 if let Ok(app) = r.read_var_string() { app_code = Some(app); }
215 }
216 PM::UpdateError { crdt, room_id, code, message, batch_id, app_code }
217 }
218 MessageType::Leave => PM::Leave { crdt, room_id },
219 };
220
221 Ok(msg)
222}
223
224pub fn try_decode(buf: &[u8]) -> Option<ProtocolMessage> {
226 decode(buf).ok()
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Leave` carries only the envelope and must round-trip exactly.
    #[test]
    fn encode_decode_leave() {
        let original = ProtocolMessage::Leave {
            crdt: CrdtType::Loro,
            room_id: String::from("room-123"),
        };
        let decoded = decode(&encode(&original).unwrap()).unwrap();
        assert_eq!(original, decoded);
    }

    /// `extra: None` is written as an empty byte string, so it decodes as an
    /// empty payload.
    #[test]
    fn encode_join_response_ok_defaults_extra() {
        let msg = ProtocolMessage::JoinResponseOk {
            crdt: CrdtType::Yjs,
            room_id: String::from("room-123"),
            permission: Permission::Read,
            version: vec![10, 20],
            extra: None,
        };
        let bytes = encode(&msg).unwrap();
        if let ProtocolMessage::JoinResponseOk { permission, version, extra, .. } =
            decode(&bytes).unwrap()
        {
            assert!(matches!(permission, Permission::Read));
            assert_eq!(version, vec![10, 20]);
            assert!(extra.unwrap_or_default().is_empty());
        } else {
            panic!("wrong decoded type");
        }
    }

    /// A room ID one byte over the limit is rejected before anything is written.
    #[test]
    fn encode_rejects_room_id_over_limit() {
        let msg = ProtocolMessage::JoinRequest {
            crdt: CrdtType::Loro,
            room_id: "x".repeat(MAX_ROOM_ID_LENGTH + 1),
            auth: Vec::new(),
            version: Vec::new(),
        };
        let err = encode(&msg).expect_err("room ID over the limit must be rejected");
        assert!(err.contains("Room ID too long"));
    }

    /// A payload larger than `MAX_MESSAGE_SIZE` is rejected after encoding.
    #[test]
    fn encode_rejects_payload_over_max_size() {
        let msg = ProtocolMessage::DocUpdate {
            crdt: CrdtType::Loro,
            room_id: "room-oversized".to_string(),
            updates: vec![vec![0u8; MAX_MESSAGE_SIZE + 1024]],
        };
        let err = encode(&msg).expect_err("oversized payload must be rejected");
        assert!(err.contains("exceeds maximum"));
    }
}