1use reddb_types::Value;
16
// ---------------------------------------------------------------------------
// Legacy frame message types.
//
// Every legacy frame starts with a little-endian `u32` length (payload bytes
// plus one) followed by one of these one-byte tags; see `write_frame_header`.
// Tags are contiguous from 0x01 to 0x15. The one-line descriptions are read
// off the constant names and off which tags the `build_legacy_*_frame`
// helpers emit — NOTE(review): confirm message direction against the spec.
// ---------------------------------------------------------------------------

/// Query.
pub const MSG_QUERY: u8 = 0x01;
/// Query result (emitted by `build_legacy_result_frame`).
pub const MSG_RESULT: u8 = 0x02;
/// Error reply (emitted by `build_legacy_error_frame`).
pub const MSG_ERROR: u8 = 0x03;
/// Bulk insert.
pub const MSG_BULK_INSERT: u8 = 0x04;
/// Bulk-insert acknowledgement (emitted by `build_legacy_bulk_ok_frame`).
pub const MSG_BULK_OK: u8 = 0x05;
/// Bulk insert, binary variant.
pub const MSG_BULK_INSERT_BINARY: u8 = 0x06;
/// Query, binary variant.
pub const MSG_QUERY_BINARY: u8 = 0x07;
/// Bulk insert, pre-validated variant.
pub const MSG_BULK_INSERT_PREVALIDATED: u8 = 0x08;
/// Streaming bulk insert: start.
pub const MSG_BULK_STREAM_START: u8 = 0x09;
/// Streaming bulk insert: a batch of rows.
pub const MSG_BULK_STREAM_ROWS: u8 = 0x0A;
/// Streaming bulk insert: commit.
pub const MSG_BULK_STREAM_COMMIT: u8 = 0x0B;
/// Streaming bulk insert acknowledgement; sent with an empty payload by
/// `build_legacy_bulk_stream_ack_frame`.
pub const MSG_BULK_STREAM_ACK: u8 = 0x0C;
/// Prepare a statement.
pub const MSG_PREPARE: u8 = 0x0D;
/// Statement prepared (emitted by `build_legacy_prepared_ok_frame`).
pub const MSG_PREPARED_OK: u8 = 0x0E;
/// Execute a prepared statement.
pub const MSG_EXECUTE_PREPARED: u8 = 0x0F;
/// Deallocate a prepared statement.
pub const MSG_DEALLOCATE: u8 = 0x10;
/// Declare a cursor.
pub const MSG_DECLARE_CURSOR: u8 = 0x11;
/// Cursor declared (emitted by `build_legacy_cursor_ok_frame`).
pub const MSG_CURSOR_OK: u8 = 0x12;
/// Fetch from a cursor.
pub const MSG_FETCH: u8 = 0x13;
/// Batch of cursor rows (emitted by `build_legacy_cursor_batch_frame`).
pub const MSG_CURSOR_BATCH: u8 = 0x14;
/// Close a cursor.
pub const MSG_CLOSE_CURSOR: u8 = 0x15;
39
// ---------------------------------------------------------------------------
// Value type tags.
//
// `encode_value` writes one of these bytes in front of every value. The
// legacy format has no tag for raw bytes or timestamps: `WireValue::Bytes`
// is sent as a bare `VAL_NULL` and `WireValue::Timestamp` under `VAL_U64`.
// ---------------------------------------------------------------------------

