Skip to main content

agent_diva_providers/
registry.rs

1//! Provider registry - single source of truth for LLM provider metadata
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6/// API specification type
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
8#[serde(rename_all = "snake_case")]
9pub enum ApiType {
10    #[default]
11    Openai,
12    Anthropic,
13    Google,
14    Other,
15}
16
17/// One LLM provider's metadata
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProviderSpec {
20    // Identity
21    pub name: String,
22    #[serde(default)]
23    pub api_type: ApiType,
24    pub keywords: Vec<String>,
25    pub env_key: String,
26    pub display_name: String,
27    #[serde(default)]
28    pub default_model: Option<String>,
29
30    // Model prefixing
31    pub litellm_prefix: String,
32    pub skip_prefixes: Vec<String>,
33
34    // Extra env vars
35    pub env_extras: Vec<(String, String)>,
36
37    pub default_api_base: String,
38
39    // Prompt caching support
40    #[serde(default)]
41    pub supports_prompt_caching: bool,
42
43    // Models list
44    #[serde(default)]
45    pub models: Vec<String>,
46
47    // Per-model param overrides
48    pub model_overrides: Vec<(String, HashMap<String, serde_json::Value>)>,
49}
50
51impl ProviderSpec {
52    pub fn label(&self) -> String {
53        if !self.display_name.is_empty() {
54            self.display_name.clone()
55        } else {
56            let mut name = self.name.clone();
57            if let Some(first_char) = name.chars().next() {
58                name = first_char.to_uppercase().to_string() + &name[first_char.len_utf8()..];
59            }
60            name
61        }
62    }
63
64    pub fn default_model(&self) -> Option<&str> {
65        self.default_model
66            .as_deref()
67            .map(str::trim)
68            .filter(|value| !value.is_empty())
69    }
70}
71
72/// Registry of available LLM providers
73pub struct ProviderRegistry {
74    providers: Vec<ProviderSpec>,
75}
76
77impl ProviderRegistry {
78    /// Create a new provider registry with default providers
79    pub fn new() -> Self {
80        Self {
81            providers: Self::default_providers(),
82        }
83    }
84
85    /// Get all provider specs
86    pub fn all(&self) -> &[ProviderSpec] {
87        &self.providers
88    }
89
90    /// Find a provider by model name (case-insensitive keyword matching)
91    pub fn find_by_model(&self, model: &str) -> Option<&ProviderSpec> {
92        let model_lower = model.to_lowercase();
93        self.providers
94            .iter()
95            .find(|spec| spec.keywords.iter().any(|kw| model_lower.contains(kw)))
96    }
97
98    /// Find a provider by config field name
99    pub fn find_by_name(&self, name: &str) -> Option<&ProviderSpec> {
100        self.providers.iter().find(|spec| spec.name == name)
101    }
102
103    fn default_providers() -> Vec<ProviderSpec> {
104        let yaml = include_str!("providers.yaml");
105        serde_yaml::from_str(yaml).expect("Failed to parse default providers configuration")
106    }
107}
108
109impl Default for ProviderRegistry {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a spec with only the fields the label/default-model tests care
    /// about; everything else is empty or default.
    fn spec_named(name: &str, display_name: &str, default_model: Option<&str>) -> ProviderSpec {
        ProviderSpec {
            name: name.to_string(),
            api_type: ApiType::default(),
            keywords: Vec::new(),
            env_key: String::new(),
            display_name: display_name.to_string(),
            default_model: default_model.map(str::to_string),
            litellm_prefix: String::new(),
            skip_prefixes: Vec::new(),
            env_extras: Vec::new(),
            default_api_base: String::new(),
            supports_prompt_caching: false,
            models: Vec::new(),
            model_overrides: Vec::new(),
        }
    }

    #[test]
    fn test_default_registry_is_populated() {
        let registry = ProviderRegistry::new();
        assert!(!registry.all().is_empty(), "embedded providers.yaml is empty");
    }

    #[test]
    fn test_find_by_model() {
        let registry = ProviderRegistry::new();
        let cases = [
            ("claude-3-opus", "anthropic"),
            ("deepseek-chat", "deepseek"),
            ("qwen-max", "dashscope"),
            // Mixed-case model names must still resolve (case-insensitive lookup).
            ("MiniMax-M2.1", "minimax"),
        ];
        for &(model, expected) in &cases {
            assert_eq!(
                registry.find_by_model(model).map(|spec| spec.name.as_str()),
                Some(expected),
                "model {model:?} resolved to the wrong provider"
            );
        }
    }

    #[test]
    fn test_find_by_name() {
        let registry = ProviderRegistry::new();
        let spec = registry
            .find_by_name("anthropic")
            .expect("anthropic provider should be registered");
        assert_eq!(spec.display_name, "Anthropic");
        assert!(registry.find_by_name("__no_such_provider__").is_none());
    }

    #[test]
    fn test_label_falls_back_to_capitalized_name() {
        assert_eq!(spec_named("anthropic", "Anthropic", None).label(), "Anthropic");
        assert_eq!(spec_named("deepseek", "", None).label(), "Deepseek");
        assert_eq!(spec_named("", "", None).label(), "");
    }

    #[test]
    fn test_default_model_trims_and_rejects_blank() {
        assert_eq!(
            spec_named("x", "", Some("  gpt-4o  ")).default_model(),
            Some("gpt-4o")
        );
        assert_eq!(spec_named("x", "", Some("   ")).default_model(), None);
        assert_eq!(spec_named("x", "", None).default_model(), None);
    }

    #[test]
    fn test_api_type_default_is_openai() {
        assert_eq!(ApiType::default(), ApiType::Openai);
    }
}