// collet 0.1.0
//
// Relentless agentic coding orchestrator with zero-drop agent loops.
// Documentation
use std::collections::HashMap;

use serde::Deserialize;
use tracing::warn;

/// Static mapping from collet provider names to LiteLLM provider name(s).
///
/// Each entry pairs a collet provider name with the `litellm_provider` values
/// it should match in the registry. A registry model belongs to the collet
/// provider if its `litellm_provider` equals *any* name in the list.
///
/// Collet names that do not appear here are matched against
/// `litellm_provider` verbatim by `ModelRegistry::models_for_provider`.
const PROVIDER_NAME_MAP: &[(&str, &[&str])] = &[
    ("zai-coding", &["zai"]),
    ("openai", &["openai"]),
    ("deepseek", &["deepseek"]),
    ("anthropic", &["anthropic"]),
    ("google-gemini", &["gemini"]),
    ("groq", &["groq"]),
    ("mistral", &["mistral"]),
    ("together-ai", &["together_ai"]),
    ("openrouter", &["openrouter"]),
    ("ollama", &["ollama"]),
    // lm-studio reuses the ollama registry entries.
    // NOTE(review): confirm this is intended rather than a LiteLLM `lm_studio` provider.
    ("lm-studio", &["ollama"]),
    // Multi-name entries: a model under either LiteLLM provider matches.
    ("azure-openai", &["azure", "azure_ai"]),
    ("bedrock", &["bedrock", "bedrock_converse"]),
];

/// A parsed model entry from the LiteLLM registry.
///
/// Built by `ModelRegistry::parse` from one key/value pair of the LiteLLM
/// JSON; entries without a `litellm_provider` never become a `RegistryModel`.
#[derive(Debug, Clone)]
pub struct RegistryModel {
    /// LiteLLM key, e.g. `"deepseek/deepseek-chat"`.
    pub key: String,
    /// LiteLLM provider, e.g. `"deepseek"`.
    pub litellm_provider: String,
    /// Maximum input tokens; `None` when absent or not a usable integer.
    pub max_input_tokens: Option<u32>,
    /// Maximum output tokens, taken from `max_output_tokens` and falling back
    /// to the legacy `max_tokens` field when that is absent.
    pub max_output_tokens: Option<u32>,
    /// Whether the model supports function calling (`false` when unspecified).
    pub supports_function_calling: bool,
    /// Whether the model accepts image input (`false` when unspecified).
    pub supports_vision: bool,
    /// Cost per **million** input tokens (converted from per-token).
    pub input_cost_per_million: Option<f64>,
    /// Cost per **million** output tokens (converted from per-token).
    pub output_cost_per_million: Option<f64>,
}

/// Intermediate serde target matching the LiteLLM JSON shape.
///
/// Only the fields collet uses are declared; any other keys in an entry are
/// ignored by serde's default behavior.
#[derive(Deserialize)]
struct RawEntry {
    /// Entries lacking this are skipped by `ModelRegistry::parse`.
    litellm_provider: Option<String>,
    // Token limits go through `lenient_u32` because the LiteLLM file can hold
    // non-numeric placeholder text in these fields.
    #[serde(default, deserialize_with = "lenient_u32::deserialize")]
    max_input_tokens: Option<u32>,
    #[serde(default, deserialize_with = "lenient_u32::deserialize")]
    max_output_tokens: Option<u32>,
    /// Fallback for `max_output_tokens` when that field is absent.
    #[serde(default, deserialize_with = "lenient_u32::deserialize")]
    max_tokens: Option<u32>,
    supports_function_calling: Option<bool>,
    supports_vision: Option<bool>,
    /// Per-token price; scaled to per-million in `ModelRegistry::parse`.
    input_cost_per_token: Option<f64>,
    /// Per-token price; scaled to per-million in `ModelRegistry::parse`.
    output_cost_per_token: Option<f64>,
}

/// Deserializes a field as `Option<u32>`, treating non-numeric values
/// (e.g. litellm description strings) as `None` instead of erroring.
mod lenient_u32 {
    use serde::Deserializer;

    pub fn deserialize<'de, D>(d: D) -> Result<Option<u32>, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct Visitor;
        impl serde::de::Visitor<'_> for Visitor {
            type Value = Option<u32>;
            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str("u32 or string")
            }
            fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
                Ok(Some(v as u32))
            }
            fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Self::Value, E> {
                Ok(u32::try_from(v).ok())
            }
            fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Self::Value, E> {
                Ok(Some(v as u32))
            }
            fn visit_str<E: serde::de::Error>(self, _: &str) -> Result<Self::Value, E> {
                Ok(None)
            }
            fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
                Ok(None)
            }
            fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> {
                Ok(None)
            }
        }
        d.deserialize_any(Visitor)
    }
}

