use std::collections::BTreeMap;
use std::fmt;
use std::sync::{Arc, RwLock};
use serde::{Deserialize, Serialize};
/// Metadata describing a single command that can be stored in a
/// [`CommandRegistry`] and looked up by exact name or by free-text query.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct CommandInfo {
/// Command name. [`CommandRegistry::register`] uses it as the map key, so
/// registering a second command with the same name replaces the first.
pub name: String,
/// Plugin the command belongs to, if any.
/// NOTE(review): not read anywhere in this file — meaning assumed from the name.
pub plugin: Option<String>,
/// Human-readable description; matched by `CommandRegistry::search` and
/// contributes per-word points in `CommandRegistry::resolve`.
pub description: Option<String>,
/// Declared arguments of the command.
pub args: Vec<CommandArg>,
/// Return type as text, if known.
/// NOTE(review): the expected format is not established in this file.
pub return_type: Option<String>,
/// Whether the command is asynchronous.
/// NOTE(review): not read anywhere in this file.
pub is_async: bool,
/// Free-text statement of what the command is for; contributes per-word
/// points in `CommandRegistry::resolve`.
pub intent: Option<String>,
/// Grouping label; contributes per-word points in `CommandRegistry::resolve`.
pub category: Option<String>,
/// Example strings; `CommandRegistry::resolve` rewards an example that
/// contains the whole query, and otherwise awards points per matching word.
pub examples: Vec<String>,
}
/// Description of one argument accepted by a [`CommandInfo`].
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct CommandArg {
/// Argument name.
pub name: String,
/// Argument type as text.
/// NOTE(review): format (Rust type syntax? JSON type name?) not established here.
pub type_name: String,
/// Presumably whether callers must supply this argument — confirm with producers.
pub required: bool,
/// Optional JSON value describing the argument; presumably a JSON Schema —
/// TODO confirm against whatever populates it.
pub schema: Option<serde_json::Value>,
}
/// Function that builds a [`CommandInfo`], collected through the `inventory`
/// crate. [`auto_discovered_commands`] calls every collected factory.
///
/// Instances are expected to be submitted elsewhere (e.g. with
/// `inventory::submit!`, possibly via a macro) — not visible in this file.
/// Hidden from rustdoc because it is an implementation detail.
#[doc(hidden)]
pub struct CommandInfoFactory(pub fn() -> CommandInfo);
// Declares `CommandInfoFactory` as an `inventory` collection so submitted
// instances can be iterated with `inventory::iter::<CommandInfoFactory>`.
inventory::collect!(CommandInfoFactory);
impl CommandInfo {
    /// Starts a descriptor for the command `name`.
    ///
    /// Every optional field is empty: `None` for the `Option`s, empty vectors
    /// for `args`/`examples`, and `is_async = false`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            examples: Vec::new(),
            category: None,
            intent: None,
            is_async: false,
            return_type: None,
            args: Vec::new(),
            description: None,
            plugin: None,
            name,
        }
    }

    /// Returns `self` with the description set, replacing any previous one.
    #[must_use]
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Returns `self` with the intent set, replacing any previous one.
    #[must_use]
    pub fn with_intent(self, intent: impl Into<String>) -> Self {
        Self {
            intent: Some(intent.into()),
            ..self
        }
    }

    /// Returns `self` with the category set, replacing any previous one.
    #[must_use]
    pub fn with_category(self, category: impl Into<String>) -> Self {
        Self {
            category: Some(category.into()),
            ..self
        }
    }
}
/// Name-keyed catalogue of [`CommandInfo`]s that can be shared between threads.
///
/// The map lives behind `Arc<RwLock<..>>`, so cloning a registry is cheap and
/// every clone reads and mutates the *same* set of commands. A `BTreeMap` keeps
/// entries ordered by name, which is the order `list` and `search` return them in.
#[derive(Debug, Clone)]
pub struct CommandRegistry {
commands: Arc<RwLock<BTreeMap<String, CommandInfo>>>,
}
impl CommandRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        let empty: BTreeMap<String, CommandInfo> = BTreeMap::new();
        Self {
            commands: Arc::new(RwLock::new(empty)),
        }
    }

    /// Takes the shared read lock. Poisoning is deliberately ignored: the guard
    /// is recovered even if a previous lock holder panicked.
    fn read_commands(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, BTreeMap<String, CommandInfo>> {
        self.commands
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Takes the exclusive write lock, ignoring poisoning in the same way as the
    /// read helper.
    fn write_commands(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, BTreeMap<String, CommandInfo>> {
        self.commands
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Adds `info`, keyed by its name; an existing entry with the same name is
    /// replaced.
    pub fn register(&self, info: CommandInfo) {
        let key = info.name.clone();
        self.write_commands().insert(key, info);
    }

    /// Returns a copy of the command registered under exactly `name`, if any.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<CommandInfo> {
        let commands = self.read_commands();
        commands.get(name).cloned()
    }

    /// Returns copies of all registered commands, ordered by name.
    #[must_use]
    pub fn list(&self) -> Vec<CommandInfo> {
        self.read_commands().values().cloned().collect()
    }

    /// Returns the number of registered commands.
    #[must_use]
    pub fn count(&self) -> usize {
        self.read_commands().len()
    }

    /// Case-insensitive substring search over command names and descriptions.
    ///
    /// Results are ordered by name. An empty `query` matches every command,
    /// because every string contains the empty string.
    #[must_use]
    pub fn search(&self, query: &str) -> Vec<CommandInfo> {
        let needle = query.to_lowercase();
        let mentions = |text: &str| text.to_lowercase().contains(needle.as_str());
        self.read_commands()
            .values()
            .filter(|cmd| {
                mentions(cmd.name.as_str())
                    || cmd.description.as_deref().is_some_and(|d| mentions(d))
            })
            .cloned()
            .collect()
    }

    /// Ranks commands by relevance to a free-text `query` using `score_command`.
    ///
    /// Commands scoring `0.0` or less are dropped. The rest come back in
    /// descending score order; `sort_by` is stable, so ties keep name order.
    /// An empty or whitespace-only query yields an empty vector.
    #[must_use]
    pub fn resolve(&self, query: &str) -> Vec<ScoredCommand> {
        let lowered = query.to_lowercase();
        let words: Vec<&str> = lowered.split_whitespace().collect();
        if words.is_empty() {
            return Vec::new();
        }

        let mut ranked: Vec<ScoredCommand> = self
            .read_commands()
            .values()
            .filter_map(|cmd| {
                let score = score_command(cmd, &lowered, &words);
                (score > 0.0).then(|| ScoredCommand {
                    command: cmd.clone(),
                    score,
                })
            })
            .collect();
        ranked.sort_by(|lhs, rhs| rhs.score.total_cmp(&lhs.score));
        ranked
    }
}
#[must_use]
pub fn auto_discovered_commands() -> Vec<CommandInfo> {
inventory::iter::<CommandInfoFactory>
.into_iter()
.map(|factory| (factory.0)())
.collect()
}
impl CommandRegistry {
    /// Builds a registry pre-populated with every command returned by
    /// [`auto_discovered_commands`].
    ///
    /// `register` overwrites entries with the same name, so if two discovered
    /// commands share a name the one registered last wins.
    /// NOTE(review): discovery order is not established in this file.
    #[must_use]
    pub fn from_auto_discovery() -> Self {
        let registry = Self::new();
        auto_discovered_commands()
            .into_iter()
            .for_each(|info| registry.register(info));
        registry
    }
}
impl Default for CommandRegistry {
fn default() -> Self {
Self::new()
}
}
/// A [`CommandInfo`] paired with its relevance score, as produced by
/// [`CommandRegistry::resolve`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredCommand {
/// The matched command (a clone of the registry entry).
pub command: CommandInfo,
/// Relevance score; higher means a better match. `resolve` only returns
/// entries whose score is greater than zero.
pub score: f64,
}
impl fmt::Display for ScoredCommand {
    /// Renders as `<name> (score: <score to two decimals>)`,
    /// e.g. `list_files (score: 2.50)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = &self.command.name;
        let score = self.score;
        write!(f, "{name} (score: {score:.2})")
    }
}
/// Bonus when the command name equals the query's words joined with `_`.
const SCORE_EXACT_NAME: f64 = 10.0;
/// Points per query word found anywhere inside the command name.
const SCORE_NAME_SUBSTRING: f64 = 3.0;
/// Points per query word equal to one `_`-separated segment of the name.
const SCORE_NAME_WORD: f64 = 2.0;
/// Points per query word found inside the description.
const SCORE_DESCRIPTION: f64 = 1.5;
/// Points per query word found inside the intent.
const SCORE_INTENT: f64 = 2.5;
/// Points per query word found inside the category.
const SCORE_CATEGORY: f64 = 1.0;
/// Bonus when an example contains the whole (whitespace-normalised) query.
const SCORE_EXAMPLE_FULL: f64 = 4.0;
/// Points per query word found inside an example scanned before any phrase hit.
const SCORE_EXAMPLE_WORD: f64 = 0.5;

