Skip to main content

geode_client/
proto.rs

//! Protobuf encoding/decoding for the Geode wire protocol.
//!
//! Wire format: a 4-byte big-endian length prefix followed by the protobuf
//! message body.
//!
//! This module hand-writes protobuf encoding/decoding compatible with the
//! `geode.proto` schema. Client messages (`QuicClientMessage`) are encoded and
//! server messages (`QuicServerMessage`) are decoded. It supports both QUIC
//! (length-prefixed) and gRPC transports.
8
9use std::collections::HashMap;
10
11use crate::error::{Error, Result};
12
// Protobuf wire types: the low three bits of every encoded field tag.
/// Varint-encoded scalars (`int64`, `bool`, ...).
const WIRE_VARINT: u32 = 0b000;
/// Eight fixed bytes (read little-endian by `decode_double`).
const WIRE_FIXED64: u32 = 0b001;
/// Length-delimited payloads: strings, bytes and embedded messages.
const WIRE_BYTES: u32 = 0b010;
/// Four fixed bytes.
const WIRE_FIXED32: u32 = 0b101;
18
19// =============================================================================
20// Message Types (matching geode.proto)
21// =============================================================================
22
/// Authentication handshake message sent by the client.
///
/// Encoded by `encode_hello_request` as field 1 `username`, field 2 `password`
/// and field 3 `tenant_id`.
///
/// `Debug` is implemented by hand so the password is never written to logs or
/// panic messages; it is shown as `"<redacted>"`.
#[derive(Clone, Default, PartialEq)]
pub struct HelloRequest {
    /// Field 1.
    pub username: String,
    /// Field 2. Sent in clear inside the protobuf body; redacted from `Debug`.
    pub password: String,
    /// Field 3.
    pub tenant_id: String,
}

impl std::fmt::Debug for HelloRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HelloRequest")
            .field("username", &self.username)
            // Never print the secret, even in debug output.
            .field("password", &"<redacted>")
            .field("tenant_id", &self.tenant_id)
            .finish()
    }
}
30
/// The server's reply to a [`HelloRequest`].
///
/// Populated by `decode_hello_response`; the field numbers below are the ones
/// that decoder reads.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct HelloResponse {
    /// Field 1 (varint bool).
    pub success: bool,
    /// Field 2.
    pub session_id: String,
    /// Field 3.
    pub error_message: String,
    /// Field 4, repeated: each occurrence on the wire appends one entry.
    pub capabilities: Vec<String>,
}
39
/// Request to execute a GQL query.
///
/// Encoded by `encode_execute_request`: field 1 `session_id`, field 2 `query`,
/// field 3 one `{1: key, 2: value}` entry message per parameter.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ExecuteRequest {
    /// Field 1.
    pub session_id: String,
    /// Field 2: the GQL query text.
    pub query: String,
    /// Field 3: string-to-string query parameters.
    pub parameters: HashMap<String, String>,
}

/// One column of a result schema (field 1 `name`, field 2 `col_type`).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ColumnDefinition {
    /// Field 1.
    pub name: String,
    /// Field 2: the column's type name as sent by the server.
    pub col_type: String,
}

/// Result schema: the ordered column definitions (repeated field 1).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SchemaDefinition {
    /// Repeated field 1.
    pub columns: Vec<ColumnDefinition>,
}

/// A single cell value, mirroring a protobuf `oneof`.
///
/// `decode_value` sets whichever variants appear on the wire and does not
/// enforce that only one of them is present.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Value {
    /// Field 1 (length-delimited UTF-8).
    pub string_val: Option<String>,
    /// Field 2 (varint, two's complement for negatives).
    pub int_val: Option<i64>,
    /// Field 3 (fixed64, little-endian IEEE 754).
    pub double_val: Option<f64>,
    /// Field 4 (varint bool).
    pub bool_val: Option<bool>,
    /// Field 5 (varint bool); `false` when absent.
    pub null_val: bool,
}

/// One result row: its cell values (repeated field 1), in column order.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Row {
    /// Repeated field 1.
    pub values: Vec<Value>,
}

/// A page of result rows plus a flag marking the final page.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct DataPage {
    /// Repeated field 1.
    pub rows: Vec<Row>,
    /// Field 2 (varint bool).
    pub last_page: bool,
}

/// Query error reported by the server.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ProtoError {
    /// Field 1.
    pub code: String,
    /// Field 2.
    pub message: String,
    /// Field 3.
    pub error_type: String,
}

