1pub mod auth;
4pub mod batch;
5pub mod frames;
6pub mod handshake;
7pub mod opcodes;
8pub mod request_fields;
9pub mod text_fields;
10
11pub use auth::{AuthMethod, AuthResponse};
12pub use batch::{BatchDocument, BatchVector};
13pub use frames::{ErrorPayload, NativeRequest, NativeResponse};
14pub use handshake::{
15 CAP_COLUMNAR, CAP_CRDT, CAP_FTS, CAP_GRAPHRAG, CAP_MSGPACK, CAP_SPATIAL, CAP_STREAMING,
16 CAP_TIMESERIES, DEFAULT_NATIVE_PORT, FRAME_HEADER_LEN, HELLO_ACK_MAGIC, HELLO_ERROR_MAGIC,
17 HELLO_ERROR_MAGIC_U32, HELLO_MAGIC, HelloAckFrame, HelloErrorCode, HelloErrorFrame, HelloFrame,
18 Limits, MAX_FRAME_SIZE, PROTO_VERSION, PROTO_VERSION_MAX, PROTO_VERSION_MIN,
19};
20pub use opcodes::{OpCode, ResponseStatus, UnknownOpCode};
21pub use request_fields::RequestFields;
22pub use text_fields::TextFields;
23
#[cfg(test)]
mod tests {
    //! Unit tests for the native wire-protocol types re-exported from this module.

    use super::*;
    use crate::value::Value;

    /// Pins the `u8` discriminant of representative opcodes so wire values cannot drift.
    #[test]
    fn opcode_repr() {
        assert_eq!(OpCode::Auth as u8, 0x01);
        assert_eq!(OpCode::Sql as u8, 0x20);
        assert_eq!(OpCode::Begin as u8, 0x40);
        assert_eq!(OpCode::GraphHop as u8, 0x50);
        assert_eq!(OpCode::TextSearch as u8, 0x60);
        assert_eq!(OpCode::VectorBatchInsert as u8, 0x70);
    }

    /// Checks `OpCode::is_write` for a sample of mutating and non-mutating opcodes.
    #[test]
    fn opcode_is_write() {
        assert!(OpCode::PointPut.is_write());
        assert!(OpCode::PointDelete.is_write());
        assert!(OpCode::CrdtApply.is_write());
        assert!(OpCode::EdgePut.is_write());
        assert!(!OpCode::PointGet.is_write());
        assert!(!OpCode::Sql.is_write());
        assert!(!OpCode::VectorSearch.is_write());
        assert!(!OpCode::Ping.is_write());
    }

    /// Pins the `u8` discriminant of every `ResponseStatus` variant exercised here.
    #[test]
    fn response_status_repr() {
        assert_eq!(ResponseStatus::Ok as u8, 0);
        assert_eq!(ResponseStatus::Partial as u8, 1);
        assert_eq!(ResponseStatus::Error as u8, 2);
    }

    /// `NativeResponse::ok` echoes the sequence number and carries no error payload.
    #[test]
    fn native_response_ok() {
        let r = NativeResponse::ok(42);
        assert_eq!(r.seq, 42);
        assert_eq!(r.status, ResponseStatus::Ok);
        assert!(r.error.is_none());
    }

    /// `NativeResponse::error` echoes the sequence number and carries code + message.
    #[test]
    fn native_response_error() {
        let r = NativeResponse::error(1, "42P01", "collection not found");
        // Mirrors `native_response_ok`: the first argument is the request sequence number.
        assert_eq!(r.seq, 1);
        assert_eq!(r.status, ResponseStatus::Error);
        let e = r.error.expect("error response must carry an ErrorPayload");
        assert_eq!(e.code, "42P01");
        assert_eq!(e.message, "collection not found");
    }

    /// Converting a `QueryResult` keeps seq, watermark, columns and rows.
    #[test]
    fn native_response_from_query_result() {
        let qr = crate::result::QueryResult {
            columns: vec!["id".into(), "name".into()],
            rows: vec![vec![
                Value::String("u1".into()),
                Value::String("Alice".into()),
            ]],
            rows_affected: 0,
        };
        let r = NativeResponse::from_query_result(5, qr, 100);
        assert_eq!(r.seq, 5);
        assert_eq!(r.watermark_lsn, 100);
        assert_eq!(r.columns.as_ref().unwrap().len(), 2);
        assert_eq!(r.rows.as_ref().unwrap().len(), 1);
    }

