1use std::sync::Arc;
45
46use async_trait::async_trait;
47use dashmap::DashMap;
48use thiserror::Error;
49
50use crate::protocol::types::{
51 CompleteRequest, CompleteResult, CompletionInfo, CompletionRef, CompletionValue,
52};
53
54#[derive(Debug, Error)]
56pub enum CompletionError {
57 #[error("Invalid reference: {0}")]
59 InvalidReference(String),
60
61 #[error("No completion provider for: {0}")]
63 NoProvider(String),
64
65 #[error("Completion error: {0}")]
67 Internal(String),
68}
69
/// Default cap on the number of completion values returned per request.
///
/// Used as the initial limit of a new `CompletionManager` and as the default
/// return value of `CompletionProvider::max_completions`.
pub const DEFAULT_MAX_COMPLETIONS: usize = 100;
72
73#[async_trait]
77pub trait CompletionProvider: Send + Sync {
78 async fn complete_prompt(
90 &self,
91 prompt_name: &str,
92 argument_name: &str,
93 argument_value: &str,
94 ) -> Vec<CompletionValue> {
95 let _ = (prompt_name, argument_name, argument_value);
96 vec![]
97 }
98
99 async fn complete_resource(
111 &self,
112 uri: &str,
113 argument_name: &str,
114 argument_value: &str,
115 ) -> Vec<CompletionValue> {
116 let _ = (uri, argument_name, argument_value);
117 vec![]
118 }
119
120 fn max_completions(&self) -> usize {
122 DEFAULT_MAX_COMPLETIONS
123 }
124}
125
126pub struct NoOpCompletionProvider;
128
129#[async_trait]
130impl CompletionProvider for NoOpCompletionProvider {}
131
132#[derive(Clone)]
137pub struct CompletionManager {
138 prompt_providers: Arc<DashMap<String, Arc<dyn CompletionProvider>>>,
140 resource_providers: Arc<DashMap<String, Arc<dyn CompletionProvider>>>,
142 default_provider: Arc<dyn CompletionProvider>,
144 max_completions: usize,
146}
147
148impl Default for CompletionManager {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154impl CompletionManager {
155 pub fn new() -> Self {
157 Self {
158 prompt_providers: Arc::new(DashMap::new()),
159 resource_providers: Arc::new(DashMap::new()),
160 default_provider: Arc::new(NoOpCompletionProvider),
161 max_completions: DEFAULT_MAX_COMPLETIONS,
162 }
163 }
164
165 pub fn with_default_provider(provider: Arc<dyn CompletionProvider>) -> Self {
167 Self {
168 prompt_providers: Arc::new(DashMap::new()),
169 resource_providers: Arc::new(DashMap::new()),
170 default_provider: provider,
171 max_completions: DEFAULT_MAX_COMPLETIONS,
172 }
173 }
174
175 pub fn set_max_completions(&mut self, max: usize) {
177 self.max_completions = max;
178 }
179
180 pub fn register_prompt_provider(
182 &self,
183 prompt_name: impl Into<String>,
184 provider: Arc<dyn CompletionProvider>,
185 ) {
186 self.prompt_providers.insert(prompt_name.into(), provider);
187 }
188
189 pub fn register_resource_provider(
191 &self,
192 uri_pattern: impl Into<String>,
193 provider: Arc<dyn CompletionProvider>,
194 ) {
195 self.resource_providers.insert(uri_pattern.into(), provider);
196 }
197
198 pub fn set_default_provider(&mut self, provider: Arc<dyn CompletionProvider>) {
200 self.default_provider = provider;
201 }
202
203 pub async fn complete(&self, request: CompleteRequest) -> Result<CompleteResult, CompletionError> {
205 let values = match &request.reference {
206 CompletionRef::Prompt { name } => {
207 let provider = self
208 .prompt_providers
209 .get(name)
210 .map(|p| Arc::clone(&p))
211 .unwrap_or_else(|| Arc::clone(&self.default_provider));
212
213 provider
214 .complete_prompt(name, &request.argument.name, &request.argument.value)
215 .await
216 }
217 CompletionRef::Resource { uri } => {
218 let provider = self
220 .resource_providers
221 .get(uri)
222 .map(|p| Arc::clone(&p))
223 .or_else(|| self.find_matching_resource_provider(uri))
224 .unwrap_or_else(|| Arc::clone(&self.default_provider));
225
226 provider
227 .complete_resource(uri, &request.argument.name, &request.argument.value)
228 .await
229 }
230 };
231
232 let total = values.len();
234 let max = self.max_completions;
235 let has_more = total > max;
236 let values: Vec<_> = values.into_iter().take(max).collect();
237
238 let completion = if has_more {
239 CompletionInfo::with_pagination(values, total, true)
240 } else {
241 CompletionInfo::with_values(values)
242 };
243
244 Ok(CompleteResult { completion })
245 }
246
247 fn find_matching_resource_provider(&self, uri: &str) -> Option<Arc<dyn CompletionProvider>> {
249 for entry in self.resource_providers.iter() {
252 let pattern = entry.key();
253 if uri.starts_with(pattern) || pattern.starts_with(uri) {
254 return Some(Arc::clone(entry.value()));
255 }
256
257 if let Some(provider) = self.match_template_pattern(pattern, uri, entry.value()) {
259 return Some(provider);
260 }
261 }
262 None
263 }
264
265 fn match_template_pattern(
267 &self,
268 pattern: &str,
269 uri: &str,
270 provider: &Arc<dyn CompletionProvider>,
271 ) -> Option<Arc<dyn CompletionProvider>> {
272 let pattern_prefix = pattern.split('{').next()?;
274 if uri.starts_with(pattern_prefix) {
275 return Some(Arc::clone(provider));
276 }
277 None
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::CompletionArgument;

    /// Test double returning the configured candidates whose value starts with
    /// the request's current argument value.
    struct TestCompletionProvider {
        prompt_completions: Vec<CompletionValue>,
        resource_completions: Vec<CompletionValue>,
    }

    impl TestCompletionProvider {
        fn new(prompt: Vec<&str>, resource: Vec<&str>) -> Self {
            Self {
                prompt_completions: prompt.into_iter().map(CompletionValue::new).collect(),
                resource_completions: resource.into_iter().map(CompletionValue::new).collect(),
            }
        }
    }

    /// Returns clones of the `candidates` whose value starts with `prefix`.
    fn starting_with(candidates: &[CompletionValue], prefix: &str) -> Vec<CompletionValue> {
        candidates
            .iter()
            .filter(|v| v.value.starts_with(prefix))
            .cloned()
            .collect()
    }

    #[async_trait]
    impl CompletionProvider for TestCompletionProvider {
        async fn complete_prompt(
            &self,
            _prompt_name: &str,
            _argument_name: &str,
            argument_value: &str,
        ) -> Vec<CompletionValue> {
            starting_with(&self.prompt_completions, argument_value)
        }

        async fn complete_resource(
            &self,
            _uri: &str,
            _argument_name: &str,
            argument_value: &str,
        ) -> Vec<CompletionValue> {
            starting_with(&self.resource_completions, argument_value)
        }
    }

    /// Builds a request for `reference` with a single named argument.
    fn request(
        reference: CompletionRef,
        argument_name: &str,
        argument_value: &str,
    ) -> CompleteRequest {
        CompleteRequest {
            reference,
            argument: CompletionArgument {
                name: argument_name.to_string(),
                value: argument_value.to_string(),
            },
        }
    }

    fn prompt_ref(name: &str) -> CompletionRef {
        CompletionRef::Prompt {
            name: name.to_string(),
        }
    }

    fn resource_ref(uri: &str) -> CompletionRef {
        CompletionRef::Resource {
            uri: uri.to_string(),
        }
    }

    /// Returns the completion values as `&str` for concise assertions.
    fn value_strs(result: &CompleteResult) -> Vec<&str> {
        result
            .completion
            .values
            .iter()
            .map(|v| v.value.as_str())
            .collect()
    }

    #[tokio::test]
    async fn test_completion_manager_new() {
        let manager = CompletionManager::new();
        assert_eq!(manager.max_completions, DEFAULT_MAX_COMPLETIONS);
    }

    #[tokio::test]
    async fn test_prompt_completion() {
        let manager = CompletionManager::new();
        manager.register_prompt_provider(
            "greet",
            Arc::new(TestCompletionProvider::new(vec!["Alice", "Amy", "Bob"], vec![])),
        );

        let result = manager
            .complete(request(prompt_ref("greet"), "name", "A"))
            .await
            .unwrap();

        assert_eq!(result.completion.values.len(), 2);
        assert!(result
            .completion
            .values
            .iter()
            .all(|v| v.value.starts_with('A')));
    }

    #[tokio::test]
    async fn test_resource_completion() {
        let manager = CompletionManager::new();
        manager.register_resource_provider(
            "file:///",
            Arc::new(TestCompletionProvider::new(
                vec![],
                vec!["src/main.rs", "src/lib.rs", "tests/test.rs"],
            )),
        );

        let result = manager
            .complete(request(resource_ref("file:///src"), "path", "src/"))
            .await
            .unwrap();

        assert_eq!(result.completion.values.len(), 2);
    }

    #[tokio::test]
    async fn test_exact_resource_match_preferred() {
        let manager = CompletionManager::new();
        manager.register_resource_provider(
            "file:///",
            Arc::new(TestCompletionProvider::new(vec![], vec!["prefix"])),
        );
        manager.register_resource_provider(
            "file:///a",
            Arc::new(TestCompletionProvider::new(vec![], vec!["exact"])),
        );

        let result = manager
            .complete(request(resource_ref("file:///a"), "path", ""))
            .await
            .unwrap();

        assert_eq!(value_strs(&result), vec!["exact"]);
    }

    #[tokio::test]
    async fn test_template_resource_pattern() {
        let manager = CompletionManager::new();
        manager.register_resource_provider(
            "db://{table}/rows",
            Arc::new(TestCompletionProvider::new(vec![], vec!["users", "orders"])),
        );

        let result = manager
            .complete(request(resource_ref("db://orders/rows"), "table", "u"))
            .await
            .unwrap();

        assert_eq!(value_strs(&result), vec!["users"]);
    }

    #[tokio::test]
    async fn test_unmatched_resource_uses_default() {
        let manager = CompletionManager::with_default_provider(Arc::new(
            TestCompletionProvider::new(vec![], vec!["fallback"]),
        ));
        manager.register_resource_provider(
            "file:///",
            Arc::new(TestCompletionProvider::new(vec![], vec!["file"])),
        );

        let result = manager
            .complete(request(resource_ref("https://example.com"), "path", ""))
            .await
            .unwrap();

        assert_eq!(value_strs(&result), vec!["fallback"]);
    }

    #[tokio::test]
    async fn test_default_provider() {
        let manager = CompletionManager::with_default_provider(Arc::new(
            TestCompletionProvider::new(vec!["default1", "default2"], vec![]),
        ));

        let result = manager
            .complete(request(prompt_ref("unknown_prompt"), "arg", "default"))
            .await
            .unwrap();

        assert_eq!(result.completion.values.len(), 2);
    }

    #[tokio::test]
    async fn test_max_completions_limit() {
        let mut manager = CompletionManager::new();
        manager.set_max_completions(2);
        manager.register_prompt_provider(
            "test",
            Arc::new(TestCompletionProvider::new(
                vec!["a1", "a2", "a3", "a4", "a5"],
                vec![],
            )),
        );

        let result = manager
            .complete(request(prompt_ref("test"), "arg", "a"))
            .await
            .unwrap();

        assert_eq!(result.completion.values.len(), 2);
        assert_eq!(result.completion.total, Some(5));
        assert_eq!(result.completion.has_more, Some(true));
    }

    #[tokio::test]
    async fn test_empty_completions() {
        let manager = CompletionManager::new();

        let result = manager
            .complete(request(prompt_ref("nonexistent"), "arg", "x"))
            .await
            .unwrap();

        assert!(result.completion.values.is_empty());
    }

    #[test]
    fn test_completion_value_constructors() {
        let simple = CompletionValue::new("value");
        assert_eq!(simple.value, "value");
        assert!(simple.label.is_none());
        assert!(simple.description.is_none());

        let with_label = CompletionValue::with_label("value", "Label");
        assert_eq!(with_label.value, "value");
        assert_eq!(with_label.label, Some("Label".to_string()));
        assert!(with_label.description.is_none());

        let with_desc = CompletionValue::with_description("value", "Description");
        assert_eq!(with_desc.value, "value");
        assert!(with_desc.label.is_none());
        assert_eq!(with_desc.description, Some("Description".to_string()));

        let full = CompletionValue::full("value", "Label", "Description");
        assert_eq!(full.value, "value");
        assert_eq!(full.label, Some("Label".to_string()));
        assert_eq!(full.description, Some("Description".to_string()));
    }

    #[test]
    fn test_completion_info_constructors() {
        let empty = CompletionInfo::empty();
        assert!(empty.values.is_empty());
        assert!(empty.total.is_none());
        assert!(empty.has_more.is_none());

        let with_values =
            CompletionInfo::with_values(vec![CompletionValue::new("a"), CompletionValue::new("b")]);
        assert_eq!(with_values.values.len(), 2);
        assert!(with_values.total.is_none());
        assert!(with_values.has_more.is_none());

        let paginated = CompletionInfo::with_pagination(vec![CompletionValue::new("a")], 10, true);
        assert_eq!(paginated.values.len(), 1);
        assert_eq!(paginated.total, Some(10));
        assert_eq!(paginated.has_more, Some(true));
    }
}