use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::compensation::CompensationHint;
use super::shape::ShapeDefinition;
/// One-byte tag identifying the payload carried by a `SyncFrame`.
///
/// The discriminant is the value written as the first header byte of a frame and
/// read back by [`SyncMessageType::from_u8`], so changing a code breaks decoding of
/// frames produced with the old value. Codes are grouped by high nibble; bytes not
/// listed here (e.g. `0x51`, `0x00`) are unassigned and rejected by `from_u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum SyncMessageType {
    Handshake = 0x01,
    HandshakeAck = 0x02,
    DeltaPush = 0x10,
    DeltaAck = 0x11,
    DeltaReject = 0x12,
    ShapeSubscribe = 0x20,
    ShapeSnapshot = 0x21,
    ShapeDelta = 0x22,
    ShapeUnsubscribe = 0x23,
    VectorClockSync = 0x30,
    TimeseriesPush = 0x40,
    TimeseriesAck = 0x41,
    ResyncRequest = 0x50,
    Throttle = 0x52,
    TokenRefresh = 0x60,
    TokenRefreshAck = 0x61,
    /// A single tag for both directions; `PingPongMsg::is_pong` tells them apart.
    PingPong = 0xFF,
}

impl SyncMessageType {
    /// Maps a wire tag back to its variant, i.e. the inverse of `tag as u8`.
    ///
    /// Returns `None` for any byte that is not an assigned code.
    pub fn from_u8(v: u8) -> Option<Self> {
        // Each variant appears once. The numeric codes come from the discriminants
        // themselves, so none can be mistyped here. A newly added variant must
        // still be appended to this table.
        let known = [
            Self::Handshake,
            Self::HandshakeAck,
            Self::DeltaPush,
            Self::DeltaAck,
            Self::DeltaReject,
            Self::ShapeSubscribe,
            Self::ShapeSnapshot,
            Self::ShapeDelta,
            Self::ShapeUnsubscribe,
            Self::VectorClockSync,
            Self::TimeseriesPush,
            Self::TimeseriesAck,
            Self::ResyncRequest,
            Self::Throttle,
            Self::TokenRefresh,
            Self::TokenRefreshAck,
            Self::PingPong,
        ];
        known.iter().copied().find(|&candidate| candidate as u8 == v)
    }
}
/// A single length-prefixed protocol frame.
///
/// Wire layout produced by [`SyncFrame::to_bytes`] and parsed by
/// [`SyncFrame::from_bytes`]:
///
/// | offset | size  | content                          |
/// |--------|-------|----------------------------------|
/// | 0      | 1     | `SyncMessageType` tag            |
/// | 1      | 4     | body length, `u32` little-endian |
/// | 5      | len   | body bytes                       |
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SyncFrame {
    /// Tag written as the first header byte.
    pub msg_type: SyncMessageType,
    /// Payload. Its length must fit in the 4-byte length field to be encodable.
    pub body: Vec<u8>,
}

impl SyncFrame {
    /// Size of the fixed header: 1 tag byte followed by a 4-byte body length.
    pub const HEADER_SIZE: usize = 5;

    /// Serializes the frame as header followed by body.
    ///
    /// # Panics
    ///
    /// Panics if `body` is longer than `u32::MAX` bytes. Frames built through
    /// [`SyncFrame::new_msgpack`] or [`SyncFrame::encode_or_empty`] never reach
    /// that size.
    pub fn to_bytes(&self) -> Vec<u8> {
        // A plain `as u32` would silently truncate and emit a header whose length
        // disagrees with the body, desynchronizing whoever parses the stream.
        let len = u32::try_from(self.body.len())
            .expect("SyncFrame body length exceeds the u32 length field");
        let mut buf = Vec::with_capacity(Self::HEADER_SIZE + self.body.len());
        buf.push(self.msg_type as u8);
        buf.extend_from_slice(&len.to_le_bytes());
        buf.extend_from_slice(&self.body);
        buf
    }

