mcpkit_core/extension/
discovery.rs1use super::{Extension, ExtensionRegistry};
40use crate::capability::{ClientCapabilities, ServerCapabilities};
41use std::collections::HashMap;
42
/// Whether an extension named in an [`ExtensionQuery`] must be advertised by the peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtensionRequirement {
    /// The query is unsatisfied when the peer does not advertise the extension.
    Required,
    /// The extension is reported when advertised, but its absence is not an error.
    Optional,
}
51
52#[derive(Debug, Clone, Default)]
57pub struct ExtensionQuery {
58 requirements: HashMap<String, ExtensionRequirement>,
59 min_versions: HashMap<String, String>,
60}
61
62impl ExtensionQuery {
63 #[must_use]
65 pub fn new() -> Self {
66 Self::default()
67 }
68
69 #[must_use]
73 pub fn require(mut self, name: impl Into<String>) -> Self {
74 self.requirements
75 .insert(name.into(), ExtensionRequirement::Required);
76 self
77 }
78
79 #[must_use]
83 pub fn optional(mut self, name: impl Into<String>) -> Self {
84 self.requirements
85 .insert(name.into(), ExtensionRequirement::Optional);
86 self
87 }
88
89 #[must_use]
93 pub fn with_min_version(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
94 let name = name.into();
95 self.requirements
97 .entry(name.clone())
98 .or_insert(ExtensionRequirement::Required);
99 self.min_versions.insert(name, version.into());
100 self
101 }
102
103 #[must_use]
105 pub fn check(&self, capabilities: &ServerCapabilities) -> ExtensionQueryResult {
106 let registry = capabilities
107 .experimental
108 .as_ref()
109 .and_then(ExtensionRegistry::from_experimental);
110
111 let mut found = HashMap::new();
112 let mut missing_required = Vec::new();
113 let mut version_mismatches = Vec::new();
114
115 for (name, requirement) in &self.requirements {
116 let extension = registry.as_ref().and_then(|r| r.get(name));
117
118 if let Some(ext) = extension {
119 if let Some(min_version) = self.min_versions.get(name) {
121 if let Some(ref actual_version) = ext.version {
122 if !version_satisfies(actual_version, min_version) {
123 version_mismatches.push(VersionMismatch {
124 extension: name.clone(),
125 required: min_version.clone(),
126 actual: actual_version.clone(),
127 });
128 continue;
129 }
130 }
131 }
132 found.insert(name.clone(), ext.clone());
133 } else if *requirement == ExtensionRequirement::Required {
134 missing_required.push(name.clone());
135 }
136 }
137
138 ExtensionQueryResult {
139 found,
140 missing_required,
141 version_mismatches,
142 }
143 }
144
145 #[must_use]
147 pub fn check_client(&self, capabilities: &ClientCapabilities) -> ExtensionQueryResult {
148 let registry = capabilities
149 .experimental
150 .as_ref()
151 .and_then(ExtensionRegistry::from_experimental);
152
153 let mut found = HashMap::new();
154 let mut missing_required = Vec::new();
155 let mut version_mismatches = Vec::new();
156
157 for (name, requirement) in &self.requirements {
158 let extension = registry.as_ref().and_then(|r| r.get(name));
159
160 if let Some(ext) = extension {
161 if let Some(min_version) = self.min_versions.get(name) {
162 if let Some(ref actual_version) = ext.version {
163 if !version_satisfies(actual_version, min_version) {
164 version_mismatches.push(VersionMismatch {
165 extension: name.clone(),
166 required: min_version.clone(),
167 actual: actual_version.clone(),
168 });
169 continue;
170 }
171 }
172 }
173 found.insert(name.clone(), ext.clone());
174 } else if *requirement == ExtensionRequirement::Required {
175 missing_required.push(name.clone());
176 }
177 }
178
179 ExtensionQueryResult {
180 found,
181 missing_required,
182 version_mismatches,
183 }
184 }
185}
186
187#[derive(Debug, Clone)]
189pub struct ExtensionQueryResult {
190 pub found: HashMap<String, Extension>,
192 pub missing_required: Vec<String>,
194 pub version_mismatches: Vec<VersionMismatch>,
196}
197
198impl ExtensionQueryResult {
199 #[must_use]
201 pub fn is_satisfied(&self) -> bool {
202 self.missing_required.is_empty() && self.version_mismatches.is_empty()
203 }
204
205 #[must_use]
207 pub fn has(&self, name: &str) -> bool {
208 self.found.contains_key(name)
209 }
210
211 #[must_use]
213 pub fn get(&self, name: &str) -> Option<&Extension> {
214 self.found.get(name)
215 }
216
217 #[must_use]
219 pub fn found_names(&self) -> impl Iterator<Item = &str> {
220 self.found.keys().map(String::as_str)
221 }
222}
223
/// An advertised extension whose version is below the minimum a query requested.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VersionMismatch {
    /// Name of the extension.
    pub extension: String,
    /// Minimum version requested by the query.
    pub required: String,
    /// Version the peer actually advertised.
    pub actual: String,
}