/// SQL `NULL`; no payload follows.
pub const VAL_NULL: u8 = 0x00;
/// Signed integer; 8 little-endian bytes follow.
pub const VAL_I64: u8 = 0x01;
/// 64-bit float; 8 little-endian bytes follow.
pub const VAL_F64: u8 = 0x02;
/// Text; a little-endian `u32` byte length and that many UTF-8 bytes follow.
pub const VAL_TEXT: u8 = 0x03;
/// Boolean; one byte follows (any non-zero byte decodes as `true`).
pub const VAL_BOOL: u8 = 0x04;
/// Unsigned integer; 8 little-endian bytes follow.
pub const VAL_U64: u8 = 0x05;
47
/// A value in the shape carried by the legacy wire protocol.
///
/// Not every variant has its own wire tag. `encode_value` writes `Bytes` as a
/// bare null tag (the contents are dropped). It writes `Timestamp` under the
/// unsigned-integer tag, so a timestamp decodes back as `U64`.
#[derive(Debug, Clone, PartialEq)]
pub enum WireValue {
    /// Absent / SQL `NULL`.
    Null,
    /// Signed 64-bit integer.
    I64(i64),
    /// Unsigned 64-bit integer.
    U64(u64),
    /// 64-bit floating-point number.
    F64(f64),
    /// UTF-8 text.
    Text(String),
    /// Boolean.
    Bool(bool),
    /// Raw bytes; not representable in the legacy encoding (sent as null).
    Bytes(Vec<u8>),
    /// Timestamp as an unsigned integer. NOTE(review): units are not
    /// established in this module — confirm.
    Timestamp(u64),
}
59
60impl From<&Value> for WireValue {
67 fn from(value: &Value) -> Self {
68 match value {
69 Value::Null => WireValue::Null,
70 Value::Integer(n) => WireValue::I64(*n),
71 Value::UnsignedInteger(n) => WireValue::U64(*n),
72 Value::Float(f) => WireValue::F64(*f),
73 Value::Text(s) => WireValue::Text(s.to_string()),
74 Value::Blob(bytes) => WireValue::Bytes(bytes.clone()),
75 Value::Boolean(b) => WireValue::Bool(*b),
76 Value::Timestamp(t) => WireValue::Timestamp(*t as u64),
77 _ => WireValue::Null,
78 }
79 }
80}
81
82impl From<Value> for WireValue {
83 fn from(value: Value) -> Self {
84 match value {
85 Value::Null => WireValue::Null,
86 Value::Integer(n) => WireValue::I64(n),
87 Value::UnsignedInteger(n) => WireValue::U64(n),
88 Value::Float(f) => WireValue::F64(f),
89 Value::Text(s) => WireValue::Text(s.to_string()),
90 Value::Blob(bytes) => WireValue::Bytes(bytes),
91 Value::Boolean(b) => WireValue::Bool(b),
92 Value::Timestamp(t) => WireValue::Timestamp(t as u64),
93 _ => WireValue::Null,
94 }
95 }
96}
97
98impl TryFrom<WireValue> for Value {
99 type Error = &'static str;
100
101 fn try_from(value: WireValue) -> Result<Self, Self::Error> {
102 match value {
103 WireValue::Null => Ok(Value::Null),
104 WireValue::I64(n) => Ok(Value::Integer(n)),
105 WireValue::U64(n) => Ok(Value::UnsignedInteger(n)),
106 WireValue::F64(f) => Ok(Value::Float(f)),
107 WireValue::Text(s) => Ok(Value::text(s)),
108 WireValue::Bool(b) => Ok(Value::Boolean(b)),
109 WireValue::Bytes(bytes) => Ok(Value::Blob(bytes)),
110 WireValue::Timestamp(t) => {
111 let timestamp = i64::try_from(t).map_err(|_| "timestamp exceeds i64 range")?;
112 Ok(Value::Timestamp(timestamp))
113 }
114 }
115 }
116}
117
/// Appends a legacy frame header to `buf`.
///
/// The header is a little-endian `u32` length that counts the type byte plus
/// `payload_len` payload bytes, followed by the `msg_type` byte itself.
///
/// # Panics
/// Panics if `payload_len == u32::MAX`, because the length field (payload plus
/// type byte) would not fit in a `u32`. The check is explicit so that debug and
/// release builds agree. A silently wrapped length would desynchronise the
/// peer's frame reader.
#[inline]
pub fn write_frame_header(buf: &mut Vec<u8>, msg_type: u8, payload_len: u32) {
    // The length prefix includes the one-byte message type.
    let total = payload_len
        .checked_add(1)
        .expect("legacy frame payload too large: length prefix overflows u32");
    buf.extend_from_slice(&total.to_le_bytes());
    buf.push(msg_type);
}

