Skip to main content

battlecommand_forge/
swarm.rs

1use crate::llm::LlmClient;
2use crate::model_config::ModelConfig;
3/// Swarm mode: planner → coder → QA iteration.
4/// Runs multiple iterations of code generation and picks the best version.
5/// Each iteration: plan → code → validate → QA review.
6/// Best version selected by validation pass + QA score.
7use anyhow::Result;
8use std::time::Instant;
9
/// One candidate produced by a single swarm round.
///
/// A version bundles everything the round generated (plan and code) with how
/// it was judged (validation result and QA review), so the caller can rank the
/// candidates afterwards.
#[derive(Clone, Debug)]
pub struct SwarmVersion {
    /// 1-based round number that produced this version.
    pub iteration: u32,
    /// Planner output fed to the coder (or a "Planning failed" placeholder).
    pub plan: String,
    /// Source code extracted from the coder's reply.
    pub code: String,
    /// Free-text feedback from the QA reviewer.
    pub qa_feedback: String,
    /// QA score, clamped to at most 10 by `parse_qa_response`.
    pub qa_score: u32,
    /// Whether the language-specific validation command reported success.
    pub validated: bool,
    /// Captured output of the validation step (or a spawn error message).
    pub validation_output: String,
}
20
/// Options controlling a swarm run.
#[derive(Debug, Clone)]
pub struct SwarmOpts {
    /// Maximum number of plan → code → validate → QA rounds; the run can stop
    /// earlier when a round passes validation with a high QA score.
    pub iterations: u32,
    /// Directory (created if missing) that receives `main.<ext>`.
    pub output_dir: String,
    /// Target language name, e.g. `"python"`, `"rust"`, `"go"`, `"cpp"`.
    /// It selects the code-fence tag, the file extension and the validator;
    /// unrecognised names fall back to a `.txt` file and an existence check.
    pub language: String,
}

