// battlecommand_forge/models.rs
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::time::Instant;

11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ModelInfo {
13 pub name: String,
14 pub size: String,
15 pub modified: String,
16}
17
/// Maps each pipeline role to the Ollama model tag it should run.
#[derive(Debug, Clone)]
pub struct PresetConfig {
    /// Preset identifier ("premium", "balanced", "fast", ...).
    pub name: String,
    pub architect: String,
    pub tester: String,
    pub coder: String,
    pub reviewer: String,
    pub security: String,
}

28impl PresetConfig {
29 pub fn default_premium() -> Self {
30 Self {
31 name: "premium".into(),
32 architect: "qwen3-coder-next:q8_0".into(),
33 tester: "qwen3-coder-next:q8_0".into(),
34 coder: "qwen3-coder-next:q8_0".into(),
35 reviewer: "qwen3-coder-next:q8_0".into(),
36 security: "qwen3-coder-next:q8_0".into(),
37 }
38 }
39
40 pub fn default_balanced() -> Self {
41 Self {
42 name: "balanced".into(),
43 architect: "qwen2.5-coder:32b".into(),
44 tester: "qwen2.5-coder:32b".into(),
45 coder: "qwen2.5-coder:32b".into(),
46 reviewer: "qwen2.5-coder:32b".into(),
47 security: "qwen2.5-coder:32b".into(),
48 }
49 }
50
51 pub fn default_fast() -> Self {
52 Self {
53 name: "fast".into(),
54 architect: "qwen2.5-coder:7b".into(),
55 tester: "qwen2.5-coder:7b".into(),
56 coder: "qwen2.5-coder:7b".into(),
57 reviewer: "qwen2.5-coder:7b".into(),
58 security: "qwen2.5-coder:7b".into(),
59 }
60 }
61}
62
63pub async fn list_ollama_models() -> Result<Vec<ModelInfo>> {
65 let client = reqwest::Client::new();
66 let resp = client
67 .get(format!("{}/api/tags", crate::llm::ollama_url()))
68 .send()
69 .await?;
70
71 let body: serde_json::Value = resp.json().await?;
72 let models = body["models"]
73 .as_array()
74 .map(|arr| {
75 arr.iter()
76 .map(|m| ModelInfo {
77 name: m["name"].as_str().unwrap_or("").to_string(),
78 size: format_bytes(m["size"].as_u64().unwrap_or(0)),
79 modified: m["modified_at"]
80 .as_str()
81 .unwrap_or("")
82 .chars()
83 .take(10)
84 .collect(),
85 })
86 .collect()
87 })
88 .unwrap_or_default();
89
90 Ok(models)
91}
92
93pub fn get_preset(name: &str) -> PresetConfig {
95 match name {
96 "fast" => PresetConfig::default_fast(),
97 "balanced" => PresetConfig::default_balanced(),
98 _ => PresetConfig::default_premium(),
99 }
100}
101
102pub async fn benchmark_model(model: &str) -> Result<BenchmarkResult> {
104 let client = reqwest::Client::builder()
105 .timeout(std::time::Duration::from_secs(120))
106 .build()?;
107
108 let prompt = "Write a Python function that checks if a number is prime. Include docstring and type hints.";
109
110 let body = serde_json::json!({
111 "model": model,
112 "prompt": prompt,
113 "stream": false,
114 "options": { "temperature": 0.0, "num_ctx": 4096 }
115 });
116
117 let start = Instant::now();
118 let resp = client
119 .post(format!("{}/api/generate", crate::llm::ollama_url()))
120 .json(&body)
121 .send()
122 .await?;
123
124 let elapsed = start.elapsed();
125 let json: serde_json::Value = resp.json().await?;
126
127 let response = json["response"].as_str().unwrap_or("").to_string();
128 let eval_count = json["eval_count"].as_u64().unwrap_or(0);
129 let eval_duration_ns = json["eval_duration"].as_u64().unwrap_or(1);
130 let tokens_per_sec = if eval_duration_ns > 0 {
131 (eval_count as f64) / (eval_duration_ns as f64 / 1_000_000_000.0)
132 } else {
133 0.0
134 };
135
136 Ok(BenchmarkResult {
137 model: model.to_string(),
138 tokens_generated: eval_count as u32,
139 total_time_secs: elapsed.as_secs_f64(),
140 tokens_per_sec,
141 response_lines: response.lines().count() as u32,
142 })
143}
144
/// Throughput statistics produced by one `benchmark_model` run.
#[derive(Debug)]
pub struct BenchmarkResult {
    /// Model tag that was benchmarked.
    pub model: String,
    /// Tokens generated during the eval phase.
    pub tokens_generated: u32,
    /// Wall-clock time for the whole request, in seconds.
    pub total_time_secs: f64,
    /// Generation throughput derived from Ollama's eval counters.
    pub tokens_per_sec: f64,
    /// Number of lines in the generated response.
    pub response_lines: u32,
}

