1use crate::llm::LlmClient;
2use crate::model_config::ModelConfig;
3use anyhow::Result;
8use std::time::Instant;
9
/// One complete plan → code → validate → review pass of the swarm loop.
///
/// Versions are accumulated in iteration order; the highest-ranked one's
/// `code` becomes the final output file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SwarmVersion {
    /// 1-based number of the iteration that produced this version.
    pub iteration: u32,
    /// Implementation plan returned by the planner model (or a
    /// `"Planning failed: ..."` message when that call errored).
    pub plan: String,
    /// Source code extracted from the coder model's response.
    pub code: String,
    /// Free-form review text returned by the QA model.
    pub qa_feedback: String,
    /// QA score as parsed from the critique model's reply (capped at 10 by
    /// `parse_qa_response`).
    pub qa_score: u32,
    /// Whether the language-specific validation command succeeded.
    pub validated: bool,
    /// Output captured from the validation command (stdout then stderr).
    pub validation_output: String,
}
20
/// Tuning knobs for [`run_swarm`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SwarmOpts {
    /// Maximum number of plan → code → QA iterations. The loop may stop
    /// earlier once a version both validates and scores at least 9/10.
    pub iterations: u32,
    /// Directory that receives the generated `main.<ext>` file.
    pub output_dir: String,
    /// Target language name; used in the prompts and to choose the file
    /// extension and validation command.
    pub language: String,
}