/// In-memory index of all models parsed from a LiteLLM pricing JSON.
///
/// Built via [`ModelRegistry::parse`] (or [`ModelRegistry::empty`]).
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// All parsed models; `parse` leaves this sorted by `key`, so "first
    /// match" lookups are deterministic.
    models: Vec<RegistryModel>,
}

impl ModelRegistry {
    /// Parse the LiteLLM `model_prices_and_context_window.json` bytes into a
    /// registry.
    ///
    /// Entries without a `litellm_provider` or whose key is `"sample_spec"`
    /// are silently skipped. If the bytes are not a JSON object of entries, a
    /// warning is logged and an empty registry is returned. Models are sorted
    /// by key.
    pub fn parse(data: &[u8]) -> Self {
        let raw: HashMap<String, RawEntry> = match serde_json::from_slice(data) {
            Ok(map) => map,
            Err(e) => {
                warn!("failed to parse model registry JSON: {e}");
                return Self::empty();
            }
        };

        let mut models: Vec<RegistryModel> = raw
            .into_iter()
            .filter(|(key, _)| key != "sample_spec")
            .filter_map(|(key, entry)| {
                let provider = entry.litellm_provider?;
                Some(RegistryModel {
                    key,
                    litellm_provider: provider,
                    max_input_tokens: entry.max_input_tokens,
                    // Fall back to `max_tokens` when `max_output_tokens` is absent.
                    max_output_tokens: entry.max_output_tokens.or(entry.max_tokens),
                    supports_function_calling: entry.supports_function_calling.unwrap_or(false),
                    supports_vision: entry.supports_vision.unwrap_or(false),
                    input_cost_per_million: entry.input_cost_per_token.map(|c| c * 1_000_000.0),
                    output_cost_per_million: entry.output_cost_per_token.map(|c| c * 1_000_000.0),
                })
            })
            .collect();

        // Keys come from a map and are unique, so an unstable sort is exact.
        models.sort_unstable_by(|a, b| a.key.cmp(&b.key));

        Self { models }
    }

    /// Return an empty registry.
    pub fn empty() -> Self {
        Self { models: Vec::new() }
    }

    /// Whether the registry contains zero models.
    pub fn is_empty(&self) -> bool {
        self.models.is_empty()
    }

    /// Return the total number of models in the registry.
    pub fn len(&self) -> usize {
        self.models.len()
    }

    /// Collect every model whose `litellm_provider` belongs to
    /// `collet_provider`, in key order and without any display filtering.
    ///
    /// The name is resolved through [`PROVIDER_NAME_MAP`]; when it has no
    /// mapping it is compared against `litellm_provider` directly.
    fn provider_models(&self, collet_provider: &str) -> Vec<&RegistryModel> {
        let litellm_names = resolve_provider(collet_provider);
        self.models
            .iter()
            .filter(|m| {
                if litellm_names.is_empty() {
                    m.litellm_provider == collet_provider
                } else {
                    litellm_names.iter().any(|n| *n == m.litellm_provider)
                }
            })
            .collect()
    }

    /// Return the models whose `litellm_provider` matches the given collet
    /// provider name (resolved through [`PROVIDER_NAME_MAP`]), for display.
    ///
    /// If the collet name has no explicit mapping, the name itself is tried as
    /// a direct `litellm_provider` match. Fine-tuning entries (`ft:...`) are
    /// dropped and entries are deduplicated by display name, keeping the first
    /// in key order.
    pub fn models_for_provider(&self, collet_provider: &str) -> Vec<&RegistryModel> {
        // Dedup so the wizard doesn't show identical-looking items; `seen`
        // borrows names from `self`, so no per-entry allocation is needed.
        let mut seen = std::collections::HashSet::new();
        self.provider_models(collet_provider)
            .into_iter()
            .filter(|&m| {
                let display = Self::model_name(&m.key);
                !display.starts_with("ft:") && seen.insert(display)
            })
            .collect()
    }

    /// Extract the model name portion from a LiteLLM key.
    ///
    /// `"deepseek/deepseek-chat"` → `"deepseek-chat"`.
    /// If the key contains no `/`, the whole key is returned.
    pub fn model_name(key: &str) -> &str {
        key.rsplit_once('/').map_or(key, |(_, name)| name)
    }

