1use std::collections::HashMap;
19
20use serde::{Deserialize, Serialize};
21
22use super::compensation::CompensationHint;
23use super::shape::ShapeDefinition;
24
/// One-byte message discriminant that starts every sync frame.
///
/// The explicit discriminants are the on-wire encoding: `SyncFrame::to_bytes` writes
/// `msg_type as u8` as the first byte, and [`SyncMessageType::from_u8`] maps it back. Existing
/// values must therefore never be renumbered, or peers will misinterpret each other's frames.
///
/// Values are grouped by range: `0x0_` handshake, `0x1_` deltas, `0x2_` shapes,
/// `0x30` vector clocks, `0x4_` timeseries, `0x5_` resync / flow control, `0x6_` token
/// refresh, `0xFF` ping/pong. `0x51` is not assigned here. NOTE(review): presumably reserved;
/// confirm before reusing it.
///
/// The payload of each variant is, by naming convention, the matching `*Msg` struct in this
/// module (e.g. `Handshake` → `HandshakeMsg`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum SyncMessageType {
    Handshake = 0x01,
    HandshakeAck = 0x02,
    DeltaPush = 0x10,
    DeltaAck = 0x11,
    DeltaReject = 0x12,
    ShapeSubscribe = 0x20,
    ShapeSnapshot = 0x21,
    ShapeDelta = 0x22,
    ShapeUnsubscribe = 0x23,
    VectorClockSync = 0x30,
    TimeseriesPush = 0x40,
    TimeseriesAck = 0x41,
    ResyncRequest = 0x50,
    Throttle = 0x52,
    TokenRefresh = 0x60,
    TokenRefreshAck = 0x61,
    /// Ping and pong share this discriminant; the payload's `is_pong` flag tells them apart.
    PingPong = 0xFF,
}

impl SyncMessageType {
    /// Parses a wire byte into a message type.
    ///
    /// Returns `None` for any byte that is not one of the discriminants declared on the enum.
    /// Keep this table in sync with the enum; the `message_type_roundtrip` test checks it.
    pub fn from_u8(v: u8) -> Option<Self> {
        match v {
            0x01 => Some(Self::Handshake),
            0x02 => Some(Self::HandshakeAck),
            0x10 => Some(Self::DeltaPush),
            0x11 => Some(Self::DeltaAck),
            0x12 => Some(Self::DeltaReject),
            0x20 => Some(Self::ShapeSubscribe),
            0x21 => Some(Self::ShapeSnapshot),
            0x22 => Some(Self::ShapeDelta),
            0x23 => Some(Self::ShapeUnsubscribe),
            0x30 => Some(Self::VectorClockSync),
            0x40 => Some(Self::TimeseriesPush),
            0x41 => Some(Self::TimeseriesAck),
            0x50 => Some(Self::ResyncRequest),
            0x52 => Some(Self::Throttle),
            0x60 => Some(Self::TokenRefresh),
            0x61 => Some(Self::TokenRefreshAck),
            0xFF => Some(Self::PingPong),
            _ => None,
        }
    }
}

/// Standard-trait form of [`SyncMessageType::from_u8`].
impl TryFrom<u8> for SyncMessageType {
    /// The unrecognised byte is handed back so callers can report it.
    type Error = u8;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Self::from_u8(value).ok_or(value)
    }
}