impl Default for SwarmOpts {
    /// Three iterations of Python, written under `output/swarm`.
    fn default() -> Self {
        Self {
            iterations: 3,
            output_dir: "output/swarm".into(),
            language: "python".into(),
        }
    }
}
36
37pub async fn run_swarm(prompt: &str, config: &ModelConfig, opts: &SwarmOpts) -> Result<()> {
38 let start = Instant::now();
39 let mut versions: Vec<SwarmVersion> = Vec::new();
40
41 println!("Swarm Mode: {} iterations", opts.iterations);
42 println!(" Coder: {}", config.coder.model);
43 println!(
44 " Prompt: {}\n",
45 if prompt.len() > 80 {
46 &prompt[..80]
47 } else {
48 prompt
49 }
50 );
51
52 let planner = LlmClient::with_limits(
53 &config.architect.model,
54 config.architect.context_size(),
55 2048,
56 );
57 let coder = LlmClient::with_limits(
58 &config.coder.model,
59 config.coder.context_size(),
60 config.coder.max_predict(),
61 );
62 let qa = LlmClient::with_limits(&config.critique.model, config.critique.context_size(), 1024);
63
64 tokio::fs::create_dir_all(&opts.output_dir).await?;
66
67 for iter in 1..=opts.iterations {
68 let iter_start = Instant::now();
69 let prev = versions.last();
70
71 let plan_system = "You are a senior engineer creating an implementation plan. Be specific about function names, signatures, data structures, and edge cases. Output a numbered list, no code.";
73 let plan_user = if let Some(prev) = prev {
74 format!(
75 "TASK: {}\nLANGUAGE: {}\n\nPREVIOUS ATTEMPT (iteration {}) scored {}/10:\nQA FEEDBACK:\n{}\n\nCreate an improved plan that addresses the feedback.",
76 prompt, opts.language, prev.iteration, prev.qa_score, prev.qa_feedback
77 )
78 } else {
79 format!(
80 "TASK: {}\nLANGUAGE: {}\n\nCreate a precise implementation plan.",
81 prompt, opts.language
82 )
83 };
84
85 let plan = planner
86 .generate("architect", plan_system, &plan_user)
87 .await
88 .unwrap_or_else(|e| format!("Planning failed: {}", e));
89
90 let code_system = format!(
92 "You are an elite AI coder. Follow the implementation plan exactly.\n\
93 RULES:\n\
94 1. Write COMPLETE, WORKING code. No placeholders, no TODOs.\n\
95 2. Include ALL imports at the top.\n\
96 3. Handle all edge cases from the plan.\n\
97 4. Output ONLY code in a ```{}``` code block.",
98 opts.language
99 );
100 let code_user = format!(
101 "TASK: {}\n\nIMPLEMENTATION PLAN:\n{}\n\nWrite the complete implementation.",
102 prompt, plan
103 );
104
105 let code_resp = coder
106 .generate_live("coder", &code_system, &code_user)
107 .await
108 .unwrap_or_else(|e| format!("// Code generation failed: {}", e));
109 let code = extract_code(&code_resp, &opts.language);
110
111 let file_ext = match opts.language.as_str() {
113 "python" => "py",
114 "javascript" | "js" => "js",
115 "typescript" | "ts" => "ts",
116 "rust" => "rs",
117 "go" => "go",
118 "cpp" | "c++" => "cpp",
119 _ => "txt",
120 };
121 let file_name = format!("{}/main.{}", opts.output_dir, file_ext);
122 tokio::fs::write(&file_name, &code).await?;
123
124 let (validated, validation_output) = run_validation(&file_name, &opts.language).await;
126
127 let qa_system = "Review this code against the task requirements. Score 1-10. Be STRICT. Respond with ONLY JSON: {\"score\": N, \"feedback\": \"...\"}";
129 let code_snippet = if code.len() > 8000 {
130 &code[..8000]
131 } else {
132 &code
133 };
134 let validation_status = if validated {
135 "PASSED".to_string()
136 } else {
137 format!(
138 "FAILED: {}",
139 &validation_output[..validation_output.len().min(300)]
140 )
141 };
142 let qa_user = format!(
143 "TASK: {}\nVALIDATION: {}\n\nCODE:\n```{}\n{}\n```\n\nScore and list specific issues.",
144 prompt, validation_status, opts.language, code_snippet
145 );
146 let qa_resp = qa
147 .generate("critique", qa_system, &qa_user)
148 .await
149 .unwrap_or_else(|_| "{\"score\": 5, \"feedback\": \"QA failed\"}".into());
150 let (qa_feedback, qa_score) = parse_qa_response(&qa_resp);
151
152 let iter_time = iter_start.elapsed().as_secs_f64();
153 println!(
154 " [iter {}/{}] score {}/10 {} {:.1}s",
155 iter,
156 opts.iterations,
157 qa_score,
158 if validated { "PASS" } else { "FAIL" },
159 iter_time
160 );
161
162 let version = SwarmVersion {
163 iteration: iter,
164 plan,
165 code,
166 qa_feedback,
167 qa_score,
168 validated,
169 validation_output,
170 };
171
172 if qa_score >= 9 && validated {
174 println!(" Early exit: iter {} scored {}/10 PASS", iter, qa_score);
175 versions.push(version);
176 break;
177 }
178
179 versions.push(version);
180 }
181
182 let best_idx = select_best(&versions);
184 let best = &versions[best_idx];
185
186 let file_ext = match opts.language.as_str() {
188 "python" => "py",
189 "javascript" | "js" => "js",
190 "typescript" | "ts" => "ts",
191 "rust" => "rs",
192 "go" => "go",
193 "cpp" | "c++" => "cpp",
194 _ => "txt",
195 };
196 let file_name = format!("{}/main.{}", opts.output_dir, file_ext);
197 tokio::fs::write(&file_name, &best.code).await?;
198
199 let duration = start.elapsed().as_secs_f64();
200 println!(
201 "\nSwarm complete: best=iter{} (score {}/10 {}) | {:.1}s total",
202 best.iteration,
203 best.qa_score,
204 if best.validated { "PASS" } else { "FAIL" },
205 duration
206 );
207 println!("Output: {}", file_name);
208
209 Ok(())
210}
211
212fn select_best(versions: &[SwarmVersion]) -> usize {
213 versions
214 .iter()
215 .enumerate()
216 .max_by_key(|(_, v)| (v.validated as u32 * 100 + v.qa_score, v.iteration))
217 .map(|(i, _)| i)
218 .unwrap_or(0)
219}
220
/// Extracts source code from a model response that may wrap it in Markdown
/// fences (```` ```lang ... ``` ````).
///
/// Selection order:
/// 1. the first non-empty fenced block whose info string starts with
///    `language` (case-insensitive), so a leading shell/install snippet does
///    not win over the requested code;
/// 2. otherwise the first non-empty fenced block (an unterminated final fence
///    runs to the end of the response);
/// 3. otherwise the whole response unchanged (the model answered with bare
///    code).
///
/// Extracted block bodies keep their original lines, each terminated by `\n`.
fn extract_code(response: &str, language: &str) -> String {
    // Each entry is (lower-cased fence info string, block body).
    let mut blocks: Vec<(String, String)> = Vec::new();
    let mut open: Option<(String, String)> = None;

    for line in response.lines() {
        if let Some(info) = line.trim_start().strip_prefix("```") {
            // A fence line closes the open block, or opens a new one.
            match open.take() {
                Some(block) => blocks.push(block),
                None => open = Some((info.trim().to_ascii_lowercase(), String::new())),
            }
            continue;
        }
        if let Some((_, body)) = open.as_mut() {
            body.push_str(line);
            body.push('\n');
        }
    }
    // Keep an unterminated trailing block: everything after its fence counts.
    if let Some(block) = open {
        blocks.push(block);
    }

    let wanted = language.trim().to_ascii_lowercase();
    let tag_matches = |tag: &str| tag.split_whitespace().next() == Some(wanted.as_str());

    blocks
        .iter()
        .filter(|(_, body)| !body.is_empty())
        .find(|(tag, _)| tag_matches(tag.as_str()))
        .or_else(|| blocks.iter().find(|(_, body)| !body.is_empty()))
        .map(|(_, body)| body.clone())
        .unwrap_or_else(|| response.to_string())
}
243
244async fn run_validation(file_path: &str, language: &str) -> (bool, String) {
245 let cmd = match language {
246 "python" => format!(
247 "python3 -c \"import ast; ast.parse(open('{}').read()); print('SYNTAX OK')\"",
248 file_path
249 ),
250 "rust" => format!("rustc --edition 2021 {} -o /dev/null 2>&1", file_path),
251 "cpp" | "c++" => format!("c++ -std=c++17 -fsyntax-only {} 2>&1", file_path),
252 "go" => format!("go vet {} 2>&1", file_path),
253 _ => format!("test -f {} && echo 'FILE EXISTS'", file_path),
254 };
255
256 match tokio::process::Command::new("sh")
257 .arg("-c")
258 .arg(&cmd)
259 .output()
260 .await
261 {
262 Ok(output) => {
263 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
264 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
265 (output.status.success(), format!("{}{}", stdout, stderr))
266 }
267 Err(e) => (false, format!("Validation error: {}", e)),
268 }
269}
270
271fn parse_qa_response(text: &str) -> (String, u32) {
272 if let Some(start) = text.find('{') {
273 if let Some(end) = text.rfind('}') {
274 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&text[start..=end]) {
275 let score = parsed["score"].as_u64().unwrap_or(5) as u32;
276 let feedback = parsed["feedback"]
277 .as_str()
278 .unwrap_or("No feedback")
279 .to_string();
280 return (feedback, score.min(10));
281 }
282 }
283 }
284 (text.to_string(), 5)
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a version with only the fields `select_best` ranks on set.
    fn version(iteration: u32, qa_score: u32, validated: bool) -> SwarmVersion {
        SwarmVersion {
            iteration,
            plan: String::new(),
            code: format!("v{}", iteration),
            qa_feedback: String::new(),
            qa_score,
            validated,
            validation_output: String::new(),
        }
    }

    #[test]
    fn parse_qa_valid() {
        let resp = r#"{"score": 7, "feedback": "Missing error handling"}"#;
        let (feedback, score) = parse_qa_response(resp);
        assert_eq!(score, 7);
        assert!(feedback.contains("error handling"));
    }

    #[test]
    fn parse_qa_garbage() {
        let (_, score) = parse_qa_response("not json");
        assert_eq!(score, 5);
    }

    #[test]
    fn parse_qa_json_inside_prose() {
        let resp = "Sure! {\"score\": 8, \"feedback\": \"Looks good\"} Hope that helps.";
        let (feedback, score) = parse_qa_response(resp);
        assert_eq!(score, 8);
        assert_eq!(feedback, "Looks good");
    }

    #[test]
    fn parse_qa_clamps_score_to_ten() {
        let (_, score) = parse_qa_response(r#"{"score": 42, "feedback": "x"}"#);
        assert_eq!(score, 10);
    }

    #[test]
    fn parse_qa_missing_feedback() {
        let (feedback, score) = parse_qa_response(r#"{"score": 3}"#);
        assert_eq!(score, 3);
        assert_eq!(feedback, "No feedback");
    }

    #[test]
    fn select_best_prefers_validated() {
        let versions = vec![version(1, 9, false), version(2, 6, true)];
        assert_eq!(select_best(&versions), 1);
    }

    #[test]
    fn select_best_breaks_ties_by_latest_iteration() {
        let versions = vec![version(1, 7, true), version(2, 7, true)];
        assert_eq!(select_best(&versions), 1);
    }

    #[test]
    fn select_best_empty_defaults_to_zero() {
        assert_eq!(select_best(&[]), 0);
    }

    #[test]
    fn extract_code_returns_fenced_body() {
        let resp = "Here you go:\n```python\nprint('hi')\n```\nEnjoy";
        assert_eq!(extract_code(resp, "python"), "print('hi')\n");
    }

    #[test]
    fn extract_code_without_fence_returns_input() {
        assert_eq!(extract_code("x = 1", "python"), "x = 1");
    }
}