    /// Parses one frame from the start of `data`.
    ///
    /// Returns `None` in three cases: `data` is shorter than the header, the tag
    /// byte is not a known `SyncMessageType`, or fewer than the declared number of
    /// body bytes follow the header.
    ///
    /// Bytes after the declared body are ignored. A caller parsing a stream can
    /// advance by `HEADER_SIZE + frame.body.len()`.
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        let header = data.get(..Self::HEADER_SIZE)?;
        let msg_type = SyncMessageType::from_u8(header[0])?;
        let len_bytes: [u8; 4] = header[1..].try_into().ok()?;
        let len = usize::try_from(u32::from_le_bytes(len_bytes)).ok()?;
        // Slice the remainder with `get(..len)` instead of computing
        // `HEADER_SIZE + len`. On 32-bit targets (e.g. wasm32) that sum overflows
        // for an untrusted length near `u32::MAX` and panics instead of returning
        // `None`.
        let body = data[Self::HEADER_SIZE..].get(..len)?.to_vec();
        Some(Self { msg_type, body })
    }

    /// Builds a frame whose body is `value` encoded with
    /// `rmp_serde::to_vec_named`, i.e. MessagePack with field names included.
    ///
    /// Returns `None` if encoding fails or the encoded body does not fit the
    /// 4-byte length field.
    pub fn new_msgpack<T: Serialize>(msg_type: SyncMessageType, value: &T) -> Option<Self> {
        let body = rmp_serde::to_vec_named(value).ok()?;
        // Reject here so `to_bytes` is never handed an unencodable length.
        if u32::try_from(body.len()).is_err() {
            return None;
        }
        Some(Self { msg_type, body })
    }

    /// Like [`SyncFrame::new_msgpack`], but falls back to a frame with an empty
    /// body instead of returning `None`.
    ///
    /// The failure is not reported. Callers that must detect it should use
    /// `new_msgpack`.
    pub fn encode_or_empty<T: Serialize>(msg_type: SyncMessageType, value: &T) -> Self {
        Self::new_msgpack(msg_type, value).unwrap_or_else(|| Self {
            msg_type,
            body: Vec::new(),
        })
    }

    /// Decodes the body as MessagePack into `T`, returning `None` on any decode
    /// error.
    ///
    /// `T` may borrow from the body (e.g. `&str` fields) for as long as `self`
    /// lives.
    pub fn decode_body<'a, T: Deserialize<'a>>(&'a self) -> Option<T> {
        rmp_serde::from_slice(&self.body).ok()
    }
}
/// Session-opening message, sent under `SyncMessageType::Handshake` in this file's
/// tests (presumably client → server).
///
/// Fields marked `#[serde(default)]` decode to their `Default` value when missing
/// from the encoded map. Senders that omit them are therefore still accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeMsg {
    /// Auth token (a JWT, per the field name). It is not validated in this file.
    pub jwt_token: String,
    /// Nested clock map.
    /// NOTE(review): key meaning (outer = collection? inner = peer id?) is not
    /// established here; confirm against the producer.
    pub vector_clock: HashMap<String, HashMap<String, u64>>,
    /// Presumably the ids of shapes to (re)subscribe to; confirm.
    pub subscribed_shapes: Vec<String>,
    /// Free-form client version string (e.g. `"0.1.0"` in tests).
    pub client_version: String,
    /// Empty string when absent from the encoded message.
    #[serde(default)]
    pub lite_id: String,
    /// 0 when absent from the encoded message.
    #[serde(default)]
    pub epoch: u64,
    /// 0 when absent from the encoded message.
    /// TODO(review): confirm whether 0 means "unversioned client".
    #[serde(default)]
    pub wire_version: u16,
}
/// Reply to a handshake. The pairing with `HandshakeMsg` is by name only and is
/// not enforced here.
///
/// Fields marked `#[serde(default)]` decode to their `Default` value when missing
/// from the encoded map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeAckMsg {
    /// Whether the handshake was accepted.
    pub success: bool,
    pub session_id: String,
    /// Single-level clock map. Key meaning is not defined in this file.
    pub server_clock: HashMap<String, u64>,
    /// Failure description. Presumably `None` when `success` is true; confirm.
    pub error: Option<String>,
    /// `false` when absent from the encoded message.
    #[serde(default)]
    pub fork_detected: bool,
    /// 0 when absent from the encoded message.
    #[serde(default)]
    pub server_wire_version: u16,
}
/// A single document delta pushed to the peer.
///
/// NOTE(review): serde serializes a plain `Vec<u8>` as a sequence of integers, not
/// as a MessagePack `bin`. That costs up to 2 bytes per payload byte.
/// `#[serde(with = "serde_bytes")]` would shrink frames, but it changes the wire
/// format, so coordinate it with a protocol version bump.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaPushMsg {
    pub collection: String,
    pub document_id: String,
    /// Opaque delta bytes. Their encoding is not defined in this file.
    pub delta: Vec<u8>,
    pub peer_id: u64,
    /// Identifies this push. `DeltaAckMsg` and `DeltaRejectMsg` carry a field of
    /// the same name, presumably echoing it.
    pub mutation_id: u64,
    /// 0 when absent from the encoded message.
    /// TODO(review): the checksum algorithm, and whether 0 means "unchecked", are
    /// not visible here.
    #[serde(default)]
    pub checksum: u32,
}
/// Positive acknowledgement for the push identified by `mutation_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaAckMsg {
    pub mutation_id: u64,
    /// Presumably the log sequence number assigned to the accepted delta; confirm.
    pub lsn: u64,
}
/// Rejection of the push identified by `mutation_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeltaRejectMsg {
    pub mutation_id: u64,
    /// Free-form reason text (e.g. `"unique violation"` in tests).
    pub reason: String,
    /// Optional structured hint, e.g. `CompensationHint::UniqueViolation`.
    pub compensation: Option<CompensationHint>,
}
/// Subscription request carrying the full shape definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeSubscribeMsg {
    pub shape: ShapeDefinition,
}
/// Snapshot of the documents matching a shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeSnapshotMsg {
    pub shape_id: String,
    /// Opaque snapshot bytes. Their encoding is not defined in this file.
    /// NOTE(review): serde serializes a plain `Vec<u8>` as an integer sequence,
    /// not a MessagePack `bin`.
    pub data: Vec<u8>,
    /// Presumably the LSN the snapshot was taken at; confirm.
    pub snapshot_lsn: u64,
    /// Number of documents in `data`.
    /// NOTE(review): `usize` is platform-width, so on a 32-bit peer (e.g. wasm32)
    /// a value above `u32::MAX` fails to deserialize.
    pub doc_count: usize,
}
/// Incremental change to one document within a subscribed shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeDeltaMsg {
    pub shape_id: String,
    pub collection: String,
    pub document_id: String,
    /// Operation kind. It is stringly typed and its allowed values are not
    /// constrained in this file.
    pub operation: String,
    /// Opaque delta bytes. Their encoding is not defined in this file.
    pub delta: Vec<u8>,
    pub lsn: u64,
}
/// Cancels the subscription to the shape with id `shape_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShapeUnsubscribeMsg {
    pub shape_id: String,
}
/// Vector-clock exchange from the sender identified by `sender_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorClockSyncMsg {
    /// Clock entries. Key meaning is not defined in this file.
    pub clocks: HashMap<String, u64>,
    pub sender_id: u64,
}
/// Asks the peer to resend state for `collection`, starting at
/// `from_mutation_id` (whether that bound is inclusive is not defined here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResyncRequestMsg {
    pub reason: ResyncReason,
    pub from_mutation_id: u64,
    pub collection: String,
}
/// Why a resync was requested.
///
/// Uses serde's default (externally tagged) enum representation, since no
/// `#[serde(tag = ...)]` attribute is present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResyncReason {
    /// A mutation sequence number other than the expected one arrived.
    SequenceGap {
        expected: u64,
        received: u64,
    },
    /// The checksum for `mutation_id` did not verify.
    ChecksumMismatch {
        mutation_id: u64,
    },
    /// Local state is considered unusable.
    CorruptedState,
}
/// Back-pressure signal: `throttle` turns throttling on or off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThrottleMsg {
    pub throttle: bool,
    pub queue_depth: u64,
    /// Suggested send rate. TODO(review): the unit (messages/sec?) is not
    /// defined here.
    pub suggested_rate: u64,
}
/// Supplies a replacement auth token for the current session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenRefreshMsg {
    pub new_token: String,
}
/// Result of a token refresh.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenRefreshAckMsg {
    pub success: bool,
    /// Failure description. Presumably `None` when `success` is true; confirm.
    pub error: Option<String>,
    /// Token lifetime in seconds (per the field name). It is 0 when absent from
    /// the encoded message.
    #[serde(default)]
    pub expires_in_secs: u64,
}
/// Keep-alive payload for `SyncMessageType::PingPong`.
///
/// Both directions share that single tag, and `is_pong` distinguishes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PingPongMsg {
    /// Timestamp in milliseconds (per the field name). The clock source and epoch
    /// are not defined here.
    pub timestamp_ms: u64,
    pub is_pong: bool,
}
/// Batch of time-series samples pushed for `collection`.
///
/// NOTE(review): the three `*_block` fields are plain `Vec<u8>`, which serde
/// encodes as integer sequences rather than MessagePack `bin`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeseriesPushMsg {
    pub lite_id: String,
    pub collection: String,
    /// Opaque encoded timestamp column. Its format is not defined in this file.
    pub ts_block: Vec<u8>,
    /// Opaque encoded value column. Its format is not defined in this file.
    pub val_block: Vec<u8>,
    /// Opaque encoded series column. Its format is not defined in this file.
    pub series_block: Vec<u8>,
    pub sample_count: u64,
    /// Lower timestamp bound of the batch. TODO(review): unit and inclusivity are
    /// not defined here.
    pub min_ts: i64,
    /// Upper timestamp bound of the batch. TODO(review): unit and inclusivity are
    /// not defined here.
    pub max_ts: i64,
    /// TODO(review): key/value meaning (series id → watermark?) is not
    /// established here.
    pub watermarks: HashMap<u64, u64>,
}
/// Outcome of a time-series push for `collection`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeseriesAckMsg {
    pub collection: String,
    /// Number of samples accepted, presumably.
    pub accepted: u64,
    /// Number of samples rejected, presumably.
    pub rejected: u64,
    pub lsn: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A ping survives encode → frame bytes → parse → decode unchanged.
    #[test]
    fn frame_roundtrip() {
        let ping = PingPongMsg {
            timestamp_ms: 12345,
            is_pong: false,
        };
        let frame = SyncFrame::new_msgpack(SyncMessageType::PingPong, &ping).unwrap();
        let bytes = frame.to_bytes();
        let decoded = SyncFrame::from_bytes(&bytes).unwrap();
        assert_eq!(decoded.msg_type, SyncMessageType::PingPong);
        let decoded_ping: PingPongMsg = decoded.decode_body().unwrap();
        assert_eq!(decoded_ping.timestamp_ms, 12345);
        assert!(!decoded_ping.is_pong);
    }

    /// A handshake encodes to a non-empty body behind the 0x01 tag.
    #[test]
    fn handshake_serialization() {
        let msg = HandshakeMsg {
            jwt_token: "test.jwt.token".into(),
            vector_clock: HashMap::new(),
            subscribed_shapes: vec!["shape1".into()],
            client_version: "0.1.0".into(),
            lite_id: String::new(),
            epoch: 0,
            wire_version: 1,
        };
        let frame = SyncFrame::new_msgpack(SyncMessageType::Handshake, &msg).unwrap();
        let bytes = frame.to_bytes();
        assert!(bytes.len() > SyncFrame::HEADER_SIZE);
        assert_eq!(bytes[0], 0x01);
    }

    /// An optional compensation hint survives the round trip.
    #[test]
    fn delta_reject_with_compensation() {
        let reject = DeltaRejectMsg {
            mutation_id: 42,
            reason: "unique violation".into(),
            compensation: Some(CompensationHint::UniqueViolation {
                field: "email".into(),
                conflicting_value: "alice@example.com".into(),
            }),
        };
        let frame = SyncFrame::new_msgpack(SyncMessageType::DeltaReject, &reject).unwrap();
        let decoded: DeltaRejectMsg = SyncFrame::from_bytes(&frame.to_bytes())
            .unwrap()
            .decode_body()
            .unwrap();
        assert_eq!(decoded.mutation_id, 42);
        assert!(matches!(
            decoded.compensation,
            Some(CompensationHint::UniqueViolation { .. })
        ));
    }

    /// Every assigned tag maps back to itself; unassigned bytes are rejected.
    #[test]
    fn message_type_roundtrip() {
        for v in [
            0x01, 0x02, 0x10, 0x11, 0x12, 0x20, 0x21, 0x22, 0x23, 0x30, 0x40, 0x41, 0x50, 0x52,
            0x60, 0x61, 0xFF,
        ] {
            let mt = SyncMessageType::from_u8(v).unwrap();
            assert_eq!(mt as u8, v);
        }
        assert!(SyncMessageType::from_u8(0x99).is_none());
    }

    /// A shape subscription survives the round trip.
    #[test]
    fn shape_subscribe_roundtrip() {
        let msg = ShapeSubscribeMsg {
            shape: ShapeDefinition {
                shape_id: "s1".into(),
                tenant_id: 1,
                shape_type: super::super::shape::ShapeType::Vector {
                    collection: "embeddings".into(),
                    field_name: None,
                },
                description: "all embeddings".into(),
                field_filter: vec![],
            },
        };
        let frame = SyncFrame::new_msgpack(SyncMessageType::ShapeSubscribe, &msg).unwrap();
        let decoded: ShapeSubscribeMsg = SyncFrame::from_bytes(&frame.to_bytes())
            .unwrap()
            .decode_body()
            .unwrap();
        assert_eq!(decoded.shape.shape_id, "s1");
    }

    /// Pins the header layout: tag byte, then the body length as u32 little-endian.
    #[test]
    fn header_layout_is_tag_then_le_length() {
        let frame = SyncFrame {
            msg_type: SyncMessageType::DeltaAck,
            body: vec![0xAA, 0xBB, 0xCC],
        };
        assert_eq!(frame.to_bytes(), [0x11, 3, 0, 0, 0, 0xAA, 0xBB, 0xCC]);
    }

    /// Malformed input yields `None` rather than a panic.
    #[test]
    fn from_bytes_rejects_malformed_frames() {
        // Shorter than the header.
        assert!(SyncFrame::from_bytes(&[]).is_none());
        assert!(SyncFrame::from_bytes(&[0x01, 0, 0, 0]).is_none());
        // Unknown tag byte.
        assert!(SyncFrame::from_bytes(&[0x99, 0, 0, 0, 0]).is_none());
        // Declared body longer than the bytes that follow.
        assert!(SyncFrame::from_bytes(&[0x01, 2, 0, 0, 0, 0xAA]).is_none());
        // Maximum declared length: must be `None`, not an overflow panic.
        assert!(SyncFrame::from_bytes(&[0x01, 0xFF, 0xFF, 0xFF, 0xFF]).is_none());
    }

    /// An empty body is valid, and bytes past the declared body are ignored.
    #[test]
    fn from_bytes_accepts_empty_body_and_ignores_trailing_bytes() {
        let frame = SyncFrame::from_bytes(&[0xFF, 0, 0, 0, 0]).unwrap();
        assert_eq!(frame.msg_type, SyncMessageType::PingPong);
        assert!(frame.body.is_empty());

        let frame = SyncFrame::from_bytes(&[0x11, 1, 0, 0, 0, 0x7F, 0xEE, 0xEE]).unwrap();
        assert_eq!(frame.msg_type, SyncMessageType::DeltaAck);
        assert_eq!(frame.body, [0x7F]);
    }

    /// A body that is not valid MessagePack decodes to `None`.
    #[test]
    fn decode_body_returns_none_for_garbage() {
        // 0xC1 is a reserved ("never used") MessagePack marker.
        let frame = SyncFrame {
            msg_type: SyncMessageType::PingPong,
            body: vec![0xC1],
        };
        assert!(frame.decode_body::<PingPongMsg>().is_none());
    }
}