/// Yields the wire discriminant (same as `t as u8`).
impl From<SyncMessageType> for u8 {
    fn from(t: SyncMessageType) -> Self {
        t as u8
    }
}
80
81pub struct SyncFrame {
86 pub msg_type: SyncMessageType,
87 pub body: Vec<u8>,
88}
89
90impl SyncFrame {
91 pub const HEADER_SIZE: usize = 5;
92
93 pub fn to_bytes(&self) -> Vec<u8> {
95 let len = self.body.len() as u32;
96 let mut buf = Vec::with_capacity(Self::HEADER_SIZE + self.body.len());
97 buf.push(self.msg_type as u8);
98 buf.extend_from_slice(&len.to_le_bytes());
99 buf.extend_from_slice(&self.body);
100 buf
101 }
102
103 pub fn from_bytes(data: &[u8]) -> Option<Self> {
107 if data.len() < Self::HEADER_SIZE {
108 return None;
109 }
110 let msg_type = SyncMessageType::from_u8(data[0])?;
111 let len = u32::from_le_bytes(data[1..5].try_into().ok()?) as usize;
112 if data.len() < Self::HEADER_SIZE + len {
113 return None;
114 }
115 let body = data[Self::HEADER_SIZE..Self::HEADER_SIZE + len].to_vec();
116 Some(Self { msg_type, body })
117 }
118
119 pub fn new_msgpack<T: Serialize>(msg_type: SyncMessageType, value: &T) -> Option<Self> {
121 let body = rmp_serde::to_vec_named(value).ok()?;
122 Some(Self { msg_type, body })
123 }
124
125 pub fn encode_or_empty<T: Serialize>(msg_type: SyncMessageType, value: &T) -> Self {
128 Self::new_msgpack(msg_type, value).unwrap_or(Self {
129 msg_type,
130 body: Vec::new(),
131 })
132 }
133
134 pub fn decode_body<'a, T: Deserialize<'a>>(&'a self) -> Option<T> {
136 rmp_serde::from_slice(&self.body).ok()
137 }
138}
139
/// Payload of a `SyncMessageType::Handshake` frame, used to open a sync session.
///
/// Fields marked `#[serde(default)]` may be missing from the incoming message and then take
/// their `Default` value (empty string / `0`). NOTE(review): presumably this is for
/// compatibility with peers that predate those fields; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeMsg {
    /// Authentication token (a JWT, per the field name).
    pub jwt_token: String,
    /// Nested clock map of `String -> (String -> u64)`. NOTE(review): the meaning of the two
    /// key levels is not visible here; confirm against the clock implementation.
    pub vector_clock: HashMap<String, HashMap<String, u64>>,
    /// Shape IDs the sender wants to be subscribed to.
    pub subscribed_shapes: Vec<String>,
    /// Sender's software version string (e.g. `"0.1.0"`).
    pub client_version: String,
    /// Identifier of the sending Lite instance; empty string if absent on the wire.
    #[serde(default)]
    pub lite_id: String,
    /// Epoch counter; `0` if absent on the wire.
    #[serde(default)]
    pub epoch: u64,
    /// Wire protocol version spoken by the sender; `0` if absent on the wire.
    #[serde(default)]
    pub wire_version: u16,
}
164
/// Payload of a `SyncMessageType::HandshakeAck` frame, presumably the reply to a
/// `HandshakeMsg`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeAckMsg {
    /// Whether the handshake was accepted.
    pub success: bool,
    /// Identifier assigned to the session.
    pub session_id: String,
    /// Responder's clock map (`String -> u64`). NOTE(review): key meaning not visible here.
    pub server_clock: HashMap<String, u64>,
    /// Error description. Presumably `Some` only when `success` is `false`; confirm.
    pub error: Option<String>,
    /// Fork-detection flag; `false` if absent on the wire.
    #[serde(default)]
    pub fork_detected: bool,
    /// Wire protocol version spoken by the responder; `0` if absent on the wire.
    #[serde(default)]
    pub server_wire_version: u16,
}
183
/// Payload of a `SyncMessageType::DeltaPush` frame: one document change being pushed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaPushMsg {
    /// Collection that contains the document.
    pub collection: String,
    /// Identifier of the changed document.
    pub document_id: String,
    /// Encoded delta bytes. NOTE(review): the delta format is not visible in this module.
    pub delta: Vec<u8>,
    /// Numeric identifier of the originating peer.
    pub peer_id: u64,
    /// Mutation identifier; `DeltaAckMsg` / `DeltaRejectMsg` carry a field of the same name.
    pub mutation_id: u64,
    /// Integrity checksum over the delta; `0` if absent on the wire.
    /// NOTE(review): the checksum algorithm, and whether `0` means "not provided", are not
    /// visible here; confirm.
    #[serde(default)]
    pub checksum: u32,
}
202
/// Payload of a `SyncMessageType::DeltaAck` frame: acknowledges an accepted mutation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaAckMsg {
    /// Mutation being acknowledged (matches `DeltaPushMsg::mutation_id`).
    pub mutation_id: u64,
    /// Log sequence number assigned to the mutation (per the field name).
    pub lsn: u64,
}
211
/// Payload of a `SyncMessageType::DeltaReject` frame: reports that a pushed mutation was
/// refused.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaRejectMsg {
    /// Mutation being rejected (matches `DeltaPushMsg::mutation_id`).
    pub mutation_id: u64,
    /// Human-readable rejection reason (e.g. `"unique violation"`).
    pub reason: String,
    /// Optional structured hint for the sender (e.g. `CompensationHint::UniqueViolation`).
    pub compensation: Option<CompensationHint>,
}
222
/// Payload of a `SyncMessageType::ShapeSubscribe` frame: requests a subscription to a shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeSubscribeMsg {
    /// Full definition of the shape being subscribed to.
    pub shape: ShapeDefinition,
}
229
/// Payload of a `SyncMessageType::ShapeSnapshot` frame: a bulk snapshot of a shape's data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeSnapshotMsg {
    /// Shape this snapshot belongs to.
    pub shape_id: String,
    /// Encoded snapshot bytes. NOTE(review): the encoding is not visible in this module.
    pub data: Vec<u8>,
    /// Log sequence number the snapshot corresponds to (per the field name).
    pub snapshot_lsn: u64,
    /// Number of documents in the snapshot.
    /// NOTE(review): `usize` is platform-width. A value above `u32::MAX` sent by a 64-bit peer
    /// cannot be deserialized on a 32-bit target. Consider `u64` in a future wire version.
    pub doc_count: usize,
}
242
/// Payload of a `SyncMessageType::ShapeDelta` frame: one document change within a
/// subscribed shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeDeltaMsg {
    /// Shape whose subscribers should receive this change.
    pub shape_id: String,
    /// Collection that contains the document.
    pub collection: String,
    /// Identifier of the changed document.
    pub document_id: String,
    /// Kind of change, as a string. NOTE(review): the set of allowed values is not visible
    /// here; confirm.
    pub operation: String,
    /// Encoded delta bytes.
    pub delta: Vec<u8>,
    /// Log sequence number of the change (per the field name).
    pub lsn: u64,
}
259
/// Payload of a `SyncMessageType::ShapeUnsubscribe` frame: ends a shape subscription.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeUnsubscribeMsg {
    /// Shape to unsubscribe from.
    pub shape_id: String,
}
265
/// Payload of a `SyncMessageType::VectorClockSync` frame: shares the sender's clock state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorClockSyncMsg {
    /// Clock map (`String -> u64`). NOTE(review): key meaning not visible here.
    pub clocks: HashMap<String, u64>,
    /// Numeric identifier of the sender.
    pub sender_id: u64,
}
274
/// Payload of a `SyncMessageType::ResyncRequest` frame: asks the peer to resynchronise a
/// collection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResyncRequestMsg {
    /// Why the resync is needed.
    pub reason: ResyncReason,
    /// Mutation id to resync from. NOTE(review): whether this bound is inclusive or exclusive
    /// is not visible here; confirm.
    pub from_mutation_id: u64,
    /// Collection to resynchronise.
    pub collection: String,
}
294
/// Cause carried in a `ResyncRequestMsg`.
///
/// There is no `#[serde(tag = ...)]` attribute, so serde uses its default externally tagged
/// representation (the variant name is the map key). Renaming a variant therefore changes the
/// wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResyncReason {
    /// A gap was detected in a sequence.
    SequenceGap {
        /// The sequence value that was expected next.
        expected: u64,
        /// The sequence value actually received.
        received: u64,
    },
    /// A checksum did not match.
    ChecksumMismatch {
        /// Mutation whose checksum failed to verify.
        mutation_id: u64,
    },
    /// Local state is considered corrupted.
    CorruptedState,
}
313
/// Payload of a `SyncMessageType::Throttle` frame: a flow-control signal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThrottleMsg {
    /// `true` to request throttling; presumably `false` lifts it. Confirm.
    pub throttle: bool,
    /// Current queue depth reported by the sender.
    pub queue_depth: u64,
    /// Rate the receiver is asked to stay under. NOTE(review): the unit (e.g. messages/sec)
    /// is not visible here; confirm.
    pub suggested_rate: u64,
}
328
/// Payload of a `SyncMessageType::TokenRefresh` frame: supplies a replacement auth token
/// mid-session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenRefreshMsg {
    /// The new token. NOTE(review): presumably the same format as `HandshakeMsg::jwt_token`;
    /// confirm.
    pub new_token: String,
}
340
/// Payload of a `SyncMessageType::TokenRefreshAck` frame: result of a token refresh.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenRefreshAckMsg {
    /// Whether the new token was accepted.
    pub success: bool,
    /// Error description. Presumably `Some` only when `success` is `false`.
    pub error: Option<String>,
    /// Seconds until the new token expires (per the field name); `0` if absent on the wire.
    #[serde(default)]
    pub expires_in_secs: u64,
}
352
/// Payload of a `SyncMessageType::PingPong` frame.
///
/// Ping and pong share one message type; `is_pong` distinguishes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingPongMsg {
    /// Timestamp in milliseconds (per the `_ms` suffix). NOTE(review): the epoch/clock source
    /// is not visible here.
    pub timestamp_ms: u64,
    /// `false` for a ping, `true` for a pong.
    pub is_pong: bool,
}
361
/// Payload of a `SyncMessageType::TimeseriesPush` frame: a batch of timeseries samples.
///
/// NOTE(review): the three `*_block` fields are opaque encoded byte blocks whose formats are
/// not visible in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeseriesPushMsg {
    /// Identifier of the sending Lite instance.
    pub lite_id: String,
    /// Target timeseries collection.
    pub collection: String,
    /// Encoded timestamp column.
    pub ts_block: Vec<u8>,
    /// Encoded value column.
    pub val_block: Vec<u8>,
    /// Encoded series-identifier column.
    pub series_block: Vec<u8>,
    /// Number of samples in the batch.
    pub sample_count: u64,
    /// Smallest timestamp in the batch (unit not visible here).
    pub min_ts: i64,
    /// Largest timestamp in the batch (unit not visible here).
    pub max_ts: i64,
    /// Watermarks keyed by `u64`. NOTE(review): key/value meaning not visible here; confirm.
    pub watermarks: HashMap<u64, u64>,
}
385
/// Payload of a `SyncMessageType::TimeseriesAck` frame: result of a timeseries push.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeseriesAckMsg {
    /// Collection the push targeted.
    pub collection: String,
    /// Number of samples accepted.
    pub accepted: u64,
    /// Number of samples rejected.
    pub rejected: u64,
    /// Log sequence number assigned to the accepted data (per the field name).
    pub lsn: u64,
}
398
#[cfg(test)]
mod tests {
    //! Round-trip and framing tests for the sync wire protocol.

