mcpkit_server/capability/
completions.rs1use crate::context::Context;
7use crate::handler::CompletionHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::completion::{
10 CompleteRequest, CompleteResult, Completion, CompletionArgument, CompletionRef, CompletionTotal,
11};
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15
16pub type BoxedCompletionFn = Box<
18 dyn for<'a> Fn(
19 &'a str,
20 &'a Context<'a>,
21 ) -> Pin<Box<dyn Future<Output = Result<Vec<String>, McpError>> + Send + 'a>>
22 + Send
23 + Sync,
24>;
25
26pub struct RegisteredCompletion {
28 pub ref_type: String,
30 pub ref_value: String,
32 pub arg_name: String,
34 pub handler: BoxedCompletionFn,
36}
37
38pub struct CompletionService {
42 completions: HashMap<(String, String, String), RegisteredCompletion>,
44}
45
46impl Default for CompletionService {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl CompletionService {
53 pub fn new() -> Self {
55 Self {
56 completions: HashMap::new(),
57 }
58 }
59
60 pub fn register_prompt_completion<F, Fut>(
62 &mut self,
63 prompt_name: impl Into<String>,
64 arg_name: impl Into<String>,
65 handler: F,
66 ) where
67 F: Fn(&str, &Context<'_>) -> Fut + Send + Sync + 'static,
68 Fut: Future<Output = Result<Vec<String>, McpError>> + Send + 'static,
69 {
70 let ref_type = "ref/prompt".to_string();
71 let ref_value = prompt_name.into();
72 let arg_name = arg_name.into();
73 let key = (ref_type.clone(), ref_value.clone(), arg_name.clone());
74
75 let boxed: BoxedCompletionFn = Box::new(move |input, ctx| Box::pin(handler(input, ctx)));
76
77 self.completions.insert(
78 key,
79 RegisteredCompletion {
80 ref_type,
81 ref_value,
82 arg_name,
83 handler: boxed,
84 },
85 );
86 }
87
88 pub fn register_resource_completion<F, Fut>(
90 &mut self,
91 uri_pattern: impl Into<String>,
92 arg_name: impl Into<String>,
93 handler: F,
94 ) where
95 F: Fn(&str, &Context<'_>) -> Fut + Send + Sync + 'static,
96 Fut: Future<Output = Result<Vec<String>, McpError>> + Send + 'static,
97 {
98 let ref_type = "ref/resource".to_string();
99 let ref_value = uri_pattern.into();
100 let arg_name = arg_name.into();
101 let key = (ref_type.clone(), ref_value.clone(), arg_name.clone());
102
103 let boxed: BoxedCompletionFn = Box::new(move |input, ctx| Box::pin(handler(input, ctx)));
104
105 self.completions.insert(
106 key,
107 RegisteredCompletion {
108 ref_type,
109 ref_value,
110 arg_name,
111 handler: boxed,
112 },
113 );
114 }
115
116 pub async fn complete(
118 &self,
119 request: &CompleteRequest,
120 ctx: &Context<'_>,
121 ) -> Result<CompleteResult, McpError> {
122 let ref_type = request.ref_.ref_type().to_string();
123 let ref_value = request.ref_.value().to_string();
124 let arg_name = request.argument.name.clone();
125 let input = &request.argument.value;
126
127 let key = (ref_type, ref_value, arg_name);
128
129 if let Some(registered) = self.completions.get(&key) {
130 let values = (registered.handler)(input, ctx).await?;
131 let total = values.len();
132
133 Ok(CompleteResult {
134 completion: Completion {
135 values,
136 total: Some(CompletionTotal::Exact(total)),
137 has_more: Some(false),
138 },
139 })
140 } else {
141 Ok(CompleteResult {
143 completion: Completion {
144 values: Vec::new(),
145 total: Some(CompletionTotal::Exact(0)),
146 has_more: Some(false),
147 },
148 })
149 }
150 }
151
152 pub fn has_completion(
154 &self,
155 ref_type: &str,
156 ref_value: &str,
157 arg_name: &str,
158 ) -> bool {
159 let key = (ref_type.to_string(), ref_value.to_string(), arg_name.to_string());
160 self.completions.contains_key(&key)
161 }
162}
163
164impl CompletionHandler for CompletionService {
165 async fn complete_resource(
166 &self,
167 partial_uri: &str,
168 ctx: &Context<'_>,
169 ) -> Result<Vec<String>, McpError> {
170 let request = CompleteRequest {
171 ref_: CompletionRef::resource(partial_uri),
172 argument: CompletionArgument {
173 name: "uri".to_string(),
174 value: partial_uri.to_string(),
175 },
176 };
177
178 let result = CompletionService::complete(self, &request, ctx).await?;
179 Ok(result.completion.values)
180 }
181
182 async fn complete_prompt_arg(
183 &self,
184 prompt_name: &str,
185 arg_name: &str,
186 partial_value: &str,
187 ctx: &Context<'_>,
188 ) -> Result<Vec<String>, McpError> {
189 let request = CompleteRequest {
190 ref_: CompletionRef::prompt(prompt_name),
191 argument: CompletionArgument {
192 name: arg_name.to_string(),
193 value: partial_value.to_string(),
194 },
195 };
196
197 let result = CompletionService::complete(self, &request, ctx).await?;
198 Ok(result.completion.values)
199 }
200}
201
202pub struct CompleteRequestBuilder {
204 ref_: CompletionRef,
205 arg_name: String,
206 arg_value: String,
207}
208
209impl CompleteRequestBuilder {
210 pub fn for_prompt(prompt_name: impl Into<String>, arg_name: impl Into<String>) -> Self {
212 Self {
213 ref_: CompletionRef::prompt(prompt_name.into()),
214 arg_name: arg_name.into(),
215 arg_value: String::new(),
216 }
217 }
218
219 pub fn for_resource(uri: impl Into<String>, arg_name: impl Into<String>) -> Self {
221 Self {
222 ref_: CompletionRef::resource(uri.into()),
223 arg_name: arg_name.into(),
224 arg_value: String::new(),
225 }
226 }
227
228 pub fn value(mut self, value: impl Into<String>) -> Self {
230 self.arg_value = value.into();
231 self
232 }
233
234 pub fn build(self) -> CompleteRequest {
236 CompleteRequest {
237 ref_: self.ref_,
238 argument: mcpkit_core::types::completion::CompletionArgument {
239 name: self.arg_name,
240 value: self.arg_value,
241 },
242 }
243 }
244}
245
/// Stateless helpers for narrowing a list of completion candidates.
pub struct CompletionFilter;

impl CompletionFilter {
    /// Returns the candidates that start with `prefix`, ignoring case.
    ///
    /// Both sides are lowercased with [`str::to_lowercase`]. Input order is kept, and
    /// an empty prefix matches every candidate.
    pub fn by_prefix(values: &[String], prefix: &str) -> Vec<String> {
        let needle = prefix.to_lowercase();
        let mut matches = Vec::new();
        for candidate in values {
            if candidate.to_lowercase().starts_with(needle.as_str()) {
                matches.push(candidate.clone());
            }
        }
        matches
    }

    /// Returns the candidates that contain `substring` anywhere, ignoring case.
    ///
    /// Both sides are lowercased with [`str::to_lowercase`]; input order is kept.
    pub fn by_substring(values: &[String], substring: &str) -> Vec<String> {
        let needle = substring.to_lowercase();
        let mut matches = Vec::new();
        for candidate in values {
            if candidate.to_lowercase().contains(needle.as_str()) {
                matches.push(candidate.clone());
            }
        }
        matches
    }

    /// Keeps at most the first `max` candidates.
    pub fn limit(mut values: Vec<String>, max: usize) -> Vec<String> {
        values.truncate(max);
        values
    }
}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;

    /// Builds the owned values a `Context` borrows from; keep them alive while the
    /// context is in use.
    fn make_context() -> (RequestId, ClientCapabilities, ServerCapabilities, NoOpPeer) {
        (
            RequestId::Number(1),
            ClientCapabilities::default(),
            ServerCapabilities::default(),
            NoOpPeer,
        )
    }

    /// Candidate language names shared by several tests.
    fn languages() -> Vec<String> {
        ["python", "javascript", "typescript", "rust"]
            .iter()
            .map(|s| s.to_string())
            .collect()
    }

    /// A service with a prefix-filtering completion for `code-review`'s `language` argument.
    fn language_service() -> CompletionService {
        let mut service = CompletionService::new();
        let langs = languages();
        service.register_prompt_completion("code-review", "language", move |input, _ctx| {
            let langs = langs.clone();
            let input = input.to_string();
            async move { Ok(CompletionFilter::by_prefix(&langs, &input)) }
        });
        service
    }

    #[test]
    fn test_complete_request_builder() {
        let request = CompleteRequestBuilder::for_prompt("code-review", "language")
            .value("py")
            .build();

        assert_eq!(request.ref_.ref_type(), "ref/prompt");
        assert_eq!(request.argument.name, "language");
        assert_eq!(request.argument.value, "py");
    }

    #[test]
    fn test_complete_request_builder_defaults_to_empty_value() {
        let request = CompleteRequestBuilder::for_resource("file:///{path}", "path").build();

        assert_eq!(request.argument.name, "path");
        assert_eq!(request.argument.value, "");
    }

    #[test]
    fn test_completion_filter() {
        let values = languages();

        let filtered = CompletionFilter::by_prefix(&values, "py");
        assert_eq!(filtered, vec!["python"]);

        let filtered = CompletionFilter::by_substring(&values, "script");
        assert_eq!(filtered, vec!["javascript", "typescript"]);

        let limited = CompletionFilter::limit(values.clone(), 2);
        assert_eq!(limited.len(), 2);
    }

    #[tokio::test]
    async fn test_completion_service() {
        let service = language_service();
        assert!(service.has_completion("ref/prompt", "code-review", "language"));

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let request = CompleteRequestBuilder::for_prompt("code-review", "language")
            .value("py")
            .build();

        let result = service.complete(&request, &ctx).await.unwrap();
        assert_eq!(result.completion.values, vec!["python"]);
    }

    #[tokio::test]
    async fn test_unregistered_completion_is_empty() {
        let service = language_service();
        assert!(!service.has_completion("ref/prompt", "code-review", "style"));

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let request = CompleteRequestBuilder::for_prompt("code-review", "style")
            .value("py")
            .build();

        let result = service.complete(&request, &ctx).await.unwrap();
        assert!(result.completion.values.is_empty());
        assert_eq!(result.completion.has_more, Some(false));
    }

    #[tokio::test]
    async fn test_resource_completion() {
        let mut service = CompletionService::new();
        service.register_resource_completion("file:///{path}", "path", |input, _ctx| {
            let input = input.to_string();
            async move { Ok(vec![format!("{}.rs", input)]) }
        });
        assert!(service.has_completion("ref/resource", "file:///{path}", "path"));

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let request = CompleteRequestBuilder::for_resource("file:///{path}", "path")
            .value("main")
            .build();

        let result = service.complete(&request, &ctx).await.unwrap();
        assert_eq!(result.completion.values, vec!["main.rs"]);
    }

    #[tokio::test]
    async fn test_prompt_arg_via_handler_trait() {
        let service = language_service();

        let (req_id, client_caps, server_caps, peer) = make_context();
        let ctx = Context::new(&req_id, None, &client_caps, &server_caps, &peer);

        let values =
            CompletionHandler::complete_prompt_arg(&service, "code-review", "language", "ty", &ctx)
                .await
                .unwrap();
        assert_eq!(values, vec!["typescript"]);
    }
}