1use crate::ProtocolVersion;
28use std::collections::HashSet;
29
30#[derive(Debug, Clone, PartialEq)]
32pub struct ProtocolDetection {
33 pub version: ProtocolVersion,
35 pub confidence: f32,
37 pub method: DetectionMethod,
39}
40
/// The kind of evidence a [`ProtocolDetection`] was based on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DetectionMethod {
    /// Decided from a marker inside the message's `headers` object.
    Headers,
    /// Decided from which top-level keys the message contains.
    Structure,
    /// Reserved for content-type based detection; `detect_protocol` in this
    /// module never returns it.
    ContentType,
    /// Nothing was recognised, so the fallback version was assumed.
    Default,
}

impl std::fmt::Display for DetectionMethod {
    /// Writes the lowercase label for the method (e.g. `content-type`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Headers => "headers",
            Self::Structure => "structure",
            Self::ContentType => "content-type",
            Self::Default => "default",
        };
        f.write_str(label)
    }
}
64
65#[derive(Debug, Clone)]
67pub enum NegotiationError {
68 NoCommonVersion,
70 UnsupportedVersion(ProtocolVersion),
72 InvalidData(String),
74}
75
76impl std::fmt::Display for NegotiationError {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 match self {
79 NegotiationError::NoCommonVersion => {
80 write!(f, "No common protocol version found")
81 }
82 NegotiationError::UnsupportedVersion(v) => {
83 write!(f, "Protocol version {} is not supported", v)
84 }
85 NegotiationError::InvalidData(msg) => {
86 write!(f, "Invalid protocol data: {}", msg)
87 }
88 }
89 }
90}
91
92impl std::error::Error for NegotiationError {}
93
94#[derive(Debug, Clone)]
96pub struct ProtocolNegotiator {
97 supported: HashSet<ProtocolVersion>,
99 preferred: Option<ProtocolVersion>,
101}
102
103impl Default for ProtocolNegotiator {
104 fn default() -> Self {
105 Self::new()
106 }
107}
108
109impl ProtocolNegotiator {
110 pub fn new() -> Self {
112 let mut supported = HashSet::new();
113 supported.insert(ProtocolVersion::V2);
114 supported.insert(ProtocolVersion::V5);
115
116 Self {
117 supported,
118 preferred: Some(ProtocolVersion::V2), }
120 }
121
122 pub fn v2_only() -> Self {
124 let mut supported = HashSet::new();
125 supported.insert(ProtocolVersion::V2);
126
127 Self {
128 supported,
129 preferred: Some(ProtocolVersion::V2),
130 }
131 }
132
133 pub fn prefer_v5() -> Self {
135 let mut supported = HashSet::new();
136 supported.insert(ProtocolVersion::V2);
137 supported.insert(ProtocolVersion::V5);
138
139 Self {
140 supported,
141 preferred: Some(ProtocolVersion::V5),
142 }
143 }
144
145 #[must_use]
147 pub fn prefer(mut self, version: ProtocolVersion) -> Self {
148 self.preferred = Some(version);
149 self.supported.insert(version);
150 self
151 }
152
153 #[must_use]
155 pub fn support(mut self, version: ProtocolVersion) -> Self {
156 self.supported.insert(version);
157 self
158 }
159
160 #[must_use]
162 pub fn unsupport(mut self, version: ProtocolVersion) -> Self {
163 self.supported.remove(&version);
164 if self.preferred == Some(version) {
165 self.preferred = None;
166 }
167 self
168 }
169
170 #[inline]
172 pub fn is_supported(&self, version: ProtocolVersion) -> bool {
173 self.supported.contains(&version)
174 }
175
176 #[inline]
178 pub fn supported_versions(&self) -> Vec<ProtocolVersion> {
179 self.supported.iter().copied().collect()
180 }
181
182 #[inline]
184 pub fn preferred_version(&self) -> Option<ProtocolVersion> {
185 self.preferred
186 }
187
188 pub fn negotiate(
193 &self,
194 remote_versions: &[ProtocolVersion],
195 ) -> Result<ProtocolVersion, NegotiationError> {
196 let remote_set: HashSet<_> = remote_versions.iter().copied().collect();
198 let common: Vec<_> = self.supported.intersection(&remote_set).copied().collect();
199
200 if common.is_empty() {
201 return Err(NegotiationError::NoCommonVersion);
202 }
203
204 if let Some(preferred) = self.preferred {
206 if common.contains(&preferred) {
207 return Ok(preferred);
208 }
209 }
210
211 if common.contains(&ProtocolVersion::V5) {
213 Ok(ProtocolVersion::V5)
214 } else {
215 Ok(ProtocolVersion::V2)
216 }
217 }
218
219 pub fn validate_version(&self, version: ProtocolVersion) -> Result<(), NegotiationError> {
221 if self.is_supported(version) {
222 Ok(())
223 } else {
224 Err(NegotiationError::UnsupportedVersion(version))
225 }
226 }
227}
228
229pub fn detect_protocol(json: &serde_json::Value) -> ProtocolDetection {
233 if let Some(headers) = json.get("headers") {
235 if headers.get("protocol").is_some() {
236 return ProtocolDetection {
237 version: ProtocolVersion::V5,
238 confidence: 1.0,
239 method: DetectionMethod::Headers,
240 };
241 }
242
243 if headers.get("lang").is_some() {
245 return ProtocolDetection {
246 version: ProtocolVersion::V2,
247 confidence: 0.9,
248 method: DetectionMethod::Headers,
249 };
250 }
251 }
252
253 if json.get("headers").is_some()
255 && json.get("properties").is_some()
256 && json.get("body").is_some()
257 {
258 return ProtocolDetection {
259 version: ProtocolVersion::V2,
260 confidence: 0.8,
261 method: DetectionMethod::Structure,
262 };
263 }
264
265 ProtocolDetection {
267 version: ProtocolVersion::V2,
268 confidence: 0.5,
269 method: DetectionMethod::Default,
270 }
271}
272
273pub fn detect_protocol_from_bytes(bytes: &[u8]) -> Result<ProtocolDetection, NegotiationError> {
275 let json: serde_json::Value =
276 serde_json::from_slice(bytes).map_err(|e| NegotiationError::InvalidData(e.to_string()))?;
277
278 Ok(detect_protocol(&json))
279}
280
281pub fn negotiate_protocol(
285 local: &[ProtocolVersion],
286 remote: &[ProtocolVersion],
287) -> Result<ProtocolVersion, NegotiationError> {
288 let mut negotiator = ProtocolNegotiator::new();
289
290 negotiator.supported.clear();
292 for v in local {
293 negotiator = negotiator.support(*v);
294 }
295
296 if let Some(&first) = local.first() {
297 negotiator = negotiator.prefer(first);
298 }
299
300 negotiator.negotiate(remote)
301}
302
303#[derive(Debug, Clone, Default)]
305pub struct ProtocolCapabilities {
306 pub chains: bool,
308 pub groups: bool,
310 pub chords: bool,
312 pub eta: bool,
314 pub expires: bool,
316 pub revocation: bool,
318 pub events: bool,
320 pub results: bool,
322}
323
324impl ProtocolCapabilities {
325 pub fn v2() -> Self {
327 Self {
328 chains: true,
329 groups: true,
330 chords: true,
331 eta: true,
332 expires: true,
333 revocation: true,
334 events: true,
335 results: true,
336 }
337 }
338
339 pub fn v5() -> Self {
341 Self {
342 chains: true,
343 groups: true,
344 chords: true,
345 eta: true,
346 expires: true,
347 revocation: true,
348 events: true,
349 results: true,
350 }
351 }
352
353 pub fn for_version(version: ProtocolVersion) -> Self {
355 match version {
356 ProtocolVersion::V2 => Self::v2(),
357 ProtocolVersion::V5 => Self::v5(),
358 }
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_protocol_negotiator_default() {
        let negotiator = ProtocolNegotiator::new();
        assert!(negotiator.is_supported(ProtocolVersion::V2));
        assert!(negotiator.is_supported(ProtocolVersion::V5));
        assert_eq!(negotiator.preferred_version(), Some(ProtocolVersion::V2));
    }

    #[test]
    fn test_protocol_negotiator_v2_only() {
        let negotiator = ProtocolNegotiator::v2_only();
        assert!(negotiator.is_supported(ProtocolVersion::V2));
        assert!(!negotiator.is_supported(ProtocolVersion::V5));
    }

    #[test]
    fn test_protocol_negotiator_prefer_v5() {
        let negotiator = ProtocolNegotiator::prefer_v5();
        assert!(negotiator.is_supported(ProtocolVersion::V2));
        assert!(negotiator.is_supported(ProtocolVersion::V5));
        assert_eq!(negotiator.preferred_version(), Some(ProtocolVersion::V5));
    }

    #[test]
    fn test_negotiate_common_version() {
        let negotiator = ProtocolNegotiator::new();
        let result = negotiator.negotiate(&[ProtocolVersion::V2]);
        assert_eq!(result.unwrap(), ProtocolVersion::V2);
    }

    #[test]
    fn test_negotiate_prefers_preferred() {
        let negotiator = ProtocolNegotiator::new().prefer(ProtocolVersion::V5);
        let result = negotiator.negotiate(&[ProtocolVersion::V2, ProtocolVersion::V5]);
        assert_eq!(result.unwrap(), ProtocolVersion::V5);
    }

    #[test]
    fn test_negotiate_falls_back_when_preferred_not_common() {
        // V5 is preferred locally, but the remote only speaks V2.
        let negotiator = ProtocolNegotiator::prefer_v5();
        let result = negotiator.negotiate(&[ProtocolVersion::V2]);
        assert_eq!(result.unwrap(), ProtocolVersion::V2);
    }

    #[test]
    fn test_negotiate_without_preference_picks_v5() {
        // Removing and re-adding V2 clears the preference while keeping support.
        let negotiator = ProtocolNegotiator::new()
            .unsupport(ProtocolVersion::V2)
            .support(ProtocolVersion::V2);
        assert_eq!(negotiator.preferred_version(), None);
        let result = negotiator.negotiate(&[ProtocolVersion::V2, ProtocolVersion::V5]);
        assert_eq!(result.unwrap(), ProtocolVersion::V5);
    }

    #[test]
    fn test_negotiate_no_common() {
        let negotiator = ProtocolNegotiator::v2_only();
        let result = negotiator.negotiate(&[ProtocolVersion::V5]);
        assert!(matches!(result, Err(NegotiationError::NoCommonVersion)));
    }

    #[test]
    fn test_negotiate_empty_remote() {
        let negotiator = ProtocolNegotiator::new();
        let result = negotiator.negotiate(&[]);
        assert!(matches!(result, Err(NegotiationError::NoCommonVersion)));
    }

    #[test]
    fn test_unsupport_clears_preference() {
        let negotiator = ProtocolNegotiator::new().unsupport(ProtocolVersion::V2);
        assert!(!negotiator.is_supported(ProtocolVersion::V2));
        assert_eq!(negotiator.preferred_version(), None);
    }

    #[test]
    fn test_validate_version_supported() {
        let negotiator = ProtocolNegotiator::new();
        assert!(negotiator.validate_version(ProtocolVersion::V2).is_ok());
    }

    #[test]
    fn test_validate_version_unsupported() {
        let negotiator = ProtocolNegotiator::v2_only().unsupport(ProtocolVersion::V2);
        let result = negotiator.validate_version(ProtocolVersion::V5);
        assert!(matches!(
            result,
            Err(NegotiationError::UnsupportedVersion(_))
        ));
        assert!(negotiator.validate_version(ProtocolVersion::V2).is_err());
    }

    #[test]
    fn test_detect_protocol_v2() {
        let msg = json!({
            "headers": {
                "task": "test",
                "id": "123",
                "lang": "py"
            },
            "properties": {},
            "body": "test"
        });

        let detection = detect_protocol(&msg);
        assert_eq!(detection.version, ProtocolVersion::V2);
        assert!(detection.confidence >= 0.8);
    }

    #[test]
    fn test_detect_protocol_v5() {
        let msg = json!({
            "headers": {
                "task": "test",
                "id": "123",
                "protocol": 2
            },
            "properties": {},
            "body": "test"
        });

        let detection = detect_protocol(&msg);
        assert_eq!(detection.version, ProtocolVersion::V5);
        assert_eq!(detection.confidence, 1.0);
    }

    #[test]
    fn test_detect_protocol_structure() {
        let msg = json!({"headers": {}, "properties": {}, "body": ""});
        let detection = detect_protocol(&msg);
        assert_eq!(detection.version, ProtocolVersion::V2);
        assert_eq!(detection.method, DetectionMethod::Structure);
        assert_eq!(detection.confidence, 0.8);
    }

    #[test]
    fn test_detect_protocol_default() {
        let detection = detect_protocol(&json!({"body": "x"}));
        assert_eq!(detection.version, ProtocolVersion::V2);
        assert_eq!(detection.method, DetectionMethod::Default);
        assert_eq!(detection.confidence, 0.5);
    }

    #[test]
    fn test_detect_protocol_from_bytes() {
        let bytes = br#"{"headers":{"lang":"py"},"properties":{},"body":""}"#;
        let detection = detect_protocol_from_bytes(bytes).unwrap();
        assert_eq!(detection.version, ProtocolVersion::V2);
    }

    #[test]
    fn test_detect_protocol_from_invalid_bytes() {
        let result = detect_protocol_from_bytes(b"not json");
        assert!(matches!(result, Err(NegotiationError::InvalidData(_))));
    }

    #[test]
    fn test_negotiate_protocol_helper() {
        let result = negotiate_protocol(
            &[ProtocolVersion::V2, ProtocolVersion::V5],
            &[ProtocolVersion::V2],
        );
        assert_eq!(result.unwrap(), ProtocolVersion::V2);
    }

    #[test]
    fn test_negotiate_protocol_helper_prefers_first_local() {
        let result = negotiate_protocol(
            &[ProtocolVersion::V5, ProtocolVersion::V2],
            &[ProtocolVersion::V2, ProtocolVersion::V5],
        );
        assert_eq!(result.unwrap(), ProtocolVersion::V5);
    }

    #[test]
    fn test_negotiate_protocol_helper_empty_local() {
        let result = negotiate_protocol(&[], &[ProtocolVersion::V2]);
        assert!(matches!(result, Err(NegotiationError::NoCommonVersion)));
    }

    #[test]
    fn test_protocol_capabilities() {
        let caps = ProtocolCapabilities::for_version(ProtocolVersion::V2);
        assert!(caps.chains);
        assert!(caps.groups);
        assert!(caps.chords);
        assert!(caps.events);
    }

    #[test]
    fn test_protocol_capabilities_v5_and_default() {
        let caps = ProtocolCapabilities::for_version(ProtocolVersion::V5);
        assert!(caps.chains && caps.revocation && caps.results);

        let empty = ProtocolCapabilities::default();
        assert!(!empty.chains && !empty.groups && !empty.results);
    }

    #[test]
    fn test_detection_method_display() {
        assert_eq!(DetectionMethod::Headers.to_string(), "headers");
        assert_eq!(DetectionMethod::Structure.to_string(), "structure");
        assert_eq!(DetectionMethod::ContentType.to_string(), "content-type");
        assert_eq!(DetectionMethod::Default.to_string(), "default");
    }

    #[test]
    fn test_negotiation_error_display() {
        let err = NegotiationError::NoCommonVersion;
        assert_eq!(err.to_string(), "No common protocol version found");

        let err = NegotiationError::UnsupportedVersion(ProtocolVersion::V5);
        assert!(err.to_string().contains("v5"));

        let err = NegotiationError::InvalidData("test".to_string());
        assert!(err.to_string().contains("test"));
    }

    #[test]
    fn test_supported_versions() {
        let negotiator = ProtocolNegotiator::new();
        let versions = negotiator.supported_versions();
        assert!(versions.contains(&ProtocolVersion::V2));
        assert!(versions.contains(&ProtocolVersion::V5));
    }
}