/// Scores how well `cmd` matches a query; `0.0` means "no match" and higher is better.
///
/// * `query_lower` — the whole query, already lower-cased. Its whitespace is
///   normalised (runs collapsed, ends trimmed) before phrase comparisons, so
///   `"  list   files "` is treated exactly like `"list files"`.
/// * `query_words` — the lower-cased query split on whitespace.
///
/// The result is `bonuses + per_word_points / query_words.len()`. Bonuses
/// (exact name, example containing the whole phrase) are added in full, while
/// per-word points are averaged so longer queries are not favoured just for
/// having more words. An empty `query_words` scores `0.0` rather than
/// producing NaN from a division by zero.
fn score_command(cmd: &CommandInfo, query_lower: &str, query_words: &[&str]) -> f64 {
    if query_words.is_empty() {
        return 0.0;
    }
    // Collapse whitespace once so phrase checks ignore stray/repeated spaces or
    // tabs: e.g. "list  files" must still earn the exact-name bonus for `list_files`.
    let normalized: Vec<&str> = query_lower.split_whitespace().collect();
    let phrase = normalized.join(" ");
    let snake_phrase = normalized.join("_");

    let mut exact_bonus = 0.0;
    let mut score = 0.0;

    let name_lower = cmd.name.to_lowercase();
    let name_words: Vec<&str> = name_lower.split('_').collect();
    if name_lower == snake_phrase {
        exact_bonus += SCORE_EXACT_NAME;
    }
    for word in query_words {
        if name_lower.contains(word) {
            score += SCORE_NAME_SUBSTRING;
        }
        if name_words.contains(word) {
            score += SCORE_NAME_WORD;
        }
    }

    if let Some(desc) = &cmd.description {
        score += weighted_hits(&desc.to_lowercase(), query_words, SCORE_DESCRIPTION);
    }
    if let Some(intent) = &cmd.intent {
        score += weighted_hits(&intent.to_lowercase(), query_words, SCORE_INTENT);
    }
    if let Some(category) = &cmd.category {
        score += weighted_hits(&category.to_lowercase(), query_words, SCORE_CATEGORY);
    }

    for example in &cmd.examples {
        let ex_lower = example.to_lowercase();
        // A whole-phrase hit earns the bonus and stops the scan, so this example
        // and any later ones contribute no per-word points.
        if !phrase.is_empty() && ex_lower.contains(phrase.as_str()) {
            exact_bonus += SCORE_EXAMPLE_FULL;
            break;
        }
        score += weighted_hits(&ex_lower, query_words, SCORE_EXAMPLE_WORD);
    }

    let word_count = query_words.len() as f64;
    exact_bonus + score / word_count
}

/// Returns `weight` times the number of `query_words` that occur as substrings
/// of `haystack`.
fn weighted_hits(haystack: &str, query_words: &[&str], weight: f64) -> f64 {
    let hits = query_words
        .iter()
        .filter(|word| haystack.contains(**word))
        .count();
    hits as f64 * weight
}