    use super::*;

    #[test]
    fn frame_roundtrip() {
        let ping = PingPongMsg {
            timestamp_ms: 12345,
            is_pong: false,
        };
        let frame = SyncFrame::new_msgpack(SyncMessageType::PingPong, &ping).unwrap();
        let bytes = frame.to_bytes();
        let decoded = SyncFrame::from_bytes(&bytes).unwrap();
        assert_eq!(decoded.msg_type, SyncMessageType::PingPong);
        let decoded_ping: PingPongMsg = decoded.decode_body().unwrap();
        assert_eq!(decoded_ping.timestamp_ms, 12345);
        assert!(!decoded_ping.is_pong);
    }

    #[test]
    fn handshake_serialization() {
        let msg = HandshakeMsg {
            jwt_token: "test.jwt.token".into(),
            vector_clock: HashMap::new(),
            subscribed_shapes: vec!["shape1".into()],
            client_version: "0.1.0".into(),
            lite_id: String::new(),
            epoch: 0,
            wire_version: 1,
        };
        let frame = SyncFrame::new_msgpack(SyncMessageType::Handshake, &msg).unwrap();
        let bytes = frame.to_bytes();
        assert!(bytes.len() > SyncFrame::HEADER_SIZE);
        assert_eq!(bytes[0], 0x01);

        let decoded: HandshakeMsg = SyncFrame::from_bytes(&bytes)
            .unwrap()
            .decode_body()
            .unwrap();
        assert_eq!(decoded.jwt_token, "test.jwt.token");
        assert_eq!(decoded.wire_version, 1);
    }

