celers_protocol/
negotiation.rs

//! Protocol version negotiation and detection
//!
//! This module provides utilities for detecting and negotiating Celery protocol
//! versions between CeleRS and Python Celery workers/clients.
//!
//! # Protocol Versions
//!
//! - **Protocol v1**: Legacy format (Celery 3.x and earlier) - Not supported
//! - **Protocol v2**: Current stable format (Celery 4.x+) - Fully supported
//! - **Protocol v5**: Extended format (Celery 5.x+) - Fully supported
//!
//! # Example
//!
//! ```
//! use celers_protocol::negotiation::{ProtocolNegotiator, negotiate_protocol};
//! use celers_protocol::ProtocolVersion;
//!
//! // Negotiate between supported versions
//! let negotiator = ProtocolNegotiator::new()
//!     .prefer(ProtocolVersion::V5)
//!     .support(ProtocolVersion::V2);
//!
//! // The remote side only offers v2, so the v5 preference cannot be honored.
//! let agreed = negotiator.negotiate(&[ProtocolVersion::V2]).unwrap();
//! assert_eq!(agreed, ProtocolVersion::V2);
//!
//! // Or negotiate directly from two version lists; the first local entry is preferred.
//! let agreed = negotiate_protocol(
//!     &[ProtocolVersion::V5, ProtocolVersion::V2],
//!     &[ProtocolVersion::V2, ProtocolVersion::V5],
//! )
//! .unwrap();
//! assert_eq!(agreed, ProtocolVersion::V5);
//! ```
26
27use crate::ProtocolVersion;
28use std::collections::HashSet;
29
30/// Protocol detection result
31#[derive(Debug, Clone, PartialEq)]
32pub struct ProtocolDetection {
33    /// Detected protocol version
34    pub version: ProtocolVersion,
35    /// Confidence level (0.0 - 1.0)
36    pub confidence: f32,
37    /// Detection method used
38    pub method: DetectionMethod,
39}
40
/// The kind of evidence a protocol detection was based on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DetectionMethod {
    /// A recognised key inside the message headers
    Headers,
    /// The overall shape of the message (which top-level keys exist)
    Structure,
    /// The message content type
    ContentType,
    /// No evidence was found; a fallback assumption was used
    Default,
}