    /// Find the best-matching [`RegistryModel`] for a given collet provider and
    /// model name.  Matching order:
    ///
    /// 1. Exact key match (`"deepseek/deepseek-chat"` == model)
    /// 2. Trailing path match (`key` ends with `"/model_name"`)
    /// 3. Substring match in either direction (handles versioned names like
    ///    `"deepseek-chat-v2"`), preferring the candidate whose name length is
    ///    closest to `model_name`; ties go to the lowest key.
    ///
    /// Unlike [`Self::models_for_provider`], this searches every entry for the
    /// provider (including `ft:` entries and duplicate display names), so an
    /// explicitly configured key resolves to its own entry. The fuzzy step
    /// only considers `ft:` entries when `model_name` itself starts with `ft:`.
    ///
    /// Returns `None` for an empty `model_name`, which would otherwise be a
    /// substring of every key.
    pub fn find_model(&self, collet_provider: &str, model_name: &str) -> Option<&RegistryModel> {
        if model_name.is_empty() {
            return None;
        }
        let candidates = self.provider_models(collet_provider);
        // 1. Exact key
        if let Some(m) = candidates.iter().copied().find(|m| m.key == model_name) {
            return Some(m);
        }
        // 2. Trailing path component
        let suffix = format!("/{model_name}");
        if let Some(m) = candidates.iter().copied().find(|m| m.key.ends_with(&suffix)) {
            return Some(m);
        }
        // 3. Substring, most specific (closest length) first, so e.g.
        //    "gpt-4o-2024-08-06" prefers "gpt-4o" over "gpt-4".
        let wants_fine_tune = model_name.starts_with("ft:");
        candidates
            .into_iter()
            .filter(|m| {
                let name = Self::model_name(&m.key);
                !name.is_empty()
                    && (wants_fine_tune || !name.starts_with("ft:"))
                    && (name.contains(model_name) || model_name.contains(name))
            })
            .min_by_key(|m| Self::model_name(&m.key).len().abs_diff(model_name.len()))
    }

    /// Return all unique `litellm_provider` values, sorted alphabetically.
    pub fn providers(&self) -> Vec<String> {
        // Deduplicate borrowed names first; only the unique ones are cloned.
        self.models
            .iter()
            .map(|m| m.litellm_provider.as_str())
            .collect::<std::collections::BTreeSet<&str>>()
            .into_iter()
            .map(str::to_owned)
            .collect()
    }
}