    /// `status_row` yields a single `status` column holding the given text.
    #[test]
    fn native_response_status_row() {
        let r = NativeResponse::status_row(3, "OK");
        assert_eq!(r.columns.as_ref().unwrap(), &["status"]);
        assert_eq!(r.rows.as_ref().unwrap()[0][0].as_str(), Some("OK"));
    }

    /// A request survives a msgpack encode/decode, including its field payload.
    #[test]
    fn msgpack_roundtrip_request() {
        let req = NativeRequest {
            op: OpCode::Sql,
            seq: 1,
            fields: RequestFields::Text(TextFields {
                sql: Some("SELECT 1".into()),
                ..Default::default()
            }),
        };
        let bytes = zerompk::to_msgpack_vec(&req).unwrap();
        let decoded: NativeRequest = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.op, OpCode::Sql);
        assert_eq!(decoded.seq, 1);
        // Header fields alone do not prove the body round-trips; check the SQL text too.
        match decoded.fields {
            RequestFields::Text(text) => assert_eq!(text.sql.as_deref(), Some("SELECT 1")),
            _ => panic!("expected RequestFields::Text after roundtrip"),
        }
    }

    /// A response survives a msgpack encode/decode, including status, columns and rows.
    #[test]
    fn msgpack_roundtrip_response() {
        let resp = NativeResponse::from_query_result(
            7,
            crate::result::QueryResult {
                columns: vec!["x".into()],
                rows: vec![vec![Value::Integer(42)]],
                rows_affected: 0,
            },
            99,
        );
        let bytes = zerompk::to_msgpack_vec(&resp).unwrap();
        let decoded: NativeResponse = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.seq, 7);
        assert_eq!(decoded.watermark_lsn, 99);
        assert_eq!(decoded.status, resp.status);
        assert_eq!(decoded.columns.as_ref().unwrap(), &["x"]);
        assert_eq!(decoded.rows.unwrap()[0][0].as_i64(), Some(42));
    }

    /// Both `AuthMethod` variants exercised here survive a msgpack roundtrip intact.
    #[test]
    fn auth_method_variants() {
        let trust = AuthMethod::Trust {
            username: "admin".into(),
        };
        let bytes = zerompk::to_msgpack_vec(&trust).unwrap();
        let decoded: AuthMethod = zerompk::from_msgpack(&bytes).unwrap();
        match decoded {
            AuthMethod::Trust { username } => assert_eq!(username, "admin"),
            _ => panic!("expected Trust variant"),
        }

        let pw = AuthMethod::Password {
            username: "user".into(),
            password: "secret".into(),
        };
        let bytes = zerompk::to_msgpack_vec(&pw).unwrap();
        let decoded: AuthMethod = zerompk::from_msgpack(&bytes).unwrap();
        match decoded {
            AuthMethod::Password { username, password } => {
                assert_eq!(username, "user");
                assert_eq!(password, "secret");
            }
            _ => panic!("expected Password variant"),
        }
    }

    /// `HelloFrame::encode` followed by `decode` reproduces the original frame.
    #[test]
    fn hello_frame_roundtrip() {
        let frame = HelloFrame {
            proto_min: 1,
            proto_max: 3,
            capabilities: CAP_STREAMING | CAP_GRAPHRAG | CAP_FTS,
        };
        let buf = frame.encode();
        let decoded = HelloFrame::decode(&buf).expect("decode failed");
        assert_eq!(decoded, frame);
    }

    /// Corrupting the first encoded byte makes `HelloFrame::decode` return `None`.
    #[test]
    fn hello_frame_bad_magic() {
        let mut buf = HelloFrame {
            proto_min: 1,
            proto_max: 1,
            capabilities: 0,
        }
        .encode();
        buf[0] = 0xFF;
        assert!(HelloFrame::decode(&buf).is_none());
    }

    /// `HelloAckFrame` roundtrips when every `Limits` field is populated.
    #[test]
    fn hello_ack_frame_roundtrip_all_limits() {
        let frame = HelloAckFrame {
            proto_version: 1,
            capabilities: CAP_STREAMING | CAP_CRDT,
            server_version: "0.1.0-dev".into(),
            limits: Limits {
                max_vector_dim: Some(1536),
                max_top_k: Some(1000),
                max_scan_limit: Some(10_000),
                max_batch_size: Some(512),
                max_crdt_delta_bytes: Some(1 << 20),
                max_query_text_bytes: Some(4096),
                max_graph_depth: Some(16),
            },
        };
        let enc = frame.encode();
        let decoded = HelloAckFrame::decode(&enc).expect("decode failed");
        assert_eq!(decoded, frame);
    }

    /// `HelloAckFrame` roundtrips when only one `Limits` field is populated.
    #[test]
    fn hello_ack_frame_roundtrip_some_limits() {
        let frame = HelloAckFrame {
            proto_version: 1,
            capabilities: 0,
            server_version: "1.0.0".into(),
            limits: Limits {
                max_vector_dim: Some(768),
                max_top_k: None,
                max_scan_limit: None,
                max_batch_size: None,
                max_crdt_delta_bytes: None,
                max_query_text_bytes: None,
                max_graph_depth: None,
            },
        };
        let enc = frame.encode();
        let decoded = HelloAckFrame::decode(&enc).expect("decode failed");
        assert_eq!(decoded, frame);
    }

    /// `HelloAckFrame` roundtrips with `Limits::default()`.
    #[test]
    fn hello_ack_frame_roundtrip_no_limits() {
        let frame = HelloAckFrame {
            proto_version: 1,
            capabilities: CAP_STREAMING,
            server_version: "0.2.0".into(),
            limits: Limits::default(),
        };
        let enc = frame.encode();
        let decoded = HelloAckFrame::decode(&enc).expect("decode failed");
        assert_eq!(decoded, frame);
    }

    /// Corrupting the first encoded byte makes `HelloAckFrame::decode` return `None`.
    #[test]
    fn hello_ack_bad_magic() {
        let frame = HelloAckFrame {
            proto_version: 1,
            capabilities: 0,
            server_version: "x".into(),
            limits: Limits::default(),
        };
        let mut enc = frame.encode();
        enc[0] = 0xFF;
        assert!(HelloAckFrame::decode(&enc).is_none());
    }

    /// Every capability flag is non-zero and no two flags share a bit.
    #[test]
    fn cap_bits_non_overlapping() {
        let caps = [
            CAP_STREAMING,
            CAP_GRAPHRAG,
            CAP_FTS,
            CAP_CRDT,
            CAP_SPATIAL,
            CAP_TIMESERIES,
            CAP_COLUMNAR,
        ];
        // Checking only the popcount of the OR can miss an overlap that a zero flag
        // happens to compensate for, so test every pair directly.
        for (i, &a) in caps.iter().enumerate() {
            assert_ne!(a, 0, "capability #{} must set at least one bit", i);
            for (j, &b) in caps.iter().enumerate().skip(i + 1) {
                assert_eq!(a & b, 0, "capabilities #{} and #{} share bits", i, j);
            }
        }

        let all = CAP_STREAMING
            | CAP_GRAPHRAG
            | CAP_FTS
            | CAP_CRDT
            | CAP_SPATIAL
            | CAP_TIMESERIES
            | CAP_COLUMNAR;
        assert_eq!(all.count_ones(), 7);
    }

    /// Checked `u64 -> u32` narrowing rejects `u32::MAX + 1` and accepts `u32::MAX`.
    ///
    /// NOTE(review): despite the name, this only exercises std's `TryFrom<u64> for u32`;
    /// it does not call the plan builder. It records the narrowing contract the name
    /// implies — consider driving the real plan-builder path once it is reachable here.
    #[test]
    fn vector_id_overflow_rejected_in_plan_builder() {
        let wire_id: u64 = u64::from(u32::MAX) + 1;
        assert!(
            u32::try_from(wire_id).is_err(),
            "u64 > u32::MAX must not fit in u32"
        );

        let wire_ok: u64 = u64::from(u32::MAX);
        assert_eq!(u32::try_from(wire_ok).unwrap(), u32::MAX);
    }
}