mcpkit_core/extension/
discovery.rs

1//! Extension discovery mechanism.
2//!
3//! This module provides utilities for discovering and querying extensions
4//! from capabilities. Extension discovery allows clients to:
5//!
6//! - Detect which extensions a server supports
7//! - Query extension versions and configurations
8//! - Check for required vs optional extensions
9//!
10//! # Example
11//!
12//! ```rust
13//! use mcpkit_core::extension::{Extension, ExtensionRegistry};
14//! use mcpkit_core::extension::discovery::{ExtensionQuery, ExtensionRequirement};
15//! use mcpkit_core::capability::ServerCapabilities;
16//!
17//! // Server declares extensions
18//! let registry = ExtensionRegistry::new()
19//!     .register(Extension::new("io.mcp.apps").with_version("0.1.0"))
20//!     .register(Extension::new("io.example.custom").with_version("1.0.0"));
21//!
22//! let caps = ServerCapabilities::new()
23//!     .with_tools()
24//!     .with_extensions(registry);
25//!
26//! // Client queries for extensions
27//! let query = ExtensionQuery::new()
28//!     .require("io.mcp.apps")
29//!     .optional("io.example.custom")
30//!     .optional("io.example.missing");
31//!
32//! let result = query.check(&caps);
33//! assert!(result.is_satisfied());
34//! assert!(result.has("io.mcp.apps"));
35//! assert!(result.has("io.example.custom"));
36//! assert!(!result.has("io.example.missing"));
37//! ```
38
39use super::{Extension, ExtensionRegistry};
40use crate::capability::{ClientCapabilities, ServerCapabilities};
41use std::collections::HashMap;
42
/// How strongly an [`ExtensionQuery`] depends on a particular extension.
///
/// The level decides whether an extension that is not advertised makes the
/// whole query unsatisfied.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExtensionRequirement {
    /// The query is unsatisfied unless this extension is advertised.
    Required,
    /// The query can still be satisfied when this extension is not advertised.
    Optional,
}
51
52/// A query for checking extension support.
53///
54/// Build a query with required and optional extensions, then check
55/// against capabilities to see which are satisfied.
56#[derive(Debug, Clone, Default)]
57pub struct ExtensionQuery {
58    requirements: HashMap<String, ExtensionRequirement>,
59    min_versions: HashMap<String, String>,
60}
61
62impl ExtensionQuery {
63    /// Create a new empty extension query.
64    #[must_use]
65    pub fn new() -> Self {
66        Self::default()
67    }
68
69    /// Add a required extension.
70    ///
71    /// The query will only be satisfied if this extension is present.
72    #[must_use]
73    pub fn require(mut self, name: impl Into<String>) -> Self {
74        self.requirements
75            .insert(name.into(), ExtensionRequirement::Required);
76        self
77    }
78
79    /// Add an optional extension.
80    ///
81    /// The query can be satisfied even if this extension is absent.
82    #[must_use]
83    pub fn optional(mut self, name: impl Into<String>) -> Self {
84        self.requirements
85            .insert(name.into(), ExtensionRequirement::Optional);
86        self
87    }
88
89    /// Require a minimum version for an extension.
90    ///
91    /// The extension must be present with at least this version.
92    #[must_use]
93    pub fn with_min_version(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
94        let name = name.into();
95        // Ensure the extension is in requirements
96        self.requirements
97            .entry(name.clone())
98            .or_insert(ExtensionRequirement::Required);
99        self.min_versions.insert(name, version.into());
100        self
101    }
102
103    /// Check the query against server capabilities.
104    #[must_use]
105    pub fn check(&self, capabilities: &ServerCapabilities) -> ExtensionQueryResult {
106        let registry = capabilities
107            .experimental
108            .as_ref()
109            .and_then(ExtensionRegistry::from_experimental);
110
111        let mut found = HashMap::new();
112        let mut missing_required = Vec::new();
113        let mut version_mismatches = Vec::new();
114
115        for (name, requirement) in &self.requirements {
116            let extension = registry.as_ref().and_then(|r| r.get(name));
117
118            if let Some(ext) = extension {
119                // Check version if specified
120                if let Some(min_version) = self.min_versions.get(name) {
121                    if let Some(ref actual_version) = ext.version {
122                        if !version_satisfies(actual_version, min_version) {
123                            version_mismatches.push(VersionMismatch {
124                                extension: name.clone(),
125                                required: min_version.clone(),
126                                actual: actual_version.clone(),
127                            });
128                            continue;
129                        }
130                    }
131                }
132                found.insert(name.clone(), ext.clone());
133            } else if *requirement == ExtensionRequirement::Required {
134                missing_required.push(name.clone());
135            }
136        }
137
138        ExtensionQueryResult {
139            found,
140            missing_required,
141            version_mismatches,
142        }
143    }
144
145    /// Check the query against client capabilities.
146    #[must_use]
147    pub fn check_client(&self, capabilities: &ClientCapabilities) -> ExtensionQueryResult {
148        let registry = capabilities
149            .experimental
150            .as_ref()
151            .and_then(ExtensionRegistry::from_experimental);
152
153        let mut found = HashMap::new();
154        let mut missing_required = Vec::new();
155        let mut version_mismatches = Vec::new();
156
157        for (name, requirement) in &self.requirements {
158            let extension = registry.as_ref().and_then(|r| r.get(name));
159
160            if let Some(ext) = extension {
161                if let Some(min_version) = self.min_versions.get(name) {
162                    if let Some(ref actual_version) = ext.version {
163                        if !version_satisfies(actual_version, min_version) {
164                            version_mismatches.push(VersionMismatch {
165                                extension: name.clone(),
166                                required: min_version.clone(),
167                                actual: actual_version.clone(),
168                            });
169                            continue;
170                        }
171                    }
172                }
173                found.insert(name.clone(), ext.clone());
174            } else if *requirement == ExtensionRequirement::Required {
175                missing_required.push(name.clone());
176            }
177        }
178
179        ExtensionQueryResult {
180            found,
181            missing_required,
182            version_mismatches,
183        }
184    }
185}
186
187/// Result of an extension query.
188#[derive(Debug, Clone)]
189pub struct ExtensionQueryResult {
190    /// Extensions that were found.
191    pub found: HashMap<String, Extension>,
192    /// Required extensions that were missing.
193    pub missing_required: Vec<String>,
194    /// Extensions with version mismatches.
195    pub version_mismatches: Vec<VersionMismatch>,
196}
197
198impl ExtensionQueryResult {
199    /// Check if the query is satisfied (all required extensions present).
200    #[must_use]
201    pub fn is_satisfied(&self) -> bool {
202        self.missing_required.is_empty() && self.version_mismatches.is_empty()
203    }
204
205    /// Check if a specific extension was found.
206    #[must_use]
207    pub fn has(&self, name: &str) -> bool {
208        self.found.contains_key(name)
209    }
210
211    /// Get a found extension by name.
212    #[must_use]
213    pub fn get(&self, name: &str) -> Option<&Extension> {
214        self.found.get(name)
215    }
216
217    /// Get the list of found extension names.
218    #[must_use]
219    pub fn found_names(&self) -> impl Iterator<Item = &str> {
220        self.found.keys().map(String::as_str)
221    }
222}
223
/// A queried extension whose advertised version is older than requested.
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare or dedupe
/// mismatches. Implements `Display` to produce a human-readable diagnostic.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct VersionMismatch {
    /// Name of the extension that failed the version check.
    pub extension: String,
    /// Minimum version the query asked for.
    pub required: String,
    /// Version the peer actually advertised.
    pub actual: String,
}

