1use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49#[serde(untagged)]
50pub enum ExtensionValue {
51 String(String),
53 Integer(i64),
55 Float(f64),
57 Boolean(bool),
59 Array(Vec<serde_json::Value>),
61 Object(HashMap<String, serde_json::Value>),
63 Null,
65}
66
67pub trait Extension: Send + Sync {
69 fn name(&self) -> &str;
71
72 fn validate(&self, value: &ExtensionValue) -> Result<(), String>;
74
75 fn transform(&self, value: ExtensionValue) -> Result<ExtensionValue, String> {
77 Ok(value)
78 }
79
80 fn is_compatible(&self, _version: crate::ProtocolVersion) -> bool {
82 true }
84}
85
86#[derive(Default)]
88pub struct ExtensionRegistry {
89 extensions: HashMap<String, Box<dyn Extension>>,
90}
91
92impl ExtensionRegistry {
93 pub fn new() -> Self {
95 Self {
96 extensions: HashMap::new(),
97 }
98 }
99
100 pub fn register(&mut self, extension: Box<dyn Extension>) -> Result<(), String> {
102 let name = extension.name().to_string();
103
104 if self.extensions.contains_key(&name) {
105 return Err(format!("Extension '{}' already registered", name));
106 }
107
108 self.extensions.insert(name, extension);
109 Ok(())
110 }
111
112 pub fn unregister(&mut self, name: &str) -> bool {
114 self.extensions.remove(name).is_some()
115 }
116
117 pub fn get(&self, name: &str) -> Option<&dyn Extension> {
119 self.extensions.get(name).map(|b| b.as_ref())
120 }
121
122 #[inline]
124 pub fn has(&self, name: &str) -> bool {
125 self.extensions.contains_key(name)
126 }
127
128 #[inline]
130 pub fn list(&self) -> Vec<&str> {
131 self.extensions.keys().map(|s| s.as_str()).collect()
132 }
133
134 pub fn validate(&self, name: &str, value: &ExtensionValue) -> Result<(), String> {
136 match self.get(name) {
137 Some(ext) => ext.validate(value),
138 None => Err(format!("Extension '{}' not registered", name)),
139 }
140 }
141
142 pub fn transform(&self, name: &str, value: ExtensionValue) -> Result<ExtensionValue, String> {
144 match self.get(name) {
145 Some(ext) => ext.transform(value),
146 None => Err(format!("Extension '{}' not registered", name)),
147 }
148 }
149
150 pub fn validate_all(&self, extensions: &HashMap<String, ExtensionValue>) -> Result<(), String> {
152 for (name, value) in extensions {
153 self.validate(name, value)?;
154 }
155 Ok(())
156 }
157}
158
159#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct ExtendedMessage {
162 #[serde(flatten)]
164 pub message: crate::Message,
165
166 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
168 pub extensions: HashMap<String, ExtensionValue>,
169}
170
171impl ExtendedMessage {
172 pub fn new(message: crate::Message) -> Self {
174 Self {
175 message,
176 extensions: HashMap::new(),
177 }
178 }
179
180 pub fn with_extension(mut self, name: String, value: ExtensionValue) -> Self {
182 self.extensions.insert(name, value);
183 self
184 }
185
186 pub fn get_extension(&self, name: &str) -> Option<&ExtensionValue> {
188 self.extensions.get(name)
189 }
190
191 pub fn remove_extension(&mut self, name: &str) -> Option<ExtensionValue> {
193 self.extensions.remove(name)
194 }
195
196 pub fn validate_extensions(&self, registry: &ExtensionRegistry) -> Result<(), String> {
198 registry.validate_all(&self.extensions)
199 }
200}
201
202pub struct TelemetryExtension;
206
207impl Extension for TelemetryExtension {
208 fn name(&self) -> &str {
209 "telemetry"
210 }
211
212 fn validate(&self, value: &ExtensionValue) -> Result<(), String> {
213 match value {
214 ExtensionValue::Object(map) => {
215 if !map.contains_key("trace_id") && !map.contains_key("span_id") {
216 return Err("Telemetry extension requires 'trace_id' or 'span_id'".to_string());
217 }
218 Ok(())
219 }
220 _ => Err("Telemetry extension must be an object".to_string()),
221 }
222 }
223}
224
225pub struct MetricsExtension;
227
228impl Extension for MetricsExtension {
229 fn name(&self) -> &str {
230 "metrics"
231 }
232
233 fn validate(&self, value: &ExtensionValue) -> Result<(), String> {
234 match value {
235 ExtensionValue::Object(_) | ExtensionValue::Array(_) => Ok(()),
236 _ => Err("Metrics extension must be an object or array".to_string()),
237 }
238 }
239}
240
241pub struct RoutingExtension;
243
244impl Extension for RoutingExtension {
245 fn name(&self) -> &str {
246 "routing"
247 }
248
249 fn validate(&self, value: &ExtensionValue) -> Result<(), String> {
250 match value {
251 ExtensionValue::Object(map) => {
252 if let Some(priority) = map.get("priority") {
253 if let Some(p) = priority.as_i64() {
254 if !(0..=9).contains(&p) {
255 return Err("Routing priority must be 0-9".to_string());
256 }
257 }
258 }
259 Ok(())
260 }
261 _ => Err("Routing extension must be an object".to_string()),
262 }
263 }
264}
265
266pub fn create_default_registry() -> ExtensionRegistry {
268 let mut registry = ExtensionRegistry::new();
269 registry
270 .register(Box::new(TelemetryExtension))
271 .expect("Failed to register TelemetryExtension");
272 registry
273 .register(Box::new(MetricsExtension))
274 .expect("Failed to register MetricsExtension");
275 registry
276 .register(Box::new(RoutingExtension))
277 .expect("Failed to register RoutingExtension");
278 registry
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use uuid::Uuid;

    #[test]
    fn test_extension_registry() {
        let mut registry = ExtensionRegistry::new();

        struct TestExt;
        impl Extension for TestExt {
            fn name(&self) -> &str {
                "test"
            }
            fn validate(&self, _value: &ExtensionValue) -> Result<(), String> {
                Ok(())
            }
        }

        assert!(registry.register(Box::new(TestExt)).is_ok());
        assert!(registry.has("test"));
        assert_eq!(registry.list(), vec!["test"]);
    }

    #[test]
    fn test_duplicate_registration() {
        let mut registry = ExtensionRegistry::new();

        struct TestExt;
        impl Extension for TestExt {
            fn name(&self) -> &str {
                "test"
            }
            fn validate(&self, _value: &ExtensionValue) -> Result<(), String> {
                Ok(())
            }
        }

        assert!(registry.register(Box::new(TestExt)).is_ok());
        assert!(registry.register(Box::new(TestExt)).is_err());
    }

    #[test]
    fn test_extension_validation() {
        let registry = create_default_registry();

        let telemetry = ExtensionValue::Object(
            vec![("trace_id".to_string(), json!("abc123"))]
                .into_iter()
                .collect(),
        );

        assert!(registry.validate("telemetry", &telemetry).is_ok());
    }

    #[test]
    fn test_invalid_telemetry() {
        let registry = create_default_registry();

        let invalid = ExtensionValue::Object(HashMap::new());
        assert!(registry.validate("telemetry", &invalid).is_err());
    }

    #[test]
    fn test_validate_unregistered_extension() {
        let registry = create_default_registry();

        assert!(registry.validate("unknown", &ExtensionValue::Null).is_err());
        assert!(registry.transform("unknown", ExtensionValue::Null).is_err());
    }

    #[test]
    fn test_metrics_extension_validation() {
        let registry = create_default_registry();

        let array = ExtensionValue::Array(vec![json!(1)]);
        assert!(registry.validate("metrics", &array).is_ok());

        let object = ExtensionValue::Object(HashMap::new());
        assert!(registry.validate("metrics", &object).is_ok());

        assert!(registry.validate("metrics", &ExtensionValue::Integer(1)).is_err());
    }

    #[test]
    fn test_extended_message() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&crate::TaskArgs::new()).unwrap();
        let msg = crate::Message::new("tasks.test".to_string(), task_id, body);

        let ext_msg = ExtendedMessage::new(msg).with_extension(
            "telemetry".to_string(),
            ExtensionValue::Object(
                vec![("trace_id".to_string(), json!("xyz789"))]
                    .into_iter()
                    .collect(),
            ),
        );

        assert!(ext_msg.get_extension("telemetry").is_some());
    }

    #[test]
    fn test_extended_message_validation() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&crate::TaskArgs::new()).unwrap();
        let msg = crate::Message::new("tasks.test".to_string(), task_id, body);

        let ext_msg = ExtendedMessage::new(msg).with_extension(
            "telemetry".to_string(),
            ExtensionValue::Object(
                vec![("trace_id".to_string(), json!("abc123"))]
                    .into_iter()
                    .collect(),
            ),
        );

        let registry = create_default_registry();
        assert!(ext_msg.validate_extensions(&registry).is_ok());
    }

    #[test]
    fn test_remove_extension() {
        let task_id = Uuid::new_v4();
        let body = serde_json::to_vec(&crate::TaskArgs::new()).unwrap();
        let msg = crate::Message::new("tasks.test".to_string(), task_id, body);

        let mut ext_msg = ExtendedMessage::new(msg)
            .with_extension("routing".to_string(), ExtensionValue::Object(HashMap::new()));

        assert!(ext_msg.remove_extension("routing").is_some());
        assert!(ext_msg.get_extension("routing").is_none());
        assert!(ext_msg.remove_extension("routing").is_none());
    }

    #[test]
    fn test_unregister_extension() {
        let mut registry = ExtensionRegistry::new();

        struct TestExt;
        impl Extension for TestExt {
            fn name(&self) -> &str {
                "test"
            }
            fn validate(&self, _value: &ExtensionValue) -> Result<(), String> {
                Ok(())
            }
        }

        registry.register(Box::new(TestExt)).unwrap();
        assert!(registry.has("test"));

        assert!(registry.unregister("test"));
        assert!(!registry.has("test"));
    }

    #[test]
    fn test_routing_extension_validation() {
        let registry = create_default_registry();

        let valid_routing = ExtensionValue::Object(
            vec![("priority".to_string(), json!(5))]
                .into_iter()
                .collect(),
        );
        assert!(registry.validate("routing", &valid_routing).is_ok());

        let invalid_routing = ExtensionValue::Object(
            vec![("priority".to_string(), json!(10))]
                .into_iter()
                .collect(),
        );
        assert!(registry.validate("routing", &invalid_routing).is_err());
    }

    #[test]
    fn test_extension_value_serialization() {
        let value = ExtensionValue::Object(
            vec![("key".to_string(), json!("value"))]
                .into_iter()
                .collect(),
        );

        let serialized = serde_json::to_string(&value).unwrap();
        let deserialized: ExtensionValue = serde_json::from_str(&serialized).unwrap();

        assert_eq!(value, deserialized);
    }

    /// Pins the untagged variant-selection order: `Integer` is tried before
    /// `Float`, and `null` falls through to `Null`.
    #[test]
    fn test_untagged_deserialization_variant_selection() {
        fn parse(s: &str) -> ExtensionValue {
            serde_json::from_str(s).unwrap()
        }

        assert_eq!(parse("5"), ExtensionValue::Integer(5));
        assert_eq!(parse("5.5"), ExtensionValue::Float(5.5));
        assert_eq!(parse("true"), ExtensionValue::Boolean(true));
        assert_eq!(parse("\"x\""), ExtensionValue::String("x".to_string()));
        assert_eq!(parse("null"), ExtensionValue::Null);
    }
}