154impl std::fmt::Display for BenchmarkResult {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 write!(
157 f,
158 "{}: {} tokens in {:.1}s ({:.1} tok/s), {} lines",
159 self.model,
160 self.tokens_generated,
161 self.total_time_secs,
162 self.tokens_per_sec,
163 self.response_lines
164 )
165 }
166}
167
/// Rough VRAM requirement in GB, inferred from the model tag alone:
/// a parameter-count marker and a quantization marker are sniffed out of
/// the (lowercased) name, and weight size plus a flat ~2 GB of runtime /
/// KV-cache overhead is returned.
pub fn estimate_vram_gb(model_name: &str) -> f64 {
    // First row whose markers match wins; row order is significant
    // (e.g. "70b" must be probed before "7b"). Falls back to `fallback`.
    fn pick(haystack: &str, rows: &[(&[&str], f64)], fallback: f64) -> f64 {
        rows.iter()
            .find(|(markers, _)| markers.iter().any(|m| haystack.contains(m)))
            .map_or(fallback, |&(_, v)| v)
    }

    let lower = model_name.to_lowercase();

    // Approximate parameter counts (billions), keyed by size markers.
    let size_rows: &[(&[&str], f64)] = &[
        (&["80b", "70b"], 75.0),
        (&["35b", "32b", "30b"], 32.0),
        (&["27b", "24b"], 25.0),
        (&["14b", "16b"], 15.0),
        (&["7b", "9b", "8b"], 8.0),
        (&["4b", "3b"], 4.0),
    ];
    // Bytes per parameter, keyed by quantization markers.
    let quant_rows: &[(&[&str], f64)] = &[
        (&["bf16", "fp16"], 2.0),
        (&["q8"], 1.0),
        (&["q4"], 0.5),
    ];

    let params_b = pick(&lower, size_rows, 7.0); // unknown size: assume mid-small
    let bytes_per_param = pick(&lower, quant_rows, 0.6); // ~typical q4/q5 mix

    (params_b * bytes_per_param) + 2.0
}

/// Formats a byte count for display using decimal (SI) units:
/// >= 1 GB → "X.X GB", >= 1 MB → "X MB", >= 1 KB → "X KB", else "X B".
fn format_bytes(bytes: u64) -> String {
    if bytes >= 1_000_000_000 {
        format!("{:.1} GB", bytes as f64 / 1_000_000_000.0)
    } else if bytes >= 1_000_000 {
        format!("{:.0} MB", bytes as f64 / 1_000_000.0)
    } else if bytes >= 1_000 {
        // FIX: previously anything under 1 MB printed as raw bytes
        // (e.g. "734003 B"); add the missing KB tier for readability.
        format!("{:.0} KB", bytes as f64 / 1_000.0)
    } else {
        format!("{} B", bytes)
    }
}

214#[cfg(test)]
215mod tests {
216 use super::*;
217
218 #[test]
219 fn test_get_preset() {
220 let p = get_preset("fast");
221 assert_eq!(p.coder, "qwen2.5-coder:7b");
222
223 let p = get_preset("premium");
224 assert_eq!(p.coder, "qwen3-coder-next:q8_0");
225 }
226
227 #[test]
228 fn test_estimate_vram() {
229 let vram = estimate_vram_gb("qwen2.5-coder:7b");
230 assert!(vram > 4.0 && vram < 15.0);
231
232 let vram = estimate_vram_gb("model-80b-q8_0");
235 assert!(vram > 70.0);
236 }
237
238 #[test]
239 fn test_format_bytes() {
240 assert_eq!(format_bytes(8_900_000_000), "8.9 GB");
241 assert_eq!(format_bytes(500_000_000), "500 MB");
242 }
243}