/// Builds a complete legacy frame (header followed by `payload`) in a new
/// buffer sized exactly for it.
///
/// # Panics
/// Panics if `payload` is too long for the `u32` length prefix
/// (`payload.len() >= u32::MAX`). A truncating cast here would produce a header
/// that disagrees with the bytes that follow it.
pub fn build_legacy_frame(msg_type: u8, payload: &[u8]) -> Vec<u8> {
    let payload_len = u32::try_from(payload.len())
        .expect("legacy frame payload too large: length exceeds u32");
    // 4-byte length prefix + 1-byte message type.
    let mut out = Vec::with_capacity(5 + payload.len());
    write_frame_header(&mut out, msg_type, payload_len);
    out.extend_from_slice(payload);
    out
}
132
133pub fn build_legacy_result_frame(payload: &[u8]) -> Vec<u8> {
134 build_legacy_frame(MSG_RESULT, payload)
135}
136
137pub fn build_legacy_error_frame(message: &[u8]) -> Vec<u8> {
138 build_legacy_frame(MSG_ERROR, message)
139}
140
141pub fn build_legacy_bulk_ok_frame(payload: &[u8]) -> Vec<u8> {
142 build_legacy_frame(MSG_BULK_OK, payload)
143}
144
145pub fn build_legacy_bulk_stream_ack_frame() -> Vec<u8> {
146 build_legacy_frame(MSG_BULK_STREAM_ACK, &[])
147}
148
149pub fn build_legacy_prepared_ok_frame(payload: &[u8]) -> Vec<u8> {
150 build_legacy_frame(MSG_PREPARED_OK, payload)
151}
152
153pub fn build_legacy_cursor_ok_frame(payload: &[u8]) -> Vec<u8> {
154 build_legacy_frame(MSG_CURSOR_OK, payload)
155}
156
157pub fn build_legacy_cursor_batch_frame(payload: &[u8]) -> Vec<u8> {
158 build_legacy_frame(MSG_CURSOR_BATCH, payload)
159}
160
161#[inline]
162pub fn encode_value(buf: &mut Vec<u8>, value: &WireValue) {
163 match value {
164 WireValue::Null | WireValue::Bytes(_) => buf.push(VAL_NULL),
165 WireValue::I64(n) => {
166 buf.push(VAL_I64);
167 buf.extend_from_slice(&n.to_le_bytes());
168 }
169 WireValue::U64(n) | WireValue::Timestamp(n) => {
170 buf.push(VAL_U64);
171 buf.extend_from_slice(&n.to_le_bytes());
172 }
173 WireValue::F64(f) => {
174 buf.push(VAL_F64);
175 buf.extend_from_slice(&f.to_le_bytes());
176 }
177 WireValue::Text(s) => {
178 buf.push(VAL_TEXT);
179 let bytes = s.as_bytes();
180 buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
181 buf.extend_from_slice(bytes);
182 }
183 WireValue::Bool(b) => {
184 buf.push(VAL_BOOL);
185 buf.push(*b as u8);
186 }
187 }
188}
189
190#[inline]
191pub fn decode_value(data: &[u8], pos: &mut usize) -> WireValue {
192 try_decode_value(data, pos).unwrap_or(WireValue::Null)
193}
194
195#[inline]
196pub fn try_decode_value(data: &[u8], pos: &mut usize) -> Result<WireValue, &'static str> {
197 if *pos >= data.len() {
198 return Err("missing value tag");
199 }
200
201 let tag = data[*pos];
202 *pos += 1;
203
204 match tag {
205 VAL_NULL => Ok(WireValue::Null),
206 VAL_I64 => Ok(WireValue::I64(i64::from_le_bytes(read_array::<8>(
207 data,
208 pos,
209 "truncated i64 value",
210 )?))),
211 VAL_U64 => Ok(WireValue::U64(u64::from_le_bytes(read_array::<8>(
212 data,
213 pos,
214 "truncated u64 value",
215 )?))),
216 VAL_F64 => Ok(WireValue::F64(f64::from_le_bytes(read_array::<8>(
217 data,
218 pos,
219 "truncated f64 value",
220 )?))),
221 VAL_TEXT => {
222 let len =
223 u32::from_le_bytes(read_array::<4>(data, pos, "truncated text length")?) as usize;
224 let bytes = read_bytes(data, pos, len, "truncated text value")?;
225 let cow = std::string::String::from_utf8_lossy(bytes);
226 Ok(WireValue::Text(cow.into_owned()))
227 }
228 VAL_BOOL => {
229 let bytes = read_bytes(data, pos, 1, "truncated bool value")?;
230 Ok(WireValue::Bool(bytes[0] != 0))
231 }
232 _ => Err("unknown value tag"),
233 }
234}
235
/// Borrows the next `len` bytes of `data` starting at `*pos` and advances
/// `*pos` past them.
///
/// # Errors
/// Returns `err` if fewer than `len` bytes remain; `*pos` is left unchanged.
/// The end index saturates, so an enormous `len` reports `err` instead of
/// overflowing.
#[inline]
fn read_bytes<'a>(
    data: &'a [u8],
    pos: &mut usize,
    len: usize,
    err: &'static str,
) -> Result<&'a [u8], &'static str> {
    let start = *pos;
    let stop = start.saturating_add(len);
    let chunk = data.get(start..stop).ok_or(err)?;
    *pos = stop;
    Ok(chunk)
}

