// rab/provider/generate_models.rs

//! `rab generate-models` subcommand.
//!
//! Fetches <https://models.dev/api.json>, applies pi-style corrections,
//! and writes `src/provider/models.json`, resolved relative to the current
//! working directory (which is expected to be the repo root).
//!
//! All-or-nothing: any error aborts before writing.

use std::time::Duration;

use serde_json::Value;

/// models.dev catalog: a JSON object keyed by provider id, where each provider
/// carries a `models` map keyed by model id.
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
/// Relative path to the models catalog, resolved against the current working
/// directory (so the command must be run from the repo root).
const OUTPUT_PATH: &str = "src/provider/models.json";

/// Providers we regenerate, as `(key in the output catalog, key in models.dev)`.
///
/// Every models.dev key listed here must be present in the fetched catalog;
/// a missing one aborts generation before anything is written.
const TARGET_PROVIDERS: &[(&str, &str)] = &[
    ("github-copilot", "github-copilot"),
    ("opencode", "opencode"),
    ("opencode-go", "opencode-go"),
    ("deepseek", "deepseek"),
];

22/// Run the generate. Called from main.rs when args contain "generate-models".
23pub async fn run_generate_models() -> anyhow::Result<()> {
24    // 1. Fetch models.dev
25    eprintln!("Fetching {} ...", MODELS_DEV_URL);
26    let raw = fetch(MODELS_DEV_URL).await?;
27    let root: Value = serde_json::from_str(&raw)
28        .map_err(|e| anyhow::anyhow!("Failed to parse models.dev response: {}", e))?;
29
30    // 2. Resolve output path and verify it exists (safety guard)
31    let output_path = std::env::current_dir()?.join(OUTPUT_PATH);
32    if !output_path.exists() {
33        anyhow::bail!(
34            "{} not found.\nRun this from the rab repo root, or specify a project that has the built-in catalog.\n  cargo run -- generate-models",
35            output_path.display()
36        );
37    }
38
39    // 3. Read existing file (preserve user edits to other providers)
40    let mut output: Value = if output_path.exists() {
41        let content = std::fs::read_to_string(&output_path)?;
42        serde_json::from_str(&content).unwrap_or(Value::Object(serde_json::Map::new()))
43    } else {
44        Value::Object(serde_json::Map::new())
45    };
46
47    if !output.is_object() {
48        output = Value::Object(serde_json::Map::new());
49    }
50
51    // 4. Process each target provider — all processing inside a block
52    //    so the mutable borrow on `output` drops before the write below.
53    let total: usize = {
54        let obj = output
55            .as_object_mut()
56            .ok_or_else(|| anyhow::anyhow!("output is not an object"))?;
57
58        if !obj.contains_key("providers") {
59            obj.insert("providers".into(), Value::Object(serde_json::Map::new()));
60        }
61
62        let providers_obj = obj["providers"]
63            .as_object_mut()
64            .ok_or_else(|| anyhow::anyhow!("providers is not an object"))?;
65
66        for &(provider_key, models_dev_key) in TARGET_PROVIDERS {
67            let models_map = root
68                .get(models_dev_key)
69                .and_then(|s| s.get("models"))
70                .and_then(|m| m.as_object())
71                .ok_or_else(|| {
72                    anyhow::anyhow!(
73                        "No models for '{}' in models.dev. Aborting.",
74                        models_dev_key
75                    )
76                })?;
77
78            let models: Vec<Value> = models_map
79                .iter()
80                .filter(|(_, v)| {
81                    v.get("tool_call").and_then(|x| x.as_bool()) == Some(true)
82                        && v.get("status").and_then(|x| x.as_str()) != Some("deprecated")
83                })
84                .map(|(model_id, model_val)| build_model_entry(provider_key, model_id, model_val))
85                .collect::<Result<Vec<_>, _>>()?;
86
87            let headers = provider_headers(provider_key);
88            let mut provider_entry = serde_json::json!({
89                "name": provider_display_name(provider_key),
90                "baseUrl": provider_base_url(provider_key),
91                "api": provider_base_api(provider_key),
92                "env": { "apiKey": provider_env_var(provider_key) },
93                "models": models
94            });
95            if !headers.is_empty() {
96                let headers_obj: serde_json::Map<String, Value> = headers
97                    .iter()
98                    .map(|(k, v)| ((*k).to_string(), Value::String((*v).to_string())))
99                    .collect();
100                provider_entry["headers"] = Value::Object(headers_obj);
101            }
102
103            providers_obj.insert(provider_key.to_string(), provider_entry);
104        }
105
106        // Count total models
107        providers_obj
108            .values()
109            .filter_map(|p| p.get("models").and_then(|m| m.as_array()))
110            .map(|m| m.len())
111            .sum()
112    }; // mutable borrow on `output` ends here
113
114    // 5. Write back (only reached if all processing succeeded)
115    let json = serde_json::to_string_pretty(&output)?;
116    if let Some(parent) = output_path.parent() {
117        std::fs::create_dir_all(parent)?;
118    }
119    std::fs::write(&output_path, &json)?;
120
121    eprintln!(
122        "Wrote {} models across {} providers to {}",
123        total,
124        TARGET_PROVIDERS.len(),
125        output_path.display()
126    );
127    Ok(())
128}
129
/// Human-readable name for a provider key.
///
/// Keys without a dedicated name (including plain `"opencode"`) fall back to
/// "OpenCode Zen".
fn provider_display_name(key: &str) -> &'static str {
    const NAMES: [(&str, &str); 3] = [
        ("github-copilot", "GitHub Copilot"),
        ("opencode-go", "OpenCode Zen Go"),
        ("deepseek", "DeepSeek"),
    ];
    NAMES
        .iter()
        .find(|(name_key, _)| *name_key == key)
        .map_or("OpenCode Zen", |&(_, display)| display)
}

