ipfrs_network/
protocol.rs

1//! Protocol handler registry and version negotiation
2//!
3//! Provides:
4//! - Protocol handler registration and lifecycle
5//! - Protocol version negotiation
6//! - Protocol capability advertisement
//! - Runtime registration and removal of handler trait objects
8
9use ipfrs_core::error::{Error, Result};
10use std::collections::HashMap;
11use std::sync::Arc;
12
/// A protocol version in the `major.minor.patch` form used by semantic
/// versioning.
///
/// The derived ordering compares `major`, then `minor`, then `patch`, so the
/// order in which the fields are declared below is significant.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(PartialOrd, Ord)]
pub struct ProtocolVersion {
    /// Major component; versions with different majors are never compatible.
    pub major: u16,
    /// Minor component.
    pub minor: u16,
    /// Patch component.
    pub patch: u16,
}
20
21impl ProtocolVersion {
22    /// Create a new protocol version
23    pub fn new(major: u16, minor: u16, patch: u16) -> Self {
24        Self {
25            major,
26            minor,
27            patch,
28        }
29    }
30
31    /// Parse version from string (e.g., "1.2.3")
32    pub fn parse(s: &str) -> Result<Self> {
33        let parts: Vec<&str> = s.split('.').collect();
34        if parts.len() != 3 {
35            return Err(Error::Network(format!("Invalid version string: {}", s)));
36        }
37
38        let major = parts[0]
39            .parse()
40            .map_err(|e| Error::Network(format!("Invalid major version: {}", e)))?;
41        let minor = parts[1]
42            .parse()
43            .map_err(|e| Error::Network(format!("Invalid minor version: {}", e)))?;
44        let patch = parts[2]
45            .parse()
46            .map_err(|e| Error::Network(format!("Invalid patch version: {}", e)))?;
47
48        Ok(Self::new(major, minor, patch))
49    }
50
51    /// Check if this version is compatible with another
52    /// Compatible if major versions match and this minor >= other minor
53    pub fn is_compatible_with(&self, other: &ProtocolVersion) -> bool {
54        self.major == other.major && self.minor >= other.minor
55    }
56}
57
58impl std::fmt::Display for ProtocolVersion {
59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60        write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
61    }
62}
63
64/// Protocol identifier
65#[derive(Debug, Clone, PartialEq, Eq, Hash)]
66pub struct ProtocolId {
67    /// Protocol name
68    pub name: String,
69    /// Protocol version
70    pub version: ProtocolVersion,
71}
72
73impl ProtocolId {
74    /// Create a new protocol ID
75    pub fn new(name: String, version: ProtocolVersion) -> Self {
76        Self { name, version }
77    }
78
79    /// Get the full protocol string (e.g., "/ipfrs/tensorswap/1.0.0")
80    pub fn to_protocol_string(&self) -> String {
81        format!("/ipfrs/{}/{}", self.name, self.version)
82    }
83
84    /// Parse protocol ID from string
85    pub fn parse(s: &str) -> Result<Self> {
86        let parts: Vec<&str> = s.trim_matches('/').split('/').collect();
87        if parts.len() != 3 || parts[0] != "ipfrs" {
88            return Err(Error::Network(format!("Invalid protocol string: {}", s)));
89        }
90
91        let name = parts[1].to_string();
92        let version = ProtocolVersion::parse(parts[2])?;
93
94        Ok(Self::new(name, version))
95    }
96}
97
98impl std::fmt::Display for ProtocolId {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        write!(f, "{}", self.to_protocol_string())
101    }
102}
103
/// Capabilities that a protocol handler advertises.
#[derive(Debug, Clone)]
pub struct ProtocolCapabilities {
    /// Names of the features the protocol supports.
    pub features: Vec<String>,
    /// Largest message the protocol accepts, in bytes (1 MiB by default).
    pub max_message_size: usize,
    /// `true` when the protocol supports streaming.
    pub supports_streaming: bool,
}
114
115impl Default for ProtocolCapabilities {
116    fn default() -> Self {
117        Self {
118            features: Vec::new(),
119            max_message_size: 1024 * 1024, // 1MB default
120            supports_streaming: false,
121        }
122    }
123}
124
125/// Type alias for boxed protocol handler
126type BoxedProtocolHandler = Arc<parking_lot::RwLock<Box<dyn ProtocolHandler>>>;
127
128/// Protocol handler trait
129pub trait ProtocolHandler: Send + Sync {
130    /// Get protocol ID
131    fn protocol_id(&self) -> ProtocolId;
132
133    /// Get protocol capabilities
134    fn capabilities(&self) -> ProtocolCapabilities {
135        ProtocolCapabilities::default()
136    }
137
138    /// Handle incoming protocol request
139    fn handle_request(&mut self, request: &[u8]) -> Result<Vec<u8>>;
140
141    /// Initialize the handler
142    fn initialize(&mut self) -> Result<()> {
143        Ok(())
144    }
145
146    /// Shutdown the handler
147    fn shutdown(&mut self) -> Result<()> {
148        Ok(())
149    }
150}
151
152/// Protocol handler registry
153pub struct ProtocolRegistry {
154    /// Registered handlers by protocol ID
155    handlers: parking_lot::RwLock<HashMap<ProtocolId, BoxedProtocolHandler>>,
156    /// Protocol aliases (name -> list of versions)
157    aliases: parking_lot::RwLock<HashMap<String, Vec<ProtocolVersion>>>,
158}
159
160impl ProtocolRegistry {
161    /// Create a new protocol registry
162    pub fn new() -> Self {
163        Self {
164            handlers: parking_lot::RwLock::new(HashMap::new()),
165            aliases: parking_lot::RwLock::new(HashMap::new()),
166        }
167    }
168
169    /// Register a protocol handler
170    pub fn register(&self, handler: Box<dyn ProtocolHandler>) -> Result<()> {
171        let protocol_id = handler.protocol_id();
172        let mut handlers = self.handlers.write();
173
174        if handlers.contains_key(&protocol_id) {
175            return Err(Error::Network(format!(
176                "Protocol already registered: {}",
177                protocol_id
178            )));
179        }
180
181        // Add to aliases
182        let mut aliases = self.aliases.write();
183        aliases
184            .entry(protocol_id.name.clone())
185            .or_default()
186            .push(protocol_id.version.clone());
187
188        handlers.insert(protocol_id, Arc::new(parking_lot::RwLock::new(handler)));
189
190        Ok(())
191    }
192
193    /// Unregister a protocol handler
194    pub fn unregister(&self, protocol_id: &ProtocolId) -> Result<()> {
195        let mut handlers = self.handlers.write();
196
197        if let Some(handler) = handlers.remove(protocol_id) {
198            // Shutdown the handler
199            let mut handler = handler.write();
200            handler.shutdown()?;
201
202            // Remove from aliases
203            let mut aliases = self.aliases.write();
204            if let Some(versions) = aliases.get_mut(&protocol_id.name) {
205                versions.retain(|v| v != &protocol_id.version);
206                if versions.is_empty() {
207                    aliases.remove(&protocol_id.name);
208                }
209            }
210
211            Ok(())
212        } else {
213            Err(Error::Network(format!(
214                "Protocol not registered: {}",
215                protocol_id
216            )))
217        }
218    }
219
220    /// Get a protocol handler
221    pub fn get(&self, protocol_id: &ProtocolId) -> Option<BoxedProtocolHandler> {
222        let handlers = self.handlers.read();
223        handlers.get(protocol_id).cloned()
224    }
225
226    /// Find compatible protocol version
227    pub fn find_compatible(&self, name: &str, min_version: &ProtocolVersion) -> Option<ProtocolId> {
228        let aliases = self.aliases.read();
229        if let Some(versions) = aliases.get(name) {
230            // Find the highest compatible version
231            let mut compatible_versions: Vec<_> = versions
232                .iter()
233                .filter(|v| v.is_compatible_with(min_version))
234                .collect();
235
236            compatible_versions.sort_by(|a, b| b.cmp(a)); // Sort descending
237
238            if let Some(version) = compatible_versions.first() {
239                return Some(ProtocolId::new(name.to_string(), (*version).clone()));
240            }
241        }
242        None
243    }
244
245    /// Get all registered protocol IDs
246    pub fn list_protocols(&self) -> Vec<ProtocolId> {
247        let handlers = self.handlers.read();
248        handlers.keys().cloned().collect()
249    }
250
251    /// Handle a request with the appropriate protocol handler
252    pub fn handle_request(&self, protocol_id: &ProtocolId, request: &[u8]) -> Result<Vec<u8>> {
253        if let Some(handler) = self.get(protocol_id) {
254            let mut handler = handler.write();
255            handler.handle_request(request)
256        } else {
257            Err(Error::Network(format!(
258                "No handler registered for protocol: {}",
259                protocol_id
260            )))
261        }
262    }
263
264    /// Get protocol capabilities
265    pub fn get_capabilities(&self, protocol_id: &ProtocolId) -> Option<ProtocolCapabilities> {
266        if let Some(handler) = self.get(protocol_id) {
267            let handler = handler.read();
268            Some(handler.capabilities())
269        } else {
270            None
271        }
272    }
273
274    /// Shutdown all handlers
275    pub fn shutdown_all(&self) -> Result<()> {
276        let handlers = self.handlers.write();
277        for handler in handlers.values() {
278            let mut handler = handler.write();
279            handler.shutdown()?;
280        }
281        Ok(())
282    }
283}
284
285impl Default for ProtocolRegistry {
286    fn default() -> Self {
287        Self::new()
288    }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    // Mock protocol handler for testing: echoes every request back.
    struct MockProtocolHandler {
        id: ProtocolId,
    }

    impl MockProtocolHandler {
        fn new(name: &str, version: ProtocolVersion) -> Self {
            Self {
                id: ProtocolId::new(name.to_string(), version),
            }
        }
    }

    impl ProtocolHandler for MockProtocolHandler {
        fn protocol_id(&self) -> ProtocolId {
            self.id.clone()
        }

        fn handle_request(&mut self, request: &[u8]) -> Result<Vec<u8>> {
            Ok(request.to_vec())
        }
    }

    #[test]
    fn test_protocol_version_creation() {
        let version = ProtocolVersion::new(1, 2, 3);
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);
    }

    #[test]
    fn test_protocol_version_parse() {
        let version = ProtocolVersion::parse("1.2.3").unwrap();
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);

        assert!(ProtocolVersion::parse("invalid").is_err());
        assert!(ProtocolVersion::parse("1.2").is_err());
    }

    #[test]
    fn test_protocol_version_compatibility() {
        let v1_0_0 = ProtocolVersion::new(1, 0, 0);
        let v1_1_0 = ProtocolVersion::new(1, 1, 0);
        let v1_2_0 = ProtocolVersion::new(1, 2, 0);
        let v2_0_0 = ProtocolVersion::new(2, 0, 0);

        // Same major, higher or equal minor is compatible
        assert!(v1_2_0.is_compatible_with(&v1_0_0));
        assert!(v1_1_0.is_compatible_with(&v1_0_0));
        assert!(v1_0_0.is_compatible_with(&v1_0_0));

        // Lower minor is not compatible
        assert!(!v1_0_0.is_compatible_with(&v1_1_0));

        // Different major is not compatible
        assert!(!v2_0_0.is_compatible_with(&v1_0_0));
        assert!(!v1_0_0.is_compatible_with(&v2_0_0));
    }

    #[test]
    fn test_protocol_version_display() {
        let version = ProtocolVersion::new(1, 2, 3);
        assert_eq!(format!("{}", version), "1.2.3");
    }

    #[test]
    fn test_protocol_id_creation() {
        let version = ProtocolVersion::new(1, 0, 0);
        let id = ProtocolId::new("test".to_string(), version);
        assert_eq!(id.name, "test");
        assert_eq!(id.version.major, 1);
    }

    #[test]
    fn test_protocol_id_to_string() {
        let version = ProtocolVersion::new(1, 0, 0);
        let id = ProtocolId::new("tensorswap".to_string(), version);
        assert_eq!(id.to_protocol_string(), "/ipfrs/tensorswap/1.0.0");
    }

    #[test]
    fn test_protocol_id_parse() {
        let id = ProtocolId::parse("/ipfrs/tensorswap/1.0.0").unwrap();
        assert_eq!(id.name, "tensorswap");
        assert_eq!(id.version.major, 1);

        assert!(ProtocolId::parse("/invalid/tensorswap/1.0.0").is_err());
        assert!(ProtocolId::parse("/ipfrs/tensorswap").is_err());
    }

    #[test]
    fn test_protocol_id_round_trip() {
        let id = ProtocolId::new("tensorswap".to_string(), ProtocolVersion::new(2, 5, 1));
        let parsed = ProtocolId::parse(&id.to_protocol_string()).unwrap();
        assert_eq!(parsed, id);
        assert_eq!(id.to_string(), "/ipfrs/tensorswap/2.5.1");
    }

    #[test]
    fn test_registry_creation() {
        let registry = ProtocolRegistry::new();
        assert_eq!(registry.list_protocols().len(), 0);
    }

    #[test]
    fn test_register_handler() {
        let registry = ProtocolRegistry::new();
        let handler = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 0, 0),
        ));

        registry.register(handler).unwrap();
        assert_eq!(registry.list_protocols().len(), 1);
    }

    #[test]
    fn test_register_duplicate() {
        let registry = ProtocolRegistry::new();
        let handler1 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 0, 0),
        ));
        let handler2 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 0, 0),
        ));

        registry.register(handler1).unwrap();
        assert!(registry.register(handler2).is_err());
    }

    #[test]
    fn test_get_handler() {
        let registry = ProtocolRegistry::new();
        let version = ProtocolVersion::new(1, 0, 0);
        let protocol_id = ProtocolId::new("test".to_string(), version.clone());

        let handler = Box::new(MockProtocolHandler::new("test", version));
        registry.register(handler).unwrap();

        let retrieved = registry.get(&protocol_id);
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_find_compatible() {
        let registry = ProtocolRegistry::new();

        let handler1 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 0, 0),
        ));
        let handler2 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 1, 0),
        ));
        let handler3 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 2, 0),
        ));

        registry.register(handler1).unwrap();
        registry.register(handler2).unwrap();
        registry.register(handler3).unwrap();

        // Should find the highest compatible version
        let min_version = ProtocolVersion::new(1, 0, 0);
        let compatible = registry.find_compatible("test", &min_version);

        assert!(compatible.is_some());
        let compatible = compatible.unwrap();
        assert_eq!(compatible.version.major, 1);
        assert_eq!(compatible.version.minor, 2);
    }

    #[test]
    fn test_find_compatible_respects_major_and_minor() {
        let registry = ProtocolRegistry::new();
        registry
            .register(Box::new(MockProtocolHandler::new(
                "test",
                ProtocolVersion::new(1, 0, 0),
            )))
            .unwrap();
        registry
            .register(Box::new(MockProtocolHandler::new(
                "test",
                ProtocolVersion::new(2, 0, 0),
            )))
            .unwrap();

        // A different major version is never selected.
        let found = registry
            .find_compatible("test", &ProtocolVersion::new(1, 0, 0))
            .unwrap();
        assert_eq!(found.version, ProtocolVersion::new(1, 0, 0));

        // No registered 1.x version has minor >= 1.
        assert!(registry
            .find_compatible("test", &ProtocolVersion::new(1, 1, 0))
            .is_none());

        // Unknown protocol names yield nothing.
        assert!(registry
            .find_compatible("missing", &ProtocolVersion::new(1, 0, 0))
            .is_none());
    }

    #[test]
    fn test_unregister_handler() {
        let registry = ProtocolRegistry::new();
        let version = ProtocolVersion::new(1, 0, 0);
        let protocol_id = ProtocolId::new("test".to_string(), version.clone());

        let handler = Box::new(MockProtocolHandler::new("test", version));
        registry.register(handler).unwrap();

        registry.unregister(&protocol_id).unwrap();
        assert_eq!(registry.list_protocols().len(), 0);
    }

    #[test]
    fn test_unregister_removes_alias() {
        let registry = ProtocolRegistry::new();
        let version = ProtocolVersion::new(1, 0, 0);
        let protocol_id = ProtocolId::new("test".to_string(), version.clone());

        registry
            .register(Box::new(MockProtocolHandler::new("test", version.clone())))
            .unwrap();
        registry.unregister(&protocol_id).unwrap();

        // Version negotiation must no longer offer the removed protocol.
        assert!(registry.find_compatible("test", &version).is_none());

        // The same ID can be registered again afterwards.
        registry
            .register(Box::new(MockProtocolHandler::new("test", version)))
            .unwrap();
        assert_eq!(registry.list_protocols(), vec![protocol_id]);
    }

    #[test]
    fn test_unregister_unknown_protocol() {
        let registry = ProtocolRegistry::new();
        let protocol_id = ProtocolId::new("test".to_string(), ProtocolVersion::new(1, 0, 0));
        assert!(registry.unregister(&protocol_id).is_err());
    }

    #[test]
    fn test_handle_request() {
        let registry = ProtocolRegistry::new();
        let version = ProtocolVersion::new(1, 0, 0);
        let protocol_id = ProtocolId::new("test".to_string(), version.clone());

        let handler = Box::new(MockProtocolHandler::new("test", version));
        registry.register(handler).unwrap();

        let request = b"test request";
        let response = registry.handle_request(&protocol_id, request).unwrap();

        assert_eq!(response, request);
    }

    #[test]
    fn test_handle_request_unknown_protocol() {
        let registry = ProtocolRegistry::new();
        let protocol_id = ProtocolId::new("test".to_string(), ProtocolVersion::new(1, 0, 0));
        assert!(registry.handle_request(&protocol_id, b"ping").is_err());
    }

    #[test]
    fn test_get_capabilities() {
        let registry = ProtocolRegistry::new();
        let version = ProtocolVersion::new(1, 0, 0);
        let protocol_id = ProtocolId::new("test".to_string(), version.clone());

        assert!(registry.get_capabilities(&protocol_id).is_none());

        registry
            .register(Box::new(MockProtocolHandler::new("test", version)))
            .unwrap();
        let caps = registry.get_capabilities(&protocol_id).unwrap();
        assert_eq!(caps.max_message_size, 1024 * 1024);
        assert!(!caps.supports_streaming);
    }

    #[test]
    fn test_shutdown_all() {
        let registry = ProtocolRegistry::new();
        registry
            .register(Box::new(MockProtocolHandler::new(
                "a",
                ProtocolVersion::new(1, 0, 0),
            )))
            .unwrap();
        registry
            .register(Box::new(MockProtocolHandler::new(
                "b",
                ProtocolVersion::new(1, 0, 0),
            )))
            .unwrap();

        assert!(registry.shutdown_all().is_ok());
    }

    #[test]
    fn test_protocol_capabilities_default() {
        let caps = ProtocolCapabilities::default();
        assert_eq!(caps.max_message_size, 1024 * 1024);
        assert!(!caps.supports_streaming);
    }
}