/// Server-side timing for a query (fields 1–4, varint `int64`).
///
/// NOTE(review): the `_ns` suffix suggests nanoseconds; the unit is not
/// checked here.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ExecutionMetrics {
    /// Field 1.
    pub parse_duration_ns: i64,
    /// Field 2.
    pub plan_duration_ns: i64,
    /// Field 3.
    pub execute_duration_ns: i64,
    /// Field 4.
    pub total_duration_ns: i64,
}

/// Empty keep-alive message.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Heartbeat;

/// One streamed execution response, mirroring a protobuf `oneof`.
///
/// `decode_execution_response` reads field 1 `schema`, 2 `page`, 3 `error`,
/// 4 `metrics` and 5 `heartbeat`, setting each variant it sees.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ExecutionResponse {
    /// Field 1.
    pub schema: Option<SchemaDefinition>,
    /// Field 2.
    pub page: Option<DataPage>,
    /// Field 3.
    pub error: Option<ProtoError>,
    /// Field 4.
    pub metrics: Option<ExecutionMetrics>,
    /// Field 5.
    pub heartbeat: Option<Heartbeat>,
}
114
/// Ping request; carries no fields.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct PingRequest;

/// Ping reply (field 1 `ok`, varint bool).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct PingResponse {
    /// Field 1.
    pub ok: bool,
}

/// Starts a transaction (field 1 `read_only`, varint bool).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct BeginRequest {
    /// Field 1; omitted on the wire when `false`.
    pub read_only: bool,
}

/// Acknowledges a transaction start.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct BeginResponse {
    /// Field 1.
    pub session_id: String,
    /// Field 2.
    pub tx_id: String,
}

/// Commits the current transaction; carries no fields.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CommitRequest;

/// Commit acknowledgement (field 1 `success`, varint bool).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CommitResponse {
    /// Field 1.
    pub success: bool,
}

/// Rolls back the current transaction; carries no fields.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct RollbackRequest;