/// Provider-level API base URL.
///
/// Keys without a dedicated endpoint (including plain `"opencode"`) use the
/// OpenCode Zen gateway.
fn provider_base_url(key: &str) -> &'static str {
    if key == "github-copilot" {
        "https://api.individual.githubcopilot.com"
    } else if key == "opencode-go" {
        "https://opencode.ai/zen/go"
    } else if key == "deepseek" {
        "https://api.deepseek.com"
    } else {
        "https://opencode.ai/zen"
    }
}

/// Name of the environment variable holding the provider's API key.
fn provider_env_var(key: &str) -> &'static str {
    let dedicated = match key {
        "github-copilot" => Some("COPILOT_GITHUB_TOKEN"),
        "deepseek" => Some("DEEPSEEK_API_KEY"),
        _ => None,
    };
    // Both OpenCode gateways (and any unrecognised key) share one credential.
    dedicated.unwrap_or("OPENCODE_API_KEY")
}

/// Provider-level API protocol identifier.
///
/// Every provider currently maps to `"openai-completions"`; the key is ignored.
fn provider_base_api(_key: &str) -> &'static str {
    const DEFAULT_API: &str = "openai-completions";
    DEFAULT_API
}

/// Provider-level HTTP headers (e.g. for GitHub Copilot).
///
/// Returns `(name, value)` pairs; the list is empty for providers that need no
/// extra headers.
fn provider_headers(key: &str) -> Vec<(&'static str, &'static str)> {
    const COPILOT_HEADERS: [(&str, &str); 4] = [
        ("User-Agent", "GitHubCopilotChat/0.35.0"),
        ("Editor-Version", "vscode/1.107.0"),
        ("Editor-Plugin-Version", "copilot-chat/0.35.0"),
        ("Copilot-Integration-Id", "vscode-chat"),
    ];
    if key == "github-copilot" {
        COPILOT_HEADERS.to_vec()
    } else {
        Vec::new()
    }
}

