// ormdb_proto/handshake.rs

//! Protocol handshake types for connection negotiation.
2
use rkyv::{Archive, Deserialize, Serialize};
4
5/// Client handshake message sent when establishing a connection.
6#[derive(Debug, Clone, PartialEq, Archive, Serialize, Deserialize)]
7pub struct Handshake {
8    /// Protocol version the client supports.
9    pub protocol_version: u32,
10    /// Client identifier (for logging and debugging).
11    pub client_id: String,
12    /// Capabilities the client supports.
13    pub capabilities: Vec<String>,
14}
15
16impl Handshake {
17    /// Create a new handshake with the current protocol version.
18    pub fn new(client_id: impl Into<String>) -> Self {
19        Self {
20            protocol_version: crate::PROTOCOL_VERSION,
21            client_id: client_id.into(),
22            capabilities: vec![],
23        }
24    }
25
26    /// Create a handshake with a specific protocol version.
27    pub fn with_version(protocol_version: u32, client_id: impl Into<String>) -> Self {
28        Self {
29            protocol_version,
30            client_id: client_id.into(),
31            capabilities: vec![],
32        }
33    }
34
35    /// Add a capability to the handshake.
36    pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
37        self.capabilities.push(capability.into());
38        self
39    }
40
41    /// Add multiple capabilities to the handshake.
42    pub fn with_capabilities(mut self, capabilities: Vec<String>) -> Self {
43        self.capabilities.extend(capabilities);
44        self
45    }
46}
47
48/// Server response to a client handshake.
49#[derive(Debug, Clone, PartialEq, Archive, Serialize, Deserialize)]
50pub struct HandshakeResponse {
51    /// Whether the handshake was accepted.
52    pub accepted: bool,
53    /// Protocol version the server will use for this connection.
54    pub protocol_version: u32,
55    /// Current schema version on the server.
56    pub schema_version: u64,
57    /// Server identifier.
58    pub server_id: String,
59    /// Capabilities the server supports.
60    pub capabilities: Vec<String>,
61    /// Error message if handshake was rejected.
62    pub error: Option<String>,
63}
64
65impl HandshakeResponse {
66    /// Create a successful handshake response.
67    pub fn accept(
68        protocol_version: u32,
69        schema_version: u64,
70        server_id: impl Into<String>,
71    ) -> Self {
72        Self {
73            accepted: true,
74            protocol_version,
75            schema_version,
76            server_id: server_id.into(),
77            capabilities: vec![],
78            error: None,
79        }
80    }
81
82    /// Create a rejected handshake response.
83    pub fn reject(error: impl Into<String>) -> Self {
84        Self {
85            accepted: false,
86            protocol_version: 0,
87            schema_version: 0,
88            server_id: String::new(),
89            capabilities: vec![],
90            error: Some(error.into()),
91        }
92    }
93
94    /// Add a capability to the response.
95    pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
96        self.capabilities.push(capability.into());
97        self
98    }
99
100    /// Add multiple capabilities to the response.
101    pub fn with_capabilities(mut self, capabilities: Vec<String>) -> Self {
102        self.capabilities.extend(capabilities);
103        self
104    }
105}
106
/// Standard capability identifiers.
///
/// Each constant is the string placed in a `capabilities` list of a
/// handshake or handshake response.
pub mod capabilities {
    /// Streaming query results.
    pub const STREAMING: &str = "streaming";
    /// Change data capture / subscriptions.
    pub const CDC: &str = "cdc";
    /// Batch operations.
    pub const BATCH: &str = "batch";
    /// Compression support.
    pub const COMPRESSION: &str = "compression";
    /// Transaction support.
    pub const TRANSACTIONS: &str = "transactions";
}
120
/// Check whether a client protocol version is compatible with a server
/// protocol version.
///
/// Returns `true` only when the two versions are equal.
#[must_use]
pub fn is_version_compatible(client_version: u32, server_version: u32) -> bool {
    // For now, require exact match. In the future, we can support
    // version ranges or negotiate down to a common version.
    client_version == server_version
}
127
128/// Negotiate the protocol version between client and server.
129pub fn negotiate_version(client_version: u32, server_version: u32) -> Option<u32> {
130    if is_version_compatible(client_version, server_version) {
131        Some(server_version)
132    } else {
133        None
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stamps the crate protocol version and builders collect capabilities.
    #[test]
    fn test_handshake_creation() {
        let handshake = Handshake::new("test-client")
            .with_capability(capabilities::STREAMING)
            .with_capability(capabilities::BATCH);

        assert_eq!(handshake.protocol_version, crate::PROTOCOL_VERSION);
        assert_eq!(handshake.client_id, "test-client");
        assert_eq!(handshake.capabilities.len(), 2);
        assert!(handshake.capabilities.contains(&capabilities::STREAMING.to_string()));
        assert!(handshake.capabilities.contains(&capabilities::BATCH.to_string()));
    }

    /// `with_version` keeps the caller-supplied version and starts with no capabilities.
    #[test]
    fn test_handshake_with_version() {
        let handshake = Handshake::with_version(42, "test-client");

        assert_eq!(handshake.protocol_version, 42);
        assert_eq!(handshake.client_id, "test-client");
        assert!(handshake.capabilities.is_empty());
    }

    /// `accept` copies its arguments and leaves `error` unset.
    #[test]
    fn test_handshake_response_accept() {
        let response = HandshakeResponse::accept(1, 5, "server-1")
            .with_capability(capabilities::STREAMING)
            .with_capability(capabilities::TRANSACTIONS);

        assert!(response.accepted);
        assert_eq!(response.protocol_version, 1);
        assert_eq!(response.schema_version, 5);
        assert_eq!(response.server_id, "server-1");
        assert!(response.error.is_none());
        assert_eq!(response.capabilities.len(), 2);
    }

    /// `reject` carries only the error message; every other field is zero/empty.
    #[test]
    fn test_handshake_response_reject() {
        let response = HandshakeResponse::reject("Unsupported protocol version");

        assert!(!response.accepted);
        assert_eq!(response.error, Some("Unsupported protocol version".to_string()));
        assert_eq!(response.protocol_version, 0);
        assert_eq!(response.schema_version, 0);
        assert!(response.server_id.is_empty());
        assert!(response.capabilities.is_empty());
    }

    /// Only identical versions are compatible, in either argument order.
    #[test]
    fn test_version_compatibility() {
        assert!(is_version_compatible(1, 1));
        assert!(!is_version_compatible(1, 2));
        assert!(!is_version_compatible(2, 1));

        assert_eq!(negotiate_version(1, 1), Some(1));
        assert_eq!(negotiate_version(1, 2), None);
        assert_eq!(negotiate_version(2, 1), None);
    }

    /// Both message types survive an rkyv serialize -> access -> deserialize cycle.
    #[test]
    fn test_handshake_serialization_roundtrip() {
        let handshake = Handshake::new("rust-client-v1")
            .with_capabilities(vec![
                capabilities::STREAMING.into(),
                capabilities::CDC.into(),
            ]);

        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&handshake).unwrap();
        let archived = rkyv::access::<ArchivedHandshake, rkyv::rancor::Error>(&bytes).unwrap();
        let deserialized: Handshake =
            rkyv::deserialize::<Handshake, rkyv::rancor::Error>(archived).unwrap();

        assert_eq!(handshake, deserialized);

        // Test response
        let response = HandshakeResponse::accept(1, 10, "ormdb-server")
            .with_capability(capabilities::TRANSACTIONS);

        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&response).unwrap();
        let archived =
            rkyv::access::<ArchivedHandshakeResponse, rkyv::rancor::Error>(&bytes).unwrap();
        let deserialized: HandshakeResponse =
            rkyv::deserialize::<HandshakeResponse, rkyv::rancor::Error>(archived).unwrap();

        assert_eq!(response, deserialized);
    }
}