impl std::fmt::Display for VersionMismatch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "extension `{}` requires version >= {}, but {} was advertised",
            self.extension, self.required, self.actual
        )
    }
}
234
/// Returns `true` if version `actual` is at least version `required`.
///
/// Comparison follows Semantic Versioning precedence, parsed leniently:
///
/// * surrounding whitespace and a leading `v`/`V` are ignored;
/// * build metadata (anything after `+`) is ignored;
/// * `major.minor.patch` are compared numerically using each component's
///   leading digits; a missing component, or one without leading digits,
///   counts as `0` (so `"1"` equals `"1.0.0"`);
/// * a pre-release (`1.2.3-rc.1`) ranks below its release (`1.2.3`), and two
///   pre-releases of the same core version are compared identifier by
///   identifier (numeric identifiers numerically, others lexically, numeric
///   below alphanumeric, shorter list lower when otherwise equal).
///
/// Malformed input never panics.
fn version_satisfies(actual: &str, required: &str) -> bool {
    use std::cmp::Ordering;

    /// Numeric value of a component's leading digits (`"3rc"` -> 3), or 0.
    fn leading_number(component: &str) -> u64 {
        let end = component
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(component.len());
        component[..end].parse().unwrap_or(0)
    }

    /// Splits a version into its `(major, minor, patch)` core and optional pre-release tag.
    fn parse(version: &str) -> ((u64, u64, u64), Option<&str>) {
        let version = version.trim();
        let version = version
            .strip_prefix('v')
            .or_else(|| version.strip_prefix('V'))
            .unwrap_or(version);
        // Build metadata never participates in precedence.
        let version = version.split_once('+').map_or(version, |(head, _)| head);
        let (core, pre) = match version.split_once('-') {
            Some((core, pre)) => (core, Some(pre)),
            None => (version, None),
        };
        let mut numbers = core.split('.').map(leading_number);
        let major = numbers.next().unwrap_or(0);
        let minor = numbers.next().unwrap_or(0);
        let patch = numbers.next().unwrap_or(0);
        ((major, minor, patch), pre)
    }

    /// Compares dot-separated pre-release identifiers (SemVer 2.0.0, item 11).
    fn compare_pre_release(left: &str, right: &str) -> Ordering {
        let mut lhs = left.split('.');
        let mut rhs = right.split('.');
        loop {
            let ord = match (lhs.next(), rhs.next()) {
                (None, None) => return Ordering::Equal,
                // With all shared identifiers equal, the shorter list ranks lower.
                (None, Some(_)) => return Ordering::Less,
                (Some(_), None) => return Ordering::Greater,
                (Some(a), Some(b)) => match (a.parse::<u64>(), b.parse::<u64>()) {
                    (Ok(x), Ok(y)) => x.cmp(&y),
                    // Numeric identifiers rank below alphanumeric ones.
                    (Ok(_), Err(_)) => Ordering::Less,
                    (Err(_), Ok(_)) => Ordering::Greater,
                    (Err(_), Err(_)) => a.cmp(b),
                },
            };
            if ord != Ordering::Equal {
                return ord;
            }
        }
    }

    let (actual_core, actual_pre) = parse(actual);
    let (required_core, required_pre) = parse(required);

    let ordering = actual_core
        .cmp(&required_core)
        .then_with(|| match (actual_pre, required_pre) {
            (None, None) => Ordering::Equal,
            // A release outranks any pre-release of the same core version.
            (None, Some(_)) => Ordering::Greater,
            (Some(_), None) => Ordering::Less,
            (Some(a), Some(r)) => compare_pre_release(a, r),
        });

    ordering != Ordering::Less
}
252
253#[must_use]
257pub fn negotiate_extensions(
258 client: &ClientCapabilities,
259 server: &ServerCapabilities,
260) -> ExtensionRegistry {
261 let client_registry = client
262 .experimental
263 .as_ref()
264 .and_then(ExtensionRegistry::from_experimental);
265 let server_registry = server
266 .experimental
267 .as_ref()
268 .and_then(ExtensionRegistry::from_experimental);
269
270 let mut result = ExtensionRegistry::new();
271
272 if let (Some(client_reg), Some(server_reg)) = (client_registry, server_registry) {
274 for name in client_reg.names() {
275 if let Some(server_ext) = server_reg.get(name) {
276 result = result.register(server_ext.clone());
278 }
279 }
280 }
281
282 result
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Server capabilities advertising exactly the extensions in `registry`.
    fn server_with(registry: ExtensionRegistry) -> ServerCapabilities {
        ServerCapabilities::new().with_extensions(registry)
    }

    /// A registry containing only `io.mcp.apps` at version `0.1.0`.
    fn apps_only() -> ExtensionRegistry {
        ExtensionRegistry::new().register(Extension::new("io.mcp.apps").with_version("0.1.0"))
    }

    #[test]
    fn test_extension_query_satisfied() {
        let caps = server_with(
            apps_only().register(Extension::new("io.example.custom").with_version("1.0.0")),
        );

        let outcome = ExtensionQuery::new()
            .require("io.mcp.apps")
            .optional("io.example.custom")
            .check(&caps);

        assert!(outcome.is_satisfied());
        assert!(outcome.has("io.mcp.apps"));
        assert!(outcome.has("io.example.custom"));
    }

    #[test]
    fn test_extension_query_missing_required() {
        let caps = server_with(apps_only());

        let outcome = ExtensionQuery::new()
            .require("io.mcp.apps")
            .require("io.example.missing")
            .check(&caps);

        assert!(!outcome.is_satisfied());
        assert!(outcome
            .missing_required
            .iter()
            .any(|name| name.as_str() == "io.example.missing"));
    }

    #[test]
    fn test_extension_query_optional_missing() {
        let caps = server_with(apps_only());

        let outcome = ExtensionQuery::new()
            .require("io.mcp.apps")
            .optional("io.example.missing")
            .check(&caps);

        assert!(outcome.is_satisfied());
        assert!(!outcome.has("io.example.missing"));
    }

    #[test]
    fn test_version_satisfies() {
        // (actual, required, expected)
        let cases = [
            ("1.0.0", "1.0.0", true),
            ("1.1.0", "1.0.0", true),
            ("2.0.0", "1.0.0", true),
            ("0.9.0", "1.0.0", false),
            ("1.0.0", "1.0.1", false),
        ];
        for &(actual, required, expected) in &cases {
            assert_eq!(
                version_satisfies(actual, required),
                expected,
                "version_satisfies({actual:?}, {required:?})"
            );
        }
    }

    #[test]
    fn test_version_requirement() {
        let caps = server_with(apps_only());

        let exact = ExtensionQuery::new().with_min_version("io.mcp.apps", "0.1.0");
        assert!(exact.check(&caps).is_satisfied());

        let too_new = ExtensionQuery::new()
            .with_min_version("io.mcp.apps", "0.2.0")
            .check(&caps);
        assert!(!too_new.is_satisfied());
        assert_eq!(too_new.version_mismatches.len(), 1);
    }

    #[test]
    fn test_negotiate_extensions() {
        let client = ClientCapabilities::new().with_extensions(
            ExtensionRegistry::new()
                .register(Extension::new("io.mcp.apps").with_version("0.1.0"))
                .register(Extension::new("io.client.only").with_version("1.0.0")),
        );
        let server = server_with(
            ExtensionRegistry::new()
                .register(Extension::new("io.mcp.apps").with_version("0.2.0"))
                .register(Extension::new("io.server.only").with_version("1.0.0")),
        );

        let negotiated = negotiate_extensions(&client, &server);

        assert!(negotiated.has("io.mcp.apps"));
        assert!(!negotiated.has("io.client.only"));
        assert!(!negotiated.has("io.server.only"));

        // Shared extensions carry the server's declaration, including its version.
        let shared = negotiated
            .get("io.mcp.apps")
            .expect("io.mcp.apps is advertised by both peers");
        assert_eq!(shared.version.as_deref(), Some("0.2.0"));
    }
}