1use ipfrs_core::error::{Error, Result};
10use std::collections::HashMap;
11use std::sync::Arc;
12
/// Three-part protocol version written as `major.minor.patch`.
///
/// The derived ordering compares `major`, then `minor`, then `patch`
/// (field declaration order); the registry uses it to pick the newest
/// compatible version of a protocol.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ProtocolVersion {
    /// Major component; versions with different majors are never compatible.
    pub major: u16,
    /// Minor component; a higher minor is compatible with a lower one.
    pub minor: u16,
    /// Patch component; not consulted by the compatibility check.
    pub patch: u16,
}
20
21impl ProtocolVersion {
22 pub fn new(major: u16, minor: u16, patch: u16) -> Self {
24 Self {
25 major,
26 minor,
27 patch,
28 }
29 }
30
31 pub fn parse(s: &str) -> Result<Self> {
33 let parts: Vec<&str> = s.split('.').collect();
34 if parts.len() != 3 {
35 return Err(Error::Network(format!("Invalid version string: {}", s)));
36 }
37
38 let major = parts[0]
39 .parse()
40 .map_err(|e| Error::Network(format!("Invalid major version: {}", e)))?;
41 let minor = parts[1]
42 .parse()
43 .map_err(|e| Error::Network(format!("Invalid minor version: {}", e)))?;
44 let patch = parts[2]
45 .parse()
46 .map_err(|e| Error::Network(format!("Invalid patch version: {}", e)))?;
47
48 Ok(Self::new(major, minor, patch))
49 }
50
51 pub fn is_compatible_with(&self, other: &ProtocolVersion) -> bool {
54 self.major == other.major && self.minor >= other.minor
55 }
56}
57
58impl std::fmt::Display for ProtocolVersion {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
61 }
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Hash)]
66pub struct ProtocolId {
67 pub name: String,
69 pub version: ProtocolVersion,
71}
72
73impl ProtocolId {
74 pub fn new(name: String, version: ProtocolVersion) -> Self {
76 Self { name, version }
77 }
78
79 pub fn to_protocol_string(&self) -> String {
81 format!("/ipfrs/{}/{}", self.name, self.version)
82 }
83
84 pub fn parse(s: &str) -> Result<Self> {
86 let parts: Vec<&str> = s.trim_matches('/').split('/').collect();
87 if parts.len() != 3 || parts[0] != "ipfrs" {
88 return Err(Error::Network(format!("Invalid protocol string: {}", s)));
89 }
90
91 let name = parts[1].to_string();
92 let version = ProtocolVersion::parse(parts[2])?;
93
94 Ok(Self::new(name, version))
95 }
96}
97
98impl std::fmt::Display for ProtocolId {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 write!(f, "{}", self.to_protocol_string())
101 }
102}
103
/// Capabilities a protocol handler advertises to the registry.
#[derive(Clone, Debug)]
pub struct ProtocolCapabilities {
    /// Names of optional features the handler supports.
    pub features: Vec<String>,
    /// Maximum accepted message size (presumably bytes — TODO confirm).
    pub max_message_size: usize,
    /// Whether the handler supports streaming.
    pub supports_streaming: bool,
}
114
115impl Default for ProtocolCapabilities {
116 fn default() -> Self {
117 Self {
118 features: Vec::new(),
119 max_message_size: 1024 * 1024, supports_streaming: false,
121 }
122 }
123}
124
125type BoxedProtocolHandler = Arc<parking_lot::RwLock<Box<dyn ProtocolHandler>>>;
127
128pub trait ProtocolHandler: Send + Sync {
130 fn protocol_id(&self) -> ProtocolId;
132
133 fn capabilities(&self) -> ProtocolCapabilities {
135 ProtocolCapabilities::default()
136 }
137
138 fn handle_request(&mut self, request: &[u8]) -> Result<Vec<u8>>;
140
141 fn initialize(&mut self) -> Result<()> {
143 Ok(())
144 }
145
146 fn shutdown(&mut self) -> Result<()> {
148 Ok(())
149 }
150}
151
152pub struct ProtocolRegistry {
154 handlers: parking_lot::RwLock<HashMap<ProtocolId, BoxedProtocolHandler>>,
156 aliases: parking_lot::RwLock<HashMap<String, Vec<ProtocolVersion>>>,
158}
159
160impl ProtocolRegistry {
161 pub fn new() -> Self {
163 Self {
164 handlers: parking_lot::RwLock::new(HashMap::new()),
165 aliases: parking_lot::RwLock::new(HashMap::new()),
166 }
167 }
168
169 pub fn register(&self, handler: Box<dyn ProtocolHandler>) -> Result<()> {
171 let protocol_id = handler.protocol_id();
172 let mut handlers = self.handlers.write();
173
174 if handlers.contains_key(&protocol_id) {
175 return Err(Error::Network(format!(
176 "Protocol already registered: {}",
177 protocol_id
178 )));
179 }
180
181 let mut aliases = self.aliases.write();
183 aliases
184 .entry(protocol_id.name.clone())
185 .or_default()
186 .push(protocol_id.version.clone());
187
188 handlers.insert(protocol_id, Arc::new(parking_lot::RwLock::new(handler)));
189
190 Ok(())
191 }
192
193 pub fn unregister(&self, protocol_id: &ProtocolId) -> Result<()> {
195 let mut handlers = self.handlers.write();
196
197 if let Some(handler) = handlers.remove(protocol_id) {
198 let mut handler = handler.write();
200 handler.shutdown()?;
201
202 let mut aliases = self.aliases.write();
204 if let Some(versions) = aliases.get_mut(&protocol_id.name) {
205 versions.retain(|v| v != &protocol_id.version);
206 if versions.is_empty() {
207 aliases.remove(&protocol_id.name);
208 }
209 }
210
211 Ok(())
212 } else {
213 Err(Error::Network(format!(
214 "Protocol not registered: {}",
215 protocol_id
216 )))
217 }
218 }
219
220 pub fn get(&self, protocol_id: &ProtocolId) -> Option<BoxedProtocolHandler> {
222 let handlers = self.handlers.read();
223 handlers.get(protocol_id).cloned()
224 }
225
226 pub fn find_compatible(&self, name: &str, min_version: &ProtocolVersion) -> Option<ProtocolId> {
228 let aliases = self.aliases.read();
229 if let Some(versions) = aliases.get(name) {
230 let mut compatible_versions: Vec<_> = versions
232 .iter()
233 .filter(|v| v.is_compatible_with(min_version))
234 .collect();
235
236 compatible_versions.sort_by(|a, b| b.cmp(a)); if let Some(version) = compatible_versions.first() {
239 return Some(ProtocolId::new(name.to_string(), (*version).clone()));
240 }
241 }
242 None
243 }
244
245 pub fn list_protocols(&self) -> Vec<ProtocolId> {
247 let handlers = self.handlers.read();
248 handlers.keys().cloned().collect()
249 }
250
251 pub fn handle_request(&self, protocol_id: &ProtocolId, request: &[u8]) -> Result<Vec<u8>> {
253 if let Some(handler) = self.get(protocol_id) {
254 let mut handler = handler.write();
255 handler.handle_request(request)
256 } else {
257 Err(Error::Network(format!(
258 "No handler registered for protocol: {}",
259 protocol_id
260 )))
261 }
262 }
263
264 pub fn get_capabilities(&self, protocol_id: &ProtocolId) -> Option<ProtocolCapabilities> {
266 if let Some(handler) = self.get(protocol_id) {
267 let handler = handler.read();
268 Some(handler.capabilities())
269 } else {
270 None
271 }
272 }
273
274 pub fn shutdown_all(&self) -> Result<()> {
276 let handlers = self.handlers.write();
277 for handler in handlers.values() {
278 let mut handler = handler.write();
279 handler.shutdown()?;
280 }
281 Ok(())
282 }
283}
284
285impl Default for ProtocolRegistry {
286 fn default() -> Self {
287 Self::new()
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler for registry tests: echoes every request back unchanged.
    struct MockProtocolHandler {
        id: ProtocolId,
    }

    impl MockProtocolHandler {
        fn new(name: &str, version: ProtocolVersion) -> Self {
            Self {
                id: ProtocolId::new(name.to_string(), version),
            }
        }
    }

    impl ProtocolHandler for MockProtocolHandler {
        fn protocol_id(&self) -> ProtocolId {
            self.id.clone()
        }

        fn handle_request(&mut self, request: &[u8]) -> Result<Vec<u8>> {
            Ok(request.to_vec())
        }
    }

    /// Shorthand for a boxed mock handler named `name` at `major.minor.0`.
    fn mock(name: &str, major: u16, minor: u16) -> Box<MockProtocolHandler> {
        Box::new(MockProtocolHandler::new(
            name,
            ProtocolVersion::new(major, minor, 0),
        ))
    }

    #[test]
    fn test_protocol_version_creation() {
        let version = ProtocolVersion::new(1, 2, 3);
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);
    }

    #[test]
    fn test_protocol_version_parse() {
        let version = ProtocolVersion::parse("1.2.3").unwrap();
        assert_eq!(version.major, 1);
        assert_eq!(version.minor, 2);
        assert_eq!(version.patch, 3);

        assert!(ProtocolVersion::parse("invalid").is_err());
        assert!(ProtocolVersion::parse("1.2").is_err());
    }

    #[test]
    fn test_protocol_version_parse_rejects_malformed() {
        assert!(ProtocolVersion::parse("").is_err());
        assert!(ProtocolVersion::parse("1.2.3.4").is_err());
        assert!(ProtocolVersion::parse("1.2.x").is_err());
        assert!(ProtocolVersion::parse("1.-2.3").is_err());
        // 65536 does not fit in a u16 component.
        assert!(ProtocolVersion::parse("1.2.65536").is_err());
    }

    #[test]
    fn test_protocol_version_ordering() {
        assert!(ProtocolVersion::new(1, 2, 3) < ProtocolVersion::new(1, 10, 0));
        assert!(ProtocolVersion::new(2, 0, 0) > ProtocolVersion::new(1, 99, 99));
    }

    #[test]
    fn test_protocol_version_compatibility() {
        let v1_0_0 = ProtocolVersion::new(1, 0, 0);
        let v1_1_0 = ProtocolVersion::new(1, 1, 0);
        let v1_2_0 = ProtocolVersion::new(1, 2, 0);
        let v2_0_0 = ProtocolVersion::new(2, 0, 0);

        // Same major, equal or higher minor: compatible.
        assert!(v1_2_0.is_compatible_with(&v1_0_0));
        assert!(v1_1_0.is_compatible_with(&v1_0_0));
        assert!(v1_0_0.is_compatible_with(&v1_0_0));

        // Lower minor: not compatible.
        assert!(!v1_0_0.is_compatible_with(&v1_1_0));

        // Different major: never compatible.
        assert!(!v2_0_0.is_compatible_with(&v1_0_0));
        assert!(!v1_0_0.is_compatible_with(&v2_0_0));
    }

    #[test]
    fn test_protocol_version_display() {
        let version = ProtocolVersion::new(1, 2, 3);
        assert_eq!(format!("{}", version), "1.2.3");
    }

    #[test]
    fn test_protocol_id_creation() {
        let version = ProtocolVersion::new(1, 0, 0);
        let id = ProtocolId::new("test".to_string(), version);
        assert_eq!(id.name, "test");
        assert_eq!(id.version.major, 1);
    }

    #[test]
    fn test_protocol_id_to_string() {
        let version = ProtocolVersion::new(1, 0, 0);
        let id = ProtocolId::new("tensorswap".to_string(), version);
        assert_eq!(id.to_protocol_string(), "/ipfrs/tensorswap/1.0.0");
    }

    #[test]
    fn test_protocol_id_parse() {
        let id = ProtocolId::parse("/ipfrs/tensorswap/1.0.0").unwrap();
        assert_eq!(id.name, "tensorswap");
        assert_eq!(id.version.major, 1);

        assert!(ProtocolId::parse("/invalid/tensorswap/1.0.0").is_err());
        assert!(ProtocolId::parse("/ipfrs/tensorswap").is_err());
    }

    #[test]
    fn test_protocol_id_roundtrip_and_display() {
        let id = ProtocolId::new("tensorswap".to_string(), ProtocolVersion::new(1, 2, 3));
        assert_eq!(ProtocolId::parse(&id.to_protocol_string()).unwrap(), id);
        assert_eq!(id.to_string(), id.to_protocol_string());
    }

    #[test]
    fn test_registry_creation() {
        let registry = ProtocolRegistry::new();
        assert_eq!(registry.list_protocols().len(), 0);
    }

    #[test]
    fn test_register_handler() {
        let registry = ProtocolRegistry::new();
        let handler = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 0, 0),
        ));

        registry.register(handler).unwrap();
        assert_eq!(registry.list_protocols().len(), 1);
    }

    #[test]
    fn test_register_duplicate() {
        let registry = ProtocolRegistry::new();
        let handler1 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 0, 0),
        ));
        let handler2 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 0, 0),
        ));

        registry.register(handler1).unwrap();
        assert!(registry.register(handler2).is_err());
    }

    #[test]
    fn test_get_handler() {
        let registry = ProtocolRegistry::new();
        let version = ProtocolVersion::new(1, 0, 0);
        let protocol_id = ProtocolId::new("test".to_string(), version.clone());

        let handler = Box::new(MockProtocolHandler::new("test", version));
        registry.register(handler).unwrap();

        let retrieved = registry.get(&protocol_id);
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_find_compatible() {
        let registry = ProtocolRegistry::new();

        let handler1 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 0, 0),
        ));
        let handler2 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 1, 0),
        ));
        let handler3 = Box::new(MockProtocolHandler::new(
            "test",
            ProtocolVersion::new(1, 2, 0),
        ));

        registry.register(handler1).unwrap();
        registry.register(handler2).unwrap();
        registry.register(handler3).unwrap();

        // The highest compatible version should win.
        let min_version = ProtocolVersion::new(1, 0, 0);
        let compatible = registry.find_compatible("test", &min_version);

        assert!(compatible.is_some());
        let compatible = compatible.unwrap();
        assert_eq!(compatible.version.major, 1);
        assert_eq!(compatible.version.minor, 2);
    }

    #[test]
    fn test_find_compatible_no_match() {
        let registry = ProtocolRegistry::new();
        registry.register(mock("test", 1, 2)).unwrap();

        assert!(registry.find_compatible("missing", &ProtocolVersion::new(1, 0, 0)).is_none());
        assert!(registry.find_compatible("test", &ProtocolVersion::new(1, 3, 0)).is_none());
        assert!(registry.find_compatible("test", &ProtocolVersion::new(2, 0, 0)).is_none());
    }

    #[test]
    fn test_unregister_handler() {
        let registry = ProtocolRegistry::new();
        let version = ProtocolVersion::new(1, 0, 0);
        let protocol_id = ProtocolId::new("test".to_string(), version.clone());

        let handler = Box::new(MockProtocolHandler::new("test", version));
        registry.register(handler).unwrap();

        registry.unregister(&protocol_id).unwrap();
        assert_eq!(registry.list_protocols().len(), 0);
    }

    #[test]
    fn test_unregister_updates_find_compatible() {
        let registry = ProtocolRegistry::new();
        registry.register(mock("test", 1, 0)).unwrap();
        registry.register(mock("test", 1, 1)).unwrap();

        let newest = ProtocolId::new("test".to_string(), ProtocolVersion::new(1, 1, 0));
        registry.unregister(&newest).unwrap();

        let found = registry
            .find_compatible("test", &ProtocolVersion::new(1, 0, 0))
            .unwrap();
        assert_eq!(found.version, ProtocolVersion::new(1, 0, 0));

        let oldest = ProtocolId::new("test".to_string(), ProtocolVersion::new(1, 0, 0));
        registry.unregister(&oldest).unwrap();
        assert!(registry.find_compatible("test", &ProtocolVersion::new(1, 0, 0)).is_none());
    }

    #[test]
    fn test_unregister_unknown_is_error() {
        let registry = ProtocolRegistry::new();
        let protocol_id = ProtocolId::new("test".to_string(), ProtocolVersion::new(1, 0, 0));
        assert!(registry.unregister(&protocol_id).is_err());
    }

    #[test]
    fn test_handle_request() {
        let registry = ProtocolRegistry::new();
        let version = ProtocolVersion::new(1, 0, 0);
        let protocol_id = ProtocolId::new("test".to_string(), version.clone());

        let handler = Box::new(MockProtocolHandler::new("test", version));
        registry.register(handler).unwrap();

        let request = b"test request";
        let response = registry.handle_request(&protocol_id, request).unwrap();

        assert_eq!(response, request);
    }

    #[test]
    fn test_handle_request_unregistered_is_error() {
        let registry = ProtocolRegistry::new();
        let protocol_id = ProtocolId::new("test".to_string(), ProtocolVersion::new(1, 0, 0));
        assert!(registry.handle_request(&protocol_id, b"ping").is_err());
    }

    #[test]
    fn test_get_capabilities() {
        let registry = ProtocolRegistry::new();
        registry.register(mock("test", 1, 0)).unwrap();

        let registered = ProtocolId::new("test".to_string(), ProtocolVersion::new(1, 0, 0));
        let caps = registry.get_capabilities(&registered).unwrap();
        assert_eq!(caps.max_message_size, 1024 * 1024);
        assert!(caps.features.is_empty());

        let missing = ProtocolId::new("test".to_string(), ProtocolVersion::new(1, 5, 0));
        assert!(registry.get_capabilities(&missing).is_none());
    }

    #[test]
    fn test_shutdown_all_keeps_handlers_registered() {
        let registry = ProtocolRegistry::new();
        registry.register(mock("a", 1, 0)).unwrap();
        registry.register(mock("b", 1, 0)).unwrap();

        registry.shutdown_all().unwrap();
        assert_eq!(registry.list_protocols().len(), 2);
    }

    #[test]
    fn test_protocol_capabilities_default() {
        let caps = ProtocolCapabilities::default();
        assert_eq!(caps.max_message_size, 1024 * 1024);
        assert!(!caps.supports_streaming);
    }
}