1use crate::context::Context;
7use crate::handler::CompletionHandler;
8use mcpkit_core::error::McpError;
9use mcpkit_core::types::completion::{
10 CompleteRequest, CompleteResult, Completion, CompletionArgument, CompletionRef, CompletionTotal,
11};
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15
16pub type BoxedCompletionFn = Box<
18 dyn for<'a> Fn(
19 &'a str,
20 &'a Context<'a>,
21 )
22 -> Pin<Box<dyn Future<Output = Result<Vec<String>, McpError>> + Send + 'a>>
23 + Send
24 + Sync,
25>;
26
/// A completion handler together with the reference and argument it serves.
pub struct RegisteredCompletion {
    /// Reference kind this handler is registered for; `CompletionService`
    /// stores either `"ref/prompt"` or `"ref/resource"` here.
    pub ref_type: String,
    /// Prompt name or resource URI the handler applies to.
    pub ref_value: String,
    /// Name of the argument whose value is being completed.
    pub arg_name: String,
    /// Type-erased callback that produces the completion values.
    pub handler: BoxedCompletionFn,
}
38
39pub struct CompletionService {
43 completions: HashMap<(String, String, String), RegisteredCompletion>,
45}
46
47impl Default for CompletionService {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl CompletionService {
54 #[must_use]
56 pub fn new() -> Self {
57 Self {
58 completions: HashMap::new(),
59 }
60 }
61
62 pub fn register_prompt_completion<F, Fut>(
64 &mut self,
65 prompt_name: impl Into<String>,
66 arg_name: impl Into<String>,
67 handler: F,
68 ) where
69 F: Fn(&str, &Context<'_>) -> Fut + Send + Sync + 'static,
70 Fut: Future<Output = Result<Vec<String>, McpError>> + Send + 'static,
71 {
72 let ref_type = "ref/prompt".to_string();
73 let ref_value = prompt_name.into();
74 let arg_name = arg_name.into();
75 let key = (ref_type.clone(), ref_value.clone(), arg_name.clone());
76
77 let boxed: BoxedCompletionFn = Box::new(move |input, ctx| Box::pin(handler(input, ctx)));
78
79 self.completions.insert(
80 key,
81 RegisteredCompletion {
82 ref_type,
83 ref_value,
84 arg_name,
85 handler: boxed,
86 },
87 );
88 }
89
90 pub fn register_resource_completion<F, Fut>(
92 &mut self,
93 uri_pattern: impl Into<String>,
94 arg_name: impl Into<String>,
95 handler: F,
96 ) where
97 F: Fn(&str, &Context<'_>) -> Fut + Send + Sync + 'static,
98 Fut: Future<Output = Result<Vec<String>, McpError>> + Send + 'static,
99 {
100 let ref_type = "ref/resource".to_string();
101 let ref_value = uri_pattern.into();
102 let arg_name = arg_name.into();
103 let key = (ref_type.clone(), ref_value.clone(), arg_name.clone());
104
105 let boxed: BoxedCompletionFn = Box::new(move |input, ctx| Box::pin(handler(input, ctx)));
106
107 self.completions.insert(
108 key,
109 RegisteredCompletion {
110 ref_type,
111 ref_value,
112 arg_name,
113 handler: boxed,
114 },
115 );
116 }
117
118 pub async fn complete(
120 &self,
121 request: &CompleteRequest,
122 ctx: &Context<'_>,
123 ) -> Result<CompleteResult, McpError> {
124 let ref_type = request.ref_.ref_type().to_string();
125 let ref_value = request.ref_.value().to_string();
126 let arg_name = request.argument.name.clone();
127 let input = &request.argument.value;
128
129 let key = (ref_type, ref_value, arg_name);
130
131 if let Some(registered) = self.completions.get(&key) {
132 let values = (registered.handler)(input, ctx).await?;
133 let total = values.len();
134
135 Ok(CompleteResult {
136 completion: Completion {
137 values,
138 total: Some(CompletionTotal::Exact(total)),
139 has_more: Some(false),
140 },
141 })
142 } else {
143 Ok(CompleteResult {
145 completion: Completion {
146 values: Vec::new(),
147 total: Some(CompletionTotal::Exact(0)),
148 has_more: Some(false),
149 },
150 })
151 }
152 }
153
154 #[must_use]
156 pub fn has_completion(&self, ref_type: &str, ref_value: &str, arg_name: &str) -> bool {
157 let key = (
158 ref_type.to_string(),
159 ref_value.to_string(),
160 arg_name.to_string(),
161 );
162 self.completions.contains_key(&key)
163 }
164}
165
166impl CompletionHandler for CompletionService {
167 async fn complete_resource(
168 &self,
169 partial_uri: &str,
170 ctx: &Context<'_>,
171 ) -> Result<Vec<String>, McpError> {
172 let request = CompleteRequest {
173 ref_: CompletionRef::resource(partial_uri),
174 argument: CompletionArgument {
175 name: "uri".to_string(),
176 value: partial_uri.to_string(),
177 },
178 };
179
180 let result = Self::complete(self, &request, ctx).await?;
181 Ok(result.completion.values)
182 }
183
184 async fn complete_prompt_arg(
185 &self,
186 prompt_name: &str,
187 arg_name: &str,
188 partial_value: &str,
189 ctx: &Context<'_>,
190 ) -> Result<Vec<String>, McpError> {
191 let request = CompleteRequest {
192 ref_: CompletionRef::prompt(prompt_name),
193 argument: CompletionArgument {
194 name: arg_name.to_string(),
195 value: partial_value.to_string(),
196 },
197 };
198
199 let result = Self::complete(self, &request, ctx).await?;
200 Ok(result.completion.values)
201 }
202}
203
204pub struct CompleteRequestBuilder {
206 ref_: CompletionRef,
207 arg_name: String,
208 arg_value: String,
209}
210
211impl CompleteRequestBuilder {
212 pub fn for_prompt(prompt_name: impl Into<String>, arg_name: impl Into<String>) -> Self {
214 Self {
215 ref_: CompletionRef::prompt(prompt_name.into()),
216 arg_name: arg_name.into(),
217 arg_value: String::new(),
218 }
219 }
220
221 pub fn for_resource(uri: impl Into<String>, arg_name: impl Into<String>) -> Self {
223 Self {
224 ref_: CompletionRef::resource(uri.into()),
225 arg_name: arg_name.into(),
226 arg_value: String::new(),
227 }
228 }
229
230 pub fn value(mut self, value: impl Into<String>) -> Self {
232 self.arg_value = value.into();
233 self
234 }
235
236 #[must_use]
238 pub fn build(self) -> CompleteRequest {
239 CompleteRequest {
240 ref_: self.ref_,
241 argument: mcpkit_core::types::completion::CompletionArgument {
242 name: self.arg_name,
243 value: self.arg_value,
244 },
245 }
246 }
247}
248
/// Stateless helpers for narrowing a list of completion candidates.
pub struct CompletionFilter;

impl CompletionFilter {
    /// Returns clones of the values that start with `prefix`, comparing both
    /// sides after Unicode lowercasing. Input order is preserved.
    #[must_use]
    pub fn by_prefix(values: &[String], prefix: &str) -> Vec<String> {
        let needle = prefix.to_lowercase();
        let mut matches = Vec::new();
        for candidate in values {
            if candidate.to_lowercase().starts_with(needle.as_str()) {
                matches.push(candidate.clone());
            }
        }
        matches
    }

    /// Returns clones of the values that contain `substring` anywhere,
    /// comparing both sides after Unicode lowercasing. Input order is preserved.
    #[must_use]
    pub fn by_substring(values: &[String], substring: &str) -> Vec<String> {
        let needle = substring.to_lowercase();
        values
            .iter()
            .filter(|candidate| candidate.to_lowercase().contains(needle.as_str()))
            .map(ToOwned::to_owned)
            .collect()
    }

    /// Keeps at most the first `max` values.
    #[must_use]
    pub fn limit(mut values: Vec<String>, max: usize) -> Vec<String> {
        values.truncate(max);
        values
    }
}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::{Context, NoOpPeer};
    use mcpkit_core::capability::{ClientCapabilities, ServerCapabilities};
    use mcpkit_core::protocol::RequestId;
    use mcpkit_core::protocol_version::ProtocolVersion;

    /// Owned values a test keeps alive so it can borrow them into a [`Context`].
    type ContextParts = (
        RequestId,
        ClientCapabilities,
        ServerCapabilities,
        ProtocolVersion,
        NoOpPeer,
    );

    /// Default context ingredients: request id 1, default capabilities,
    /// the latest protocol version and a no-op peer.
    fn make_context() -> ContextParts {
        let request_id = RequestId::Number(1);
        let client = ClientCapabilities::default();
        let server = ServerCapabilities::default();
        (request_id, client, server, ProtocolVersion::LATEST, NoOpPeer)
    }

    #[test]
    fn test_complete_request_builder() {
        let builder = CompleteRequestBuilder::for_prompt("code-review", "language");
        let request = builder.value("py").build();

        assert_eq!(request.ref_.ref_type(), "ref/prompt");
        assert_eq!(request.argument.name, "language");
        assert_eq!(request.argument.value, "py");
    }

    #[test]
    fn test_completion_filter() {
        let values: Vec<String> = ["python", "javascript", "typescript", "rust"]
            .iter()
            .map(ToString::to_string)
            .collect();

        let by_prefix = CompletionFilter::by_prefix(&values, "py");
        assert_eq!(by_prefix, vec!["python"]);

        let by_substring = CompletionFilter::by_substring(&values, "script");
        assert_eq!(by_substring, vec!["javascript", "typescript"]);

        let limited = CompletionFilter::limit(values, 2);
        assert_eq!(limited.len(), 2);
    }

    #[tokio::test]
    async fn test_completion_service() -> Result<(), Box<dyn std::error::Error>> {
        let languages: Vec<String> = ["python", "javascript", "typescript", "rust"]
            .iter()
            .map(ToString::to_string)
            .collect();

        let mut service = CompletionService::new();
        service.register_prompt_completion("code-review", "language", move |partial, _ctx| {
            let candidates = languages.clone();
            let partial = partial.to_string();
            async move { Ok(CompletionFilter::by_prefix(&candidates, &partial)) }
        });

        assert!(service.has_completion("ref/prompt", "code-review", "language"));

        let (req_id, client_caps, server_caps, protocol_version, peer) = make_context();
        let ctx = Context::new(
            &req_id,
            None,
            &client_caps,
            &server_caps,
            protocol_version,
            &peer,
        );

        let request = CompleteRequestBuilder::for_prompt("code-review", "language")
            .value("py")
            .build();

        let outcome = service.complete(&request, &ctx).await?;
        assert_eq!(outcome.completion.values, vec!["python"]);

        Ok(())
    }
}