    #[test]
    fn delta_reject_with_compensation() {
        let reject = DeltaRejectMsg {
            mutation_id: 42,
            reason: "unique violation".into(),
            compensation: Some(CompensationHint::UniqueViolation {
                field: "email".into(),
                conflicting_value: "alice@example.com".into(),
            }),
        };
        let frame = SyncFrame::new_msgpack(SyncMessageType::DeltaReject, &reject).unwrap();
        let decoded: DeltaRejectMsg = SyncFrame::from_bytes(&frame.to_bytes())
            .unwrap()
            .decode_body()
            .unwrap();
        assert_eq!(decoded.mutation_id, 42);
        assert!(matches!(
            decoded.compensation,
            Some(CompensationHint::UniqueViolation { .. })
        ));
    }

    #[test]
    fn message_type_roundtrip() {
        for v in [
            0x01, 0x02, 0x10, 0x11, 0x12, 0x20, 0x21, 0x22, 0x23, 0x30, 0x40, 0x41, 0x50, 0x52,
            0x60, 0x61, 0xFF,
        ] {
            let mt = SyncMessageType::from_u8(v).unwrap();
            assert_eq!(mt as u8, v);
        }
        assert!(SyncMessageType::from_u8(0x99).is_none());
    }

    #[test]
    fn shape_subscribe_roundtrip() {
        let msg = ShapeSubscribeMsg {
            shape: ShapeDefinition {
                shape_id: "s1".into(),
                tenant_id: 1,
                shape_type: super::super::shape::ShapeType::Vector {
                    collection: "embeddings".into(),
                    field_name: None,
                },
                description: "all embeddings".into(),
                field_filter: vec![],
            },
        };
        let frame = SyncFrame::new_msgpack(SyncMessageType::ShapeSubscribe, &msg).unwrap();
        let decoded: ShapeSubscribeMsg = SyncFrame::from_bytes(&frame.to_bytes())
            .unwrap()
            .decode_body()
            .unwrap();
        assert_eq!(decoded.shape.shape_id, "s1");
    }