/// Reads exactly `N` bytes at `*pos` into an owned array and advances `*pos`.
///
/// # Errors
/// Returns `err` (with `*pos` unchanged) if fewer than `N` bytes remain.
#[inline]
fn read_array<const N: usize>(
    data: &[u8],
    pos: &mut usize,
    err: &'static str,
) -> Result<[u8; N], &'static str> {
    read_bytes(data, pos, N, err).map(|chunk| {
        let mut fixed = [0u8; N];
        fixed.copy_from_slice(chunk);
        fixed
    })
}
263
/// Appends a column name as a little-endian `u16` byte length followed by its
/// UTF-8 bytes.
///
/// Names longer than `u16::MAX` bytes are truncated to the longest prefix
/// that fits and ends on a UTF-8 character boundary. The length prefix
/// therefore always matches the bytes written. A wrapped `len as u16` prefix
/// followed by the full name would break framing for every later column and
/// row.
#[inline]
pub fn encode_column_name(buf: &mut Vec<u8>, name: &str) {
    let mut end = name.len().min(usize::from(u16::MAX));
    // Back up to a char boundary so a truncated name stays valid UTF-8.
    // Terminates because index 0 is always a boundary.
    while !name.is_char_boundary(end) {
        end -= 1;
    }
    let bytes = &name.as_bytes()[..end];
    // Cannot truncate: `end <= u16::MAX` by construction.
    buf.extend_from_slice(&(end as u16).to_le_bytes());
    buf.extend_from_slice(bytes);
}

