Skip to main content

geode_client/
proto.rs

1//! Protobuf types generated from geode.proto via prost-build/tonic-build.
2//!
3//! This module re-exports the generated types and provides helper functions
4//! for QUIC transport framing (4-byte big-endian length prefix).
5
6// Include the prost-generated code.
7#[allow(clippy::large_enum_variant)]
8#[path = "generated/geode.rs"]
9mod generated;
10
11// Re-export all generated types at this module level.
12pub use generated::*;
13
14use prost::Message;
15
16// Note: We use fully-qualified `crate::error::Error` to avoid collision with
17// the generated `proto::Error` message type (from `message Error` in geode.proto).
18use crate::error::Result;
19
20// =============================================================================
21// QUIC framing helpers (length-prefixed protobuf)
22// =============================================================================
23
24/// Encode a QuicClientMessage to protobuf bytes with a 4-byte big-endian
25/// length prefix, as required by the QUIC transport.
26pub fn encode_with_length_prefix(msg: &QuicClientMessage) -> Vec<u8> {
27    let data = msg.encode_to_vec();
28    let length = data.len() as u32;
29    let mut result = Vec::with_capacity(4 + data.len());
30    result.extend(&length.to_be_bytes());
31    result.extend(data);
32    result
33}
34
35/// Decode a 4-byte big-endian length prefix from the start of a byte slice.
36pub fn decode_length_prefix(data: &[u8]) -> Result<u32> {
37    if data.len() < 4 {
38        return Err(crate::error::Error::protocol(
39            "Insufficient data for length prefix",
40        ));
41    }
42    Ok(u32::from_be_bytes([data[0], data[1], data[2], data[3]]))
43}
44
45/// Decode a QuicServerMessage from raw protobuf bytes (without length prefix).
46pub fn decode_quic_server_message(data: &[u8]) -> Result<QuicServerMessage> {
47    QuicServerMessage::decode(data)
48        .map_err(|e| crate::error::Error::protocol(format!("Protobuf decode error: {}", e)))
49}
50
51// =============================================================================
52// Tests
53// =============================================================================
54
#[cfg(test)]
mod tests {
    use super::*;

    // A Hello request wrapped in a client message survives encode + decode.
    #[test]
    fn test_encode_decode_hello_roundtrip() {
        let hello = HelloRequest {
            username: String::from("admin"),
            password: String::from("secret"),
            tenant_id: Some(String::from("tenant1")),
            client_name: String::new(),
            client_version: String::new(),
            wanted_conformance: String::new(),
        };
        let wire = QuicClientMessage {
            msg: Some(quic_client_message::Msg::Hello(hello)),
        }
        .encode_to_vec();
        assert!(!wire.is_empty());

        // Parse the bytes back as a QuicClientMessage and inspect the variant.
        let roundtripped = QuicClientMessage::decode(&wire[..]).unwrap();
        if let Some(quic_client_message::Msg::Hello(got)) = roundtripped.msg {
            assert_eq!(got.username, "admin");
            assert_eq!(got.password, "secret");
            assert_eq!(got.tenant_id, Some(String::from("tenant1")));
        } else {
            panic!("Expected Hello variant");
        }
    }

    // An Execute request with string and integer parameters survives encode + decode.
    #[test]
    fn test_encode_decode_execute_roundtrip() {
        let name_value = Value {
            kind: Some(value::Kind::StringVal(StringValue {
                value: String::from("Alice"),
                kind: 0,
            })),
        };
        let age_value = Value {
            kind: Some(value::Kind::IntVal(IntValue { value: 30, kind: 0 })),
        };

        let execute = ExecuteRequest {
            session_id: String::from("session123"),
            query: String::from("MATCH (n) RETURN n"),
            params: vec![
                Param {
                    name: String::from("name"),
                    value: Some(name_value),
                },
                Param {
                    name: String::from("age"),
                    value: Some(age_value),
                },
            ],
        };
        let wire = QuicClientMessage {
            msg: Some(quic_client_message::Msg::Execute(execute)),
        }
        .encode_to_vec();
        assert!(!wire.is_empty());

        let roundtripped = QuicClientMessage::decode(&wire[..]).unwrap();
        if let Some(quic_client_message::Msg::Execute(got)) = roundtripped.msg {
            assert_eq!(got.session_id, "session123");
            assert_eq!(got.query, "MATCH (n) RETURN n");
            assert_eq!(got.params.len(), 2);
        } else {
            panic!("Expected Execute variant");
        }
    }

    // The 4-byte header of a framed message equals the payload length.
    #[test]
    fn test_encode_with_length_prefix() {
        let ping = QuicClientMessage {
            msg: Some(quic_client_message::Msg::Ping(PingRequest {})),
        };
        let frame = encode_with_length_prefix(&ping);
        // The frame must at least contain the 4-byte length prefix.
        assert!(frame.len() >= 4);

        let (header, payload) = frame.split_at(4);
        let declared = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
        assert_eq!(declared as usize, payload.len());
    }

    // A well-formed 4-byte header decodes as a big-endian integer.
    #[test]
    fn test_decode_length_prefix() {
        let header: [u8; 4] = [0x00, 0x00, 0x00, 0x10];
        assert_eq!(decode_length_prefix(&header).unwrap(), 16);
    }

    // Fewer than four bytes is rejected instead of being read as a length.
    #[test]
    fn test_decode_length_prefix_insufficient_data() {
        let short_header: [u8; 2] = [0x00, 0x00];
        assert!(decode_length_prefix(&short_header).is_err());
    }

    // A HelloResponse survives encode + decode.
    #[test]
    fn test_decode_hello_response() {
        let wire = HelloResponse {
            success: true,
            session_id: String::from("sess123"),
            error_message: String::new(),
            capabilities: Vec::new(),
        }
        .encode_to_vec();

        let got = HelloResponse::decode(&wire[..]).unwrap();
        assert!(got.success);
        assert_eq!(got.session_id, "sess123");
    }

    // A PingResponse survives encode + decode.
    #[test]
    fn test_decode_ping_response() {
        let wire = PingResponse { ok: true }.encode_to_vec();
        let got = PingResponse::decode(&wire[..]).unwrap();
        assert!(got.ok);
    }

    // A Value built with the null variant reports that variant.
    #[test]
    fn test_value_null() {
        let null_value = Value {
            kind: Some(value::Kind::NullVal(NullValue {})),
        };
        assert!(matches!(
            null_value.kind,
            Some(value::Kind::NullVal(_))
        ));
    }

    // A default-constructed Value has no variant set.
    #[test]
    fn test_value_default() {
        assert!(Value::default().kind.is_none());
    }

    // Default client and server envelopes carry no inner message.
    #[test]
    fn test_message_defaults() {
        let client_envelope = QuicClientMessage::default();
        let server_envelope = QuicServerMessage::default();

        assert!(client_envelope.msg.is_none());
        assert!(server_envelope.msg.is_none());
    }
}