174fn build_model_entry(
175    provider_key: &str,
176    model_id: &str,
177    model_val: &Value,
178) -> anyhow::Result<Value> {
179    let obj = model_val
180        .as_object()
181        .ok_or_else(|| anyhow::anyhow!("Model '{}' is not an object", model_id))?;
182
183    let npm = obj
184        .get("provider")
185        .and_then(|p| p.get("npm"))
186        .and_then(|v| v.as_str());
187
188    let (api, base_url) = resolve_api_and_base_url(provider_key, model_id, npm, obj);
189    let reasoning = obj
190        .get("reasoning")
191        .and_then(|v| v.as_bool())
192        .unwrap_or(false);
193
194    let mut input: Vec<Value> = vec!["text".into()];
195    if let Some(mods) = obj
196        .get("modalities")
197        .and_then(|m| m.get("input"))
198        .and_then(|m| m.as_array())
199        && mods.iter().any(|m| m.as_str() == Some("image"))
200    {
201        input.push("image".into());
202    }
203
204    let input_cost = obj
205        .get("cost")
206        .and_then(|c| c.get("input"))
207        .and_then(|v| v.as_f64())
208        .unwrap_or(0.0);
209    let output_cost = obj
210        .get("cost")
211        .and_then(|c| c.get("output"))
212        .and_then(|v| v.as_f64())
213        .unwrap_or(0.0);
214    let cache_read = obj
215        .get("cost")
216        .and_then(|c| c.get("cache_read"))
217        .and_then(|v| v.as_f64())
218        .unwrap_or(0.0);
219    let cache_write = obj
220        .get("cost")
221        .and_then(|c| c.get("cache_write"))
222        .and_then(|v| v.as_f64())
223        .unwrap_or(0.0);
224    let context_window = obj
225        .get("limit")
226        .and_then(|l| l.get("context"))
227        .and_then(|v| v.as_u64())
228        .unwrap_or(4096);
229    let max_tokens = obj
230        .get("limit")
231        .and_then(|l| l.get("output"))
232        .and_then(|v| v.as_u64())
233        .unwrap_or(4096);
234
235    let mut entry = serde_json::json!({
236        "id": model_id,
237        "name": obj.get("name").and_then(|v| v.as_str()).unwrap_or(model_id),
238        "api": api,
239        "reasoning": reasoning,
240        "input": input,
241        "cost": {
242            "input": input_cost,
243            "output": output_cost,
244            "cacheRead": cache_read,
245            "cacheWrite": cache_write
246        },
247        "contextWindow": context_window,
248        "maxTokens": max_tokens
249    });
250
251    if let Some(bu) = base_url {
252        entry["baseUrl"] = Value::String(bu);
253    }
254
255    apply_corrections(provider_key, model_id, &mut entry, api, reasoning, obj, npm);
256
257    Ok(entry)
258}
259
260/// Determine the API identifier and optional base URL override for a model.
261fn resolve_api_and_base_url<'a>(
262    provider_key: &str,
263    model_id: &str,
264    npm: Option<&str>,
265    _obj: &'a serde_json::Map<String, Value>,
266) -> (&'a str, Option<String>) {
267    let base_path = provider_base_url(provider_key);
268
269    match npm {
270        Some("@ai-sdk/openai") => ("openai-responses", Some(format!("{}/v1", base_path))),
271        Some("@ai-sdk/anthropic") => ("anthropic-messages", Some(base_path.into())),
272        Some("@ai-sdk/google") => ("google-generative-ai", Some(format!("{}/v1", base_path))),
273        _ => {
274            // GitHub Copilot's openai-completions API is at the root, not under /v1
275            if provider_key == "github-copilot" {
276                return ("openai-completions", Some(base_path.into()));
277            }
278            if provider_key == "opencode-go" && model_id == "minimax-m2.7" {
279                return ("openai-completions", Some(format!("{}/v1", base_path)));
280            }
281            if provider_key == "opencode-go"
282                && (model_id == "qwen3.5-plus" || model_id == "qwen3.6-plus")
283            {
284                return ("openai-completions", Some(format!("{}/v1", base_path)));
285            }
286            ("openai-completions", Some(format!("{}/v1", base_path)))
287        }
288    }
289}
290
291/// Apply pi-style corrections to a model entry.
292fn apply_corrections(
293    provider_key: &str,
294    model_id: &str,
295    entry: &mut Value,
296    api: &str,
297    _reasoning: bool,
298    _obj: &serde_json::Map<String, Value>,
299    _npm: Option<&str>,
300) {
301    if api != "openai-completions" {
302        return;
303    }
304
305    let mut compat = serde_json::json!({
306        "supportsStore": false,
307        "supportsDeveloperRole": false,
308        "maxTokensField": "max_tokens"
309    });
310
311    if model_id.contains("deepseek-v4") {
312        compat["requiresReasoningContentOnAssistantMessages"] = Value::Bool(true);
313        compat["thinkingFormat"] = Value::String("deepseek".into());
314        compat["supportsReasoningEffort"] = Value::Bool(false);
315        if provider_key == "opencode" {
316            compat["supportsLongCacheRetention"] = Value::Bool(false);
317        }
318        if provider_key == "deepseek" {
319            compat["supportsThinkingControl"] = Value::Bool(true);
320        }
321        entry["thinkingLevelMap"] = serde_json::json!({
322            "minimal": null, "low": null, "medium": null, "high": "high", "xhigh": "max"
323        });
324    }
325
326    if model_id == "kimi-k2.6" {
327        compat["thinkingFormat"] = Value::String("deepseek".into());
328        compat["supportsReasoningEffort"] = Value::Bool(false);
329        compat["supportsLongCacheRetention"] = Value::Bool(false);
330    }
331
332    if model_id == "kimi-k2.5" {
333        compat["supportsLongCacheRetention"] = Value::Bool(false);
334    }
335
336    if model_id == "minimax-m2.7" {
337        compat["supportsLongCacheRetention"] = Value::Bool(false);
338    }
339
340    if model_id == "deepseek-reasoner" {
341        compat["requiresReasoningContentOnAssistantMessages"] = Value::Bool(true);
342        compat["thinkingFormat"] = Value::String("deepseek".into());
343        compat["supportsReasoningEffort"] = Value::Bool(false);
344        compat["supportsThinkingControl"] = Value::Bool(true);
345        entry["thinkingLevelMap"] = serde_json::json!({
346            "minimal": null, "low": null, "medium": null, "high": "high", "xhigh": "max"
347        });
348    }
349
350    if model_id == "grok-build-0.1" {
351        compat["supportsReasoningEffort"] = Value::Bool(false);
352        entry["thinkingLevelMap"] = serde_json::json!({
353            "off": null, "minimal": null, "low": null, "medium": null
354        });
355    }
356
357    if provider_key == "opencode-go" && (model_id.starts_with("qwen3")) {
358        compat["thinkingFormat"] = Value::String("qwen".into());
359    }
360
361    // GitHub Copilot: openai-completions models need standard Copilot compat
362    if provider_key == "github-copilot" {
363        compat["supportsReasoningEffort"] = Value::Bool(false);
364    }
365
366    entry["compat"] = compat;
367}
368
369async fn fetch(url: &str) -> anyhow::Result<String> {
370    let response = reqwest::get(url)
371        .await
372        .map_err(|e| anyhow::anyhow!("Network error fetching {}: {}", url, e))?;
373
374    if !response.status().is_success() {
375        anyhow::bail!("HTTP {} fetching {}", response.status(), url);
376    }
377
378    response
379        .text()
380        .await
381        .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))
382}