/// Appends a result-payload header to `buf` and returns the byte offset of
/// its row-count field.
///
/// Layout: little-endian `u16` column count, each column name as written by
/// `encode_column_name`, then a little-endian `u32` `row_count`. The returned
/// offset can be passed to `set_result_payload_row_count` to patch the count
/// once the real number of rows is known.
///
/// NOTE(review): the column count is written with `as u16`, so more than
/// `u16::MAX` columns would produce a wrapped count. Confirm callers cannot
/// exceed that.
pub fn encode_result_payload_header<'a, I>(buf: &mut Vec<u8>, columns: I, row_count: u32) -> usize
where
    I: IntoIterator<Item = &'a str>,
    I::IntoIter: ExactSizeIterator,
{
    let columns = columns.into_iter();
    buf.extend_from_slice(&(columns.len() as u16).to_le_bytes());
    for column in columns {
        encode_column_name(buf, column);
    }
    let row_count_offset = buf.len();
    buf.extend_from_slice(&row_count.to_le_bytes());
    row_count_offset
}
285
/// Overwrites the little-endian `u32` row-count field that starts at
/// `row_count_offset` in `buf` (the offset returned by
/// `encode_result_payload_header`).
///
/// # Errors
/// Returns `"result payload row-count offset out of bounds"` if fewer than
/// four bytes are available at `row_count_offset`; `buf` is left untouched.
pub fn set_result_payload_row_count(
    buf: &mut [u8],
    row_count_offset: usize,
    row_count: u32,
) -> Result<(), &'static str> {
    const OUT_OF_BOUNDS: &str = "result payload row-count offset out of bounds";
    let field = row_count_offset
        .checked_add(4)
        .and_then(|end| buf.get_mut(row_count_offset..end))
        .ok_or(OUT_OF_BOUNDS)?;
    field.copy_from_slice(&row_count.to_le_bytes());
    Ok(())
}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `Value`/`WireValue` pair converts both ways, by value and by
    /// reference.
    #[test]
    fn value_wirevalue_field_mapping_round_trips() {
        let cases = [
            (Value::Null, WireValue::Null),
            (Value::Integer(-7), WireValue::I64(-7)),
            (Value::UnsignedInteger(9), WireValue::U64(9)),
            (Value::Float(1.5), WireValue::F64(1.5)),
            (Value::text("hi"), WireValue::Text("hi".to_string())),
            (Value::Blob(vec![1, 2, 3]), WireValue::Bytes(vec![1, 2, 3])),
            (Value::Boolean(true), WireValue::Bool(true)),
            (Value::Timestamp(42), WireValue::Timestamp(42)),
        ];
        for (value, wire) in cases {
            assert_eq!(WireValue::from(value.clone()), wire);
            assert_eq!(WireValue::from(&value), wire);
            assert_eq!(Value::try_from(wire.clone()), Ok(value));
        }
    }

    #[test]
    fn wirevalue_timestamp_rejects_i64_overflow() {
        let overflow = WireValue::Timestamp(u64::MAX);
        assert_eq!(
            Value::try_from(overflow),
            Err("timestamp exceeds i64 range")
        );
    }

    #[test]
    fn frame_header_keeps_legacy_length_shape() {
        let mut out = Vec::new();
        write_frame_header(&mut out, MSG_RESULT, 10);
        assert_eq!(out, [11, 0, 0, 0, MSG_RESULT]);
    }

    #[test]
    fn legacy_frame_builders_wrap_payloads() {
        assert_eq!(
            build_legacy_result_frame(b"ok"),
            [3, 0, 0, 0, MSG_RESULT, b'o', b'k']
        );
        assert_eq!(
            build_legacy_error_frame(b"no"),
            [3, 0, 0, 0, MSG_ERROR, b'n', b'o']
        );
        assert_eq!(
            build_legacy_bulk_ok_frame(b"\x02\0\0\0\0\0\0\0"),
            [9, 0, 0, 0, MSG_BULK_OK, 2, 0, 0, 0, 0, 0, 0, 0]
        );
        assert_eq!(
            build_legacy_bulk_stream_ack_frame(),
            [1, 0, 0, 0, MSG_BULK_STREAM_ACK]
        );
        assert_eq!(
            build_legacy_prepared_ok_frame(b"p"),
            [2, 0, 0, 0, MSG_PREPARED_OK, b'p']
        );
        assert_eq!(
            build_legacy_cursor_ok_frame(b"c"),
            [2, 0, 0, 0, MSG_CURSOR_OK, b'c']
        );
        assert_eq!(
            build_legacy_cursor_batch_frame(b"b"),
            [2, 0, 0, 0, MSG_CURSOR_BATCH, b'b']
        );
    }

    /// Every encodable variant decodes back; `Timestamp` comes back as `U64`
    /// because the legacy format has no timestamp tag.
    #[test]
    fn wire_values_round_trip_legacy_tags() {
        let values = [
            WireValue::Null,
            WireValue::I64(-7),
            WireValue::U64(42),
            WireValue::F64(3.5),
            WireValue::Text("hello".to_string()),
            WireValue::Bool(true),
            WireValue::Timestamp(1234),
        ];

        let mut bytes = Vec::new();
        for value in &values {
            encode_value(&mut bytes, value);
        }

        let mut pos = 0;
        assert_eq!(try_decode_value(&bytes, &mut pos), Ok(WireValue::Null));
        assert_eq!(try_decode_value(&bytes, &mut pos), Ok(WireValue::I64(-7)));
        assert_eq!(try_decode_value(&bytes, &mut pos), Ok(WireValue::U64(42)));
        assert_eq!(try_decode_value(&bytes, &mut pos), Ok(WireValue::F64(3.5)));
        assert_eq!(
            try_decode_value(&bytes, &mut pos),
            Ok(WireValue::Text("hello".to_string()))
        );
        assert_eq!(
            try_decode_value(&bytes, &mut pos),
            Ok(WireValue::Bool(true))
        );
        assert_eq!(try_decode_value(&bytes, &mut pos), Ok(WireValue::U64(1234)));
        assert_eq!(pos, bytes.len());
    }

    /// Raw bytes have no legacy tag and are sent as a bare null.
    #[test]
    fn bytes_encode_as_null_tag() {
        let mut out = Vec::new();
        encode_value(&mut out, &WireValue::Bytes(vec![1, 2, 3]));
        assert_eq!(out, [VAL_NULL]);
    }

    #[test]
    fn try_decode_value_reports_malformed_input() {
        let mut pos = 0;
        assert_eq!(try_decode_value(&[], &mut pos), Err("missing value tag"));
        assert_eq!(pos, 0);

        let mut pos = 0;
        assert_eq!(try_decode_value(&[0xFF], &mut pos), Err("unknown value tag"));

        let mut pos = 0;
        assert_eq!(
            try_decode_value(&[VAL_I64, 1, 2, 3], &mut pos),
            Err("truncated i64 value")
        );

        let mut pos = 0;
        assert_eq!(
            try_decode_value(&[VAL_TEXT, 1, 0], &mut pos),
            Err("truncated text length")
        );

        let mut pos = 0;
        assert_eq!(
            try_decode_value(&[VAL_TEXT, 5, 0, 0, 0, b'a'], &mut pos),
            Err("truncated text value")
        );

        let mut pos = 0;
        assert_eq!(
            try_decode_value(&[VAL_BOOL], &mut pos),
            Err("truncated bool value")
        );
    }

    #[test]
    fn decode_value_falls_back_to_null_on_error() {
        let mut pos = 0;
        assert_eq!(decode_value(&[VAL_F64, 0], &mut pos), WireValue::Null);
    }

    #[test]
    fn text_decoding_replaces_invalid_utf8() {
        let data = [VAL_TEXT, 2, 0, 0, 0, 0xFF, b'a'];
        let mut pos = 0;
        assert_eq!(
            try_decode_value(&data, &mut pos),
            Ok(WireValue::Text("\u{FFFD}a".to_string()))
        );
        assert_eq!(pos, data.len());
    }

    #[test]
    fn result_payload_header_encodes_columns_and_row_count() {
        let mut bytes = Vec::new();
        let row_count_offset = encode_result_payload_header(&mut bytes, ["id", "name"], 3);

        assert_eq!(
            bytes,
            [
                2, 0, // column count
                2, 0, b'i', b'd', // "id"
                4, 0, b'n', b'a', b'm', b'e', // "name"
                3, 0, 0, 0, // row count
            ]
        );
        assert_eq!(row_count_offset, bytes.len() - 4);
        set_result_payload_row_count(&mut bytes, row_count_offset, 5).unwrap();
        assert_eq!(
            &bytes[row_count_offset..row_count_offset + 4],
            &[5, 0, 0, 0]
        );
    }

    #[test]
    fn row_count_patch_rejects_out_of_bounds_offset() {
        let mut bytes = vec![0u8; 3];
        assert_eq!(
            set_result_payload_row_count(&mut bytes, 0, 1),
            Err("result payload row-count offset out of bounds")
        );
        assert_eq!(bytes, [0, 0, 0]);
    }
}