mcp_host/managers/
completion.rs

//! Completion system for argument value suggestions
//!
//! Provides the `completion/complete` handler, which returns autocomplete
//! suggestions for prompt arguments and resource URIs.
//!
//! # MCP Spec
//!
//! The completion system lets clients request suggested values for an
//! argument. Servers can base their suggestions on:
//! - the prompt name, the argument name, and the argument's current partial value
//! - the resource URI (or URI pattern), the argument name, and its current partial value
//!
//! # Example
//!
//! ```rust,ignore
//! use std::sync::Arc;
//!
//! use async_trait::async_trait;
//! use mcp_host::managers::completion::{CompletionManager, CompletionProvider};
//! use mcp_host::protocol::types::CompletionValue;
//!
//! struct MyCompletionProvider;
//!
//! #[async_trait]
//! impl CompletionProvider for MyCompletionProvider {
//!     async fn complete_prompt(
//!         &self,
//!         prompt_name: &str,
//!         argument_name: &str,
//!         argument_value: &str,
//!     ) -> Vec<CompletionValue> {
//!         if prompt_name == "greet" && argument_name == "name" {
//!             vec![
//!                 CompletionValue::new("Alice"),
//!                 CompletionValue::new("Bob"),
//!             ]
//!             .into_iter()
//!             .filter(|v| v.value.starts_with(argument_value))
//!             .collect()
//!         } else {
//!             vec![]
//!         }
//!     }
//! }
//!
//! let manager = CompletionManager::new();
//! manager.register_prompt_provider("greet", Arc::new(MyCompletionProvider));
//! ```
43
44use std::sync::Arc;
45
46use async_trait::async_trait;
47use dashmap::DashMap;
48use thiserror::Error;
49
50use crate::protocol::types::{
51    CompleteRequest, CompleteResult, CompletionInfo, CompletionRef, CompletionValue,
52};
53
54/// Completion errors
55#[derive(Debug, Error)]
56pub enum CompletionError {
57    /// Invalid completion reference
58    #[error("Invalid reference: {0}")]
59    InvalidReference(String),
60
61    /// Provider not found for the given reference
62    #[error("No completion provider for: {0}")]
63    NoProvider(String),
64
65    /// Internal error during completion
66    #[error("Completion error: {0}")]
67    Internal(String),
68}
69
70/// Maximum number of completion values to return by default
71pub const DEFAULT_MAX_COMPLETIONS: usize = 100;
72
73/// Trait for providing completion suggestions
74///
75/// Implement this trait to provide custom completion logic for prompts and resources.
76#[async_trait]
77pub trait CompletionProvider: Send + Sync {
78    /// Provide completions for a prompt argument
79    ///
80    /// # Arguments
81    ///
82    /// * `prompt_name` - Name of the prompt being completed
83    /// * `argument_name` - Name of the argument being completed
84    /// * `argument_value` - Current partial value of the argument
85    ///
86    /// # Returns
87    ///
88    /// List of completion suggestions matching the partial value
89    async fn complete_prompt(
90        &self,
91        prompt_name: &str,
92        argument_name: &str,
93        argument_value: &str,
94    ) -> Vec<CompletionValue> {
95        let _ = (prompt_name, argument_name, argument_value);
96        vec![]
97    }
98
99    /// Provide completions for a resource URI argument
100    ///
101    /// # Arguments
102    ///
103    /// * `uri` - URI pattern of the resource
104    /// * `argument_name` - Name of the argument being completed
105    /// * `argument_value` - Current partial value of the argument
106    ///
107    /// # Returns
108    ///
109    /// List of completion suggestions matching the partial value
110    async fn complete_resource(
111        &self,
112        uri: &str,
113        argument_name: &str,
114        argument_value: &str,
115    ) -> Vec<CompletionValue> {
116        let _ = (uri, argument_name, argument_value);
117        vec![]
118    }
119
120    /// Maximum completions to return (default: 100)
121    fn max_completions(&self) -> usize {
122        DEFAULT_MAX_COMPLETIONS
123    }
124}
125
126/// Default completion provider that returns no suggestions
127pub struct NoOpCompletionProvider;
128
129#[async_trait]
130impl CompletionProvider for NoOpCompletionProvider {}
131
132/// Completion manager for handling completion requests
133///
134/// The manager routes completion requests to registered providers based on
135/// the reference type (prompt or resource).
136#[derive(Clone)]
137pub struct CompletionManager {
138    /// Prompt completion providers by prompt name
139    prompt_providers: Arc<DashMap<String, Arc<dyn CompletionProvider>>>,
140    /// Resource completion providers by URI pattern
141    resource_providers: Arc<DashMap<String, Arc<dyn CompletionProvider>>>,
142    /// Default provider for unregistered prompts/resources
143    default_provider: Arc<dyn CompletionProvider>,
144    /// Maximum completions to return
145    max_completions: usize,
146}
147
148impl Default for CompletionManager {
149    fn default() -> Self {
150        Self::new()
151    }
152}
153
154impl CompletionManager {
155    /// Create a new completion manager
156    pub fn new() -> Self {
157        Self {
158            prompt_providers: Arc::new(DashMap::new()),
159            resource_providers: Arc::new(DashMap::new()),
160            default_provider: Arc::new(NoOpCompletionProvider),
161            max_completions: DEFAULT_MAX_COMPLETIONS,
162        }
163    }
164
165    /// Create a completion manager with a default provider
166    pub fn with_default_provider(provider: Arc<dyn CompletionProvider>) -> Self {
167        Self {
168            prompt_providers: Arc::new(DashMap::new()),
169            resource_providers: Arc::new(DashMap::new()),
170            default_provider: provider,
171            max_completions: DEFAULT_MAX_COMPLETIONS,
172        }
173    }
174
175    /// Set maximum completions to return
176    pub fn set_max_completions(&mut self, max: usize) {
177        self.max_completions = max;
178    }
179
180    /// Register a completion provider for a specific prompt
181    pub fn register_prompt_provider(
182        &self,
183        prompt_name: impl Into<String>,
184        provider: Arc<dyn CompletionProvider>,
185    ) {
186        self.prompt_providers.insert(prompt_name.into(), provider);
187    }
188
189    /// Register a completion provider for a specific resource URI pattern
190    pub fn register_resource_provider(
191        &self,
192        uri_pattern: impl Into<String>,
193        provider: Arc<dyn CompletionProvider>,
194    ) {
195        self.resource_providers.insert(uri_pattern.into(), provider);
196    }
197
198    /// Set the default provider for unregistered prompts/resources
199    pub fn set_default_provider(&mut self, provider: Arc<dyn CompletionProvider>) {
200        self.default_provider = provider;
201    }
202
203    /// Handle a completion request
204    pub async fn complete(&self, request: CompleteRequest) -> Result<CompleteResult, CompletionError> {
205        let values = match &request.reference {
206            CompletionRef::Prompt { name } => {
207                let provider = self
208                    .prompt_providers
209                    .get(name)
210                    .map(|p| Arc::clone(&p))
211                    .unwrap_or_else(|| Arc::clone(&self.default_provider));
212
213                provider
214                    .complete_prompt(name, &request.argument.name, &request.argument.value)
215                    .await
216            }
217            CompletionRef::Resource { uri } => {
218                // Try exact match first, then pattern matching
219                let provider = self
220                    .resource_providers
221                    .get(uri)
222                    .map(|p| Arc::clone(&p))
223                    .or_else(|| self.find_matching_resource_provider(uri))
224                    .unwrap_or_else(|| Arc::clone(&self.default_provider));
225
226                provider
227                    .complete_resource(uri, &request.argument.name, &request.argument.value)
228                    .await
229            }
230        };
231
232        // Apply max completions limit
233        let total = values.len();
234        let max = self.max_completions;
235        let has_more = total > max;
236        let values: Vec<_> = values.into_iter().take(max).collect();
237
238        let completion = if has_more {
239            CompletionInfo::with_pagination(values, total, true)
240        } else {
241            CompletionInfo::with_values(values)
242        };
243
244        Ok(CompleteResult { completion })
245    }
246
247    /// Find a resource provider by pattern matching
248    fn find_matching_resource_provider(&self, uri: &str) -> Option<Arc<dyn CompletionProvider>> {
249        // Simple prefix matching for now
250        // Could be extended to support URI templates like "file:///{path}"
251        for entry in self.resource_providers.iter() {
252            let pattern = entry.key();
253            if uri.starts_with(pattern) || pattern.starts_with(uri) {
254                return Some(Arc::clone(entry.value()));
255            }
256
257            // Check for template pattern matching
258            if let Some(provider) = self.match_template_pattern(pattern, uri, entry.value()) {
259                return Some(provider);
260            }
261        }
262        None
263    }
264
265    /// Match a URI against a template pattern
266    fn match_template_pattern(
267        &self,
268        pattern: &str,
269        uri: &str,
270        provider: &Arc<dyn CompletionProvider>,
271    ) -> Option<Arc<dyn CompletionProvider>> {
272        // Extract scheme and path prefix before any template variables
273        let pattern_prefix = pattern.split('{').next()?;
274        if uri.starts_with(pattern_prefix) {
275            return Some(Arc::clone(provider));
276        }
277        None
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::CompletionArgument;

    /// Test provider that returns predefined completions, filtered to those
    /// starting with the argument's partial value.
    struct TestCompletionProvider {
        prompt_completions: Vec<CompletionValue>,
        resource_completions: Vec<CompletionValue>,
    }

    impl TestCompletionProvider {
        fn new(prompt: Vec<&str>, resource: Vec<&str>) -> Self {
            Self {
                prompt_completions: prompt.into_iter().map(CompletionValue::new).collect(),
                resource_completions: resource.into_iter().map(CompletionValue::new).collect(),
            }
        }
    }

    /// Clone the values that start with `prefix`, preserving their order.
    fn matching(values: &[CompletionValue], prefix: &str) -> Vec<CompletionValue> {
        values
            .iter()
            .filter(|v| v.value.starts_with(prefix))
            .cloned()
            .collect()
    }

    #[async_trait]
    impl CompletionProvider for TestCompletionProvider {
        async fn complete_prompt(
            &self,
            _prompt_name: &str,
            _argument_name: &str,
            argument_value: &str,
        ) -> Vec<CompletionValue> {
            matching(&self.prompt_completions, argument_value)
        }

        async fn complete_resource(
            &self,
            _uri: &str,
            _argument_name: &str,
            argument_value: &str,
        ) -> Vec<CompletionValue> {
            matching(&self.resource_completions, argument_value)
        }
    }

    /// Build a request completing `argument` (current partial `value`) for `reference`.
    fn request(reference: CompletionRef, argument: &str, value: &str) -> CompleteRequest {
        CompleteRequest {
            reference,
            argument: CompletionArgument {
                name: argument.to_string(),
                value: value.to_string(),
            },
        }
    }

    /// Reference to the prompt `name`.
    fn prompt(name: &str) -> CompletionRef {
        CompletionRef::Prompt {
            name: name.to_string(),
        }
    }

    /// Reference to the resource `uri`.
    fn resource(uri: &str) -> CompletionRef {
        CompletionRef::Resource {
            uri: uri.to_string(),
        }
    }

    /// The completion values of `result`, in order.
    fn values_of(result: &CompleteResult) -> Vec<String> {
        result
            .completion
            .values
            .iter()
            .map(|v| v.value.clone())
            .collect()
    }

    #[test]
    fn test_completion_manager_new() {
        let manager = CompletionManager::new();
        assert_eq!(manager.max_completions, DEFAULT_MAX_COMPLETIONS);
    }

    #[tokio::test]
    async fn test_prompt_completion() {
        let manager = CompletionManager::new();
        let provider = Arc::new(TestCompletionProvider::new(
            vec!["Alice", "Amy", "Bob"],
            vec![],
        ));
        manager.register_prompt_provider("greet", provider);

        let result = manager
            .complete(request(prompt("greet"), "name", "A"))
            .await
            .unwrap();
        assert_eq!(values_of(&result), vec!["Alice", "Amy"]);
        // Nothing was truncated, so no pagination info is attached.
        assert!(result.completion.has_more.is_none());
    }

    #[tokio::test]
    async fn test_registered_prompt_provider_takes_precedence() {
        let default = Arc::new(TestCompletionProvider::new(vec!["default1"], vec![]));
        let manager = CompletionManager::with_default_provider(default);
        manager.register_prompt_provider(
            "greet",
            Arc::new(TestCompletionProvider::new(vec!["Alice"], vec![])),
        );

        let result = manager
            .complete(request(prompt("greet"), "name", ""))
            .await
            .unwrap();
        assert_eq!(values_of(&result), vec!["Alice"]);
    }

    #[tokio::test]
    async fn test_resource_completion() {
        let manager = CompletionManager::new();
        let provider = Arc::new(TestCompletionProvider::new(
            vec![],
            vec!["src/main.rs", "src/lib.rs", "tests/test.rs"],
        ));
        manager.register_resource_provider("file:///", provider);

        let result = manager
            .complete(request(resource("file:///src"), "path", "src/"))
            .await
            .unwrap();
        assert_eq!(values_of(&result), vec!["src/main.rs", "src/lib.rs"]);
    }

    #[tokio::test]
    async fn test_resource_exact_match() {
        let manager = CompletionManager::new();
        manager.register_resource_provider(
            "mem://notes",
            Arc::new(TestCompletionProvider::new(vec![], vec!["todo", "ideas"])),
        );

        let result = manager
            .complete(request(resource("mem://notes"), "name", "t"))
            .await
            .unwrap();
        assert_eq!(values_of(&result), vec!["todo"]);
    }

    #[tokio::test]
    async fn test_resource_template_pattern() {
        let manager = CompletionManager::new();
        manager.register_resource_provider(
            "db://{table}",
            Arc::new(TestCompletionProvider::new(vec![], vec!["users", "orders"])),
        );

        let result = manager
            .complete(request(resource("db://users"), "table", "u"))
            .await
            .unwrap();
        assert_eq!(values_of(&result), vec!["users"]);
    }

    #[tokio::test]
    async fn test_unmatched_resource_uses_default() {
        let manager = CompletionManager::new();
        manager.register_resource_provider(
            "file:///",
            Arc::new(TestCompletionProvider::new(vec![], vec!["src/main.rs"])),
        );

        let result = manager
            .complete(request(resource("http://example.com"), "path", ""))
            .await
            .unwrap();
        assert!(result.completion.values.is_empty());
    }

    #[tokio::test]
    async fn test_default_provider() {
        let default = Arc::new(TestCompletionProvider::new(
            vec!["default1", "default2"],
            vec![],
        ));
        let manager = CompletionManager::with_default_provider(default);

        // A request for an unregistered prompt is served by the default provider.
        let result = manager
            .complete(request(prompt("unknown_prompt"), "arg", "default"))
            .await
            .unwrap();
        assert_eq!(values_of(&result), vec!["default1", "default2"]);
    }

    #[tokio::test]
    async fn test_max_completions_limit() {
        let mut manager = CompletionManager::new();
        manager.set_max_completions(2);

        let provider = Arc::new(TestCompletionProvider::new(
            vec!["a1", "a2", "a3", "a4", "a5"],
            vec![],
        ));
        manager.register_prompt_provider("test", provider);

        let result = manager
            .complete(request(prompt("test"), "arg", "a"))
            .await
            .unwrap();
        assert_eq!(values_of(&result), vec!["a1", "a2"]);
        assert_eq!(result.completion.total, Some(5));
        assert_eq!(result.completion.has_more, Some(true));
    }

    #[tokio::test]
    async fn test_empty_completions() {
        let manager = CompletionManager::new();

        let result = manager
            .complete(request(prompt("nonexistent"), "arg", "x"))
            .await
            .unwrap();
        assert!(result.completion.values.is_empty());
    }

    #[test]
    fn test_completion_value_constructors() {
        let simple = CompletionValue::new("value");
        assert_eq!(simple.value, "value");
        assert!(simple.label.is_none());
        assert!(simple.description.is_none());

        let with_label = CompletionValue::with_label("value", "Label");
        assert_eq!(with_label.value, "value");
        assert_eq!(with_label.label, Some("Label".to_string()));
        assert!(with_label.description.is_none());

        let with_desc = CompletionValue::with_description("value", "Description");
        assert_eq!(with_desc.value, "value");
        assert!(with_desc.label.is_none());
        assert_eq!(with_desc.description, Some("Description".to_string()));

        let full = CompletionValue::full("value", "Label", "Description");
        assert_eq!(full.value, "value");
        assert_eq!(full.label, Some("Label".to_string()));
        assert_eq!(full.description, Some("Description".to_string()));
    }

    #[test]
    fn test_completion_info_constructors() {
        let empty = CompletionInfo::empty();
        assert!(empty.values.is_empty());
        assert!(empty.total.is_none());
        assert!(empty.has_more.is_none());

        let with_values = CompletionInfo::with_values(vec![
            CompletionValue::new("a"),
            CompletionValue::new("b"),
        ]);
        assert_eq!(with_values.values.len(), 2);
        assert!(with_values.total.is_none());
        assert!(with_values.has_more.is_none());

        let paginated = CompletionInfo::with_pagination(vec![CompletionValue::new("a")], 10, true);
        assert_eq!(paginated.values.len(), 1);
        assert_eq!(paginated.total, Some(10));
        assert_eq!(paginated.has_more, Some(true));
    }
}