mcpkit_server/capability/
completions.rs

//! Completion capability implementation.
//!
//! This module provides argument completion for prompts and resources
//! in MCP servers.
5
6use crate::context::Context;
7use crate::handler::CompletionHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::completion::{
10    CompleteRequest, CompleteResult, Completion, CompletionArgument, CompletionRef, CompletionTotal,
11};
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15
16/// A boxed async function for handling completion requests.
17pub type BoxedCompletionFn = Box<
18    dyn for<'a> Fn(
19            &'a str,
20            &'a Context<'a>,
21        ) -> Pin<Box<dyn Future<Output = Result<Vec<String>, McpError>> + Send + 'a>>
22        + Send
23        + Sync,
24>;
25
26/// A registered completion provider.
27pub struct RegisteredCompletion {
28    /// Reference type (prompt or resource).
29    pub ref_type: String,
30    /// Reference value (name or URI pattern).
31    pub ref_value: String,
32    /// Argument name.
33    pub arg_name: String,
34    /// Completion handler.
35    pub handler: BoxedCompletionFn,
36}
37
38/// Service for handling completion requests.
39///
40/// This provides argument completion for prompts and resources.
41pub struct CompletionService {
42    /// Completions keyed by (ref_type, ref_value, arg_name).
43    completions: HashMap<(String, String, String), RegisteredCompletion>,
44}
45
46impl Default for CompletionService {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl CompletionService {
53    /// Create a new completion service.
54    pub fn new() -> Self {
55        Self {
56            completions: HashMap::new(),
57        }
58    }
59
60    /// Register a completion provider for a prompt argument.
61    pub fn register_prompt_completion<F, Fut>(
62        &mut self,
63        prompt_name: impl Into<String>,
64        arg_name: impl Into<String>,
65        handler: F,
66    ) where
67        F: Fn(&str, &Context<'_>) -> Fut + Send + Sync + 'static,
68        Fut: Future<Output = Result<Vec<String>, McpError>> + Send + 'static,
69    {
70        let ref_type = "ref/prompt".to_string();
71        let ref_value = prompt_name.into();
72        let arg_name = arg_name.into();
73        let key = (ref_type.clone(), ref_value.clone(), arg_name.clone());
74
75        let boxed: BoxedCompletionFn = Box::new(move |input, ctx| Box::pin(handler(input, ctx)));
76
77        self.completions.insert(
78            key,
79            RegisteredCompletion {
80                ref_type,
81                ref_value,
82                arg_name,
83                handler: boxed,
84            },
85        );
86    }
87
88    /// Register a completion provider for a resource argument.
89    pub fn register_resource_completion<F, Fut>(
90        &mut self,
91        uri_pattern: impl Into<String>,
92        arg_name: impl Into<String>,
93        handler: F,
94    ) where
95        F: Fn(&str, &Context<'_>) -> Fut + Send + Sync + 'static,
96        Fut: Future<Output = Result<Vec<String>, McpError>> + Send + 'static,
97    {
98        let ref_type = "ref/resource".to_string();
99        let ref_value = uri_pattern.into();
100        let arg_name = arg_name.into();
101        let key = (ref_type.clone(), ref_value.clone(), arg_name.clone());
102
103        let boxed: BoxedCompletionFn = Box::new(move |input, ctx| Box::pin(handler(input, ctx)));
104
105        self.completions.insert(
106            key,
107            RegisteredCompletion {
108                ref_type,
109                ref_value,
110                arg_name,
111                handler: boxed,
112            },
113        );
114    }
115
116    /// Complete an argument value.
117    pub async fn complete(
118        &self,
119        request: &CompleteRequest,
120        ctx: &Context<'_>,
121    ) -> Result<CompleteResult, McpError> {
122        let ref_type = request.ref_.ref_type().to_string();
123        let ref_value = request.ref_.value().to_string();
124        let arg_name = request.argument.name.clone();
125        let input = &request.argument.value;
126
127        let key = (ref_type, ref_value, arg_name);
128
129        if let Some(registered) = self.completions.get(&key) {
130            let values = (registered.handler)(input, ctx).await?;
131            let total = values.len();
132
133            Ok(CompleteResult {
134                completion: Completion {
135                    values,
136                    total: Some(CompletionTotal::Exact(total)),
137                    has_more: Some(false),
138                },
139            })
140        } else {
141            // Return empty completions if no handler registered
142            Ok(CompleteResult {
143                completion: Completion {
144                    values: Vec::new(),
145                    total: Some(CompletionTotal::Exact(0)),
146                    has_more: Some(false),
147                },
148            })
149        }
150    }
151
152    /// Check if a completion provider exists.
153    pub fn has_completion(
154        &self,
155        ref_type: &str,
156        ref_value: &str,
157        arg_name: &str,
158    ) -> bool {
159        let key = (ref_type.to_string(), ref_value.to_string(), arg_name.to_string());
160        self.completions.contains_key(&key)
161    }
162}
163
164impl CompletionHandler for CompletionService {
165    async fn complete_resource(
166        &self,
167        partial_uri: &str,
168        ctx: &Context<'_>,
169    ) -> Result<Vec<String>, McpError> {
170        let request = CompleteRequest {
171            ref_: CompletionRef::resource(partial_uri),
172            argument: CompletionArgument {
173                name: "uri".to_string(),
174                value: partial_uri.to_string(),
175            },
176        };
177
178        let result = CompletionService::complete(self, &request, ctx).await?;
179        Ok(result.completion.values)
180    }
181
182    async fn complete_prompt_arg(
183        &self,
184        prompt_name: &str,
185        arg_name: &str,
186        partial_value: &str,
187        ctx: &Context<'_>,
188    ) -> Result<Vec<String>, McpError> {
189        let request = CompleteRequest {
190            ref_: CompletionRef::prompt(prompt_name),
191            argument: CompletionArgument {
192                name: arg_name.to_string(),
193                value: partial_value.to_string(),
194            },
195        };
196
197        let result = CompletionService::complete(self, &request, ctx).await?;
198        Ok(result.completion.values)
199    }
200}
201
202/// Builder for creating completion requests.
203pub struct CompleteRequestBuilder {
204    ref_: CompletionRef,
205    arg_name: String,
206    arg_value: String,
207}
208
209impl CompleteRequestBuilder {
210    /// Create a builder for prompt argument completion.
211    pub fn for_prompt(prompt_name: impl Into<String>, arg_name: impl Into<String>) -> Self {
212        Self {
213            ref_: CompletionRef::prompt(prompt_name.into()),
214            arg_name: arg_name.into(),
215            arg_value: String::new(),
216        }
217    }
218
219    /// Create a builder for resource argument completion.
220    pub fn for_resource(uri: impl Into<String>, arg_name: impl Into<String>) -> Self {
221        Self {
222            ref_: CompletionRef::resource(uri.into()),
223            arg_name: arg_name.into(),
224            arg_value: String::new(),
225        }
226    }
227
228    /// Set the current input value.
229    pub fn value(mut self, value: impl Into<String>) -> Self {
230        self.arg_value = value.into();
231        self
232    }
233
234    /// Build the request.
235    pub fn build(self) -> CompleteRequest {
236        CompleteRequest {
237            ref_: self.ref_,
238            argument: mcpkit_core::types::completion::CompletionArgument {
239                name: self.arg_name,
240                value: self.arg_value,
241            },
242        }
243    }
244}
245
/// Stateless helpers for narrowing a list of completion candidates.
pub struct CompletionFilter;

impl CompletionFilter {
    /// Keep the values that start with `prefix`, ignoring case.
    ///
    /// Both sides are lowercased with [`str::to_lowercase`] before comparing.
    /// Input order is preserved and matching values are cloned into the result.
    /// An empty `prefix` keeps every value.
    pub fn by_prefix(values: &[String], prefix: &str) -> Vec<String> {
        let needle = prefix.to_lowercase();
        let mut matches = Vec::new();
        for candidate in values {
            if candidate.to_lowercase().starts_with(needle.as_str()) {
                matches.push(candidate.clone());
            }
        }
        matches
    }

    /// Keep the values that contain `substring` anywhere, ignoring case.
    ///
    /// Both sides are lowercased with [`str::to_lowercase`] before comparing.
    /// Input order is preserved and matching values are cloned into the result.
    pub fn by_substring(values: &[String], substring: &str) -> Vec<String> {
        let needle = substring.to_lowercase();
        let mut matches = Vec::new();
        for candidate in values {
            if candidate.to_lowercase().contains(needle.as_str()) {
                matches.push(candidate.clone());
            }
        }
        matches
    }

    /// Keep at most the first `max` values; no filtering is performed.
    pub fn limit(values: Vec<String>, max: usize) -> Vec<String> {
        let mut kept = values;
        kept.truncate(max);
        kept
    }
}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;

    /// Owned pieces needed to build a [`Context`] in tests.
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    /// Sample candidate list shared by the tests.
    fn languages() -> Vec<String> {
        vec![
            "python".to_string(),
            "javascript".to_string(),
            "typescript".to_string(),
            "rust".to_string(),
        ]
    }

    #[test]
    fn test_complete_request_builder() {
        let request = CompleteRequestBuilder::for_prompt("code-review", "language")
            .value("py")
            .build();

        assert_eq!(request.ref_.ref_type(), "ref/prompt");
        assert_eq!(request.argument.name, "language");
        assert_eq!(request.argument.value, "py");
    }

    #[test]
    fn test_completion_filter() {
        let values = languages();

        let filtered = CompletionFilter::by_prefix(&values, "py");
        assert_eq!(filtered, vec!["python"]);

        let filtered = CompletionFilter::by_substring(&values, "script");
        assert_eq!(filtered, vec!["javascript", "typescript"]);

        let limited = CompletionFilter::limit(values.clone(), 2);
        assert_eq!(limited.len(), 2);
    }

    #[tokio::test]
    async fn test_completion_service() {
        let mut service = CompletionService::new();
        let langs = languages();

        service.register_prompt_completion("code-review", "language", move |input, _ctx| {
            let langs = langs.clone();
            let input = input.to_string();
            async move { Ok(CompletionFilter::by_prefix(&langs, &input)) }
        });

        assert!(service.has_completion("ref/prompt", "code-review", "language"));

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let request = CompleteRequestBuilder::for_prompt("code-review", "language")
            .value("py")
            .build();

        let result = service.complete(&request, &ctx).await.unwrap();
        assert_eq!(result.completion.values, vec!["python"]);
    }

    #[tokio::test]
    async fn test_completion_service_unmatched_request_is_empty() {
        let mut service = CompletionService::new();
        service.register_prompt_completion("code-review", "language", |_input, _ctx| async {
            Ok(vec!["python".to_string()])
        });

        // Same prompt but a different argument: the key includes the argument name.
        assert!(!service.has_completion("ref/prompt", "code-review", "framework"));

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let request = CompleteRequestBuilder::for_prompt("code-review", "framework")
            .value("py")
            .build();

        let result = service.complete(&request, &ctx).await.unwrap();
        assert!(result.completion.values.is_empty());
        assert_eq!(result.completion.has_more, Some(false));
    }
}
349        assert_eq!(result.completion.values, vec!["python"]);
350    }
351}