    /// The header is `[type][len u32 LE]`, followed by the body verbatim.
    #[test]
    fn to_bytes_layout() {
        let frame = SyncFrame {
            msg_type: SyncMessageType::DeltaAck,
            body: vec![1, 2, 3],
        };
        assert_eq!(frame.to_bytes(), vec![0x11, 3, 0, 0, 0, 1, 2, 3]);
    }

    #[test]
    fn from_bytes_rejects_short_header() {
        assert!(SyncFrame::from_bytes(&[]).is_none());
        assert!(SyncFrame::from_bytes(&[0x01, 0, 0, 0]).is_none());
    }

    #[test]
    fn from_bytes_rejects_unknown_type() {
        assert!(SyncFrame::from_bytes(&[0x99, 0, 0, 0, 0]).is_none());
    }

    #[test]
    fn from_bytes_rejects_truncated_body() {
        // Header declares 3 body bytes but only 2 follow.
        assert!(SyncFrame::from_bytes(&[0xFF, 3, 0, 0, 0, 0xAA, 0xBB]).is_none());
        // The largest declarable length must be rejected, not panic or overflow.
        assert!(SyncFrame::from_bytes(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF]).is_none());
    }

    #[test]
    fn from_bytes_accepts_empty_body_and_ignores_trailing_bytes() {
        let empty = SyncFrame::from_bytes(&[0x01, 0, 0, 0, 0]).unwrap();
        assert_eq!(empty.msg_type, SyncMessageType::Handshake);
        assert!(empty.body.is_empty());

        let decoded = SyncFrame::from_bytes(&[0x52, 1, 0, 0, 0, 0x07, 0xEE, 0xEE]).unwrap();
        assert_eq!(decoded.msg_type, SyncMessageType::Throttle);
        assert_eq!(decoded.body, vec![0x07]);
    }
}