Skip to main content

rab/provider/
models.rs

1//! models.json — parses built-in and user models.json files
2//! and constructs yoagent ModelConfigs with rich compat stored in headers.
3
4use std::collections::HashMap;
5use std::path::Path;
6
7use anyhow::Context;
8use serde::Deserialize;
9use yoagent::provider::model::{
10    ApiProtocol, CostConfig, MaxTokensField, ModelConfig, OpenAiCompat, ThinkingFormat,
11};
12
13use super::compat::{RabMaxTokensField, RabOpenAiCompat, RabThinkingFormat};
14
/// Root structure of models.json
///
/// The file is a single JSON object whose (required) `providers` key maps each
/// provider id to its definition.
#[derive(Debug, Deserialize)]
struct ModelsJson {
    /// Provider id -> provider definition.
    providers: HashMap<String, ProviderDef>,
}
20
/// A provider definition in models.json
///
/// JSON keys are camelCase (e.g. `baseUrl`). Every field is optional. `api`,
/// `baseUrl`, and `headers` act as fallbacks for models that do not set their own.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ProviderDef {
    /// Display name; the provider id is used when absent.
    #[serde(default)]
    name: Option<String>,
    /// Base URL used by models that do not specify `baseUrl` themselves.
    #[serde(default)]
    base_url: Option<String>,
    /// API protocol string (e.g. `"openai-completions"`) used by models that do
    /// not specify `api` themselves.
    #[serde(default)]
    api: Option<String>,
    /// Environment-variable map; only the `apiKey` entry is read, and it becomes
    /// the resolved provider's `env_var_hint`.
    #[serde(default)]
    env: Option<HashMap<String, String>>,
    /// Headers applied to every model of this provider; model-level headers with
    /// the same name take precedence.
    #[serde(default)]
    headers: Option<HashMap<String, String>>,
    /// Models offered by this provider (empty when omitted).
    #[serde(default)]
    models: Vec<ModelDef>,
}
38
39/// A single model entry in models.json
40#[derive(Debug, Deserialize)]
41#[serde(rename_all = "camelCase")]
42struct ModelDef {
43    id: String,
44    #[serde(default)]
45    name: Option<String>,
46    #[serde(default)]
47    api: Option<String>,
48    #[serde(default)]
49    base_url: Option<String>,
50    #[serde(default)]
51    reasoning: bool,
52    #[serde(default)]
53    #[allow(dead_code)]
54    thinking_level_map: Option<HashMap<String, serde_json::Value>>,
55    #[serde(default)]
56    input: Option<Vec<String>>,
57    #[serde(default)]
58    cost: Option<CostDef>,
59    #[serde(default)]
60    context_window: Option<u32>,
61    #[serde(default)]
62    max_tokens: Option<u32>,
63    #[serde(default)]
64    compat: Option<RabOpenAiCompat>,
65    #[serde(default)]
66    headers: Option<HashMap<String, String>>,
67}
68
/// Per-model pricing as written in models.json.
///
/// Values are per million tokens: they are copied into yoagent's
/// `CostConfig::*_per_million` fields. When a `cost` object is present, `input`
/// and `output` are required; `cacheRead` / `cacheWrite` default to `0.0`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CostDef {
    /// Price per million input tokens.
    input: f64,
    /// Price per million output tokens.
    output: f64,
    /// Price per million cache-read tokens (`cacheRead` in JSON).
    #[serde(default)]
    cache_read: f64,
    /// Price per million cache-write tokens (`cacheWrite` in JSON).
    #[serde(default)]
    cache_write: f64,
}
79
/// A resolved provider entry in the registry.
///
/// Built from a models.json provider definition by [`load_builtin`] or
/// [`load_user`], and combined across sources with [`merge`].
#[derive(Debug, Clone)]
pub struct ProviderEntry {
    /// Provider id (its key under `providers` in models.json).
    pub id: String,
    /// Display name; equal to `id` when models.json gives none.
    pub name: String,
    /// Model configs with provider-level fallbacks already applied.
    pub models: Vec<ModelConfig>,
    /// Environment-variable name taken from the provider's `env.apiKey`, if any.
    pub env_var_hint: Option<String>,
}
88
89impl ProviderEntry {
90    pub fn env_var_name(&self) -> &str {
91        self.env_var_hint.as_deref().unwrap_or("API_KEY")
92    }
93}
94
95/// Parse a single provider definition from models.json into a `ProviderEntry`.
96fn parse_provider(id: &str, def: ProviderDef) -> anyhow::Result<ProviderEntry> {
97    let mut models = Vec::new();
98    let provider_api = def.api.as_deref();
99
100    for m in &def.models {
101        let api_str = m
102            .api
103            .as_deref()
104            .or(provider_api)
105            .unwrap_or("openai-completions");
106        let api = match api_str {
107            "openai-completions" => ApiProtocol::OpenAiCompletions,
108            "anthropic-messages" => ApiProtocol::AnthropicMessages,
109            "openai-responses" => ApiProtocol::OpenAiResponses,
110            "google-generative-ai" => ApiProtocol::GoogleGenerativeAi,
111            "google-vertex" => ApiProtocol::GoogleVertex,
112            "bedrock-converse-stream" => ApiProtocol::BedrockConverseStream,
113            "azure-openai-responses" => ApiProtocol::AzureOpenAiResponses,
114            _ => anyhow::bail!("Unknown API type: {}", api_str),
115        };
116
117        let base_url = m
118            .base_url
119            .clone()
120            .or_else(|| def.base_url.clone())
121            .unwrap_or_default();
122
123        let input = m.input.clone().unwrap_or_else(|| vec!["text".to_string()]);
124        let _has_image = input.iter().any(|s| s == "image");
125
126        let cost = m
127            .cost
128            .as_ref()
129            .map(|c| CostConfig {
130                input_per_million: c.input,
131                output_per_million: c.output,
132                cache_read_per_million: c.cache_read,
133                cache_write_per_million: c.cache_write,
134            })
135            .unwrap_or_default();
136
137        let context_window = m.context_window.unwrap_or(128_000);
138        let max_tokens = m.max_tokens.unwrap_or(16_384);
139
140        // Build the compat and store it as JSON in headers["_rab_compat"]
141        let compat = m.compat.clone().unwrap_or_default();
142        let compat_json = serde_json::to_string(&compat).unwrap_or_else(|_| "{}".to_string());
143
144        // Also build yoagent's OpenAiCompat for models that use openai-completions
145        let yoagent_compat = if api == ApiProtocol::OpenAiCompletions {
146            Some(convert_to_yoagent_compat(&compat))
147        } else {
148            None
149        };
150
151        let mut headers = HashMap::new();
152        headers.insert("_rab_compat".to_string(), compat_json);
153        if let Some(tlm) = &m.thinking_level_map
154            && let Ok(json) = serde_json::to_string(tlm)
155        {
156            headers.insert("_rab_thinking_map".to_string(), json);
157        }
158
159        // Merge provider-level headers, then model-level headers
160        if let Some(provider_headers) = &def.headers {
161            for (k, v) in provider_headers {
162                headers.entry(k.clone()).or_insert_with(|| v.clone());
163            }
164        }
165        if let Some(model_headers) = &m.headers {
166            for (k, v) in model_headers {
167                headers.insert(k.clone(), v.clone());
168            }
169        }
170
171        let model = ModelConfig {
172            id: m.id.clone(),
173            name: m.name.clone().unwrap_or_else(|| m.id.clone()),
174            api,
175            provider: id.to_string(),
176            base_url,
177            reasoning: m.reasoning,
178            context_window,
179            max_tokens,
180            cost,
181            headers,
182            compat: yoagent_compat,
183        };
184
185        models.push(model);
186    }
187
188    let env_var = def.env.as_ref().and_then(|e| e.get("apiKey")).cloned();
189
190    Ok(ProviderEntry {
191        id: id.to_string(),
192        name: def.name.unwrap_or_else(|| id.to_string()),
193        models,
194        env_var_hint: env_var,
195    })
196}
197
198/// Convert our rich compat to yoagent's OpenAiCompat for the fields they share.
199fn convert_to_yoagent_compat(rab: &RabOpenAiCompat) -> OpenAiCompat {
200    let max_tokens_field = match rab.max_tokens_field {
201        RabMaxTokensField::MaxTokens => MaxTokensField::MaxTokens,
202        RabMaxTokensField::MaxCompletionTokens => MaxTokensField::MaxCompletionTokens,
203    };
204
205    let thinking_format = match rab.thinking_format {
206        RabThinkingFormat::OpenAi
207        | RabThinkingFormat::OpenRouter
208        | RabThinkingFormat::DeepSeek
209        | RabThinkingFormat::Together
210        | RabThinkingFormat::Zai
211        | RabThinkingFormat::ChatTemplate
212        | RabThinkingFormat::QwenChatTemplate
213        | RabThinkingFormat::StringThinking
214        | RabThinkingFormat::AntLing => ThinkingFormat::OpenAi,
215        RabThinkingFormat::Qwen => ThinkingFormat::Qwen,
216    };
217
218    OpenAiCompat {
219        supports_store: rab.supports_store,
220        supports_developer_role: rab.supports_developer_role,
221        supports_reasoning_effort: rab.supports_reasoning_effort,
222        supports_thinking_control: rab.supports_thinking_control
223            || rab.thinking_format == RabThinkingFormat::DeepSeek,
224        supports_usage_in_streaming: rab.supports_usage_in_streaming,
225        max_tokens_field,
226        requires_tool_result_name: rab.requires_tool_result_name,
227        requires_assistant_after_tool_result: rab.requires_assistant_after_tool_result,
228        thinking_format,
229    }
230}
231
232/// Load providers from an embedded JSON string (from `include_str!`).
233pub fn load_builtin(builtin_json: &str) -> anyhow::Result<Vec<ProviderEntry>> {
234    let parsed: ModelsJson =
235        serde_json::from_str(builtin_json).context("Failed to parse built-in models.json")?;
236
237    let mut entries = Vec::new();
238    for (id, def) in parsed.providers {
239        match parse_provider(&id, def) {
240            Ok(entry) => entries.push(entry),
241            Err(e) => {
242                eprintln!("Warning: skipping provider '{}': {}", id, e);
243            }
244        }
245    }
246    Ok(entries)
247}
248
249/// Load providers from a user's models.json file (returns empty vec if file missing).
250pub fn load_user(path: &Path) -> anyhow::Result<Vec<ProviderEntry>> {
251    if !path.exists() {
252        return Ok(Vec::new());
253    }
254    let content = std::fs::read_to_string(path)
255        .with_context(|| format!("Failed to read {}", path.display()))?;
256    load_builtin(&content)
257}
258
259/// Merge user providers on top of built-in providers.
260/// User providers with the same `id` replace built-in entries entirely.
261pub fn merge(builtin: Vec<ProviderEntry>, user: Vec<ProviderEntry>) -> Vec<ProviderEntry> {
262    let mut map: HashMap<String, ProviderEntry> = HashMap::new();
263    for entry in builtin {
264        map.insert(entry.id.clone(), entry);
265    }
266    for entry in user {
267        map.insert(entry.id.clone(), entry);
268    }
269    map.into_values().collect()
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_builtin() {
        let json = r#"{
            "providers": {
                "test-provider": {
                    "name": "Test",
                    "baseUrl": "https://test.example/v1",
                    "api": "openai-completions",
                    "env": { "apiKey": "TEST_API_KEY" },
                    "models": [
                        {
                            "id": "test-model",
                            "name": "Test Model",
                            "reasoning": true,
                            "cost": { "input": 1.0, "output": 2.0 },
                            "contextWindow": 100000,
                            "maxTokens": 32000
                        }
                    ]
                }
            }
        }"#;
        let entries = load_builtin(json).unwrap();
        assert_eq!(entries.len(), 1);
        let entry = &entries[0];
        assert_eq!(entry.id, "test-provider");
        assert_eq!(entry.env_var_name(), "TEST_API_KEY");
        assert_eq!(entry.models.len(), 1);
        let model = &entry.models[0];
        assert_eq!(model.id, "test-model");
        assert_eq!(model.api, ApiProtocol::OpenAiCompletions);
        assert!(model.reasoning);
        assert!(model.headers.contains_key("_rab_compat"));
        assert_eq!(model.cost.input_per_million as u32, 1);
    }

    #[test]
    fn test_model_inherits_provider_defaults() {
        let json = r#"{
            "providers": {
                "p": {
                    "baseUrl": "https://provider.example",
                    "api": "anthropic-messages",
                    "headers": { "x-shared": "provider", "x-provider-only": "p" },
                    "models": [
                        { "id": "m", "headers": { "x-shared": "model" } }
                    ]
                }
            }
        }"#;
        let entries = load_builtin(json).unwrap();
        assert_eq!(entries.len(), 1);
        let entry = &entries[0];
        assert_eq!(entry.name, "p");
        assert_eq!(entry.env_var_name(), "API_KEY");
        let model = &entry.models[0];
        assert_eq!(model.name, "m");
        assert_eq!(model.provider, "p");
        assert_eq!(model.api, ApiProtocol::AnthropicMessages);
        assert_eq!(model.base_url, "https://provider.example");
        assert_eq!(model.context_window, 128_000);
        assert_eq!(model.max_tokens, 16_384);
        // yoagent compat is only built for openai-completions models.
        assert!(model.compat.is_none());
        // Model-level headers override provider-level ones.
        assert_eq!(model.headers["x-shared"], "model");
        assert_eq!(model.headers["x-provider-only"], "p");
    }

    #[test]
    fn test_unknown_api_skips_provider() {
        let json = r#"{"providers":{"bad":{"models":[{"id":"m","api":"carrier-pigeon"}]}}}"#;
        let entries = load_builtin(json).unwrap();
        assert!(entries.is_empty());
    }

    #[test]
    fn test_thinking_level_map_stored_in_headers() {
        let json = r#"{"providers":{"p":{"models":[{"id":"m","thinkingLevelMap":{"high":"max"}}]}}}"#;
        let entries = load_builtin(json).unwrap();
        let model = &entries[0].models[0];
        assert_eq!(model.headers["_rab_thinking_map"], r#"{"high":"max"}"#);
        // Default api is openai-completions, which gets a yoagent compat.
        assert!(model.compat.is_some());
    }

    #[test]
    fn test_load_user_missing_file_is_empty() {
        let path = std::env::temp_dir()
            .join("rab-models-test-definitely-missing")
            .join("models.json");
        assert!(load_user(&path).unwrap().is_empty());
    }

    #[test]
    fn test_merge_user_overrides_builtin() {
        let builtin = load_builtin(r#"{"providers":{"p1":{"name":"Builtin","baseUrl":"https://builtin.example","models":[{"id":"m1","cost":{"input":1,"output":2},"contextWindow":1000,"maxTokens":500}]}}}"#).unwrap();
        let user = load_builtin(r#"{"providers":{"p1":{"name":"User","baseUrl":"https://user.example","models":[{"id":"m1","cost":{"input":3,"output":4},"contextWindow":2000,"maxTokens":1000}]}}}"#).unwrap();
        let merged = merge(builtin, user);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].name, "User");
        assert_eq!(merged[0].models[0].cost.input_per_million as u32, 3);
    }

    #[test]
    fn test_merge_keeps_distinct_providers() {
        let builtin = load_builtin(r#"{"providers":{"a":{"models":[{"id":"m"}]}}}"#).unwrap();
        let user = load_builtin(r#"{"providers":{"b":{"models":[{"id":"m"}]}}}"#).unwrap();
        let merged = merge(builtin, user);
        let mut ids: Vec<&str> = merged.iter().map(|p| p.id.as_str()).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["a", "b"]);
    }
}