use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletableArgument {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(default)]
pub required: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub completion: Option<CompletionConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionConfig {
pub provider: CompletionProvider,
#[serde(flatten)]
pub config: HashMap<String, Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum CompletionProvider {
Static,
Resource,
Tool,
File,
Custom(String),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionRequest {
pub argument: String,
pub partial: String,
pub context: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResponse {
pub completions: Vec<CompletionItem>,
#[serde(default)]
pub has_more: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub continuation_token: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionItem {
pub value: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub label: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub icon: Option<String>,
#[serde(flatten)]
pub metadata: HashMap<String, Value>,
}
#[async_trait::async_trait]
pub trait CompletionProviderTrait: Send + Sync {
async fn complete(
&self,
request: CompletionRequest,
) -> crate::error::Result<CompletionResponse>;
}
#[derive(Debug)]
pub struct StaticCompletionProvider {
items: Vec<CompletionItem>,
}
impl StaticCompletionProvider {
pub fn new(items: Vec<CompletionItem>) -> Self {
Self { items }
}
pub fn from_strings(values: Vec<String>) -> Self {
Self {
items: values
.into_iter()
.map(|value| CompletionItem {
value,
label: None,
description: None,
icon: None,
metadata: HashMap::new(),
})
.collect(),
}
}
}
#[async_trait::async_trait]
impl CompletionProviderTrait for StaticCompletionProvider {
async fn complete(
&self,
request: CompletionRequest,
) -> crate::error::Result<CompletionResponse> {
let completions: Vec<CompletionItem> = self
.items
.iter()
.filter(|item| {
item.value.starts_with(&request.partial)
|| item
.label
.as_ref()
.is_some_and(|l| l.starts_with(&request.partial))
})
.cloned()
.collect();
Ok(CompletionResponse {
completions,
has_more: false,
continuation_token: None,
})
}
}
#[derive(Debug)]
pub struct FileCompletionProvider {
base_dir: Option<std::path::PathBuf>,
extensions: Option<Vec<String>>,
}
impl FileCompletionProvider {
pub fn new() -> Self {
Self {
base_dir: None,
extensions: None,
}
}
pub fn with_base_dir(mut self, dir: std::path::PathBuf) -> Self {
self.base_dir = Some(dir);
self
}
pub fn with_extensions(mut self, extensions: Vec<String>) -> Self {
self.extensions = Some(extensions);
self
}
}
#[async_trait::async_trait]
impl CompletionProviderTrait for FileCompletionProvider {
async fn complete(
&self,
request: CompletionRequest,
) -> crate::error::Result<CompletionResponse> {
use std::path::Path;
let partial_path = Path::new(&request.partial);
let (dir_path, file_prefix) = if request.partial.ends_with(std::path::MAIN_SEPARATOR) {
(partial_path, "")
} else {
(
partial_path.parent().unwrap_or_else(|| Path::new(".")),
partial_path
.file_name()
.and_then(|s| s.to_str())
.unwrap_or(""),
)
};
let search_dir = if dir_path.is_absolute() {
dir_path.to_path_buf()
} else if let Some(base) = &self.base_dir {
base.join(dir_path)
} else {
dir_path.to_path_buf()
};
let mut completions = Vec::new();
if let Ok(entries) = std::fs::read_dir(&search_dir) {
for entry in entries.flatten() {
if let Some(name) = entry.file_name().to_str() {
if name.starts_with(file_prefix) {
let is_dir = entry.file_type().is_ok_and(|t| t.is_dir());
if !is_dir {
if let Some(exts) = &self.extensions {
let has_valid_ext = Path::new(name)
.extension()
.and_then(|e| e.to_str())
.is_some_and(|e| exts.iter().any(|ext| ext == e));
if !has_valid_ext {
continue;
}
}
}
let value = if request.partial.is_empty() {
name.to_string()
} else {
dir_path.join(name).to_string_lossy().to_string()
};
completions.push(CompletionItem {
value,
label: Some(name.to_string()),
description: if is_dir {
Some("Directory".to_string())
} else {
Some("File".to_string())
},
icon: if is_dir {
Some("📁".to_string())
} else {
Some("📄".to_string())
},
metadata: HashMap::new(),
});
}
}
}
}
completions.sort_by(|a, b| a.value.cmp(&b.value));
Ok(CompletionResponse {
completions,
has_more: false,
continuation_token: None,
})
}
}
impl Default for FileCompletionProvider {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub struct CompletableBuilder {
name: String,
description: Option<String>,
required: bool,
completion: Option<CompletionConfig>,
}
impl CompletableBuilder {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
description: None,
required: false,
completion: None,
}
}
pub fn description(mut self, desc: impl Into<String>) -> Self {
self.description = Some(desc.into());
self
}
pub fn required(mut self) -> Self {
self.required = true;
self
}
pub fn static_completions(mut self, values: Vec<String>) -> Self {
self.completion = Some(CompletionConfig {
provider: CompletionProvider::Static,
config: {
let mut config = HashMap::new();
config.insert("values".to_string(), serde_json::to_value(values).unwrap());
config
},
});
self
}
pub fn file_completions(mut self) -> Self {
self.completion = Some(CompletionConfig {
provider: CompletionProvider::File,
config: HashMap::new(),
});
self
}
pub fn build(self) -> CompletableArgument {
CompletableArgument {
name: self.name,
description: self.description,
required: self.required,
completion: self.completion,
}
}
}
pub fn completable(name: impl Into<String>) -> CompletableBuilder {
CompletableBuilder::new(name)
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_static_completion() {
let provider = StaticCompletionProvider::from_strings(vec![
"apple".to_string(),
"apricot".to_string(),
"banana".to_string(),
"cherry".to_string(),
]);
let request = CompletionRequest {
argument: "fruit".to_string(),
partial: "ap".to_string(),
context: HashMap::new(),
};
let response = provider.complete(request).await.unwrap();
assert_eq!(response.completions.len(), 2);
assert_eq!(response.completions[0].value, "apple");
assert_eq!(response.completions[1].value, "apricot");
}
#[test]
fn test_completable_builder() {
let arg = completable("environment")
.description("Target environment")
.required()
.static_completions(vec![
"development".to_string(),
"staging".to_string(),
"production".to_string(),
])
.build();
assert_eq!(arg.name, "environment");
assert_eq!(arg.description, Some("Target environment".to_string()));
assert!(arg.required);
assert!(arg.completion.is_some());
}
}