impl Default for SwarmOpts {
    /// Three rounds of Python written to `output/swarm`.
    fn default() -> Self {
        Self {
            iterations: 3,
            output_dir: "output/swarm".into(),
            language: "python".into(),
        }
    }
}
36
37pub async fn run_swarm(prompt: &str, config: &ModelConfig, opts: &SwarmOpts) -> Result<()> {
38    let start = Instant::now();
39    let mut versions: Vec<SwarmVersion> = Vec::new();
40
41    println!("Swarm Mode: {} iterations", opts.iterations);
42    println!("  Coder: {}", config.coder.model);
43    println!(
44        "  Prompt: {}\n",
45        if prompt.len() > 80 {
46            &prompt[..80]
47        } else {
48            prompt
49        }
50    );
51
52    let planner = LlmClient::with_limits(
53        &config.architect.model,
54        config.architect.context_size(),
55        2048,
56    );
57    let coder = LlmClient::with_limits(
58        &config.coder.model,
59        config.coder.context_size(),
60        config.coder.max_predict(),
61    );
62    let qa = LlmClient::with_limits(&config.critique.model, config.critique.context_size(), 1024);
63
64    // Setup output directory
65    tokio::fs::create_dir_all(&opts.output_dir).await?;
66
67    for iter in 1..=opts.iterations {
68        let iter_start = Instant::now();
69        let prev = versions.last();
70
71        // ── Phase 1: Planner ──
72        let plan_system = "You are a senior engineer creating an implementation plan. Be specific about function names, signatures, data structures, and edge cases. Output a numbered list, no code.";
73        let plan_user = if let Some(prev) = prev {
74            format!(
75                "TASK: {}\nLANGUAGE: {}\n\nPREVIOUS ATTEMPT (iteration {}) scored {}/10:\nQA FEEDBACK:\n{}\n\nCreate an improved plan that addresses the feedback.",
76                prompt, opts.language, prev.iteration, prev.qa_score, prev.qa_feedback
77            )
78        } else {
79            format!(
80                "TASK: {}\nLANGUAGE: {}\n\nCreate a precise implementation plan.",
81                prompt, opts.language
82            )
83        };
84
85        let plan = planner
86            .generate("architect", plan_system, &plan_user)
87            .await
88            .unwrap_or_else(|e| format!("Planning failed: {}", e));
89
90        // ── Phase 2: Coder ──
91        let code_system = format!(
92            "You are an elite AI coder. Follow the implementation plan exactly.\n\
93             RULES:\n\
94             1. Write COMPLETE, WORKING code. No placeholders, no TODOs.\n\
95             2. Include ALL imports at the top.\n\
96             3. Handle all edge cases from the plan.\n\
97             4. Output ONLY code in a ```{}``` code block.",
98            opts.language
99        );
100        let code_user = format!(
101            "TASK: {}\n\nIMPLEMENTATION PLAN:\n{}\n\nWrite the complete implementation.",
102            prompt, plan
103        );
104
105        let code_resp = coder
106            .generate_live("coder", &code_system, &code_user)
107            .await
108            .unwrap_or_else(|e| format!("// Code generation failed: {}", e));
109        let code = extract_code(&code_resp, &opts.language);
110
111        // Write code to file
112        let file_ext = match opts.language.as_str() {
113            "python" => "py",
114            "javascript" | "js" => "js",
115            "typescript" | "ts" => "ts",
116            "rust" => "rs",
117            "go" => "go",
118            "cpp" | "c++" => "cpp",
119            _ => "txt",
120        };
121        let file_name = format!("{}/main.{}", opts.output_dir, file_ext);
122        tokio::fs::write(&file_name, &code).await?;
123
124        // ── Phase 3: Validate ──
125        let (validated, validation_output) = run_validation(&file_name, &opts.language).await;
126
127        // ── Phase 4: Mini-QA ──
128        let qa_system = "Review this code against the task requirements. Score 1-10. Be STRICT. Respond with ONLY JSON: {\"score\": N, \"feedback\": \"...\"}";
129        let code_snippet = if code.len() > 8000 {
130            &code[..8000]
131        } else {
132            &code
133        };
134        let validation_status = if validated {
135            "PASSED".to_string()
136        } else {
137            format!(
138                "FAILED: {}",
139                &validation_output[..validation_output.len().min(300)]
140            )
141        };
142        let qa_user = format!(
143            "TASK: {}\nVALIDATION: {}\n\nCODE:\n```{}\n{}\n```\n\nScore and list specific issues.",
144            prompt, validation_status, opts.language, code_snippet
145        );
146        let qa_resp = qa
147            .generate("critique", qa_system, &qa_user)
148            .await
149            .unwrap_or_else(|_| "{\"score\": 5, \"feedback\": \"QA failed\"}".into());
150        let (qa_feedback, qa_score) = parse_qa_response(&qa_resp);
151
152        let iter_time = iter_start.elapsed().as_secs_f64();
153        println!(
154            "  [iter {}/{}] score {}/10 {} {:.1}s",
155            iter,
156            opts.iterations,
157            qa_score,
158            if validated { "PASS" } else { "FAIL" },
159            iter_time
160        );
161
162        let version = SwarmVersion {
163            iteration: iter,
164            plan,
165            code,
166            qa_feedback,
167            qa_score,
168            validated,
169            validation_output,
170        };
171
172        // Early exit on high score + validation pass
173        if qa_score >= 9 && validated {
174            println!("  Early exit: iter {} scored {}/10 PASS", iter, qa_score);
175            versions.push(version);
176            break;
177        }
178
179        versions.push(version);
180    }
181
182    // ── Select best version ──
183    let best_idx = select_best(&versions);
184    let best = &versions[best_idx];
185
186    // Write best version to disk
187    let file_ext = match opts.language.as_str() {
188        "python" => "py",
189        "javascript" | "js" => "js",
190        "typescript" | "ts" => "ts",
191        "rust" => "rs",
192        "go" => "go",
193        "cpp" | "c++" => "cpp",
194        _ => "txt",
195    };
196    let file_name = format!("{}/main.{}", opts.output_dir, file_ext);
197    tokio::fs::write(&file_name, &best.code).await?;
198
199    let duration = start.elapsed().as_secs_f64();
200    println!(
201        "\nSwarm complete: best=iter{} (score {}/10 {}) | {:.1}s total",
202        best.iteration,
203        best.qa_score,
204        if best.validated { "PASS" } else { "FAIL" },
205        duration
206    );
207    println!("Output: {}", file_name);
208
209    Ok(())
210}
211
212fn select_best(versions: &[SwarmVersion]) -> usize {
213    versions
214        .iter()
215        .enumerate()
216        .max_by_key(|(_, v)| (v.validated as u32 * 100 + v.qa_score, v.iteration))
217        .map(|(i, _)| i)
218        .unwrap_or(0)
219}
220
/// Pull the body of the first fenced (```) code block out of an LLM reply.
///
/// Everything before the first fence line is skipped. Lines are then taken up
/// to the next fence line (or the end of the reply if the block is never
/// closed), each terminated with `\n`. If no fence exists or the block is
/// empty, the whole reply is returned unchanged. The language argument is
/// currently unused.
fn extract_code(response: &str, _language: &str) -> String {
    fn is_fence(line: &str) -> bool {
        line.trim_start().starts_with("```")
    }

    let mut lines = response.lines();
    // Consumes lines up to and including the opening fence.
    let opened = lines.by_ref().any(is_fence);

    let body = if opened {
        lines
            .take_while(|line| !is_fence(line))
            .fold(String::new(), |mut acc, line| {
                acc.push_str(line);
                acc.push('\n');
                acc
            })
    } else {
        String::new()
    };

    if body.is_empty() {
        response.to_owned()
    } else {
        body
    }
}
243
244async fn run_validation(file_path: &str, language: &str) -> (bool, String) {
245    let cmd = match language {
246        "python" => format!(
247            "python3 -c \"import ast; ast.parse(open('{}').read()); print('SYNTAX OK')\"",
248            file_path
249        ),
250        "rust" => format!("rustc --edition 2021 {} -o /dev/null 2>&1", file_path),
251        "cpp" | "c++" => format!("c++ -std=c++17 -fsyntax-only {} 2>&1", file_path),
252        "go" => format!("go vet {} 2>&1", file_path),
253        _ => format!("test -f {} && echo 'FILE EXISTS'", file_path),
254    };
255
256    match tokio::process::Command::new("sh")
257        .arg("-c")
258        .arg(&cmd)
259        .output()
260        .await
261    {
262        Ok(output) => {
263            let stdout = String::from_utf8_lossy(&output.stdout).to_string();
264            let stderr = String::from_utf8_lossy(&output.stderr).to_string();
265            (output.status.success(), format!("{}{}", stdout, stderr))
266        }
267        Err(e) => (false, format!("Validation error: {}", e)),
268    }
269}
270
271fn parse_qa_response(text: &str) -> (String, u32) {
272    if let Some(start) = text.find('{') {
273        if let Some(end) = text.rfind('}') {
274            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text[start..=end]) {
275                let score = parsed["score"].as_u64().unwrap_or(5) as u32;
276                let feedback = parsed["feedback"]
277                    .as_str()
278                    .unwrap_or("No feedback")
279                    .to_string();
280                return (feedback, score.min(10));
281            }
282        }
283    }
284    (text.to_string(), 5)
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal version for ranking tests; only the ranking fields vary.
    fn version(iteration: u32, qa_score: u32, validated: bool) -> SwarmVersion {
        SwarmVersion {
            iteration,
            plan: String::new(),
            code: format!("v{}", iteration),
            qa_feedback: String::new(),
            qa_score,
            validated,
            validation_output: String::new(),
        }
    }

    #[test]
    fn parse_qa_valid() {
        let resp = r#"{"score": 7, "feedback": "Missing error handling"}"#;
        let (feedback, score) = parse_qa_response(resp);
        assert_eq!(score, 7);
        assert!(feedback.contains("error handling"));
    }

    #[test]
    fn parse_qa_garbage() {
        let (_, score) = parse_qa_response("not json");
        assert_eq!(score, 5);
    }

    #[test]
    fn parse_qa_json_inside_prose() {
        let resp = "Review:\n{\"score\": 8, \"feedback\": \"solid\"}\nThanks!";
        let (feedback, score) = parse_qa_response(resp);
        assert_eq!(score, 8);
        assert_eq!(feedback, "solid");
    }

    #[test]
    fn parse_qa_clamps_score() {
        let (_, score) = parse_qa_response(r#"{"score": 42, "feedback": "x"}"#);
        assert_eq!(score, 10);
    }

    #[test]
    fn select_best_prefers_validated() {
        let versions = vec![version(1, 9, false), version(2, 6, true)];
        assert_eq!(select_best(&versions), 1);
    }

    #[test]
    fn select_best_breaks_ties_by_latest_iteration() {
        let versions = vec![version(1, 7, true), version(2, 7, true)];
        assert_eq!(select_best(&versions), 1);
    }

    #[test]
    fn select_best_empty_defaults_to_zero() {
        assert_eq!(select_best(&[]), 0);
    }

    #[test]
    fn extract_code_takes_first_fenced_block() {
        let resp = "Sure:\n```python\nprint(1)\n```\nDone";
        assert_eq!(extract_code(resp, "python"), "print(1)\n");
    }

    #[test]
    fn extract_code_without_fence_returns_input() {
        assert_eq!(extract_code("plain text", "python"), "plain text");
    }
}