/// Look up the LiteLLM provider name(s) registered for `collet_name` in
/// [`PROVIDER_NAME_MAP`], using the first matching entry.
///
/// An empty `Vec` means "no mapping"; callers then fall back to comparing
/// `collet_name` against `litellm_provider` directly.
fn resolve_provider(collet_name: &str) -> Vec<&'static str> {
    PROVIDER_NAME_MAP
        .iter()
        .find(|(collet, _)| *collet == collet_name)
        .map_or_else(Vec::new, |(_, litellm_names)| litellm_names.to_vec())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Small registry fixture covering the skip rules, the `max_tokens`
    /// fallback and cost conversion.
    fn sample_json() -> &'static [u8] {
        br#"{
  "sample_spec": {
    "litellm_provider": "sample",
    "max_input_tokens": 100
  },
  "deepseek/deepseek-chat": {
    "litellm_provider": "deepseek",
    "max_input_tokens": 128000,
    "max_output_tokens": 8192,
    "supports_function_calling": true,
    "supports_vision": false,
    "input_cost_per_token": 0.000003,
    "output_cost_per_token": 0.000008
  },
  "openai/gpt-4o": {
    "litellm_provider": "openai",
    "max_input_tokens": 128000,
    "max_tokens": 4096,
    "supports_function_calling": true,
    "supports_vision": true,
    "input_cost_per_token": 0.000005,
    "output_cost_per_token": 0.000015
  },
  "no-provider-model": {
    "max_input_tokens": 1000
  }
}"#
    }

    #[test]
    fn parse_skips_sample_spec_and_missing_provider() {
        let reg = ModelRegistry::parse(sample_json());
        assert_eq!(reg.models.len(), 2);
        assert!(reg.models.iter().all(|m| m.key != "sample_spec"));
        assert!(reg.models.iter().all(|m| m.key != "no-provider-model"));
    }

    #[test]
    fn parse_sorts_by_key() {
        let reg = ModelRegistry::parse(sample_json());
        assert_eq!(reg.models[0].key, "deepseek/deepseek-chat");
        assert_eq!(reg.models[1].key, "openai/gpt-4o");
    }

    #[test]
    fn cost_conversion() {
        let reg = ModelRegistry::parse(sample_json());
        let ds = reg
            .models
            .iter()
            .find(|m| m.key == "deepseek/deepseek-chat")
            .unwrap();
        let input = ds.input_cost_per_million.unwrap();
        assert!((input - 3.0).abs() < 1e-9, "expected 3.0, got {input}");
        let output = ds.output_cost_per_million.unwrap();
        assert!((output - 8.0).abs() < 1e-9, "expected 8.0, got {output}");
    }

    #[test]
    fn max_tokens_fallback() {
        let reg = ModelRegistry::parse(sample_json());
        let gpt = reg
            .models
            .iter()
            .find(|m| m.key == "openai/gpt-4o")
            .unwrap();
        assert_eq!(gpt.max_output_tokens, Some(4096));
    }

    #[test]
    fn lenient_token_fields() {
        // Non-numeric, negative and null token limits become `None`;
        // whole-number floats are accepted.
        let reg = ModelRegistry::parse(
            br#"{
  "a": {"litellm_provider": "p", "max_input_tokens": "unknown", "max_tokens": -1, "max_output_tokens": null},
  "b": {"litellm_provider": "p", "max_input_tokens": 2048.0}
}"#,
        );
        assert_eq!(reg.len(), 2);
        assert_eq!(reg.models[0].max_input_tokens, None);
        assert_eq!(reg.models[0].max_output_tokens, None);
        assert_eq!(reg.models[1].max_input_tokens, Some(2048));
    }

    #[test]
    fn models_for_provider_mapping() {
        let reg = ModelRegistry::parse(sample_json());
        let ds_models = reg.models_for_provider("deepseek");
        assert_eq!(ds_models.len(), 1);
        assert_eq!(ds_models[0].key, "deepseek/deepseek-chat");
    }

    #[test]
    fn models_for_provider_openai_mapping() {
        let reg = ModelRegistry::parse(sample_json());
        let openai = reg.models_for_provider("openai");
        assert_eq!(openai.len(), 1);
        assert_eq!(openai[0].key, "openai/gpt-4o");
    }

    #[test]
    fn models_for_provider_direct_fallback() {
        // "xai" has no PROVIDER_NAME_MAP entry, so it is matched verbatim
        // against `litellm_provider`.
        assert!(resolve_provider("xai").is_empty());
        let reg = ModelRegistry::parse(br#"{"xai/grok-2": {"litellm_provider": "xai"}}"#);
        let direct = reg.models_for_provider("xai");
        assert_eq!(direct.len(), 1);
        assert_eq!(direct[0].key, "xai/grok-2");
    }

    #[test]
    fn models_for_provider_skips_fine_tune_and_duplicates() {
        let reg = ModelRegistry::parse(
            br#"{
  "gpt-4o": {"litellm_provider": "openai"},
  "openai/gpt-4o": {"litellm_provider": "openai"},
  "ft:gpt-4o-mini": {"litellm_provider": "openai"}
}"#,
        );
        let keys: Vec<&str> = reg
            .models_for_provider("openai")
            .iter()
            .map(|m| m.key.as_str())
            .collect();
        assert_eq!(keys, vec!["gpt-4o"]);
    }

    #[test]
    fn find_model_exact_key() {
        let reg = ModelRegistry::parse(sample_json());
        let m = reg.find_model("deepseek", "deepseek/deepseek-chat");
        assert!(m.is_some());
        assert_eq!(m.unwrap().key, "deepseek/deepseek-chat");
    }

    #[test]
    fn find_model_trailing_path() {
        let reg = ModelRegistry::parse(sample_json());
        let m = reg.find_model("deepseek", "deepseek-chat");
        assert!(m.is_some(), "should match via trailing path component");
        assert_eq!(m.unwrap().key, "deepseek/deepseek-chat");
    }

    #[test]
    fn find_model_openai() {
        let reg = ModelRegistry::parse(sample_json());
        let m = reg.find_model("openai", "gpt-4o");
        assert!(m.is_some());
        assert_eq!(m.unwrap().key, "openai/gpt-4o");
    }

    #[test]
    fn find_model_unknown_returns_none() {
        let reg = ModelRegistry::parse(sample_json());
        let m = reg.find_model("deepseek", "nonexistent-model");
        assert!(m.is_none());
    }

    #[test]
    fn model_name_extraction() {
        assert_eq!(
            ModelRegistry::model_name("deepseek/deepseek-chat"),
            "deepseek-chat"
        );
        assert_eq!(ModelRegistry::model_name("gpt-4o"), "gpt-4o");
        assert_eq!(ModelRegistry::model_name("a/b/c"), "c");
    }

    #[test]
    fn providers_sorted() {
        let reg = ModelRegistry::parse(sample_json());
        let providers = reg.providers();
        assert_eq!(providers, vec!["deepseek", "openai"]);
    }

    #[test]
    fn empty_registry() {
        let reg = ModelRegistry::empty();
        assert!(reg.is_empty());
        assert_eq!(reg.providers().len(), 0);
    }

    #[test]
    fn invalid_json_returns_empty() {
        let reg = ModelRegistry::parse(b"not json");
        assert!(reg.is_empty());
    }
}