yantrikdb_protocol/
lib.rs1pub mod codec;
2pub mod error;
3pub mod frame;
4pub mod messages;
5pub mod opcodes;
6
7pub use codec::YantrikCodec;
8pub use error::ProtocolError;
9pub use frame::Frame;
10pub use opcodes::OpCode;
11
12pub fn pack<T: serde::Serialize>(msg: &T) -> Result<bytes::Bytes, ProtocolError> {
14 let data = rmp_serde::to_vec_named(msg)?;
15 Ok(bytes::Bytes::from(data))
16}
17
18pub fn unpack<'de, T: serde::Deserialize<'de>>(data: &'de [u8]) -> Result<T, ProtocolError> {
20 Ok(rmp_serde::from_slice(data)?)
21}
22
23pub fn unpack_frame<T: serde::de::DeserializeOwned>(frame: &Frame) -> Result<T, ProtocolError> {
26 if frame.is_compressed() {
27 let decompressed = zstd::decode_all(&frame.payload[..])
28 .map_err(|e| ProtocolError::Io(e))?;
29 Ok(rmp_serde::from_slice(&decompressed)?)
30 } else {
31 Ok(rmp_serde::from_slice(&frame.payload)?)
32 }
33}
34
35pub fn pack_compressed<T: serde::Serialize>(msg: &T) -> Result<bytes::Bytes, ProtocolError> {
37 let data = rmp_serde::to_vec_named(msg)?;
38 let compressed = zstd::encode_all(data.as_slice(), 3)
39 .map_err(|e| ProtocolError::Io(e))?;
40 Ok(bytes::Bytes::from(compressed))
41}
42
43pub fn make_frame_auto_compress<T: serde::Serialize>(
45 opcode: OpCode,
46 stream_id: u32,
47 msg: &T,
48 min_size_bytes: usize,
49) -> Result<Frame, ProtocolError> {
50 let raw = pack(msg)?;
51 if raw.len() < min_size_bytes {
52 return Ok(Frame::new(opcode, stream_id, raw));
53 }
54 let compressed = zstd::encode_all(&raw[..], 3).map_err(|e| ProtocolError::Io(e))?;
55 if compressed.len() < raw.len() {
57 Ok(Frame::new(opcode, stream_id, bytes::Bytes::from(compressed)).with_compression())
58 } else {
59 Ok(Frame::new(opcode, stream_id, raw))
60 }
61}
62
63pub fn make_frame<T: serde::Serialize>(
65 opcode: OpCode,
66 stream_id: u32,
67 msg: &T,
68) -> Result<Frame, ProtocolError> {
69 let payload = pack(msg)?;
70 Ok(Frame::new(opcode, stream_id, payload))
71}
72
73pub fn make_error(
75 stream_id: u32,
76 code: u16,
77 message: impl Into<String>,
78) -> Result<Frame, ProtocolError> {
79 make_frame(
80 OpCode::Error,
81 stream_id,
82 &messages::ErrorResponse {
83 code,
84 message: message.into(),
85 details: None,
86 },
87 )
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use messages::RememberRequest;

    // Builds a `RecallRequest` with the given query and fixed values for every other field.
    fn recall_request(query: &str) -> messages::RecallRequest {
        messages::RecallRequest {
            query: query.into(),
            top_k: 5,
            memory_type: None,
            include_consolidated: false,
            expand_entities: true,
            namespace: None,
            domain: None,
            source: None,
            query_embedding: None,
        }
    }

    #[test]
    fn pack_unpack_roundtrip() {
        let req = RememberRequest {
            text: "Alice leads engineering".into(),
            memory_type: "semantic".into(),
            importance: 0.9,
            valence: 0.0,
            half_life: 168.0,
            metadata: serde_json::json!({}),
            namespace: "default".into(),
            certainty: 1.0,
            domain: "work".into(),
            source: "user".into(),
            emotional_state: None,
            embedding: None,
        };

        let packed = pack(&req).unwrap();
        let unpacked: RememberRequest = unpack(&packed).unwrap();

        assert_eq!(unpacked.text, "Alice leads engineering");
        assert_eq!(unpacked.importance, 0.9);
        assert_eq!(unpacked.domain, "work");
    }

    #[test]
    fn make_frame_roundtrip() {
        let frame = make_frame(OpCode::Recall, 7, &recall_request("who leads engineering?")).unwrap();
        assert_eq!(frame.opcode, OpCode::Recall);
        assert_eq!(frame.stream_id, 7);

        let decoded: messages::RecallRequest = unpack(&frame.payload).unwrap();
        assert_eq!(decoded.query, "who leads engineering?");
        assert_eq!(decoded.top_k, 5);
    }

    #[test]
    fn make_error_frame() {
        let frame =
            make_error(0, messages::error_codes::AUTH_REQUIRED, "not authenticated").unwrap();
        assert_eq!(frame.opcode, OpCode::Error);

        let err: messages::ErrorResponse = unpack(&frame.payload).unwrap();
        // Pins the numeric wire value of AUTH_REQUIRED, not just the symbolic name.
        assert_eq!(err.code, 1000);
        assert_eq!(err.message, "not authenticated");
    }

    #[test]
    fn auto_compress_compresses_large_payload() {
        // Highly repetitive text compresses well, so the compressed form must win.
        let query = "engineering ".repeat(256);
        let frame =
            make_frame_auto_compress(OpCode::Recall, 3, &recall_request(&query), 0).unwrap();
        assert!(frame.is_compressed());
        assert_eq!(frame.opcode, OpCode::Recall);
        assert_eq!(frame.stream_id, 3);

        let decoded: messages::RecallRequest = unpack_frame(&frame).unwrap();
        assert_eq!(decoded.query, query);
        assert_eq!(decoded.top_k, 5);
    }

    #[test]
    fn auto_compress_skips_payload_below_threshold() {
        let frame =
            make_frame_auto_compress(OpCode::Recall, 3, &recall_request("short"), usize::MAX)
                .unwrap();
        assert!(!frame.is_compressed());

        let decoded: messages::RecallRequest = unpack_frame(&frame).unwrap();
        assert_eq!(decoded.query, "short");
    }

    #[test]
    fn pack_compressed_roundtrips_through_unpack_frame() {
        let query = "who leads engineering? ".repeat(64);
        let payload = pack_compressed(&recall_request(&query)).unwrap();
        let frame = Frame::new(OpCode::Recall, 9, payload).with_compression();

        let decoded: messages::RecallRequest = unpack_frame(&frame).unwrap();
        assert_eq!(decoded.query, query);
    }

    #[test]
    fn unpack_frame_rejects_corrupt_compressed_payload() {
        // Flagged as compressed but not a zstd stream: must error, not panic.
        let frame = Frame::new(OpCode::Recall, 1, bytes::Bytes::from_static(b"not zstd"))
            .with_compression();
        assert!(unpack_frame::<messages::RecallRequest>(&frame).is_err());
    }
}