mcpkit_server/capability/
completions.rs

1//! Completion capability implementation.
2//!
3//! This module provides support for argument completion
4//! in MCP servers.
5
6use crate::context::Context;
7use crate::handler::CompletionHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::completion::{
10    CompleteRequest, CompleteResult, Completion, CompletionArgument, CompletionRef, CompletionTotal,
11};
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15
16/// A boxed async function for handling completion requests.
17pub type BoxedCompletionFn = Box<
18    dyn for<'a> Fn(
19            &'a str,
20            &'a Context<'a>,
21        )
22            -> Pin<Box<dyn Future<Output = Result<Vec<String>, McpError>> + Send + 'a>>
23        + Send
24        + Sync,
25>;
26
27/// A registered completion provider.
28pub struct RegisteredCompletion {
29    /// Reference type (prompt or resource).
30    pub ref_type: String,
31    /// Reference value (name or URI pattern).
32    pub ref_value: String,
33    /// Argument name.
34    pub arg_name: String,
35    /// Completion handler.
36    pub handler: BoxedCompletionFn,
37}
38
39/// Service for handling completion requests.
40///
41/// This provides argument completion for prompts and resources.
42pub struct CompletionService {
43    /// Completions keyed by (`ref_type`, `ref_value`, `arg_name`).
44    completions: HashMap<(String, String, String), RegisteredCompletion>,
45}
46
47impl Default for CompletionService {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl CompletionService {
54    /// Create a new completion service.
55    #[must_use]
56    pub fn new() -> Self {
57        Self {
58            completions: HashMap::new(),
59        }
60    }
61
62    /// Register a completion provider for a prompt argument.
63    pub fn register_prompt_completion<F, Fut>(
64        &mut self,
65        prompt_name: impl Into<String>,
66        arg_name: impl Into<String>,
67        handler: F,
68    ) where
69        F: Fn(&str, &Context<'_>) -> Fut + Send + Sync + 'static,
70        Fut: Future<Output = Result<Vec<String>, McpError>> + Send + 'static,
71    {
72        let ref_type = "ref/prompt".to_string();
73        let ref_value = prompt_name.into();
74        let arg_name = arg_name.into();
75        let key = (ref_type.clone(), ref_value.clone(), arg_name.clone());
76
77        let boxed: BoxedCompletionFn = Box::new(move |input, ctx| Box::pin(handler(input, ctx)));
78
79        self.completions.insert(
80            key,
81            RegisteredCompletion {
82                ref_type,
83                ref_value,
84                arg_name,
85                handler: boxed,
86            },
87        );
88    }
89
90    /// Register a completion provider for a resource argument.
91    pub fn register_resource_completion<F, Fut>(
92        &mut self,
93        uri_pattern: impl Into<String>,
94        arg_name: impl Into<String>,
95        handler: F,
96    ) where
97        F: Fn(&str, &Context<'_>) -> Fut + Send + Sync + 'static,
98        Fut: Future<Output = Result<Vec<String>, McpError>> + Send + 'static,
99    {
100        let ref_type = "ref/resource".to_string();
101        let ref_value = uri_pattern.into();
102        let arg_name = arg_name.into();
103        let key = (ref_type.clone(), ref_value.clone(), arg_name.clone());
104
105        let boxed: BoxedCompletionFn = Box::new(move |input, ctx| Box::pin(handler(input, ctx)));
106
107        self.completions.insert(
108            key,
109            RegisteredCompletion {
110                ref_type,
111                ref_value,
112                arg_name,
113                handler: boxed,
114            },
115        );
116    }
117
118    /// Complete an argument value.
119    pub async fn complete(
120        &self,
121        request: &CompleteRequest,
122        ctx: &Context<'_>,
123    ) -> Result<CompleteResult, McpError> {
124        let ref_type = request.ref_.ref_type().to_string();
125        let ref_value = request.ref_.value().to_string();
126        let arg_name = request.argument.name.clone();
127        let input = &request.argument.value;
128
129        let key = (ref_type, ref_value, arg_name);
130
131        if let Some(registered) = self.completions.get(&key) {
132            let values = (registered.handler)(input, ctx).await?;
133            let total = values.len();
134
135            Ok(CompleteResult {
136                completion: Completion {
137                    values,
138                    total: Some(CompletionTotal::Exact(total)),
139                    has_more: Some(false),
140                },
141            })
142        } else {
143            // Return empty completions if no handler registered
144            Ok(CompleteResult {
145                completion: Completion {
146                    values: Vec::new(),
147                    total: Some(CompletionTotal::Exact(0)),
148                    has_more: Some(false),
149                },
150            })
151        }
152    }
153
154    /// Check if a completion provider exists.
155    #[must_use]
156    pub fn has_completion(&self, ref_type: &str, ref_value: &str, arg_name: &str) -> bool {
157        let key = (
158            ref_type.to_string(),
159            ref_value.to_string(),
160            arg_name.to_string(),
161        );
162        self.completions.contains_key(&key)
163    }
164}
165
166impl CompletionHandler for CompletionService {
167    async fn complete_resource(
168        &self,
169        partial_uri: &str,
170        ctx: &Context<'_>,
171    ) -> Result<Vec<String>, McpError> {
172        let request = CompleteRequest {
173            ref_: CompletionRef::resource(partial_uri),
174            argument: CompletionArgument {
175                name: "uri".to_string(),
176                value: partial_uri.to_string(),
177            },
178        };
179
180        let result = Self::complete(self, &request, ctx).await?;
181        Ok(result.completion.values)
182    }
183
184    async fn complete_prompt_arg(
185        &self,
186        prompt_name: &str,
187        arg_name: &str,
188        partial_value: &str,
189        ctx: &Context<'_>,
190    ) -> Result<Vec<String>, McpError> {
191        let request = CompleteRequest {
192            ref_: CompletionRef::prompt(prompt_name),
193            argument: CompletionArgument {
194                name: arg_name.to_string(),
195                value: partial_value.to_string(),
196            },
197        };
198
199        let result = Self::complete(self, &request, ctx).await?;
200        Ok(result.completion.values)
201    }
202}
203
204/// Builder for creating completion requests.
205pub struct CompleteRequestBuilder {
206    ref_: CompletionRef,
207    arg_name: String,
208    arg_value: String,
209}
210
211impl CompleteRequestBuilder {
212    /// Create a builder for prompt argument completion.
213    pub fn for_prompt(prompt_name: impl Into<String>, arg_name: impl Into<String>) -> Self {
214        Self {
215            ref_: CompletionRef::prompt(prompt_name.into()),
216            arg_name: arg_name.into(),
217            arg_value: String::new(),
218        }
219    }
220
221    /// Create a builder for resource argument completion.
222    pub fn for_resource(uri: impl Into<String>, arg_name: impl Into<String>) -> Self {
223        Self {
224            ref_: CompletionRef::resource(uri.into()),
225            arg_name: arg_name.into(),
226            arg_value: String::new(),
227        }
228    }
229
230    /// Set the current input value.
231    pub fn value(mut self, value: impl Into<String>) -> Self {
232        self.arg_value = value.into();
233        self
234    }
235
236    /// Build the request.
237    #[must_use]
238    pub fn build(self) -> CompleteRequest {
239        CompleteRequest {
240            ref_: self.ref_,
241            argument: mcpkit_core::types::completion::CompletionArgument {
242                name: self.arg_name,
243                value: self.arg_value,
244            },
245        }
246    }
247}
248
/// Stateless helpers for narrowing down a list of completion candidates.
#[derive(Debug, Clone, Copy, Default)]
pub struct CompletionFilter;

impl CompletionFilter {
    /// Keep the values that start with `prefix`, ignoring case.
    ///
    /// Both sides are lower-cased with [`str::to_lowercase`], so the comparison
    /// is Unicode-aware. An empty `prefix` keeps every value. Input order is
    /// preserved.
    #[must_use]
    pub fn by_prefix(values: &[String], prefix: &str) -> Vec<String> {
        let prefix_lower = prefix.to_lowercase();
        values
            .iter()
            .filter(|v| v.to_lowercase().starts_with(&prefix_lower))
            .cloned()
            .collect()
    }

    /// Keep the values that contain `substring` anywhere, ignoring case.
    ///
    /// Both sides are lower-cased with [`str::to_lowercase`]. An empty
    /// `substring` keeps every value. Input order is preserved.
    #[must_use]
    pub fn by_substring(values: &[String], substring: &str) -> Vec<String> {
        let sub_lower = substring.to_lowercase();
        values
            .iter()
            .filter(|v| v.to_lowercase().contains(&sub_lower))
            .cloned()
            .collect()
    }

    /// Keep at most the first `max` values; this does no matching of its own.
    ///
    /// Truncates in place, so the input allocation is reused rather than copied.
    #[must_use]
    pub fn limit(mut values: Vec<String>, max: usize) -> Vec<String> {
        values.truncate(max);
        values
    }
}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::protocol_version::ProtocolVersion;

    /// Owned values a test must keep alive in order to borrow a [`Context`].
    type ContextParts = (
        RequestId,
        ClientCapabilities,
        ServerCapabilities,
        ProtocolVersion,
        NoOpPeer,
    );

    /// Build the owned pieces of a request context with default capabilities.
    fn make_context() -> ContextParts {
        let request_id = RequestId::Number(1);
        let client_caps = ClientCapabilities::default();
        let server_caps = ServerCapabilities::default();
        (
            request_id,
            client_caps,
            server_caps,
            ProtocolVersion::LATEST,
            NoOpPeer,
        )
    }

    /// Language names shared as completion candidates by several tests.
    fn languages() -> Vec<String> {
        ["python", "javascript", "typescript", "rust"]
            .iter()
            .map(|name| name.to_string())
            .collect()
    }

    #[test]
    fn test_complete_request_builder() {
        let builder = CompleteRequestBuilder::for_prompt("code-review", "language");
        let request = builder.value("py").build();

        assert_eq!(request.ref_.ref_type(), "ref/prompt");
        assert_eq!(request.argument.name, "language");
        assert_eq!(request.argument.value, "py");
    }

    #[test]
    fn test_completion_filter() {
        let values = languages();

        assert_eq!(CompletionFilter::by_prefix(&values, "py"), vec!["python"]);
        assert_eq!(
            CompletionFilter::by_substring(&values, "script"),
            vec!["javascript", "typescript"]
        );
        assert_eq!(CompletionFilter::limit(values, 2).len(), 2);
    }

    #[tokio::test]
    async fn test_completion_service() -> Result<(), Box<dyn std::error::Error>> {
        let mut service = CompletionService::new();
        let candidates = languages();

        service.register_prompt_completion("code-review", "language", move |input, _ctx| {
            // The future must be 'static, so copy everything it uses.
            let candidates = candidates.clone();
            let prefix = input.to_owned();
            async move { Ok(CompletionFilter::by_prefix(&candidates, &prefix)) }
        });
        assert!(service.has_completion("ref/prompt", "code-review", "language"));

        let (req_id, client_caps, server_caps, protocol_version, peer) = make_context();
        let ctx = Context::new(
            &req_id,
            None,
            &client_caps,
            &server_caps,
            protocol_version,
            &peer,
        );

        let request = CompleteRequestBuilder::for_prompt("code-review", "language")
            .value("py")
            .build();
        let result = service.complete(&request, &ctx).await?;
        assert_eq!(result.completion.values, vec!["python"]);

        Ok(())
    }
}