// jamjet_protocols/registry.rs
use std::{collections::HashMap, sync::Arc};

use tracing::{debug, warn};

use crate::ProtocolAdapter;

27#[derive(Clone, Default)]
31pub struct ProtocolRegistry {
32 adapters: HashMap<String, Arc<dyn ProtocolAdapter>>,
34 url_prefixes: Vec<(String, String)>, }
38
39impl ProtocolRegistry {
40 pub fn new() -> Self {
42 Self::default()
43 }
44
45 pub fn register(
50 &mut self,
51 protocol_name: impl Into<String>,
52 adapter: Arc<dyn ProtocolAdapter>,
53 url_prefixes: impl IntoIterator<Item = impl Into<String>>,
54 ) {
55 let name: String = protocol_name.into();
56 for prefix in url_prefixes {
57 self.url_prefixes.push((prefix.into(), name.clone()));
58 }
59 debug!(protocol = %name, "Registered protocol adapter");
60 self.adapters.insert(name, adapter);
61 }
62
63 pub fn adapter(&self, protocol_name: &str) -> Option<Arc<dyn ProtocolAdapter>> {
65 self.adapters.get(protocol_name).cloned()
66 }
67
68 pub fn adapter_for_url(&self, url: &str) -> Option<Arc<dyn ProtocolAdapter>> {
73 let mut candidates: Vec<_> = self
75 .url_prefixes
76 .iter()
77 .filter(|(prefix, _)| url.starts_with(prefix.as_str()))
78 .collect();
79 candidates.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
80
81 if let Some((prefix, proto)) = candidates.first() {
82 debug!(url = %url, prefix = %prefix, protocol = %proto, "URL matched protocol adapter");
83 self.adapters.get(proto.as_str()).cloned()
84 } else {
85 warn!(url = %url, "No protocol adapter matched URL");
86 None
87 }
88 }
89
90 pub fn protocols(&self) -> Vec<&str> {
92 self.adapters.keys().map(|s| s.as_str()).collect()
93 }
94}
95
96impl std::fmt::Debug for ProtocolRegistry {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("ProtocolRegistry")
99 .field("protocols", &self.protocols())
100 .finish()
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RemoteCapabilities, TaskHandle, TaskRequest, TaskStatus, TaskStream};
    use async_trait::async_trait;

    /// Adapter stub whose `discover` reports the wrapped string as its name, so
    /// tests can tell which adapter a lookup returned.
    struct FakeAdapter(String);

    #[async_trait]
    impl ProtocolAdapter for FakeAdapter {
        async fn discover(&self, _url: &str) -> Result<RemoteCapabilities, String> {
            Ok(RemoteCapabilities {
                name: self.0.clone(),
                description: None,
                skills: vec![],
                protocols: vec![self.0.clone()],
            })
        }
        async fn invoke(&self, _url: &str, _task: TaskRequest) -> Result<TaskHandle, String> {
            Err("not implemented".into())
        }
        async fn stream(&self, _url: &str, _task: TaskRequest) -> Result<TaskStream, String> {
            Err("not implemented".into())
        }
        async fn status(&self, _url: &str, _task_id: &str) -> Result<TaskStatus, String> {
            Err("not implemented".into())
        }
        async fn cancel(&self, _url: &str, _task_id: &str) -> Result<(), String> {
            Ok(())
        }
    }

    /// Builds a shareable `FakeAdapter` that identifies itself as `name`.
    fn fake(name: &str) -> Arc<dyn ProtocolAdapter> {
        Arc::new(FakeAdapter(name.into()))
    }

    /// Returns the name a `FakeAdapter` reports, by driving its async `discover`
    /// to completion on a throwaway runtime.
    fn name_of(adapter: &Arc<dyn ProtocolAdapter>) -> String {
        tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(adapter.discover(""))
            .unwrap()
            .name
    }

    #[test]
    fn test_register_and_lookup_by_name() {
        let mut reg = ProtocolRegistry::new();
        reg.register("mcp", fake("mcp"), vec!["http://mcp/"]);
        assert!(reg.adapter("mcp").is_some());
        assert!(reg.adapter("a2a").is_none());
    }

    #[test]
    fn test_adapter_for_url_matches_prefix() {
        let mut reg = ProtocolRegistry::new();
        reg.register("anp", fake("anp"), vec!["did:"]);
        reg.register("mcp", fake("mcp"), vec!["http://mcp."]);

        assert!(reg.adapter_for_url("did:web:example.com").is_some());
        assert!(reg.adapter_for_url("http://mcp.example.com/tools").is_some());
        assert!(reg.adapter_for_url("https://unknown.com").is_none());
    }

    #[test]
    fn test_longest_prefix_wins() {
        let mut reg = ProtocolRegistry::new();
        reg.register("generic-http", fake("generic"), vec!["http://"]);
        reg.register("specific-mcp", fake("specific"), vec!["http://mcp.example.com/"]);

        let adapter = reg
            .adapter_for_url("http://mcp.example.com/v1")
            .expect("should match");
        assert_eq!(name_of(&adapter), "specific");
    }

    #[test]
    fn test_equal_length_prefixes_first_registration_wins() {
        let mut reg = ProtocolRegistry::new();
        reg.register("first", fake("first"), vec!["http://a/"]);
        reg.register("second", fake("second"), vec!["http://a/"]);

        let adapter = reg.adapter_for_url("http://a/x").expect("should match");
        assert_eq!(name_of(&adapter), "first");
    }

    #[test]
    fn test_empty_prefix_acts_as_fallback() {
        let mut reg = ProtocolRegistry::new();
        reg.register("fallback", fake("fallback"), vec![""]);
        reg.register("anp", fake("anp"), vec!["did:"]);

        let fallback = reg.adapter_for_url("https://x.example").expect("fallback");
        assert_eq!(name_of(&fallback), "fallback");
        let anp = reg.adapter_for_url("did:web:x").expect("specific prefix");
        assert_eq!(name_of(&anp), "anp");
    }

    #[test]
    fn test_reregistering_replaces_adapter() {
        let mut reg = ProtocolRegistry::new();
        reg.register("mcp", fake("old"), vec!["http://mcp/"]);
        reg.register("mcp", fake("new"), vec!["http://mcp/"]);

        let adapter = reg.adapter_for_url("http://mcp/tools").expect("should match");
        assert_eq!(name_of(&adapter), "new");
        assert_eq!(reg.protocols(), vec!["mcp"]);
    }

    #[test]
    fn test_protocols_list() {
        let mut reg = ProtocolRegistry::new();
        reg.register("mcp", fake("mcp"), vec![] as Vec<String>);
        reg.register("a2a", fake("a2a"), vec![] as Vec<String>);
        let mut protos = reg.protocols();
        protos.sort();
        assert_eq!(protos, vec!["a2a", "mcp"]);
    }
}