agent_diva_providers/
registry.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
8#[serde(rename_all = "snake_case")]
9pub enum ApiType {
10 #[default]
11 Openai,
12 Anthropic,
13 Google,
14 Other,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ProviderSpec {
20 pub name: String,
22 #[serde(default)]
23 pub api_type: ApiType,
24 pub keywords: Vec<String>,
25 pub env_key: String,
26 pub display_name: String,
27 #[serde(default)]
28 pub default_model: Option<String>,
29
30 pub litellm_prefix: String,
32 pub skip_prefixes: Vec<String>,
33
34 pub env_extras: Vec<(String, String)>,
36
37 pub default_api_base: String,
38
39 #[serde(default)]
41 pub supports_prompt_caching: bool,
42
43 #[serde(default)]
45 pub models: Vec<String>,
46
47 pub model_overrides: Vec<(String, HashMap<String, serde_json::Value>)>,
49}
50
51impl ProviderSpec {
52 pub fn label(&self) -> String {
53 if !self.display_name.is_empty() {
54 self.display_name.clone()
55 } else {
56 let mut name = self.name.clone();
57 if let Some(first_char) = name.chars().next() {
58 name = first_char.to_uppercase().to_string() + &name[first_char.len_utf8()..];
59 }
60 name
61 }
62 }
63
64 pub fn default_model(&self) -> Option<&str> {
65 self.default_model
66 .as_deref()
67 .map(str::trim)
68 .filter(|value| !value.is_empty())
69 }
70}
71
72pub struct ProviderRegistry {
74 providers: Vec<ProviderSpec>,
75}
76
77impl ProviderRegistry {
78 pub fn new() -> Self {
80 Self {
81 providers: Self::default_providers(),
82 }
83 }
84
85 pub fn all(&self) -> &[ProviderSpec] {
87 &self.providers
88 }
89
90 pub fn find_by_model(&self, model: &str) -> Option<&ProviderSpec> {
92 let model_lower = model.to_lowercase();
93 self.providers
94 .iter()
95 .find(|spec| spec.keywords.iter().any(|kw| model_lower.contains(kw)))
96 }
97
98 pub fn find_by_name(&self, name: &str) -> Option<&ProviderSpec> {
100 self.providers.iter().find(|spec| spec.name == name)
101 }
102
103 fn default_providers() -> Vec<ProviderSpec> {
104 let yaml = include_str!("providers.yaml");
105 serde_yaml::from_str(yaml).expect("Failed to parse default providers configuration")
106 }
107}
108
109impl Default for ProviderRegistry {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Each known model name must resolve to the expected provider `name`.
    #[test]
    fn test_find_by_model() {
        let registry = ProviderRegistry::new();

        let cases = [
            ("claude-3-opus", "anthropic"),
            ("deepseek-chat", "deepseek"),
            ("qwen-max", "dashscope"),
            // Mixed-case model name: matching must not depend on the model's case.
            ("MiniMax-M2.1", "minimax"),
        ];
        for (model, expected) in cases {
            // Comparing `Option<&str>` yields a readable failure message
            // instead of an `unwrap` panic on `None`.
            assert_eq!(
                registry.find_by_model(model).map(|spec| spec.name.as_str()),
                Some(expected),
                "model {model:?} resolved to the wrong provider"
            );
        }
    }

    /// Exact-name lookup hits known providers and misses unknown ones.
    #[test]
    fn test_find_by_name() {
        let registry = ProviderRegistry::new();

        let spec = registry.find_by_name("anthropic");
        assert_eq!(spec.map(|s| s.display_name.as_str()), Some("Anthropic"));
        // A non-empty `display_name` is what `label()` should return.
        assert_eq!(spec.map(ProviderSpec::label).as_deref(), Some("Anthropic"));

        assert!(registry.find_by_name("no-such-provider").is_none());
    }
}