Skip to main content

battlecommand_forge/
model_config.rs

//! Per-role model configuration with preset/TOML/env/CLI resolution.
//!
//! Ported from battleclaw-v2 model_config.rs, adapted for forge's 9-stage pipeline.
//! Resolution order (highest priority last): preset → env → TOML → CLI.
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
9// ─── Model Provider ───
10
11#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
12#[serde(rename_all = "lowercase")]
13pub enum ModelProvider {
14    Local, // Ollama
15    Cloud, // Anthropic API
16}
17
18impl std::fmt::Display for ModelProvider {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            ModelProvider::Local => write!(f, "local"),
22            ModelProvider::Cloud => write!(f, "cloud"),
23        }
24    }
25}
26
27// ─── Role Config ───
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct RoleConfig {
31    pub model: String,
32    pub provider: ModelProvider,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub context_size: Option<u32>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub max_predict: Option<u32>,
37}
38
39impl RoleConfig {
40    pub fn local(model: &str) -> Self {
41        Self {
42            model: model.to_string(),
43            provider: ModelProvider::Local,
44            context_size: None,
45            max_predict: None,
46        }
47    }
48
49    pub fn local_with_limits(model: &str, ctx: u32, predict: u32) -> Self {
50        Self {
51            model: model.to_string(),
52            provider: ModelProvider::Local,
53            context_size: Some(ctx),
54            max_predict: Some(predict),
55        }
56    }
57
58    pub fn cloud(model: &str) -> Self {
59        Self {
60            model: model.to_string(),
61            provider: ModelProvider::Cloud,
62            context_size: None,
63            max_predict: None,
64        }
65    }
66
67    pub fn cloud_with_limits(model: &str, ctx: u32, predict: u32) -> Self {
68        Self {
69            model: model.to_string(),
70            provider: ModelProvider::Cloud,
71            context_size: Some(ctx),
72            max_predict: Some(predict),
73        }
74    }
75
76    pub fn context_size(&self) -> u32 {
77        self.context_size.unwrap_or(32768)
78    }
79
80    pub fn max_predict(&self) -> u32 {
81        self.max_predict.unwrap_or(8192)
82    }
83}
84
85// ─── Preset ───
86
87#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
88#[serde(rename_all = "lowercase")]
89pub enum Preset {
90    Fast,
91    Balanced,
92    Premium,
93}
94
95impl std::fmt::Display for Preset {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        match self {
98            Preset::Fast => write!(f, "fast"),
99            Preset::Balanced => write!(f, "balanced"),
100            Preset::Premium => write!(f, "premium"),
101        }
102    }
103}
104
105impl std::str::FromStr for Preset {
106    type Err = anyhow::Error;
107
108    fn from_str(s: &str) -> Result<Self> {
109        match s.to_lowercase().as_str() {
110            "fast" => Ok(Preset::Fast),
111            "balanced" => Ok(Preset::Balanced),
112            "premium" => Ok(Preset::Premium),
113            _ => Err(anyhow::anyhow!(
114                "Unknown preset '{}'. Use: fast, balanced, premium",
115                s
116            )),
117        }
118    }
119}
120
121// ─── Model Config ───
122
123/// Per-role model assignments for the 9-stage pipeline.
124///
125/// Pipeline roles:
126///   architect  — Stage 2: spec + file manifest + TDD plan
127///   tester     — Stage 3: write test suite before implementation
128///   coder      — Stage 4: implement against tests
129///   security   — Stage 6: OWASP review
130///   critique   — Stage 7: 5-in-1 scoring (DEV/ARCH/TEST/SEC/DOCS)
131///   cto        — Stage 8: mission-level coherence
132///   complexity — Router: AI complexity scoring
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct ModelConfig {
135    pub preset: Preset,
136    pub architect: RoleConfig,
137    pub tester: RoleConfig,
138    pub coder: RoleConfig,
139    pub fix_coder: RoleConfig,
140    pub security: RoleConfig,
141    pub critique: RoleConfig,
142    pub cto: RoleConfig,
143    pub complexity: RoleConfig,
144}
145
146impl ModelConfig {
147    /// Build from preset defaults.
148    pub fn from_preset(preset: Preset) -> Self {
149        match preset {
150            Preset::Fast => Self {
151                preset,
152                architect: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 8192),
153                tester: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 16384),
154                coder: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 16384),
155                fix_coder: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 16384),
156                security: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 1024),
157                critique: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 1024),
158                cto: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 1024),
159                complexity: RoleConfig::local("qwen3.5:4b-q8_0"),
160            },
161            Preset::Balanced => Self {
162                preset,
163                architect: RoleConfig::local_with_limits("qwen2.5-coder:32b", 32768, 8192),
164                tester: RoleConfig::local_with_limits("qwen2.5-coder:32b", 32768, 16384),
165                coder: RoleConfig::local_with_limits("qwen2.5-coder:32b", 32768, 16384),
166                fix_coder: RoleConfig::local_with_limits("qwen2.5-coder:32b", 32768, 16384),
167                security: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 1024),
168                critique: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 1024),
169                cto: RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 1024),
170                complexity: RoleConfig::local("qwen3.5:4b-q8_0"),
171            },
172            Preset::Premium => Self {
173                preset,
174                architect: RoleConfig::local_with_limits("qwen2.5-coder:32b", 32768, 4096),
175                tester: RoleConfig::cloud_with_limits("claude-opus-4-6", 200000, 8192),
176                coder: RoleConfig::local_with_limits("qwen3-coder-next:q8_0", 65536, 32768),
177                fix_coder: RoleConfig::cloud_with_limits("claude-sonnet-4-6", 200000, 16384),
178                security: RoleConfig::local_with_limits("qwen3-coder:30b-a3b-q8_0", 65536, 1024),
179                critique: RoleConfig::local_with_limits("qwen3-coder:30b-a3b-q8_0", 65536, 1024),
180                cto: RoleConfig::cloud_with_limits("claude-sonnet-4-6", 200000, 1024),
181                complexity: RoleConfig::local_with_limits("qwen3-coder:30b-a3b-q8_0", 32768, 1024),
182            },
183        }
184    }
185
186    /// Merge overrides from `.battlecommand/models.toml` (if it exists).
187    pub fn merge_toml(mut self, workspace: &str) -> Self {
188        let path = format!("{}/.battlecommand/models.toml", workspace);
189        if !Path::new(&path).exists() {
190            return self;
191        }
192
193        let content = match std::fs::read_to_string(&path) {
194            Ok(c) => c,
195            Err(e) => {
196                eprintln!("[model_config] Failed to read {}: {}", path, e);
197                return self;
198            }
199        };
200
201        let toml_val: TomlConfig = match toml::from_str(&content) {
202            Ok(v) => v,
203            Err(e) => {
204                eprintln!("[model_config] Failed to parse {}: {}", path, e);
205                return self;
206            }
207        };
208
209        // If toml specifies a different preset, rebuild from that preset first
210        if let Some(preset_str) = &toml_val.preset {
211            if let Ok(preset) = preset_str.parse::<Preset>() {
212                if preset != self.preset {
213                    self = Self::from_preset(preset);
214                }
215            }
216        }
217
218        // Apply per-role overrides
219        if let Some(r) = toml_val.architect {
220            apply_role_override(&mut self.architect, r);
221        }
222        if let Some(r) = toml_val.tester {
223            apply_role_override(&mut self.tester, r);
224        }
225        if let Some(r) = toml_val.coder {
226            apply_role_override(&mut self.coder, r);
227        }
228        if let Some(r) = toml_val.fix_coder {
229            apply_role_override(&mut self.fix_coder, r);
230        }
231        if let Some(r) = toml_val.security {
232            apply_role_override(&mut self.security, r);
233        }
234        if let Some(r) = toml_val.critique {
235            apply_role_override(&mut self.critique, r);
236        }
237        if let Some(r) = toml_val.cto {
238            apply_role_override(&mut self.cto, r);
239        }
240        if let Some(r) = toml_val.complexity {
241            apply_role_override(&mut self.complexity, r);
242        }
243
244        println!("[CONFIG] Loaded model overrides from {}", path);
245        self
246    }
247
248    /// Merge from environment variables.
249    pub fn merge_env(mut self) -> Self {
250        if let Ok(v) = std::env::var("ARCHITECT_MODEL") {
251            self.architect.model = v.clone();
252            self.architect.provider = infer_provider(&v);
253        }
254        if let Ok(v) = std::env::var("TESTER_MODEL") {
255            self.tester.model = v.clone();
256            self.tester.provider = infer_provider(&v);
257        }
258        if let Ok(v) = std::env::var("CODER_MODEL") {
259            self.coder.model = v.clone();
260            self.coder.provider = infer_provider(&v);
261        }
262        if let Ok(v) = std::env::var("FIX_CODER_MODEL") {
263            self.fix_coder.model = v.clone();
264            self.fix_coder.provider = infer_provider(&v);
265        }
266        if let Ok(v) = std::env::var("SECURITY_MODEL") {
267            self.security.model = v.clone();
268            self.security.provider = infer_provider(&v);
269        }
270        if let Ok(v) = std::env::var("CRITIQUE_MODEL") {
271            self.critique.model = v.clone();
272            self.critique.provider = infer_provider(&v);
273        }
274        if let Ok(v) = std::env::var("CTO_MODEL") {
275            self.cto.model = v.clone();
276            self.cto.provider = infer_provider(&v);
277        }
278        if let Ok(v) = std::env::var("COMPLEXITY_MODEL") {
279            self.complexity.model = v.clone();
280            self.complexity.provider = infer_provider(&v);
281        }
282        // Legacy: OLLAMA_MODEL sets coder
283        if let Ok(v) = std::env::var("OLLAMA_MODEL") {
284            self.coder.model = v;
285            self.coder.provider = ModelProvider::Local;
286        }
287        // Convenience: REVIEWER_MODEL sets security+critique+cto
288        if let Ok(v) = std::env::var("REVIEWER_MODEL") {
289            let provider = infer_provider(&v);
290            self.security.model = v.clone();
291            self.security.provider = provider;
292            self.critique.model = v.clone();
293            self.critique.provider = provider;
294            self.cto.model = v;
295            self.cto.provider = provider;
296        }
297        self
298    }
299
300    /// Merge CLI flag overrides (highest priority).
301    pub fn merge_cli(
302        mut self,
303        architect: Option<&str>,
304        tester: Option<&str>,
305        coder: Option<&str>,
306        reviewer: Option<&str>,
307    ) -> Self {
308        if let Some(m) = architect {
309            self.architect.model = m.to_string();
310            self.architect.provider = infer_provider(m);
311        }
312        if let Some(m) = tester {
313            self.tester.model = m.to_string();
314            self.tester.provider = infer_provider(m);
315        }
316        if let Some(m) = coder {
317            self.coder.model = m.to_string();
318            self.coder.provider = infer_provider(m);
319        }
320        if let Some(m) = reviewer {
321            let provider = infer_provider(m);
322            self.security.model = m.to_string();
323            self.security.provider = provider;
324            self.critique.model = m.to_string();
325            self.critique.provider = provider;
326            self.cto.model = m.to_string();
327            self.cto.provider = provider;
328        }
329        self
330    }
331
332    /// Full resolution: preset → env → TOML → CLI.
333    pub fn resolve(
334        preset: Preset,
335        workspace: &str,
336        architect: Option<&str>,
337        tester: Option<&str>,
338        coder: Option<&str>,
339        reviewer: Option<&str>,
340    ) -> Self {
341        Self::from_preset(preset)
342            .merge_env()
343            .merge_toml(workspace)
344            .merge_cli(architect, tester, coder, reviewer)
345    }
346
347    /// Generate a default models.toml content.
348    pub fn generate_default_toml() -> String {
349        r#"# BattleCommand Forge — Model Configuration
350# Presets: fast, balanced, premium
351preset = "premium"
352
353# Premium dream team: Opus tester, local 80B coder, Sonnet fixer/CTO
354# Cost: ~$0.30/mission (Opus tester + Sonnet fixes). C7+ auto-upgrades coder to Sonnet.
355# Per-role overrides (uncomment to customize)
356
357# [architect]
358# model = "qwen2.5-coder:32b"           # Local 32B — concise specs, no overengineering
359# context_size = 32768
360# max_predict = 4096
361
362# [tester]
363# model = "claude-opus-4-6"             # Opus writes correct test fixtures (~$0.20)
364# context_size = 200000
365# max_predict = 8192
366
367# [coder]
368# model = "qwen3-coder-next:q8_0"       # Local 80B, single-shot generation
369# context_size = 65536
370# max_predict = 32768
371
372# [fix_coder]
373# model = "claude-sonnet-4-6"           # Sonnet for surgical fixes (~$0.05-0.10)
374# context_size = 200000
375# max_predict = 16384
376
377# [security]
378# model = "qwen3-coder:30b-a3b-q8_0"   # Most honest scorer
379# context_size = 65536
380# max_predict = 1024
381
382# [critique]
383# model = "qwen3-coder:30b-a3b-q8_0"   # DEV:3 SEC:1 for bad code
384# context_size = 65536
385# max_predict = 1024
386
387# [cto]
388# model = "claude-sonnet-4-6"           # Fast coherence checks (~$0.05)
389# context_size = 200000
390# max_predict = 1024
391"#
392        .to_string()
393    }
394
395    /// Print resolved config summary.
396    pub fn print_summary(&self) {
397        println!("Model Configuration (preset: {})", self.preset);
398        println!("{:-<60}", "");
399        println!(
400            "  Architect:   {:<35} ({})",
401            self.architect.model, self.architect.provider
402        );
403        println!(
404            "  Tester:      {:<35} ({})",
405            self.tester.model, self.tester.provider
406        );
407        println!(
408            "  Coder:       {:<35} ({})",
409            self.coder.model, self.coder.provider
410        );
411        if self.fix_coder.model != self.coder.model {
412            println!(
413                "  Fix Coder:   {:<35} ({})",
414                self.fix_coder.model, self.fix_coder.provider
415            );
416        }
417        println!(
418            "  Security:    {:<35} ({})",
419            self.security.model, self.security.provider
420        );
421        println!(
422            "  Critique:    {:<35} ({})",
423            self.critique.model, self.critique.provider
424        );
425        println!(
426            "  CTO:         {:<35} ({})",
427            self.cto.model, self.cto.provider
428        );
429        println!(
430            "  Complexity:  {:<35} ({})",
431            self.complexity.model, self.complexity.provider
432        );
433    }
434}
435
436impl Default for ModelConfig {
437    fn default() -> Self {
438        Self::from_preset(Preset::Premium)
439    }
440}
441
442// ─── TOML Schema ───
443
444#[derive(Debug, Deserialize)]
445struct TomlConfig {
446    preset: Option<String>,
447    architect: Option<TomlRoleOverride>,
448    tester: Option<TomlRoleOverride>,
449    coder: Option<TomlRoleOverride>,
450    fix_coder: Option<TomlRoleOverride>,
451    security: Option<TomlRoleOverride>,
452    critique: Option<TomlRoleOverride>,
453    cto: Option<TomlRoleOverride>,
454    complexity: Option<TomlRoleOverride>,
455}
456
457#[derive(Debug, Deserialize)]
458struct TomlRoleOverride {
459    model: Option<String>,
460    provider: Option<ModelProvider>,
461    context_size: Option<u32>,
462    max_predict: Option<u32>,
463}
464
465fn apply_role_override(role: &mut RoleConfig, ov: TomlRoleOverride) {
466    if let Some(m) = ov.model {
467        // Auto-infer provider from model name unless explicitly set
468        if ov.provider.is_none() {
469            role.provider = infer_provider(&m);
470        }
471        role.model = m;
472    }
473    if let Some(p) = ov.provider {
474        role.provider = p;
475    }
476    if let Some(c) = ov.context_size {
477        role.context_size = Some(c);
478    }
479    if let Some(p) = ov.max_predict {
480        role.max_predict = Some(p);
481    }
482}
483
484/// Infer provider from model name.
485fn infer_provider(model: &str) -> ModelProvider {
486    if model.starts_with("claude-") || model.starts_with("grok-") {
487        ModelProvider::Cloud
488    } else {
489        ModelProvider::Local
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn preset_fast() {
        let cfg = ModelConfig::from_preset(Preset::Fast);
        assert_eq!(cfg.architect.model, "qwen2.5-coder:7b");
        assert_eq!(cfg.coder.model, "qwen2.5-coder:7b");
    }

    #[test]
    fn preset_balanced() {
        let cfg = ModelConfig::from_preset(Preset::Balanced);
        assert_eq!(cfg.architect.model, "qwen2.5-coder:32b");
        assert_eq!(cfg.coder.model, "qwen2.5-coder:32b");
    }

    #[test]
    fn preset_premium() {
        let cfg = ModelConfig::from_preset(Preset::Premium);
        assert_eq!(cfg.architect.model, "qwen2.5-coder:32b");
        assert_eq!(cfg.tester.model, "claude-opus-4-6");
        assert_eq!(cfg.coder.model, "qwen3-coder-next:q8_0");
        assert_eq!(cfg.fix_coder.model, "claude-sonnet-4-6");
        assert_eq!(cfg.security.model, "qwen3-coder:30b-a3b-q8_0");
        assert_eq!(cfg.cto.model, "claude-sonnet-4-6");
    }

    #[test]
    fn cli_overrides() {
        let cfg = ModelConfig::from_preset(Preset::Premium).merge_cli(
            Some("nemotron-3-super"),
            None,
            Some("nemotron"),
            None,
        );
        assert_eq!(cfg.architect.model, "nemotron-3-super");
        assert_eq!(cfg.coder.model, "nemotron");
        // Unchanged
        assert_eq!(cfg.tester.model, "claude-opus-4-6");
    }

    #[test]
    fn reviewer_override_sets_all_three() {
        let cfg = ModelConfig::from_preset(Preset::Premium).merge_cli(
            None,
            None,
            None,
            Some("nemotron-3-nano"),
        );
        assert_eq!(cfg.security.model, "nemotron-3-nano");
        assert_eq!(cfg.critique.model, "nemotron-3-nano");
        assert_eq!(cfg.cto.model, "nemotron-3-nano");
    }

    #[test]
    fn infer_provider_cloud() {
        assert_eq!(infer_provider("claude-sonnet-4-6"), ModelProvider::Cloud);
        assert_eq!(infer_provider("grok-3"), ModelProvider::Cloud);
        assert_eq!(infer_provider("qwen3.5:35b-a3b"), ModelProvider::Local);
    }

    #[test]
    fn preset_parse() {
        assert_eq!("fast".parse::<Preset>().unwrap(), Preset::Fast);
        assert_eq!("balanced".parse::<Preset>().unwrap(), Preset::Balanced);
        assert_eq!("premium".parse::<Preset>().unwrap(), Preset::Premium);
        assert!("invalid".parse::<Preset>().is_err());
    }

    #[test]
    fn default_is_premium() {
        let cfg = ModelConfig::default();
        assert_eq!(cfg.preset, Preset::Premium);
    }

    #[test]
    fn role_config_defaults() {
        let r = RoleConfig::local("test-model");
        assert_eq!(r.context_size(), 32768);
        assert_eq!(r.max_predict(), 8192);

        let r = RoleConfig::local_with_limits("test-model", 16384, 4096);
        assert_eq!(r.context_size(), 16384);
        assert_eq!(r.max_predict(), 4096);
    }

    #[test]
    fn default_toml_template_parses() {
        let parsed: TomlConfig = toml::from_str(&ModelConfig::generate_default_toml())
            .expect("generated template must be valid TOML");
        assert_eq!(parsed.preset.as_deref(), Some("premium"));
        // Every role section is commented out, so no per-role overrides apply.
        assert!(parsed.architect.is_none());
        assert!(parsed.coder.is_none());
        assert!(parsed.complexity.is_none());
    }

    #[test]
    fn toml_override_infers_provider_and_keeps_unset_limits() {
        let mut role = RoleConfig::local_with_limits("qwen2.5-coder:7b", 32768, 1024);
        apply_role_override(
            &mut role,
            TomlRoleOverride {
                model: Some("claude-sonnet-4-6".to_string()),
                provider: None,
                context_size: None,
                max_predict: Some(2048),
            },
        );
        assert_eq!(role.model, "claude-sonnet-4-6");
        assert_eq!(role.provider, ModelProvider::Cloud);
        assert_eq!(role.context_size, Some(32768));
        assert_eq!(role.max_predict, Some(2048));
    }

    #[test]
    fn toml_override_explicit_provider_wins() {
        let mut role = RoleConfig::local("qwen2.5-coder:7b");
        apply_role_override(
            &mut role,
            TomlRoleOverride {
                model: Some("claude-sonnet-4-6".to_string()),
                provider: Some(ModelProvider::Local),
                context_size: None,
                max_predict: None,
            },
        );
        assert_eq!(role.model, "claude-sonnet-4-6");
        assert_eq!(role.provider, ModelProvider::Local);
    }

    #[test]
    fn merge_toml_without_file_is_noop() {
        let ws = std::env::temp_dir().join(format!("bc_forge_no_toml_{}", std::process::id()));
        let cfg = ModelConfig::from_preset(Preset::Fast).merge_toml(&ws.to_string_lossy());
        assert_eq!(cfg.preset, Preset::Fast);
        assert_eq!(cfg.coder.model, "qwen2.5-coder:7b");
    }

    #[test]
    fn merge_toml_applies_role_overrides() {
        let ws = std::env::temp_dir().join(format!("bc_forge_toml_{}", std::process::id()));
        let conf_dir = ws.join(".battlecommand");
        std::fs::create_dir_all(&conf_dir).unwrap();
        std::fs::write(
            conf_dir.join("models.toml"),
            "preset = \"fast\"\n[coder]\nmodel = \"claude-sonnet-4-6\"\n",
        )
        .unwrap();

        let cfg = ModelConfig::from_preset(Preset::Fast).merge_toml(&ws.to_string_lossy());
        let _ = std::fs::remove_dir_all(&ws);

        assert_eq!(cfg.coder.model, "claude-sonnet-4-6");
        assert_eq!(cfg.coder.provider, ModelProvider::Cloud);
        // Roles without a section keep their preset values.
        assert_eq!(cfg.architect.model, "qwen2.5-coder:7b");
    }
}