battlecommand_forge/models.rs

//! Model configuration, listing, and benchmarking.
//!
//! Reads presets from .battlecommand/models.toml.
//! Queries the Ollama API for available models.
//! Benchmarks models with a standard prompt.
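//!
//! An illustrative `models.toml` entry, with the schema assumed from the
//! fields of `PresetConfig` (the actual TOML loader is not in this file):
//!
//! ```toml
//! [[presets]]
//! name = "balanced"
//! architect = "qwen2.5-coder:32b"
//! tester = "qwen2.5-coder:32b"
//! coder = "qwen2.5-coder:32b"
//! reviewer = "qwen2.5-coder:32b"
//! security = "qwen2.5-coder:32b"
//! ```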

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::time::Instant;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    pub name: String,
    pub size: String,
    pub modified: String,
}

#[derive(Debug, Clone)]
pub struct PresetConfig {
    pub name: String,
    pub architect: String,
    pub tester: String,
    pub coder: String,
    pub reviewer: String,
    pub security: String,
}

impl PresetConfig {
    pub fn default_premium() -> Self {
        Self {
            name: "premium".into(),
            architect: "qwen3-coder-next:q8_0".into(),
            tester: "qwen3-coder-next:q8_0".into(),
            coder: "qwen3-coder-next:q8_0".into(),
            reviewer: "qwen3-coder-next:q8_0".into(),
            security: "qwen3-coder-next:q8_0".into(),
        }
    }

    pub fn default_balanced() -> Self {
        Self {
            name: "balanced".into(),
            architect: "qwen2.5-coder:32b".into(),
            tester: "qwen2.5-coder:32b".into(),
            coder: "qwen2.5-coder:32b".into(),
            reviewer: "qwen2.5-coder:32b".into(),
            security: "qwen2.5-coder:32b".into(),
        }
    }

    pub fn default_fast() -> Self {
        Self {
            name: "fast".into(),
            architect: "qwen2.5-coder:7b".into(),
            tester: "qwen2.5-coder:7b".into(),
            coder: "qwen2.5-coder:7b".into(),
            reviewer: "qwen2.5-coder:7b".into(),
            security: "qwen2.5-coder:7b".into(),
        }
    }
}

/// List all available Ollama models.
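///
/// A usage sketch (assumes the crate is named `battlecommand_forge` and an
/// Ollama server is reachable, hence `no_run`):
///
/// ```no_run
/// # async fn demo() -> anyhow::Result<()> {
/// for m in battlecommand_forge::models::list_ollama_models().await? {
///     println!("{} ({}, modified {})", m.name, m.size, m.modified);
/// }
/// # Ok(())
/// # }
/// ```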
pub async fn list_ollama_models() -> Result<Vec<ModelInfo>> {
    let client = reqwest::Client::new();
    let resp = client
        .get(format!("{}/api/tags", crate::llm::ollama_url()))
        .send()
        .await?
        .error_for_status()?;

    let body: serde_json::Value = resp.json().await?;
    let models = body["models"]
        .as_array()
        .map(|arr| {
            arr.iter()
                .map(|m| ModelInfo {
                    name: m["name"].as_str().unwrap_or("").to_string(),
                    size: format_bytes(m["size"].as_u64().unwrap_or(0)),
                    // Keep only the date portion (first 10 chars) of the
                    // RFC 3339 `modified_at` timestamp.
                    modified: m["modified_at"]
                        .as_str()
                        .unwrap_or("")
                        .chars()
                        .take(10)
                        .collect(),
                })
                .collect()
        })
        .unwrap_or_default();

    Ok(models)
}

/// Get the preset configuration for a given preset name.
/// Unknown names fall back to the premium preset.
pub fn get_preset(name: &str) -> PresetConfig {
    match name {
        "fast" => PresetConfig::default_fast(),
        "balanced" => PresetConfig::default_balanced(),
        _ => PresetConfig::default_premium(),
    }
}

/// Benchmark a model by sending a standard prompt and measuring tokens/sec.
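///
/// A usage sketch with a hypothetical model tag (assumes the crate is named
/// `battlecommand_forge` and an Ollama server is reachable, hence `no_run`):
///
/// ```no_run
/// # async fn demo() -> anyhow::Result<()> {
/// let result = battlecommand_forge::models::benchmark_model("qwen2.5-coder:7b").await?;
/// println!("{result}");
/// # Ok(())
/// # }
/// ```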
pub async fn benchmark_model(model: &str) -> Result<BenchmarkResult> {
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(120))
        .build()?;

    let prompt = "Write a Python function that checks if a number is prime. Include docstring and type hints.";

    let body = serde_json::json!({
        "model": model,
        "prompt": prompt,
        "stream": false,
        "options": { "temperature": 0.0, "num_ctx": 4096 }
    });

    let start = Instant::now();
    let resp = client
        .post(format!("{}/api/generate", crate::llm::ollama_url()))
        .json(&body)
        .send()
        .await?;

    let elapsed = start.elapsed();
    let json: serde_json::Value = resp.json().await?;

    let response = json["response"].as_str().unwrap_or("").to_string();
    // Ollama reports generation stats alongside the response: `eval_count`
    // is the number of generated tokens and `eval_duration` is the
    // generation time in nanoseconds.
    let eval_count = json["eval_count"].as_u64().unwrap_or(0);
    let eval_duration_ns = json["eval_duration"].as_u64().unwrap_or(0);
    let tokens_per_sec = if eval_duration_ns > 0 {
        (eval_count as f64) / (eval_duration_ns as f64 / 1_000_000_000.0)
    } else {
        0.0
    };

    Ok(BenchmarkResult {
        model: model.to_string(),
        tokens_generated: eval_count as u32,
        total_time_secs: elapsed.as_secs_f64(),
        tokens_per_sec,
        response_lines: response.lines().count() as u32,
    })
}

#[derive(Debug)]
pub struct BenchmarkResult {
    pub model: String,
    pub tokens_generated: u32,
    pub total_time_secs: f64,
    pub tokens_per_sec: f64,
    pub response_lines: u32,
}

impl std::fmt::Display for BenchmarkResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}: {} tokens in {:.1}s ({:.1} tok/s), {} lines",
            self.model,
            self.tokens_generated,
            self.total_time_secs,
            self.tokens_per_sec,
            self.response_lines
        )
    }
}

/// Estimate VRAM usage for a model based on parameter count and quantization.
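///
/// The estimate is `params_b * bytes_per_param + 2.0`, where the constant
/// covers KV-cache and runtime overhead. For example, a hypothetical
/// "model:32b-q4_0" works out to 32.0 * 0.5 + 2.0 = 18 GB.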
pub fn estimate_vram_gb(model_name: &str) -> f64 {
    let lower = model_name.to_lowercase();

    // Extract parameter count from the model name. Order matters: larger
    // sizes are checked first because e.g. "27b" contains the substring "7b".
    let params_b: f64 = if lower.contains("80b") || lower.contains("70b") {
        75.0
    } else if lower.contains("35b") || lower.contains("32b") || lower.contains("30b") {
        32.0
    } else if lower.contains("27b") || lower.contains("24b") {
        25.0
    } else if lower.contains("14b") || lower.contains("16b") {
        15.0
    } else if lower.contains("7b") || lower.contains("9b") || lower.contains("8b") {
        8.0
    } else if lower.contains("4b") || lower.contains("3b") {
        4.0
    } else {
        7.0 // unknown size: assume a mid-size ~7B model
    };

    // Quantization multiplier (bytes per parameter)
    let bytes_per_param: f64 = if lower.contains("bf16") || lower.contains("fp16") {
        2.0
    } else if lower.contains("q8") {
        1.0
    } else if lower.contains("q4") {
        0.5
    } else {
        0.6 // default ~Q5
    };

    // VRAM = params * bytes_per_param + overhead (~2GB for KV cache)
    (params_b * bytes_per_param) + 2.0
}

fn format_bytes(bytes: u64) -> String {
    if bytes >= 1_000_000_000 {
        format!("{:.1} GB", bytes as f64 / 1_000_000_000.0)
    } else if bytes >= 1_000_000 {
        format!("{:.0} MB", bytes as f64 / 1_000_000.0)
    } else {
        format!("{} B", bytes)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_preset() {
        let p = get_preset("fast");
        assert_eq!(p.coder, "qwen2.5-coder:7b");

        let p = get_preset("premium");
        assert_eq!(p.coder, "qwen3-coder-next:q8_0");
    }

    #[test]
    fn test_estimate_vram() {
        let vram = estimate_vram_gb("qwen2.5-coder:7b");
        assert!(vram > 4.0 && vram < 15.0);

        // "qwen3-coder-next:q8_0" has no "NNb" size in its name, so it falls
        // back to the 7B estimate. This is a known limitation: model names
        // don't always encode the parameter count.
        let vram = estimate_vram_gb("model-80b-q8_0");
        assert!(vram > 70.0);
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(8_900_000_000), "8.9 GB");
        assert_eq!(format_bytes(500_000_000), "500 MB");
    }
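
    // A minimal sketch pinning down the exact `Display` output of
    // `BenchmarkResult`; the value is constructed by hand, so no Ollama
    // API is involved.
    #[test]
    fn test_benchmark_result_display() {
        let r = BenchmarkResult {
            model: "demo".into(),
            tokens_generated: 100,
            total_time_secs: 2.0,
            tokens_per_sec: 50.0,
            response_lines: 12,
        };
        assert_eq!(
            r.to_string(),
            "demo: 100 tokens in 2.0s (50.0 tok/s), 12 lines"
        );
    }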
}