1use std::collections::HashMap;
5use std::path::Path;
6
7use anyhow::Context;
8use serde::Deserialize;
9use yoagent::provider::model::{
10 ApiProtocol, CostConfig, MaxTokensField, ModelConfig, OpenAiCompat, ThinkingFormat,
11};
12
13use super::compat::{RabMaxTokensField, RabOpenAiCompat, RabThinkingFormat};
14
/// Top-level shape of a models.json document.
#[derive(Debug, Deserialize)]
struct ModelsJson {
    /// Provider definitions keyed by provider id.
    providers: HashMap<String, ProviderDef>,
}
20
21#[derive(Debug, Deserialize)]
23#[serde(rename_all = "camelCase")]
24struct ProviderDef {
25 #[serde(default)]
26 name: Option<String>,
27 #[serde(default)]
28 base_url: Option<String>,
29 #[serde(default)]
30 api: Option<String>,
31 #[serde(default)]
32 env: Option<HashMap<String, String>>,
33 #[serde(default)]
34 headers: Option<HashMap<String, String>>,
35 #[serde(default)]
36 models: Vec<ModelDef>,
37}
38
39#[derive(Debug, Deserialize)]
41#[serde(rename_all = "camelCase")]
42struct ModelDef {
43 id: String,
44 #[serde(default)]
45 name: Option<String>,
46 #[serde(default)]
47 api: Option<String>,
48 #[serde(default)]
49 base_url: Option<String>,
50 #[serde(default)]
51 reasoning: bool,
52 #[serde(default)]
53 #[allow(dead_code)]
54 thinking_level_map: Option<HashMap<String, serde_json::Value>>,
55 #[serde(default)]
56 input: Option<Vec<String>>,
57 #[serde(default)]
58 cost: Option<CostDef>,
59 #[serde(default)]
60 context_window: Option<u32>,
61 #[serde(default)]
62 max_tokens: Option<u32>,
63 #[serde(default)]
64 compat: Option<RabOpenAiCompat>,
65 #[serde(default)]
66 headers: Option<HashMap<String, String>>,
67}
68
/// Pricing for a model as written in models.json.
///
/// Each value is copied into the matching `*_per_million` field of yoagent's
/// `CostConfig` (price per million tokens; the currency is not stated here).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CostDef {
    /// Price for input tokens (required).
    input: f64,
    /// Price for output tokens (required).
    output: f64,
    /// Price for cache-read tokens; `0.0` when omitted.
    #[serde(default)]
    cache_read: f64,
    /// Price for cache-write tokens; `0.0` when omitted.
    #[serde(default)]
    cache_write: f64,
}
79
/// A provider loaded from models.json, with its models converted to yoagent
/// [`ModelConfig`]s.
#[derive(Debug, Clone)]
pub struct ProviderEntry {
    /// Key of this provider under `providers` in models.json.
    pub id: String,
    /// Display name; equal to `id` when models.json gives no `name`.
    pub name: String,
    /// Models declared for this provider.
    pub models: Vec<ModelConfig>,
    /// Value of `env.apiKey` from models.json, if present (presumably the name
    /// of the environment variable holding the API key — see `env_var_name`).
    pub env_var_hint: Option<String>,
}
88
89impl ProviderEntry {
90 pub fn env_var_name(&self) -> &str {
91 self.env_var_hint.as_deref().unwrap_or("API_KEY")
92 }
93}
94
95fn parse_provider(id: &str, def: ProviderDef) -> anyhow::Result<ProviderEntry> {
97 let mut models = Vec::new();
98 let provider_api = def.api.as_deref();
99
100 for m in &def.models {
101 let api_str = m
102 .api
103 .as_deref()
104 .or(provider_api)
105 .unwrap_or("openai-completions");
106 let api = match api_str {
107 "openai-completions" => ApiProtocol::OpenAiCompletions,
108 "anthropic-messages" => ApiProtocol::AnthropicMessages,
109 "openai-responses" => ApiProtocol::OpenAiResponses,
110 "google-generative-ai" => ApiProtocol::GoogleGenerativeAi,
111 "google-vertex" => ApiProtocol::GoogleVertex,
112 "bedrock-converse-stream" => ApiProtocol::BedrockConverseStream,
113 "azure-openai-responses" => ApiProtocol::AzureOpenAiResponses,
114 _ => anyhow::bail!("Unknown API type: {}", api_str),
115 };
116
117 let base_url = m
118 .base_url
119 .clone()
120 .or_else(|| def.base_url.clone())
121 .unwrap_or_default();
122
123 let input = m.input.clone().unwrap_or_else(|| vec!["text".to_string()]);
124 let _has_image = input.iter().any(|s| s == "image");
125
126 let cost = m
127 .cost
128 .as_ref()
129 .map(|c| CostConfig {
130 input_per_million: c.input,
131 output_per_million: c.output,
132 cache_read_per_million: c.cache_read,
133 cache_write_per_million: c.cache_write,
134 })
135 .unwrap_or_default();
136
137 let context_window = m.context_window.unwrap_or(128_000);
138 let max_tokens = m.max_tokens.unwrap_or(16_384);
139
140 let compat = m.compat.clone().unwrap_or_default();
142 let compat_json = serde_json::to_string(&compat).unwrap_or_else(|_| "{}".to_string());
143
144 let yoagent_compat = if api == ApiProtocol::OpenAiCompletions {
146 Some(convert_to_yoagent_compat(&compat))
147 } else {
148 None
149 };
150
151 let mut headers = HashMap::new();
152 headers.insert("_rab_compat".to_string(), compat_json);
153 if let Some(tlm) = &m.thinking_level_map
154 && let Ok(json) = serde_json::to_string(tlm)
155 {
156 headers.insert("_rab_thinking_map".to_string(), json);
157 }
158
159 if let Some(provider_headers) = &def.headers {
161 for (k, v) in provider_headers {
162 headers.entry(k.clone()).or_insert_with(|| v.clone());
163 }
164 }
165 if let Some(model_headers) = &m.headers {
166 for (k, v) in model_headers {
167 headers.insert(k.clone(), v.clone());
168 }
169 }
170
171 let model = ModelConfig {
172 id: m.id.clone(),
173 name: m.name.clone().unwrap_or_else(|| m.id.clone()),
174 api,
175 provider: id.to_string(),
176 base_url,
177 reasoning: m.reasoning,
178 context_window,
179 max_tokens,
180 cost,
181 headers,
182 compat: yoagent_compat,
183 };
184
185 models.push(model);
186 }
187
188 let env_var = def.env.as_ref().and_then(|e| e.get("apiKey")).cloned();
189
190 Ok(ProviderEntry {
191 id: id.to_string(),
192 name: def.name.unwrap_or_else(|| id.to_string()),
193 models,
194 env_var_hint: env_var,
195 })
196}
197
198fn convert_to_yoagent_compat(rab: &RabOpenAiCompat) -> OpenAiCompat {
200 let max_tokens_field = match rab.max_tokens_field {
201 RabMaxTokensField::MaxTokens => MaxTokensField::MaxTokens,
202 RabMaxTokensField::MaxCompletionTokens => MaxTokensField::MaxCompletionTokens,
203 };
204
205 let thinking_format = match rab.thinking_format {
206 RabThinkingFormat::OpenAi
207 | RabThinkingFormat::OpenRouter
208 | RabThinkingFormat::DeepSeek
209 | RabThinkingFormat::Together
210 | RabThinkingFormat::Zai
211 | RabThinkingFormat::ChatTemplate
212 | RabThinkingFormat::QwenChatTemplate
213 | RabThinkingFormat::StringThinking
214 | RabThinkingFormat::AntLing => ThinkingFormat::OpenAi,
215 RabThinkingFormat::Qwen => ThinkingFormat::Qwen,
216 };
217
218 OpenAiCompat {
219 supports_store: rab.supports_store,
220 supports_developer_role: rab.supports_developer_role,
221 supports_reasoning_effort: rab.supports_reasoning_effort,
222 supports_thinking_control: rab.supports_thinking_control
223 || rab.thinking_format == RabThinkingFormat::DeepSeek,
224 supports_usage_in_streaming: rab.supports_usage_in_streaming,
225 max_tokens_field,
226 requires_tool_result_name: rab.requires_tool_result_name,
227 requires_assistant_after_tool_result: rab.requires_assistant_after_tool_result,
228 thinking_format,
229 }
230}
231
232pub fn load_builtin(builtin_json: &str) -> anyhow::Result<Vec<ProviderEntry>> {
234 let parsed: ModelsJson =
235 serde_json::from_str(builtin_json).context("Failed to parse built-in models.json")?;
236
237 let mut entries = Vec::new();
238 for (id, def) in parsed.providers {
239 match parse_provider(&id, def) {
240 Ok(entry) => entries.push(entry),
241 Err(e) => {
242 eprintln!("Warning: skipping provider '{}': {}", id, e);
243 }
244 }
245 }
246 Ok(entries)
247}
248
249pub fn load_user(path: &Path) -> anyhow::Result<Vec<ProviderEntry>> {
251 if !path.exists() {
252 return Ok(Vec::new());
253 }
254 let content = std::fs::read_to_string(path)
255 .with_context(|| format!("Failed to read {}", path.display()))?;
256 load_builtin(&content)
257}
258
259pub fn merge(builtin: Vec<ProviderEntry>, user: Vec<ProviderEntry>) -> Vec<ProviderEntry> {
262 let mut map: HashMap<String, ProviderEntry> = HashMap::new();
263 for entry in builtin {
264 map.insert(entry.id.clone(), entry);
265 }
266 for entry in user {
267 map.insert(entry.id.clone(), entry);
268 }
269 map.into_values().collect()
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerant comparison for parsed prices (avoids truncating `as` casts).
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[test]
    fn test_parse_builtin() {
        let json = r#"{
            "providers": {
                "test-provider": {
                    "name": "Test",
                    "baseUrl": "https://test.example/v1",
                    "api": "openai-completions",
                    "env": { "apiKey": "TEST_API_KEY" },
                    "models": [
                        {
                            "id": "test-model",
                            "name": "Test Model",
                            "reasoning": true,
                            "cost": { "input": 1.0, "output": 2.0 },
                            "contextWindow": 100000,
                            "maxTokens": 32000
                        }
                    ]
                }
            }
        }"#;
        let entries = load_builtin(json).unwrap();
        assert_eq!(entries.len(), 1);
        let entry = &entries[0];
        assert_eq!(entry.id, "test-provider");
        assert_eq!(entry.name, "Test");
        assert_eq!(entry.env_var_name(), "TEST_API_KEY");
        assert_eq!(entry.models.len(), 1);

        let model = &entry.models[0];
        assert_eq!(model.id, "test-model");
        assert_eq!(model.name, "Test Model");
        assert_eq!(model.provider, "test-provider");
        assert_eq!(model.base_url, "https://test.example/v1");
        assert_eq!(model.api, ApiProtocol::OpenAiCompletions);
        assert!(model.reasoning);
        assert_eq!(model.context_window, 100_000);
        assert_eq!(model.max_tokens, 32_000);
        assert!(model.headers.contains_key("_rab_compat"));
        assert!(model.compat.is_some());
        assert!(approx_eq(model.cost.input_per_million, 1.0));
        assert!(approx_eq(model.cost.output_per_million, 2.0));
        assert!(approx_eq(model.cost.cache_read_per_million, 0.0));
    }

    #[test]
    fn test_model_defaults_fall_back_to_provider_and_builtins() {
        let entries = load_builtin(
            r#"{"providers":{"p":{"baseUrl":"https://p.example","models":[{"id":"m"}]}}}"#,
        )
        .unwrap();
        let entry = &entries[0];
        assert_eq!(entry.name, "p");
        assert_eq!(entry.env_var_name(), "API_KEY");

        let model = &entry.models[0];
        assert_eq!(model.name, "m");
        assert_eq!(model.api, ApiProtocol::OpenAiCompletions);
        assert_eq!(model.base_url, "https://p.example");
        assert!(!model.reasoning);
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 16_384);
        assert!(model.compat.is_some());
    }

    #[test]
    fn test_model_api_overrides_provider_api() {
        let entries = load_builtin(
            r#"{"providers":{"p":{"api":"openai-completions","baseUrl":"https://p.example","models":[{"id":"m","api":"anthropic-messages","baseUrl":"https://m.example"}]}}}"#,
        )
        .unwrap();
        let model = &entries[0].models[0];
        assert_eq!(model.api, ApiProtocol::AnthropicMessages);
        assert_eq!(model.base_url, "https://m.example");
        // yoagent compat settings only apply to openai-completions.
        assert!(model.compat.is_none());
    }

    #[test]
    fn test_model_headers_override_provider_headers() {
        let entries = load_builtin(
            r#"{"providers":{"p":{"headers":{"X-A":"provider","X-B":"provider"},"models":[{"id":"m","headers":{"X-B":"model"}}]}}}"#,
        )
        .unwrap();
        let headers = &entries[0].models[0].headers;
        assert_eq!(headers.get("X-A").map(String::as_str), Some("provider"));
        assert_eq!(headers.get("X-B").map(String::as_str), Some("model"));
    }

    #[test]
    fn test_unknown_api_skips_provider() {
        let entries = load_builtin(
            r#"{"providers":{"bad":{"api":"carrier-pigeon","models":[{"id":"m"}]},"good":{"models":[{"id":"m"}]}}}"#,
        )
        .unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].id, "good");
    }

    #[test]
    fn test_invalid_json_is_an_error() {
        assert!(load_builtin("not json").is_err());
    }

    #[test]
    fn test_load_user_missing_file_is_empty() {
        let path = std::env::temp_dir()
            .join("rab-models-test-nonexistent-7f3c2a")
            .join("models.json");
        assert!(load_user(&path).unwrap().is_empty());
    }

    #[test]
    fn test_merge_user_overrides_builtin() {
        let builtin = load_builtin(r#"{"providers":{"p1":{"name":"Builtin","baseUrl":"https://builtin.example","models":[{"id":"m1","cost":{"input":1,"output":2},"contextWindow":1000,"maxTokens":500}]}}}"#).unwrap();
        let user = load_builtin(r#"{"providers":{"p1":{"name":"User","baseUrl":"https://user.example","models":[{"id":"m1","cost":{"input":3,"output":4},"contextWindow":2000,"maxTokens":1000}]}}}"#).unwrap();
        let merged = merge(builtin, user);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].name, "User");
        assert_eq!(merged[0].models[0].base_url, "https://user.example");
        assert!(approx_eq(merged[0].models[0].cost.input_per_million, 3.0));
    }

    #[test]
    fn test_merge_keeps_distinct_providers() {
        let builtin = load_builtin(r#"{"providers":{"a":{"models":[]}}}"#).unwrap();
        let user = load_builtin(r#"{"providers":{"b":{"models":[]}}}"#).unwrap();
        let mut ids: Vec<String> = merge(builtin, user).into_iter().map(|e| e.id).collect();
        ids.sort();
        assert_eq!(ids, ["a", "b"]);
    }
}