1use rkyv::{Archive, Deserialize, Serialize};
4
/// Opening message a client sends to start a session.
///
/// Derives rkyv `Archive`/`Serialize`/`Deserialize`, so the field order below
/// determines the archived layout (`ArchivedHandshake`); do not reorder fields
/// without treating it as a wire-format change.
#[derive(Debug, Clone, PartialEq, Archive, Serialize, Deserialize)]
pub struct Handshake {
    /// Protocol version the client announces. `Handshake::new` fills in
    /// `crate::PROTOCOL_VERSION`; `Handshake::with_version` lets the caller choose.
    pub protocol_version: u32,
    /// Free-form identifier of the connecting client.
    pub client_id: String,
    /// Capability names the client advertises (known values live in the
    /// `capabilities` module). Starts empty and is filled by the builder methods.
    pub capabilities: Vec<String>,
}
15
16impl Handshake {
17 pub fn new(client_id: impl Into<String>) -> Self {
19 Self {
20 protocol_version: crate::PROTOCOL_VERSION,
21 client_id: client_id.into(),
22 capabilities: vec![],
23 }
24 }
25
26 pub fn with_version(protocol_version: u32, client_id: impl Into<String>) -> Self {
28 Self {
29 protocol_version,
30 client_id: client_id.into(),
31 capabilities: vec![],
32 }
33 }
34
35 pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
37 self.capabilities.push(capability.into());
38 self
39 }
40
41 pub fn with_capabilities(mut self, capabilities: Vec<String>) -> Self {
43 self.capabilities.extend(capabilities);
44 self
45 }
46}
47
/// Server's reply to a [`Handshake`], either accepting or rejecting it.
///
/// Derives rkyv `Archive`/`Serialize`/`Deserialize`, so the field order below
/// determines the archived layout (`ArchivedHandshakeResponse`); do not reorder
/// fields without treating it as a wire-format change.
#[derive(Debug, Clone, PartialEq, Archive, Serialize, Deserialize)]
pub struct HandshakeResponse {
    /// `true` when built with `accept`, `false` when built with `reject`.
    pub accepted: bool,
    /// Protocol version reported by the server; `reject` sets it to `0`.
    pub protocol_version: u32,
    /// Schema version reported by the server; `reject` sets it to `0`.
    pub schema_version: u64,
    /// Identifier of the responding server; empty after `reject`.
    pub server_id: String,
    /// Capability names the server advertises; always empty after `reject`
    /// unless a caller adds some via the builder methods.
    pub capabilities: Vec<String>,
    /// Rejection reason: `Some(..)` after `reject`, `None` after `accept`.
    pub error: Option<String>,
}
64
65impl HandshakeResponse {
66 pub fn accept(
68 protocol_version: u32,
69 schema_version: u64,
70 server_id: impl Into<String>,
71 ) -> Self {
72 Self {
73 accepted: true,
74 protocol_version,
75 schema_version,
76 server_id: server_id.into(),
77 capabilities: vec![],
78 error: None,
79 }
80 }
81
82 pub fn reject(error: impl Into<String>) -> Self {
84 Self {
85 accepted: false,
86 protocol_version: 0,
87 schema_version: 0,
88 server_id: String::new(),
89 capabilities: vec![],
90 error: Some(error.into()),
91 }
92 }
93
94 pub fn with_capability(mut self, capability: impl Into<String>) -> Self {
96 self.capabilities.push(capability.into());
97 self
98 }
99
100 pub fn with_capabilities(mut self, capabilities: Vec<String>) -> Self {
102 self.capabilities.extend(capabilities);
103 self
104 }
105}
106
/// Well-known capability names used in the `capabilities` lists of
/// handshake messages. Listed alphabetically.
pub mod capabilities {
    /// `"batch"`: presumably batched request support — confirm with the server.
    pub const BATCH: &str = "batch";
    /// `"cdc"`: presumably change-data-capture support — confirm with the server.
    pub const CDC: &str = "cdc";
    /// `"compression"`: presumably payload compression support.
    pub const COMPRESSION: &str = "compression";
    /// `"streaming"`: presumably streamed result support.
    pub const STREAMING: &str = "streaming";
    /// `"transactions"`: presumably transaction support.
    pub const TRANSACTIONS: &str = "transactions";
}
120
/// Reports whether a client and server protocol version can talk to each other.
///
/// The current rule is strict: the two versions must be identical. There is
/// no forward or backward compatibility window.
pub fn is_version_compatible(client_version: u32, server_version: u32) -> bool {
    server_version == client_version
}
127
128pub fn negotiate_version(client_version: u32, server_version: u32) -> Option<u32> {
130 if is_version_compatible(client_version, server_version) {
131 Some(server_version)
132 } else {
133 None
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_handshake_creation() {
        let handshake = Handshake::new("test-client")
            .with_capability(capabilities::STREAMING)
            .with_capability(capabilities::BATCH);

        assert_eq!(handshake.protocol_version, crate::PROTOCOL_VERSION);
        assert_eq!(handshake.client_id, "test-client");
        assert_eq!(handshake.capabilities.len(), 2);
        assert!(handshake.capabilities.contains(&capabilities::STREAMING.to_string()));
    }

    #[test]
    fn test_handshake_with_version() {
        // An explicit version overrides the crate default; builder calls append in order.
        let handshake = Handshake::with_version(7, "legacy-client")
            .with_capability(capabilities::CDC)
            .with_capabilities(vec![capabilities::COMPRESSION.to_string()]);

        assert_eq!(handshake.protocol_version, 7);
        assert_eq!(handshake.client_id, "legacy-client");
        assert_eq!(
            handshake.capabilities,
            vec![
                capabilities::CDC.to_string(),
                capabilities::COMPRESSION.to_string(),
            ]
        );
    }

    #[test]
    fn test_handshake_response_accept() {
        let response = HandshakeResponse::accept(1, 5, "server-1")
            .with_capability(capabilities::STREAMING)
            .with_capability(capabilities::TRANSACTIONS);

        assert!(response.accepted);
        assert_eq!(response.protocol_version, 1);
        assert_eq!(response.schema_version, 5);
        assert_eq!(response.server_id, "server-1");
        assert!(response.error.is_none());
        assert_eq!(response.capabilities.len(), 2);
    }

    #[test]
    fn test_handshake_response_reject() {
        let response = HandshakeResponse::reject("Unsupported protocol version");

        assert!(!response.accepted);
        assert_eq!(response.error, Some("Unsupported protocol version".to_string()));
        // A rejection carries only placeholder server metadata.
        assert_eq!(response.protocol_version, 0);
        assert_eq!(response.schema_version, 0);
        assert!(response.server_id.is_empty());
        assert!(response.capabilities.is_empty());
    }

    #[test]
    fn test_version_compatibility() {
        assert!(is_version_compatible(1, 1));
        assert!(!is_version_compatible(1, 2));
        assert!(!is_version_compatible(2, 1));

        assert_eq!(negotiate_version(1, 1), Some(1));
        assert_eq!(negotiate_version(1, 2), None);
        assert_eq!(negotiate_version(2, 1), None);
    }

    #[test]
    fn test_handshake_serialization_roundtrip() {
        let handshake = Handshake::new("rust-client-v1")
            .with_capabilities(vec![
                capabilities::STREAMING.into(),
                capabilities::CDC.into(),
            ]);

        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&handshake).unwrap();
        let archived = rkyv::access::<ArchivedHandshake, rkyv::rancor::Error>(&bytes).unwrap();
        let deserialized: Handshake =
            rkyv::deserialize::<Handshake, rkyv::rancor::Error>(archived).unwrap();

        assert_eq!(handshake, deserialized);

        let response = HandshakeResponse::accept(1, 10, "ormdb-server")
            .with_capability(capabilities::TRANSACTIONS);

        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&response).unwrap();
        let archived =
            rkyv::access::<ArchivedHandshakeResponse, rkyv::rancor::Error>(&bytes).unwrap();
        let deserialized: HandshakeResponse =
            rkyv::deserialize::<HandshakeResponse, rkyv::rancor::Error>(archived).unwrap();

        assert_eq!(response, deserialized);
    }

    #[test]
    fn test_rejected_response_serialization_roundtrip() {
        // Exercises the `Some(error)` branch of the archived `Option<String>`.
        let response = HandshakeResponse::reject("schema mismatch");

        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&response).unwrap();
        let archived =
            rkyv::access::<ArchivedHandshakeResponse, rkyv::rancor::Error>(&bytes).unwrap();
        let deserialized: HandshakeResponse =
            rkyv::deserialize::<HandshakeResponse, rkyv::rancor::Error>(archived).unwrap();

        assert_eq!(response, deserialized);
        assert_eq!(deserialized.error.as_deref(), Some("schema mismatch"));
    }
}