/// Rollback acknowledgement (field 1 `success`, varint bool).
#[derive(Clone, Debug, Default, PartialEq)]
pub struct RollbackResponse {
    /// Field 1.
    pub success: bool,
}
157
158/// Top-level client message (oneof).
159#[derive(Debug, Clone, Default, PartialEq)]
160pub struct QuicClientMessage {
161    pub hello: Option<HelloRequest>,
162    pub execute: Option<ExecuteRequest>,
163    pub ping: Option<PingRequest>,
164    pub begin: Option<BeginRequest>,
165    pub commit: Option<CommitRequest>,
166    pub rollback: Option<RollbackRequest>,
167}
168
169/// Top-level server message (oneof).
170#[derive(Debug, Clone, Default, PartialEq)]
171pub struct QuicServerMessage {
172    pub hello: Option<HelloResponse>,
173    pub execute: Option<ExecutionResponse>,
174    pub ping: Option<PingResponse>,
175    pub begin: Option<BeginResponse>,
176    pub commit: Option<CommitResponse>,
177    pub rollback: Option<RollbackResponse>,
178}
179
180// =============================================================================
181// Encoding helpers
182// =============================================================================
183
/// Encodes `v` as a protobuf base-128 varint (least-significant group first).
///
/// Every byte except the last has its high bit set. `0` encodes as a single
/// `0x00` byte, and `u64::MAX` takes ten bytes.
fn encode_varint(v: u64) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(10);
    let mut remaining = v;
    while remaining >= 0x80 {
        bytes.push(((remaining as u8) & 0x7F) | 0x80);
        remaining >>= 7;
    }
    bytes.push(remaining as u8);
    bytes
}
200
201fn encode_tag(field_num: u32, wire_type: u32) -> Vec<u8> {
202    encode_varint(((field_num << 3) | wire_type) as u64)
203}
204
205fn encode_string(field_num: u32, s: &str) -> Vec<u8> {
206    if s.is_empty() {
207        return Vec::new();
208    }
209    let mut result = encode_tag(field_num, WIRE_BYTES);
210    result.extend(encode_varint(s.len() as u64));
211    result.extend(s.as_bytes());
212    result
213}
214
215/// Reserved for future use with byte array fields.
216#[allow(dead_code)]
217fn encode_bytes(field_num: u32, data: &[u8]) -> Vec<u8> {
218    if data.is_empty() {
219        return Vec::new();
220    }
221    let mut result = encode_tag(field_num, WIRE_BYTES);
222    result.extend(encode_varint(data.len() as u64));
223    result.extend(data);
224    result
225}
226
227fn encode_bool(field_num: u32, v: bool) -> Vec<u8> {
228    if !v {
229        return Vec::new();
230    }
231    let mut result = encode_tag(field_num, WIRE_VARINT);
232    result.push(1);
233    result
234}
235
236#[allow(dead_code)]
237fn encode_int64(field_num: u32, v: i64) -> Vec<u8> {
238    if v == 0 {
239        return Vec::new();
240    }
241    let mut result = encode_tag(field_num, WIRE_VARINT);
242    result.extend(encode_varint(v as u64));
243    result
244}
245
246fn encode_submessage(field_num: u32, data: &[u8]) -> Vec<u8> {
247    if data.is_empty() {
248        return Vec::new();
249    }
250    let mut result = encode_tag(field_num, WIRE_BYTES);
251    result.extend(encode_varint(data.len() as u64));
252    result.extend(data);
253    result
254}
255
256// =============================================================================
257// Message encoding
258// =============================================================================
259
260fn encode_hello_request(req: &HelloRequest) -> Vec<u8> {
261    let mut result = Vec::new();
262    result.extend(encode_string(1, &req.username));
263    result.extend(encode_string(2, &req.password));
264    result.extend(encode_string(3, &req.tenant_id));
265    result
266}
267
268fn encode_execute_request(req: &ExecuteRequest) -> Vec<u8> {
269    let mut result = Vec::new();
270    result.extend(encode_string(1, &req.session_id));
271    result.extend(encode_string(2, &req.query));
272    // Encode map as repeated key-value pairs
273    for (k, v) in &req.parameters {
274        let mut entry = encode_string(1, k);
275        entry.extend(encode_string(2, v));
276        result.extend(encode_submessage(3, &entry));
277    }
278    result
279}
280
281fn encode_ping_request(_req: &PingRequest) -> Vec<u8> {
282    Vec::new()
283}
284
285fn encode_begin_request(req: &BeginRequest) -> Vec<u8> {
286    let mut result = Vec::new();
287    result.extend(encode_bool(1, req.read_only));
288    result
289}
290
291fn encode_commit_request(_req: &CommitRequest) -> Vec<u8> {
292    Vec::new()
293}
294
295fn encode_rollback_request(_req: &RollbackRequest) -> Vec<u8> {
296    Vec::new()
297}
298
299/// Encode a QuicClientMessage to protobuf bytes.
300pub fn encode_quic_client_message(msg: &QuicClientMessage) -> Vec<u8> {
301    let mut result = Vec::new();
302    if let Some(ref hello) = msg.hello {
303        result.extend(encode_submessage(1, &encode_hello_request(hello)));
304    }
305    if let Some(ref execute) = msg.execute {
306        result.extend(encode_submessage(2, &encode_execute_request(execute)));
307    }
308    if let Some(ref ping) = msg.ping {
309        result.extend(encode_submessage(3, &encode_ping_request(ping)));
310    }
311    if let Some(ref begin) = msg.begin {
312        result.extend(encode_submessage(6, &encode_begin_request(begin)));
313    }
314    if let Some(ref commit) = msg.commit {
315        result.extend(encode_submessage(7, &encode_commit_request(commit)));
316    }
317    if let Some(ref rollback) = msg.rollback {
318        result.extend(encode_submessage(8, &encode_rollback_request(rollback)));
319    }
320    result
321}
322
323// =============================================================================
324// Decoding helpers
325// =============================================================================
326
327fn decode_varint(data: &[u8], pos: &mut usize) -> Result<u64> {
328    let mut result: u64 = 0;
329    let mut shift = 0;
330    while *pos < data.len() {
331        let b = data[*pos];
332        *pos += 1;
333        result |= ((b & 0x7F) as u64) << shift;
334        if b & 0x80 == 0 {
335            return Ok(result);
336        }
337        shift += 7;
338        if shift > 63 {
339            return Err(Error::protocol("Varint too long"));
340        }
341    }
342    Err(Error::protocol("Truncated varint"))
343}
344
345fn decode_tag(data: &[u8], pos: &mut usize) -> Result<(u32, u32)> {
346    let tag = decode_varint(data, pos)?;
347    Ok(((tag >> 3) as u32, (tag & 0x7) as u32))
348}
349
350fn skip_field(data: &[u8], pos: &mut usize, wire_type: u32) -> Result<()> {
351    match wire_type {
352        WIRE_VARINT => {
353            decode_varint(data, pos)?;
354        }
355        WIRE_FIXED64 => {
356            *pos += 8;
357        }
358        WIRE_BYTES => {
359            let length = decode_varint(data, pos)? as usize;
360            *pos += length;
361        }
362        WIRE_FIXED32 => {
363            *pos += 4;
364        }
365        _ => return Err(Error::protocol(format!("Unknown wire type: {}", wire_type))),
366    }
367    if *pos > data.len() {
368        return Err(Error::protocol("Message truncated"));
369    }
370    Ok(())
371}
372
373fn decode_string(data: &[u8], pos: &mut usize) -> Result<String> {
374    let length = decode_varint(data, pos)? as usize;
375    if *pos + length > data.len() {
376        return Err(Error::protocol("Truncated string"));
377    }
378    let s = String::from_utf8(data[*pos..*pos + length].to_vec())
379        .map_err(|_| Error::protocol("Invalid UTF-8"))?;
380    *pos += length;
381    Ok(s)
382}
383
384fn decode_bytes_field(data: &[u8], pos: &mut usize) -> Result<Vec<u8>> {
385    let length = decode_varint(data, pos)? as usize;
386    if *pos + length > data.len() {
387        return Err(Error::protocol("Truncated bytes"));
388    }
389    let bytes = data[*pos..*pos + length].to_vec();
390    *pos += length;
391    Ok(bytes)
392}
393
394fn decode_bool(data: &[u8], pos: &mut usize) -> Result<bool> {
395    let v = decode_varint(data, pos)?;
396    Ok(v != 0)
397}
398
399fn decode_int64(data: &[u8], pos: &mut usize) -> Result<i64> {
400    let v = decode_varint(data, pos)?;
401    // Handle sign extension for negative numbers
402    if v > 0x7FFFFFFFFFFFFFFF {
403        Ok((v as i64).wrapping_sub(i64::MIN).wrapping_add(i64::MIN))
404    } else {
405        Ok(v as i64)
406    }
407}
408
409fn decode_double(data: &[u8], pos: &mut usize) -> Result<f64> {
410    if *pos + 8 > data.len() {
411        return Err(Error::protocol("Truncated double"));
412    }
413    let bytes: [u8; 8] = data[*pos..*pos + 8].try_into().unwrap();
414    *pos += 8;
415    Ok(f64::from_le_bytes(bytes))
416}
417
418// =============================================================================
419// Message decoding
420// =============================================================================
421
422fn decode_hello_response(data: &[u8]) -> Result<HelloResponse> {
423    let mut resp = HelloResponse::default();
424    let mut pos = 0;
425    while pos < data.len() {
426        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
427        match (field_num, wire_type) {
428            (1, WIRE_VARINT) => resp.success = decode_bool(data, &mut pos)?,
429            (2, WIRE_BYTES) => resp.session_id = decode_string(data, &mut pos)?,
430            (3, WIRE_BYTES) => resp.error_message = decode_string(data, &mut pos)?,
431            (4, WIRE_BYTES) => resp.capabilities.push(decode_string(data, &mut pos)?),
432            _ => skip_field(data, &mut pos, wire_type)?,
433        }
434    }
435    Ok(resp)
436}
437
438fn decode_column_definition(data: &[u8]) -> Result<ColumnDefinition> {
439    let mut col = ColumnDefinition::default();
440    let mut pos = 0;
441    while pos < data.len() {
442        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
443        match (field_num, wire_type) {
444            (1, WIRE_BYTES) => col.name = decode_string(data, &mut pos)?,
445            (2, WIRE_BYTES) => col.col_type = decode_string(data, &mut pos)?,
446            _ => skip_field(data, &mut pos, wire_type)?,
447        }
448    }
449    Ok(col)
450}
451
452fn decode_schema_definition(data: &[u8]) -> Result<SchemaDefinition> {
453    let mut schema = SchemaDefinition::default();
454    let mut pos = 0;
455    while pos < data.len() {
456        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
457        match (field_num, wire_type) {
458            (1, WIRE_BYTES) => {
459                let col_data = decode_bytes_field(data, &mut pos)?;
460                schema.columns.push(decode_column_definition(&col_data)?);
461            }
462            _ => skip_field(data, &mut pos, wire_type)?,
463        }
464    }
465    Ok(schema)
466}
467
468fn decode_value(data: &[u8]) -> Result<Value> {
469    let mut val = Value::default();
470    let mut pos = 0;
471    while pos < data.len() {
472        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
473        match (field_num, wire_type) {
474            (1, WIRE_BYTES) => val.string_val = Some(decode_string(data, &mut pos)?),
475            (2, WIRE_VARINT) => val.int_val = Some(decode_int64(data, &mut pos)?),
476            (3, WIRE_FIXED64) => val.double_val = Some(decode_double(data, &mut pos)?),
477            (4, WIRE_VARINT) => val.bool_val = Some(decode_bool(data, &mut pos)?),
478            (5, WIRE_VARINT) => val.null_val = decode_bool(data, &mut pos)?,
479            _ => skip_field(data, &mut pos, wire_type)?,
480        }
481    }
482    Ok(val)
483}
484
485fn decode_row(data: &[u8]) -> Result<Row> {
486    let mut row = Row::default();
487    let mut pos = 0;
488    while pos < data.len() {
489        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
490        match (field_num, wire_type) {
491            (1, WIRE_BYTES) => {
492                let val_data = decode_bytes_field(data, &mut pos)?;
493                row.values.push(decode_value(&val_data)?);
494            }
495            _ => skip_field(data, &mut pos, wire_type)?,
496        }
497    }
498    Ok(row)
499}
500
501fn decode_data_page(data: &[u8]) -> Result<DataPage> {
502    let mut page = DataPage::default();
503    let mut pos = 0;
504    while pos < data.len() {
505        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
506        match (field_num, wire_type) {
507            (1, WIRE_BYTES) => {
508                let row_data = decode_bytes_field(data, &mut pos)?;
509                page.rows.push(decode_row(&row_data)?);
510            }
511            (2, WIRE_VARINT) => page.last_page = decode_bool(data, &mut pos)?,
512            _ => skip_field(data, &mut pos, wire_type)?,
513        }
514    }
515    Ok(page)
516}
517
518fn decode_proto_error(data: &[u8]) -> Result<ProtoError> {
519    let mut err = ProtoError::default();
520    let mut pos = 0;
521    while pos < data.len() {
522        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
523        match (field_num, wire_type) {
524            (1, WIRE_BYTES) => err.code = decode_string(data, &mut pos)?,
525            (2, WIRE_BYTES) => err.message = decode_string(data, &mut pos)?,
526            (3, WIRE_BYTES) => err.error_type = decode_string(data, &mut pos)?,
527            _ => skip_field(data, &mut pos, wire_type)?,
528        }
529    }
530    Ok(err)
531}
532
533fn decode_execution_metrics(data: &[u8]) -> Result<ExecutionMetrics> {
534    let mut metrics = ExecutionMetrics::default();
535    let mut pos = 0;
536    while pos < data.len() {
537        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
538        match (field_num, wire_type) {
539            (1, WIRE_VARINT) => metrics.parse_duration_ns = decode_int64(data, &mut pos)?,
540            (2, WIRE_VARINT) => metrics.plan_duration_ns = decode_int64(data, &mut pos)?,
541            (3, WIRE_VARINT) => metrics.execute_duration_ns = decode_int64(data, &mut pos)?,
542            (4, WIRE_VARINT) => metrics.total_duration_ns = decode_int64(data, &mut pos)?,
543            _ => skip_field(data, &mut pos, wire_type)?,
544        }
545    }
546    Ok(metrics)
547}
548
549fn decode_execution_response(data: &[u8]) -> Result<ExecutionResponse> {
550    let mut resp = ExecutionResponse::default();
551    let mut pos = 0;
552    while pos < data.len() {
553        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
554        match (field_num, wire_type) {
555            (1, WIRE_BYTES) => {
556                let schema_data = decode_bytes_field(data, &mut pos)?;
557                resp.schema = Some(decode_schema_definition(&schema_data)?);
558            }
559            (2, WIRE_BYTES) => {
560                let page_data = decode_bytes_field(data, &mut pos)?;
561                resp.page = Some(decode_data_page(&page_data)?);
562            }
563            (3, WIRE_BYTES) => {
564                let err_data = decode_bytes_field(data, &mut pos)?;
565                resp.error = Some(decode_proto_error(&err_data)?);
566            }
567            (4, WIRE_BYTES) => {
568                let metrics_data = decode_bytes_field(data, &mut pos)?;
569                resp.metrics = Some(decode_execution_metrics(&metrics_data)?);
570            }
571            (5, WIRE_BYTES) => {
572                let _heartbeat_data = decode_bytes_field(data, &mut pos)?;
573                resp.heartbeat = Some(Heartbeat);
574            }
575            _ => skip_field(data, &mut pos, wire_type)?,
576        }
577    }
578    Ok(resp)
579}
580
581fn decode_ping_response(data: &[u8]) -> Result<PingResponse> {
582    let mut resp = PingResponse::default();
583    let mut pos = 0;
584    while pos < data.len() {
585        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
586        match (field_num, wire_type) {
587            (1, WIRE_VARINT) => resp.ok = decode_bool(data, &mut pos)?,
588            _ => skip_field(data, &mut pos, wire_type)?,
589        }
590    }
591    Ok(resp)
592}
593
594fn decode_begin_response(data: &[u8]) -> Result<BeginResponse> {
595    let mut resp = BeginResponse::default();
596    let mut pos = 0;
597    while pos < data.len() {
598        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
599        match (field_num, wire_type) {
600            (1, WIRE_BYTES) => resp.session_id = decode_string(data, &mut pos)?,
601            (2, WIRE_BYTES) => resp.tx_id = decode_string(data, &mut pos)?,
602            _ => skip_field(data, &mut pos, wire_type)?,
603        }
604    }
605    Ok(resp)
606}
607
608fn decode_commit_response(data: &[u8]) -> Result<CommitResponse> {
609    let mut resp = CommitResponse::default();
610    let mut pos = 0;
611    while pos < data.len() {
612        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
613        match (field_num, wire_type) {
614            (1, WIRE_VARINT) => resp.success = decode_bool(data, &mut pos)?,
615            _ => skip_field(data, &mut pos, wire_type)?,
616        }
617    }
618    Ok(resp)
619}
620
621fn decode_rollback_response(data: &[u8]) -> Result<RollbackResponse> {
622    let mut resp = RollbackResponse::default();
623    let mut pos = 0;
624    while pos < data.len() {
625        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
626        match (field_num, wire_type) {
627            (1, WIRE_VARINT) => resp.success = decode_bool(data, &mut pos)?,
628            _ => skip_field(data, &mut pos, wire_type)?,
629        }
630    }
631    Ok(resp)
632}
633
634/// Decode a QuicServerMessage from protobuf bytes.
635pub fn decode_quic_server_message(data: &[u8]) -> Result<QuicServerMessage> {
636    let mut msg = QuicServerMessage::default();
637    let mut pos = 0;
638    while pos < data.len() {
639        let (field_num, wire_type) = decode_tag(data, &mut pos)?;
640        match (field_num, wire_type) {
641            (1, WIRE_BYTES) => {
642                let hello_data = decode_bytes_field(data, &mut pos)?;
643                msg.hello = Some(decode_hello_response(&hello_data)?);
644            }
645            (2, WIRE_BYTES) => {
646                let exec_data = decode_bytes_field(data, &mut pos)?;
647                msg.execute = Some(decode_execution_response(&exec_data)?);
648            }
649            (3, WIRE_BYTES) => {
650                let ping_data = decode_bytes_field(data, &mut pos)?;
651                msg.ping = Some(decode_ping_response(&ping_data)?);
652            }
653            (6, WIRE_BYTES) => {
654                let begin_data = decode_bytes_field(data, &mut pos)?;
655                msg.begin = Some(decode_begin_response(&begin_data)?);
656            }
657            (7, WIRE_BYTES) => {
658                let commit_data = decode_bytes_field(data, &mut pos)?;
659                msg.commit = Some(decode_commit_response(&commit_data)?);
660            }
661            (8, WIRE_BYTES) => {
662                let rollback_data = decode_bytes_field(data, &mut pos)?;
663                msg.rollback = Some(decode_rollback_response(&rollback_data)?);
664            }
665            _ => skip_field(data, &mut pos, wire_type)?,
666        }
667    }
668    Ok(msg)
669}
670
671/// Encode message with 4-byte big-endian length prefix.
672pub fn encode_with_length_prefix(msg: &QuicClientMessage) -> Vec<u8> {
673    let data = encode_quic_client_message(msg);
674    let length = data.len() as u32;
675    let mut result = Vec::with_capacity(4 + data.len());
676    result.extend(&length.to_be_bytes());
677    result.extend(data);
678    result
679}
680
681/// Decode 4-byte big-endian length prefix.
682pub fn decode_length_prefix(data: &[u8]) -> Result<u32> {
683    if data.len() < 4 {
684        return Err(Error::protocol("Insufficient data for length prefix"));
685    }
686    Ok(u32::from_be_bytes([data[0], data[1], data[2], data[3]]))
687}
688
689// =============================================================================
690// Tests
691// =============================================================================
692
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a length-delimited field from a one-byte tag and a body shorter
    /// than 128 bytes, so the length fits in a single varint byte.
    fn len_delimited(tag: u8, body: &[u8]) -> Vec<u8> {
        assert!(body.len() < 0x80, "helper only supports single-byte lengths");
        let mut out = vec![tag, body.len() as u8];
        out.extend_from_slice(body);
        out
    }

    #[test]
    fn test_encode_varint_boundaries() {
        assert_eq!(encode_varint(0), vec![0x00]);
        assert_eq!(encode_varint(127), vec![0x7f]);
        assert_eq!(encode_varint(128), vec![0x80, 0x01]);
        assert_eq!(encode_varint(300), vec![0xac, 0x02]);
        assert_eq!(encode_varint(u64::MAX).len(), 10);
    }

    #[test]
    fn test_varint_roundtrip() {
        for &v in &[0u64, 1, 127, 128, 300, 1 << 35, u64::MAX] {
            let bytes = encode_varint(v);
            let mut pos = 0;
            assert_eq!(decode_varint(&bytes, &mut pos).unwrap(), v);
            assert_eq!(pos, bytes.len());
        }
    }

    #[test]
    fn test_decode_varint_truncated() {
        let mut pos = 0;
        assert!(decode_varint(&[0x80, 0x80], &mut pos).is_err());
    }

    #[test]
    fn test_encode_hello_exact_bytes() {
        let msg = QuicClientMessage {
            hello: Some(HelloRequest {
                username: "admin".to_string(),
                password: "secret".to_string(),
                tenant_id: "tenant1".to_string(),
            }),
            ..Default::default()
        };
        let body = [
            len_delimited(0x0a, b"admin"),
            len_delimited(0x12, b"secret"),
            len_delimited(0x1a, b"tenant1"),
        ]
        .concat();
        // QuicClientMessage.hello is field 1 (tag 0x0a).
        assert_eq!(encode_quic_client_message(&msg), len_delimited(0x0a, &body));
    }

    #[test]
    fn test_encode_execute_exact_bytes() {
        // A single parameter keeps the output deterministic: map entries are
        // written in HashMap iteration order.
        let mut params = HashMap::new();
        params.insert("name".to_string(), "Alice".to_string());
        let msg = QuicClientMessage {
            execute: Some(ExecuteRequest {
                session_id: "session123".to_string(),
                query: "MATCH (n) RETURN n".to_string(),
                parameters: params,
            }),
            ..Default::default()
        };
        let entry = [len_delimited(0x0a, b"name"), len_delimited(0x12, b"Alice")].concat();
        let body = [
            len_delimited(0x0a, b"session123"),
            len_delimited(0x12, b"MATCH (n) RETURN n"),
            len_delimited(0x1a, &entry),
        ]
        .concat();
        // QuicClientMessage.execute is field 2 (tag 0x12).
        assert_eq!(encode_quic_client_message(&msg), len_delimited(0x12, &body));
    }

    #[test]
    fn test_encode_with_length_prefix() {
        let msg = QuicClientMessage {
            ping: Some(PingRequest),
            ..Default::default()
        };
        let encoded = encode_with_length_prefix(&msg);
        assert!(encoded.len() >= 4);
        let length = u32::from_be_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
        assert_eq!(length as usize, encoded.len() - 4);
        assert_eq!(&encoded[4..], &encode_quic_client_message(&msg)[..]);
    }

    #[test]
    fn test_decode_length_prefix() {
        let data = [0x00, 0x00, 0x00, 0x10];
        assert_eq!(decode_length_prefix(&data).unwrap(), 16);
    }

    #[test]
    fn test_decode_length_prefix_insufficient_data() {
        let data = [0x00, 0x00];
        assert!(decode_length_prefix(&data).is_err());
    }

    #[test]
    fn test_decode_hello_response() {
        // Unknown field 9 (varint) must be skipped.
        let mut encoded = vec![0x48, 0x05];
        // Field 1 (success): varint true.
        encoded.extend(&[0x08, 0x01]);
        // Field 2 (session_id).
        encoded.extend(len_delimited(0x12, b"sess123"));
        // Field 4 (capabilities), repeated.
        encoded.extend(len_delimited(0x22, b"gql"));
        encoded.extend(len_delimited(0x22, b"tx"));

        let resp = decode_hello_response(&encoded).unwrap();
        assert!(resp.success);
        assert_eq!(resp.session_id, "sess123");
        assert_eq!(resp.capabilities, vec!["gql".to_string(), "tx".to_string()]);
    }

    #[test]
    fn test_decode_truncated_string_is_error() {
        // Field 2 claims 5 bytes but only 1 follows.
        assert!(decode_hello_response(&[0x12, 0x05, b'a']).is_err());
    }

    #[test]
    fn test_decode_execution_page() {
        let int_value = vec![0x10, 0x2a]; // int_val = 42
        let mut neg_value = vec![0x10]; // int_val = -1 as a ten-byte varint
        neg_value.extend_from_slice(&[0xff; 9]);
        neg_value.push(0x01);
        let mut double_value = vec![0x19]; // field 3, fixed64
        double_value.extend_from_slice(&1.5f64.to_le_bytes());
        let str_value = len_delimited(0x0a, b"x"); // string_val
        let null_value = vec![0x28, 0x01]; // null_val = true

        let row = [
            len_delimited(0x0a, &int_value),
            len_delimited(0x0a, &neg_value),
            len_delimited(0x0a, &double_value),
            len_delimited(0x0a, &str_value),
            len_delimited(0x0a, &null_value),
        ]
        .concat();
        let page_bytes = [len_delimited(0x0a, &row), vec![0x10, 0x01]].concat();
        let exec_bytes = len_delimited(0x12, &page_bytes);
        let msg_bytes = len_delimited(0x12, &exec_bytes);

        let decoded = decode_quic_server_message(&msg_bytes).unwrap();
        let page = decoded.execute.expect("execute set").page.expect("page set");
        assert!(page.last_page);
        assert_eq!(page.rows.len(), 1);
        let values = &page.rows[0].values;
        assert_eq!(values.len(), 5);
        assert_eq!(values[0].int_val, Some(42));
        assert_eq!(values[1].int_val, Some(-1));
        assert_eq!(values[2].double_val, Some(1.5));
        assert_eq!(values[3].string_val.as_deref(), Some("x"));
        assert!(values[4].null_val);
    }

    #[test]
    fn test_decode_execution_schema_and_error() {
        let column = [len_delimited(0x0a, b"n"), len_delimited(0x12, b"INT")].concat();
        let schema = len_delimited(0x0a, &column);
        let resp = decode_execution_response(&len_delimited(0x0a, &schema)).unwrap();
        assert_eq!(
            resp.schema.unwrap().columns,
            vec![ColumnDefinition { name: "n".to_string(), col_type: "INT".to_string() }]
        );

        let err = [len_delimited(0x0a, b"E1"), len_delimited(0x12, b"boom")].concat();
        let resp = decode_execution_response(&len_delimited(0x1a, &err)).unwrap();
        let err = resp.error.unwrap();
        assert_eq!(err.code, "E1");
        assert_eq!(err.message, "boom");
    }

    #[test]
    fn test_decode_ping_response() {
        // Field 1 (ok): varint true.
        let encoded = [0x08, 0x01];
        let resp = decode_ping_response(&encoded).unwrap();
        assert!(resp.ok);
    }

    #[test]
    fn test_value_default() {
        let val = Value::default();
        assert!(val.string_val.is_none());
        assert!(val.int_val.is_none());
        assert!(val.double_val.is_none());
        assert!(val.bool_val.is_none());
        assert!(!val.null_val);
    }

    #[test]
    fn test_message_defaults() {
        let client_msg = QuicClientMessage::default();
        assert!(client_msg.hello.is_none());
        assert!(client_msg.execute.is_none());
        assert!(client_msg.ping.is_none());

        let server_msg = QuicServerMessage::default();
        assert!(server_msg.hello.is_none());
        assert!(server_msg.execute.is_none());
        assert!(server_msg.ping.is_none());
    }
}
803}