impl std::fmt::Display for VersionMismatch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "extension `{}`: requested version >= {}, found {}",
            self.extension, self.required, self.actual
        )
    }
}
234
/// Simple semver-like version comparison.
///
/// Returns `true` if `actual` >= `required`.
///
/// The ordering follows SemVer 2.0.0 precedence, but the parser is lenient
/// and does not reject malformed input:
///
/// - Surrounding whitespace and build metadata (`+...`) are ignored.
/// - The `MAJOR.MINOR.PATCH` core is compared numerically. Missing or
///   non-numeric core components count as `0`, and components beyond the
///   third are ignored, so `"1"` equals `"1.0.0"`.
/// - A pre-release suffix (`-...`) ranks below the same core without one.
///   For example, `1.2.3-rc.1` satisfies `1.2.0` but not `1.2.3`.
/// - Two pre-releases are compared identifier by identifier:
///   - numeric identifiers compare numerically and rank below alphanumeric
///     ones;
///   - alphanumeric identifiers compare lexically;
///   - a shorter identifier list ranks lower when it is a prefix of the
///     longer one.
fn version_satisfies(actual: &str, required: &str) -> bool {
    use std::cmp::Ordering;

    /// Split `version` into its numeric core and its non-empty pre-release tag.
    fn split(version: &str) -> ((u32, u32, u32), Option<&str>) {
        let version = version.trim();
        // Build metadata never affects precedence (SemVer §10).
        let version = version.split_once('+').map_or(version, |(head, _)| head);
        // The first '-' starts the pre-release; later hyphens belong to it.
        let (core, pre) = match version.split_once('-') {
            Some((core, pre)) => (core, Some(pre).filter(|p| !p.is_empty())),
            None => (version, None),
        };
        let mut parts = core
            .split('.')
            .map(|part| part.parse::<u32>().unwrap_or(0));
        let major = parts.next().unwrap_or(0);
        let minor = parts.next().unwrap_or(0);
        let patch = parts.next().unwrap_or(0);
        ((major, minor, patch), pre)
    }

    /// Compare two pre-release tags per SemVer 2.0.0 §11.4.
    fn compare_pre(left: &str, right: &str) -> Ordering {
        let mut lhs = left.split('.');
        let mut rhs = right.split('.');
        loop {
            let (a, b) = match (lhs.next(), rhs.next()) {
                (None, None) => return Ordering::Equal,
                // With an equal prefix, more identifiers means higher precedence.
                (None, Some(_)) => return Ordering::Less,
                (Some(_), None) => return Ordering::Greater,
                (Some(a), Some(b)) => (a, b),
            };
            let ord = match (a.parse::<u64>(), b.parse::<u64>()) {
                (Ok(x), Ok(y)) => x.cmp(&y),
                // Numeric identifiers always rank below alphanumeric ones.
                (Ok(_), Err(_)) => Ordering::Less,
                (Err(_), Ok(_)) => Ordering::Greater,
                (Err(_), Err(_)) => a.cmp(b),
            };
            if ord != Ordering::Equal {
                return ord;
            }
        }
    }

    let (actual_core, actual_pre) = split(actual);
    let (required_core, required_pre) = split(required);

    let ordering = actual_core
        .cmp(&required_core)
        .then_with(|| match (actual_pre, required_pre) {
            (None, None) => Ordering::Equal,
            // A release outranks any pre-release of the same core.
            (None, Some(_)) => Ordering::Greater,
            (Some(_), None) => Ordering::Less,
            (Some(a), Some(b)) => compare_pre(a, b),
        });

    ordering != Ordering::Less
}
252
253/// Negotiate extensions between client and server.
254///
255/// Returns the set of extensions both sides support.
256#[must_use]
257pub fn negotiate_extensions(
258    client: &ClientCapabilities,
259    server: &ServerCapabilities,
260) -> ExtensionRegistry {
261    let client_registry = client
262        .experimental
263        .as_ref()
264        .and_then(ExtensionRegistry::from_experimental);
265    let server_registry = server
266        .experimental
267        .as_ref()
268        .and_then(ExtensionRegistry::from_experimental);
269
270    let mut result = ExtensionRegistry::new();
271
272    // Only include extensions both sides support
273    if let (Some(client_reg), Some(server_reg)) = (client_registry, server_registry) {
274        for name in client_reg.names() {
275            if let Some(server_ext) = server_reg.get(name) {
276                // Use server's extension info (version, config)
277                result = result.register(server_ext.clone());
278            }
279        }
280    }
281
282    result
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a registry holding the given `(name, version)` pairs, in order.
    fn registry_of(entries: &[(&'static str, &'static str)]) -> ExtensionRegistry {
        entries
            .iter()
            .fold(ExtensionRegistry::new(), |registry, &(name, version)| {
                registry.register(Extension::new(name).with_version(version))
            })
    }

    /// Server capabilities that advertise exactly the given extensions.
    fn server_with(entries: &[(&'static str, &'static str)]) -> ServerCapabilities {
        ServerCapabilities::new().with_extensions(registry_of(entries))
    }

    #[test]
    fn test_extension_query_satisfied() {
        let caps = server_with(&[("io.mcp.apps", "0.1.0"), ("io.example.custom", "1.0.0")]);

        let outcome = ExtensionQuery::new()
            .require("io.mcp.apps")
            .optional("io.example.custom")
            .check(&caps);

        assert!(outcome.is_satisfied());
        assert!(outcome.has("io.mcp.apps"));
        assert!(outcome.has("io.example.custom"));
    }

    #[test]
    fn test_extension_query_missing_required() {
        let caps = server_with(&[("io.mcp.apps", "0.1.0")]);

        let outcome = ExtensionQuery::new()
            .require("io.mcp.apps")
            .require("io.example.missing")
            .check(&caps);

        assert!(!outcome.is_satisfied());
        assert!(outcome
            .missing_required
            .iter()
            .any(|name| name == "io.example.missing"));
    }

    #[test]
    fn test_extension_query_optional_missing() {
        let caps = server_with(&[("io.mcp.apps", "0.1.0")]);

        let outcome = ExtensionQuery::new()
            .require("io.mcp.apps")
            .optional("io.example.missing")
            .check(&caps);

        assert!(outcome.is_satisfied());
        assert!(!outcome.has("io.example.missing"));
    }

    #[test]
    fn test_version_satisfies() {
        for &(actual, required) in &[("1.0.0", "1.0.0"), ("1.1.0", "1.0.0"), ("2.0.0", "1.0.0")] {
            assert!(
                version_satisfies(actual, required),
                "{} should satisfy {}",
                actual,
                required
            );
        }
        for &(actual, required) in &[("0.9.0", "1.0.0"), ("1.0.0", "1.0.1")] {
            assert!(
                !version_satisfies(actual, required),
                "{} should not satisfy {}",
                actual,
                required
            );
        }
    }

    #[test]
    fn test_version_requirement() {
        let caps = server_with(&[("io.mcp.apps", "0.1.0")]);

        // The advertised version itself is an acceptable minimum.
        let exact = ExtensionQuery::new().with_min_version("io.mcp.apps", "0.1.0");
        assert!(exact.check(&caps).is_satisfied());

        // A newer minimum than advertised is reported as a mismatch.
        let outcome = ExtensionQuery::new()
            .with_min_version("io.mcp.apps", "0.2.0")
            .check(&caps);
        assert!(!outcome.is_satisfied());
        assert_eq!(outcome.version_mismatches.len(), 1);
    }

    #[test]
    fn test_negotiate_extensions() {
        let client = ClientCapabilities::new().with_extensions(registry_of(&[
            ("io.mcp.apps", "0.1.0"),
            ("io.client.only", "1.0.0"),
        ]));
        let server = server_with(&[("io.mcp.apps", "0.2.0"), ("io.server.only", "1.0.0")]);

        let negotiated = negotiate_extensions(&client, &server);

        // Only the extension both sides advertise survives.
        assert!(negotiated.has("io.mcp.apps"));
        assert!(!negotiated.has("io.client.only"));
        assert!(!negotiated.has("io.server.only"));

        // The server's entry, and therefore its version, is the one kept.
        let version = negotiated
            .get("io.mcp.apps")
            .and_then(|ext| ext.version.as_deref());
        assert_eq!(version, Some("0.2.0"));
    }
}