impl std::fmt::Display for DetectionMethod {
    /// Writes a short lowercase label (e.g. `content-type`) for logs and diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            DetectionMethod::Headers => "headers",
            DetectionMethod::Structure => "structure",
            DetectionMethod::ContentType => "content-type",
            DetectionMethod::Default => "default",
        };
        f.write_str(label)
    }
}
64
65/// Protocol negotiation error
66#[derive(Debug, Clone)]
67pub enum NegotiationError {
68    /// No common protocol version found
69    NoCommonVersion,
70    /// Protocol version not supported
71    UnsupportedVersion(ProtocolVersion),
72    /// Invalid protocol data
73    InvalidData(String),
74}
75
76impl std::fmt::Display for NegotiationError {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        match self {
79            NegotiationError::NoCommonVersion => {
80                write!(f, "No common protocol version found")
81            }
82            NegotiationError::UnsupportedVersion(v) => {
83                write!(f, "Protocol version {} is not supported", v)
84            }
85            NegotiationError::InvalidData(msg) => {
86                write!(f, "Invalid protocol data: {}", msg)
87            }
88        }
89    }
90}
91
92impl std::error::Error for NegotiationError {}
93
94/// Protocol negotiator for version agreement
95#[derive(Debug, Clone)]
96pub struct ProtocolNegotiator {
97    /// Supported protocol versions
98    supported: HashSet<ProtocolVersion>,
99    /// Preferred protocol version (highest priority)
100    preferred: Option<ProtocolVersion>,
101}
102
103impl Default for ProtocolNegotiator {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl ProtocolNegotiator {
110    /// Create a new negotiator with default support (v2 and v5)
111    pub fn new() -> Self {
112        let mut supported = HashSet::new();
113        supported.insert(ProtocolVersion::V2);
114        supported.insert(ProtocolVersion::V5);
115
116        Self {
117            supported,
118            preferred: Some(ProtocolVersion::V2), // Default to v2 for compatibility
119        }
120    }
121
122    /// Create a negotiator that only supports v2
123    pub fn v2_only() -> Self {
124        let mut supported = HashSet::new();
125        supported.insert(ProtocolVersion::V2);
126
127        Self {
128            supported,
129            preferred: Some(ProtocolVersion::V2),
130        }
131    }
132
133    /// Create a negotiator that prefers v5
134    pub fn prefer_v5() -> Self {
135        let mut supported = HashSet::new();
136        supported.insert(ProtocolVersion::V2);
137        supported.insert(ProtocolVersion::V5);
138
139        Self {
140            supported,
141            preferred: Some(ProtocolVersion::V5),
142        }
143    }
144
145    /// Set the preferred protocol version
146    #[must_use]
147    pub fn prefer(mut self, version: ProtocolVersion) -> Self {
148        self.preferred = Some(version);
149        self.supported.insert(version);
150        self
151    }
152
153    /// Add support for a protocol version
154    #[must_use]
155    pub fn support(mut self, version: ProtocolVersion) -> Self {
156        self.supported.insert(version);
157        self
158    }
159
160    /// Remove support for a protocol version
161    #[must_use]
162    pub fn unsupport(mut self, version: ProtocolVersion) -> Self {
163        self.supported.remove(&version);
164        if self.preferred == Some(version) {
165            self.preferred = None;
166        }
167        self
168    }
169
170    /// Check if a protocol version is supported
171    #[inline]
172    pub fn is_supported(&self, version: ProtocolVersion) -> bool {
173        self.supported.contains(&version)
174    }
175
176    /// Get all supported versions
177    #[inline]
178    pub fn supported_versions(&self) -> Vec<ProtocolVersion> {
179        self.supported.iter().copied().collect()
180    }
181
182    /// Get the preferred version
183    #[inline]
184    pub fn preferred_version(&self) -> Option<ProtocolVersion> {
185        self.preferred
186    }
187
188    /// Negotiate a protocol version with a remote party
189    ///
190    /// Returns the agreed version based on mutual support, preferring
191    /// our preferred version if mutually supported.
192    pub fn negotiate(
193        &self,
194        remote_versions: &[ProtocolVersion],
195    ) -> Result<ProtocolVersion, NegotiationError> {
196        // Find common versions
197        let remote_set: HashSet<_> = remote_versions.iter().copied().collect();
198        let common: Vec<_> = self.supported.intersection(&remote_set).copied().collect();
199
200        if common.is_empty() {
201            return Err(NegotiationError::NoCommonVersion);
202        }
203
204        // If our preferred version is in common, use it
205        if let Some(preferred) = self.preferred {
206            if common.contains(&preferred) {
207                return Ok(preferred);
208            }
209        }
210
211        // Otherwise, prefer v5 over v2
212        if common.contains(&ProtocolVersion::V5) {
213            Ok(ProtocolVersion::V5)
214        } else {
215            Ok(ProtocolVersion::V2)
216        }
217    }
218
219    /// Validate that a message uses a supported protocol version
220    pub fn validate_version(&self, version: ProtocolVersion) -> Result<(), NegotiationError> {
221        if self.is_supported(version) {
222            Ok(())
223        } else {
224            Err(NegotiationError::UnsupportedVersion(version))
225        }
226    }
227}
228
229/// Detect protocol version from a JSON message
230///
231/// Analyzes the message structure to determine which protocol version it uses.
232pub fn detect_protocol(json: &serde_json::Value) -> ProtocolDetection {
233    // Check for protocol header (v5 style)
234    if let Some(headers) = json.get("headers") {
235        if headers.get("protocol").is_some() {
236            return ProtocolDetection {
237                version: ProtocolVersion::V5,
238                confidence: 1.0,
239                method: DetectionMethod::Headers,
240            };
241        }
242
243        // v2 has lang header
244        if headers.get("lang").is_some() {
245            return ProtocolDetection {
246                version: ProtocolVersion::V2,
247                confidence: 0.9,
248                method: DetectionMethod::Headers,
249            };
250        }
251    }
252
253    // Check message structure
254    if json.get("headers").is_some()
255        && json.get("properties").is_some()
256        && json.get("body").is_some()
257    {
258        return ProtocolDetection {
259            version: ProtocolVersion::V2,
260            confidence: 0.8,
261            method: DetectionMethod::Structure,
262        };
263    }
264
265    // Default to v2
266    ProtocolDetection {
267        version: ProtocolVersion::V2,
268        confidence: 0.5,
269        method: DetectionMethod::Default,
270    }
271}
272
273/// Detect protocol version from message bytes
274pub fn detect_protocol_from_bytes(bytes: &[u8]) -> Result<ProtocolDetection, NegotiationError> {
275    let json: serde_json::Value =
276        serde_json::from_slice(bytes).map_err(|e| NegotiationError::InvalidData(e.to_string()))?;
277
278    Ok(detect_protocol(&json))
279}
280
281/// Simple negotiation helper function
282///
283/// Negotiates between local and remote supported versions.
284pub fn negotiate_protocol(
285    local: &[ProtocolVersion],
286    remote: &[ProtocolVersion],
287) -> Result<ProtocolVersion, NegotiationError> {
288    let mut negotiator = ProtocolNegotiator::new();
289
290    // Clear default and add only specified versions
291    negotiator.supported.clear();
292    for v in local {
293        negotiator = negotiator.support(*v);
294    }
295
296    if let Some(&first) = local.first() {
297        negotiator = negotiator.prefer(first);
298    }
299
300    negotiator.negotiate(remote)
301}
302
/// Feature flags describing what a protocol version supports.
///
/// `Default` yields every flag set to `false`. `PartialEq`/`Eq` are derived so
/// two capability sets can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProtocolCapabilities {
    /// Supports task chains
    pub chains: bool,
    /// Supports task groups
    pub groups: bool,
    /// Supports chords
    pub chords: bool,
    /// Supports ETA/countdown
    pub eta: bool,
    /// Supports task expiration
    pub expires: bool,
    /// Supports task revocation
    pub revocation: bool,
    /// Supports task events
    pub events: bool,
    /// Supports result backends
    pub results: bool,
}
323
324impl ProtocolCapabilities {
325    /// Get capabilities for protocol v2
326    pub fn v2() -> Self {
327        Self {
328            chains: true,
329            groups: true,
330            chords: true,
331            eta: true,
332            expires: true,
333            revocation: true,
334            events: true,
335            results: true,
336        }
337    }
338
339    /// Get capabilities for protocol v5
340    pub fn v5() -> Self {
341        Self {
342            chains: true,
343            groups: true,
344            chords: true,
345            eta: true,
346            expires: true,
347            revocation: true,
348            events: true,
349            results: true,
350        }
351    }
352
353    /// Get capabilities for a protocol version
354    pub fn for_version(version: ProtocolVersion) -> Self {
355        match version {
356            ProtocolVersion::V2 => Self::v2(),
357            ProtocolVersion::V5 => Self::v5(),
358        }
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_protocol_negotiator_default() {
        let negotiator = ProtocolNegotiator::new();
        assert!(negotiator.is_supported(ProtocolVersion::V2));
        assert!(negotiator.is_supported(ProtocolVersion::V5));
        assert_eq!(negotiator.preferred_version(), Some(ProtocolVersion::V2));
    }

    #[test]
    fn test_protocol_negotiator_v2_only() {
        let negotiator = ProtocolNegotiator::v2_only();
        assert!(negotiator.is_supported(ProtocolVersion::V2));
        assert!(!negotiator.is_supported(ProtocolVersion::V5));
    }

    #[test]
    fn test_protocol_negotiator_prefer_v5() {
        let negotiator = ProtocolNegotiator::prefer_v5();
        assert!(negotiator.is_supported(ProtocolVersion::V2));
        assert!(negotiator.is_supported(ProtocolVersion::V5));
        assert_eq!(negotiator.preferred_version(), Some(ProtocolVersion::V5));
    }

    #[test]
    fn test_negotiate_common_version() {
        let negotiator = ProtocolNegotiator::new();
        let result = negotiator.negotiate(&[ProtocolVersion::V2]);
        assert_eq!(result.unwrap(), ProtocolVersion::V2);
    }

    #[test]
    fn test_negotiate_prefers_preferred() {
        let negotiator = ProtocolNegotiator::new().prefer(ProtocolVersion::V5);
        let result = negotiator.negotiate(&[ProtocolVersion::V2, ProtocolVersion::V5]);
        assert_eq!(result.unwrap(), ProtocolVersion::V5);
    }

    #[test]
    fn test_negotiate_falls_back_to_common_version() {
        // The default preference (v2) is not offered by the remote side, so the
        // only mutually supported version must be chosen instead.
        let negotiator = ProtocolNegotiator::new();
        let result = negotiator.negotiate(&[ProtocolVersion::V5]);
        assert_eq!(result.unwrap(), ProtocolVersion::V5);
    }

    #[test]
    fn test_negotiate_no_common() {
        let negotiator = ProtocolNegotiator::v2_only();
        let result = negotiator.negotiate(&[ProtocolVersion::V5]);
        assert!(matches!(result, Err(NegotiationError::NoCommonVersion)));
    }

    #[test]
    fn test_negotiate_empty_remote() {
        let negotiator = ProtocolNegotiator::new();
        let result = negotiator.negotiate(&[]);
        assert!(matches!(result, Err(NegotiationError::NoCommonVersion)));
    }

    #[test]
    fn test_unsupport_clears_preference() {
        let negotiator = ProtocolNegotiator::prefer_v5().unsupport(ProtocolVersion::V5);
        assert!(!negotiator.is_supported(ProtocolVersion::V5));
        assert_eq!(negotiator.preferred_version(), None);
        let result = negotiator.negotiate(&[ProtocolVersion::V2, ProtocolVersion::V5]);
        assert_eq!(result.unwrap(), ProtocolVersion::V2);
    }

    #[test]
    fn test_validate_version_supported() {
        let negotiator = ProtocolNegotiator::new();
        assert!(negotiator.validate_version(ProtocolVersion::V2).is_ok());
    }

    #[test]
    fn test_validate_version_unsupported() {
        // Removing v2 from a v2-only negotiator leaves nothing supported, so
        // validating v2 itself must now fail.
        let negotiator = ProtocolNegotiator::v2_only().unsupport(ProtocolVersion::V2);
        let result = negotiator.validate_version(ProtocolVersion::V2);
        assert!(matches!(
            result,
            Err(NegotiationError::UnsupportedVersion(ProtocolVersion::V2))
        ));
        assert_eq!(negotiator.preferred_version(), None);
    }

    #[test]
    fn test_detect_protocol_v2() {
        let msg = json!({
            "headers": {
                "task": "test",
                "id": "123",
                "lang": "py"
            },
            "properties": {},
            "body": "test"
        });

        let detection = detect_protocol(&msg);
        assert_eq!(detection.version, ProtocolVersion::V2);
        assert!(detection.confidence >= 0.8);
    }

    #[test]
    fn test_detect_protocol_v5() {
        let msg = json!({
            "headers": {
                "task": "test",
                "id": "123",
                "protocol": 2
            },
            "properties": {},
            "body": "test"
        });

        let detection = detect_protocol(&msg);
        assert_eq!(detection.version, ProtocolVersion::V5);
        assert_eq!(detection.confidence, 1.0);
    }

    #[test]
    fn test_detect_protocol_structure() {
        let msg = json!({ "headers": {}, "properties": {}, "body": "" });
        let detection = detect_protocol(&msg);
        assert_eq!(detection.version, ProtocolVersion::V2);
        assert_eq!(detection.method, DetectionMethod::Structure);
        assert_eq!(detection.confidence, 0.8);
    }

    #[test]
    fn test_detect_protocol_default() {
        for msg in &[json!({}), json!("not an object"), json!(null)] {
            let detection = detect_protocol(msg);
            assert_eq!(detection.version, ProtocolVersion::V2);
            assert_eq!(detection.method, DetectionMethod::Default);
            assert_eq!(detection.confidence, 0.5);
        }
    }

    #[test]
    fn test_detect_protocol_from_bytes() {
        let bytes = br#"{"headers":{"lang":"py"},"properties":{},"body":""}"#;
        let detection = detect_protocol_from_bytes(bytes).unwrap();
        assert_eq!(detection.version, ProtocolVersion::V2);
    }

    #[test]
    fn test_detect_protocol_from_bytes_invalid_json() {
        let result = detect_protocol_from_bytes(b"not json");
        assert!(matches!(result, Err(NegotiationError::InvalidData(_))));
    }

    #[test]
    fn test_negotiate_protocol_helper() {
        let result = negotiate_protocol(
            &[ProtocolVersion::V2, ProtocolVersion::V5],
            &[ProtocolVersion::V2],
        );
        assert_eq!(result.unwrap(), ProtocolVersion::V2);
    }

    #[test]
    fn test_negotiate_protocol_prefers_first_local() {
        let result = negotiate_protocol(
            &[ProtocolVersion::V5, ProtocolVersion::V2],
            &[ProtocolVersion::V2, ProtocolVersion::V5],
        );
        assert_eq!(result.unwrap(), ProtocolVersion::V5);
    }

    #[test]
    fn test_negotiate_protocol_empty_local() {
        let result = negotiate_protocol(&[], &[ProtocolVersion::V2]);
        assert!(matches!(result, Err(NegotiationError::NoCommonVersion)));
    }

    #[test]
    fn test_protocol_capabilities() {
        let caps = ProtocolCapabilities::for_version(ProtocolVersion::V2);
        assert!(caps.chains);
        assert!(caps.groups);
        assert!(caps.chords);
        assert!(caps.events);

        let caps = ProtocolCapabilities::for_version(ProtocolVersion::V5);
        assert!(caps.eta);
        assert!(caps.expires);
        assert!(caps.revocation);
        assert!(caps.results);
    }

    #[test]
    fn test_detection_method_display() {
        assert_eq!(DetectionMethod::Headers.to_string(), "headers");
        assert_eq!(DetectionMethod::Structure.to_string(), "structure");
        assert_eq!(DetectionMethod::ContentType.to_string(), "content-type");
        assert_eq!(DetectionMethod::Default.to_string(), "default");
    }

    #[test]
    fn test_negotiation_error_display() {
        let err = NegotiationError::NoCommonVersion;
        assert_eq!(err.to_string(), "No common protocol version found");

        let err = NegotiationError::UnsupportedVersion(ProtocolVersion::V5);
        assert!(err.to_string().contains("v5"));

        let err = NegotiationError::InvalidData("test".to_string());
        assert!(err.to_string().contains("test"));
    }

    #[test]
    fn test_supported_versions() {
        let negotiator = ProtocolNegotiator::new();
        let versions = negotiator.supported_versions();
        assert!(versions.contains(&ProtocolVersion::V2));
        assert!(versions.contains(&ProtocolVersion::V5));
    }
}