celers_protocol/
extension_api.rs

1//! Custom protocol extensions API
2//!
3//! This module provides a flexible extension system for adding custom fields
4//! and behaviors to Celery protocol messages without breaking compatibility.
5//!
6//! # Examples
7//!
8//! ```
9//! use celers_protocol::extension_api::{Extension, ExtensionRegistry, ExtensionValue};
10//! use serde_json::json;
11//!
12//! // Define a custom extension
13//! struct TelemetryExtension;
14//!
15//! impl Extension for TelemetryExtension {
16//!     fn name(&self) -> &str {
17//!         "telemetry"
18//!     }
19//!
20//!     fn validate(&self, value: &ExtensionValue) -> Result<(), String> {
21//!         // Custom validation logic
22//!         if let ExtensionValue::Object(map) = value {
23//!             if !map.contains_key("trace_id") {
24//!                 return Err("Missing trace_id".to_string());
25//!             }
26//!         }
27//!         Ok(())
28//!     }
29//! }
30//!
31//! // Register extension
32//! let mut registry = ExtensionRegistry::new();
//! registry.register(Box::new(TelemetryExtension)).unwrap();
34//!
35//! // Use extension
36//! let value = ExtensionValue::Object(
37//!     vec![("trace_id".to_string(), json!("abc123"))]
38//!         .into_iter()
39//!         .collect()
40//! );
41//! registry.validate("telemetry", &value).unwrap();
42//! ```
43
44use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46
/// Extension value types.
///
/// A value attached to a message by a custom protocol extension. Because the
/// enum is `#[serde(untagged)]`, a value is (de)serialized as the bare JSON
/// value (`"abc"`, `42`, `1.5`, `true`, `[...]`, `{...}`, `null`) with no
/// variant tag.
///
/// Variant order is significant: when deserializing an untagged enum, serde
/// tries the variants top to bottom and keeps the first one that matches. For
/// example, a JSON integer is tried as `Integer` before `Float`. Do not
/// reorder the variants without checking round-trip behavior.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ExtensionValue {
    /// String value
    String(String),
    /// Integer value (signed 64-bit)
    Integer(i64),
    /// Float value (64-bit)
    Float(f64),
    /// Boolean value
    Boolean(bool),
    /// Array of arbitrary JSON values
    Array(Vec<serde_json::Value>),
    /// Object/map: string keys mapped to arbitrary JSON values
    Object(HashMap<String, serde_json::Value>),
    /// Null value (JSON `null`)
    Null,
}
66
67/// Extension trait for custom protocol extensions
68pub trait Extension: Send + Sync {
69    /// Extension name (must be unique)
70    fn name(&self) -> &str;
71
72    /// Validate extension value
73    fn validate(&self, value: &ExtensionValue) -> Result<(), String>;
74
75    /// Transform extension value (optional)
76    fn transform(&self, value: ExtensionValue) -> Result<ExtensionValue, String> {
77        Ok(value)
78    }
79
80    /// Check if extension is compatible with protocol version
81    fn is_compatible(&self, _version: crate::ProtocolVersion) -> bool {
82        true // Compatible by default
83    }
84}
85
86/// Registry for managing custom extensions
87#[derive(Default)]
88pub struct ExtensionRegistry {
89    extensions: HashMap<String, Box<dyn Extension>>,
90}
91
92impl ExtensionRegistry {
93    /// Create a new extension registry
94    pub fn new() -> Self {
95        Self {
96            extensions: HashMap::new(),
97        }
98    }
99
100    /// Register a new extension
101    pub fn register(&mut self, extension: Box<dyn Extension>) -> Result<(), String> {
102        let name = extension.name().to_string();
103
104        if self.extensions.contains_key(&name) {
105            return Err(format!("Extension '{}' already registered", name));
106        }
107
108        self.extensions.insert(name, extension);
109        Ok(())
110    }
111
112    /// Unregister an extension
113    pub fn unregister(&mut self, name: &str) -> bool {
114        self.extensions.remove(name).is_some()
115    }
116
117    /// Get an extension by name
118    pub fn get(&self, name: &str) -> Option<&dyn Extension> {
119        self.extensions.get(name).map(|b| b.as_ref())
120    }
121
122    /// Check if an extension is registered
123    #[inline]
124    pub fn has(&self, name: &str) -> bool {
125        self.extensions.contains_key(name)
126    }
127
128    /// List all registered extension names
129    #[inline]
130    pub fn list(&self) -> Vec<&str> {
131        self.extensions.keys().map(|s| s.as_str()).collect()
132    }
133
134    /// Validate an extension value
135    pub fn validate(&self, name: &str, value: &ExtensionValue) -> Result<(), String> {
136        match self.get(name) {
137            Some(ext) => ext.validate(value),
138            None => Err(format!("Extension '{}' not registered", name)),
139        }
140    }
141
142    /// Transform an extension value
143    pub fn transform(&self, name: &str, value: ExtensionValue) -> Result<ExtensionValue, String> {
144        match self.get(name) {
145            Some(ext) => ext.transform(value),
146            None => Err(format!("Extension '{}' not registered", name)),
147        }
148    }
149
150    /// Validate all extensions in a map
151    pub fn validate_all(&self, extensions: &HashMap<String, ExtensionValue>) -> Result<(), String> {
152        for (name, value) in extensions {
153            self.validate(name, value)?;
154        }
155        Ok(())
156    }
157}
158
159/// Message with custom extensions
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct ExtendedMessage {
162    /// Base message
163    #[serde(flatten)]
164    pub message: crate::Message,
165
166    /// Custom extensions
167    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
168    pub extensions: HashMap<String, ExtensionValue>,
169}
170
171impl ExtendedMessage {
172    /// Create an extended message
173    pub fn new(message: crate::Message) -> Self {
174        Self {
175            message,
176            extensions: HashMap::new(),
177        }
178    }
179
180    /// Add an extension
181    pub fn with_extension(mut self, name: String, value: ExtensionValue) -> Self {
182        self.extensions.insert(name, value);
183        self
184    }
185
186    /// Get an extension value
187    pub fn get_extension(&self, name: &str) -> Option<&ExtensionValue> {
188        self.extensions.get(name)
189    }
190
191    /// Remove an extension
192    pub fn remove_extension(&mut self, name: &str) -> Option<ExtensionValue> {
193        self.extensions.remove(name)
194    }
195
196    /// Validate all extensions
197    pub fn validate_extensions(&self, registry: &ExtensionRegistry) -> Result<(), String> {
198        registry.validate_all(&self.extensions)
199    }
200}
201
202// Built-in extensions
203
204/// Telemetry/tracing extension
205pub struct TelemetryExtension;
206
207impl Extension for TelemetryExtension {
208    fn name(&self) -> &str {
209        "telemetry"
210    }
211
212    fn validate(&self, value: &ExtensionValue) -> Result<(), String> {
213        match value {
214            ExtensionValue::Object(map) => {
215                if !map.contains_key("trace_id") && !map.contains_key("span_id") {
216                    return Err("Telemetry extension requires 'trace_id' or 'span_id'".to_string());
217                }
218                Ok(())
219            }
220            _ => Err("Telemetry extension must be an object".to_string()),
221        }
222    }
223}
224
225/// Metrics collection extension
226pub struct MetricsExtension;
227
228impl Extension for MetricsExtension {
229    fn name(&self) -> &str {
230        "metrics"
231    }
232
233    fn validate(&self, value: &ExtensionValue) -> Result<(), String> {
234        match value {
235            ExtensionValue::Object(_) | ExtensionValue::Array(_) => Ok(()),
236            _ => Err("Metrics extension must be an object or array".to_string()),
237        }
238    }
239}
240
241/// Custom routing extension
242pub struct RoutingExtension;
243
244impl Extension for RoutingExtension {
245    fn name(&self) -> &str {
246        "routing"
247    }
248
249    fn validate(&self, value: &ExtensionValue) -> Result<(), String> {
250        match value {
251            ExtensionValue::Object(map) => {
252                if let Some(priority) = map.get("priority") {
253                    if let Some(p) = priority.as_i64() {
254                        if !(0..=9).contains(&p) {
255                            return Err("Routing priority must be 0-9".to_string());
256                        }
257                    }
258                }
259                Ok(())
260            }
261            _ => Err("Routing extension must be an object".to_string()),
262        }
263    }
264}
265
266/// Create a registry with built-in extensions
267pub fn create_default_registry() -> ExtensionRegistry {
268    let mut registry = ExtensionRegistry::new();
269    registry
270        .register(Box::new(TelemetryExtension))
271        .expect("Failed to register TelemetryExtension");
272    registry
273        .register(Box::new(MetricsExtension))
274        .expect("Failed to register MetricsExtension");
275    registry
276        .register(Box::new(RoutingExtension))
277        .expect("Failed to register RoutingExtension");
278    registry
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use uuid::Uuid;

    /// Extension named `"test"` that accepts every value.
    struct AcceptAll;

    impl Extension for AcceptAll {
        fn name(&self) -> &str {
            "test"
        }
        fn validate(&self, _value: &ExtensionValue) -> Result<(), String> {
            Ok(())
        }
    }

    /// Builds an `ExtensionValue::Object` holding the single entry `key -> value`.
    fn object_with(key: &str, value: serde_json::Value) -> ExtensionValue {
        let mut fields = HashMap::new();
        fields.insert(key.to_string(), value);
        ExtensionValue::Object(fields)
    }

    /// Builds a minimal `tasks.test` message with empty task arguments.
    fn sample_message() -> crate::Message {
        let body = serde_json::to_vec(&crate::TaskArgs::new()).unwrap();
        crate::Message::new("tasks.test".to_string(), Uuid::new_v4(), body)
    }

    #[test]
    fn test_extension_registry() {
        let mut registry = ExtensionRegistry::new();

        assert!(registry.register(Box::new(AcceptAll)).is_ok());
        assert!(registry.has("test"));
        assert_eq!(registry.list(), vec!["test"]);
    }

    #[test]
    fn test_duplicate_registration() {
        let mut registry = ExtensionRegistry::new();

        assert!(registry.register(Box::new(AcceptAll)).is_ok());
        assert!(registry.register(Box::new(AcceptAll)).is_err());
    }

    #[test]
    fn test_extension_validation() {
        let registry = create_default_registry();
        let telemetry = object_with("trace_id", json!("abc123"));

        assert!(registry.validate("telemetry", &telemetry).is_ok());
    }

    #[test]
    fn test_invalid_telemetry() {
        let registry = create_default_registry();
        let empty = ExtensionValue::Object(HashMap::new());

        assert!(registry.validate("telemetry", &empty).is_err());
    }

    #[test]
    fn test_extended_message() {
        let ext_msg = ExtendedMessage::new(sample_message()).with_extension(
            "telemetry".to_string(),
            object_with("trace_id", json!("xyz789")),
        );

        assert!(ext_msg.get_extension("telemetry").is_some());
    }

    #[test]
    fn test_extended_message_validation() {
        let ext_msg = ExtendedMessage::new(sample_message()).with_extension(
            "telemetry".to_string(),
            object_with("trace_id", json!("abc123")),
        );
        let registry = create_default_registry();

        assert!(ext_msg.validate_extensions(&registry).is_ok());
    }

    #[test]
    fn test_unregister_extension() {
        let mut registry = ExtensionRegistry::new();
        registry.register(Box::new(AcceptAll)).unwrap();
        assert!(registry.has("test"));

        assert!(registry.unregister("test"));
        assert!(!registry.has("test"));
    }

    #[test]
    fn test_routing_extension_validation() {
        let registry = create_default_registry();

        let in_range = object_with("priority", json!(5));
        assert!(registry.validate("routing", &in_range).is_ok());

        let out_of_range = object_with("priority", json!(10));
        assert!(registry.validate("routing", &out_of_range).is_err());
    }

    #[test]
    fn test_extension_value_serialization() {
        let original = object_with("key", json!("value"));

        let text = serde_json::to_string(&original).unwrap();
        let round_tripped: ExtensionValue = serde_json::from_str(&text).unwrap();

        assert_eq!(original, round_tripped);
    }
}