1use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use std::path::Path;
9use std::time::Instant;
10
11use crate::llm::{ChatMessage, LlmClient, OllamaTool, OllamaToolFunction};
12use crate::swebench_tools;
13
/// Per-repository build/test recipe used when preparing a workspace and
/// when adjudicating FAIL_TO_PASS tests.
struct RepoConfig {
    // Shell command that installs the project in editable mode.
    install_cmd: &'static str,
    // Test command with a `{test}` placeholder substituted per test.
    test_cmd_template: &'static str,
    // Prefix prepended to the test command (currently unused: always "").
    env_setup: &'static str,
}

/// Returns the configuration for a `"owner/name"` repository slug.
///
/// Most repos share the default pip install and pytest invocation; only
/// the install extras (matplotlib, sphinx) and the test runner (django,
/// sympy) differ. Unknown repos fall back to the defaults.
fn get_repo_config(repo: &str) -> RepoConfig {
    const DEFAULT_INSTALL: &str = "python3 -m pip install -e . -q 2>/dev/null";
    const DEFAULT_TEST: &str = "python3 -m pytest -xvs {test} 2>&1";

    let (install_cmd, test_cmd_template) = match repo {
        "django/django" => (
            DEFAULT_INSTALL,
            "cd tests && python3 runtests.py {test} --verbosity=0 2>&1",
        ),
        "sympy/sympy" => (DEFAULT_INSTALL, "python3 -m pytest -xvs -k {test} 2>&1"),
        "matplotlib/matplotlib" => (
            "python3 -m pip install -e '.[dev]' -q 2>/dev/null",
            DEFAULT_TEST,
        ),
        "sphinx-doc/sphinx" => (
            "python3 -m pip install -e '.[test]' -q 2>/dev/null",
            DEFAULT_TEST,
        ),
        // scikit-learn, pytest-dev, and anything unknown use the defaults.
        _ => (DEFAULT_INSTALL, DEFAULT_TEST),
    };

    RepoConfig {
        install_cmd,
        test_cmd_template,
        env_setup: "",
    }
}
61
/// Converts a Django-style test identifier such as
/// `"test_foo (app.tests.TestBar)"` into the dotted label inside the
/// parentheses (`"app.tests.TestBar"`), which `runtests.py` accepts.
/// Identifiers without parentheses are returned unchanged.
fn normalize_django_test(test: &str) -> String {
    match test.split_once('(') {
        Some((_, after_paren)) => after_paren.trim_end_matches(')').trim().to_string(),
        None => test.to_string(),
    }
}
72
/// One SWE-bench task: a repository snapshot plus an issue to fix and
/// the tests that adjudicate the fix.
///
/// Field names follow the SWE-bench dataset schema. `FAIL_TO_PASS` /
/// `PASS_TO_PASS` are kept as raw JSON values because the dataset ships
/// them either as arrays or as JSON-encoded strings; `parse_test_list`
/// handles both shapes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SwebenchInstance {
    // Unique task id, e.g. "django__django-11019".
    pub instance_id: String,
    // "owner/name" GitHub slug; also selects the RepoConfig.
    pub repo: String,
    // Commit the agent's workspace is checked out at.
    pub base_commit: String,
    // Issue text presented to the agent.
    pub problem_statement: String,
    #[serde(default)]
    pub hints_text: Option<String>,
    // Diff that adds/updates the adjudicating tests; applied on setup.
    pub test_patch: String,
    // Tests that must go from failing to passing for the fix to count.
    #[serde(rename = "FAIL_TO_PASS")]
    pub fail_to_pass: serde_json::Value,
    // Regression tests that must keep passing.
    #[serde(rename = "PASS_TO_PASS", default)]
    pub pass_to_pass: serde_json::Value,
    #[serde(default)]
    pub version: Option<String>,
}
91
impl SwebenchInstance {
    /// Tests that must newly pass for this instance to be resolved.
    pub fn fail_to_pass_tests(&self) -> Vec<String> {
        parse_test_list(&self.fail_to_pass)
    }
    /// Regression tests expected to keep passing.
    pub fn pass_to_pass_tests(&self) -> Vec<String> {
        parse_test_list(&self.pass_to_pass)
    }
}
100
101fn parse_test_list(v: &serde_json::Value) -> Vec<String> {
102 match v {
103 serde_json::Value::Array(arr) => arr
104 .iter()
105 .filter_map(|v| v.as_str().map(String::from))
106 .collect(),
107 serde_json::Value::String(s) => {
108 if let Ok(arr) = serde_json::from_str::<Vec<String>>(s) {
109 arr
110 } else {
111 vec![s.clone()]
112 }
113 }
114 _ => vec![],
115 }
116}
117
/// Outcome of running the agent on one instance; serialized as one line
/// of `swebench_results.jsonl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstanceResult {
    pub instance_id: String,
    pub repo: String,
    // Model actually used (after any override).
    pub model: String,
    // True iff every FAIL_TO_PASS test passed after the agent's patch.
    pub resolved: bool,
    pub fail_to_pass_count: usize,
    pub fail_to_pass_passed: usize,
    pub turns_used: u32,
    // Currently always 0 — run_agent_loop does not account tokens yet.
    pub tokens_used: u64,
    pub duration_secs: f64,
    pub files_modified: Vec<String>,
    // `git diff HEAD` capturing the agent's changes.
    pub patch: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
134
/// One line of `predictions.jsonl` in the shape the SWE-bench
/// evaluation harness consumes (instance id, model name, model patch).
#[derive(Debug, Serialize)]
struct Prediction {
    instance_id: String,
    model_name_or_path: String,
    model_patch: String,
}
141
/// Configuration for a SWE-bench run: dataset selection, slicing,
/// model choice, and runtime limits.
#[derive(Debug, Clone)]
pub struct SwebenchOpts {
    // Explicit dataset file; when None, `<output_dir>/datasets/<variant>.json`.
    pub dataset_path: Option<String>,
    // Dataset variant name (e.g. "lite"); used to derive the default path.
    pub variant: String,
    // When set, run only the instance with this exact id.
    pub instance_filter: Option<String>,
    // Cap on the number of instances after filtering and offset.
    pub limit: Option<u32>,
    // Number of instances to skip from the front of the filtered list.
    pub offset: u32,
    // Root directory for datasets, cached repos, workspaces, and results.
    pub output_dir: String,
    // Model to use instead of the built-in default.
    pub model_override: Option<String>,
    // Maximum agent round-trips per instance.
    pub max_turns: u32,
    // Wall-clock timeout per instance for the agent loop.
    pub timeout_secs: u64,
    // Skip instances already recorded in swebench_results.jsonl.
    pub resume: bool,
}
155
impl Default for SwebenchOpts {
    /// Defaults: the "lite" variant, no filtering/slicing, 25 agent
    /// turns, a 30-minute per-instance timeout, and output under
    /// `.battlecommand/swebench`.
    fn default() -> Self {
        Self {
            dataset_path: None,
            variant: "lite".into(),
            instance_filter: None,
            limit: None,
            offset: 0,
            output_dir: ".battlecommand/swebench".into(),
            model_override: None,
            max_turns: 25,
            // 30 minutes.
            timeout_secs: 1800,
            resume: false,
        }
    }
}
172
173pub fn load_dataset(opts: &SwebenchOpts) -> Result<Vec<SwebenchInstance>> {
176 let path = if let Some(ref p) = opts.dataset_path {
177 p.clone()
178 } else {
179 format!("{}/datasets/{}.json", opts.output_dir, opts.variant)
180 };
181
182 if !Path::new(&path).exists() {
183 return Err(anyhow::anyhow!(
184 "Dataset not found at '{}'. Download it first:\n \
185 wget -O {} https://raw.githubusercontent.com/princeton-nlp/SWE-bench/main/swebench/test/{}.json\n \
186 Or use --dataset <path> to specify a local file.",
187 path, path, opts.variant
188 ));
189 }
190
191 let data = std::fs::read_to_string(&path)?;
192 let instances: Vec<SwebenchInstance> = if data.trim_start().starts_with('[') {
193 serde_json::from_str(&data)?
194 } else {
195 data.lines()
196 .filter(|l| !l.trim().is_empty())
197 .map(serde_json::from_str)
198 .collect::<Result<Vec<_>, _>>()?
199 };
200
201 println!("Loaded {} instances from {}", instances.len(), path);
202
203 let mut filtered = instances;
204 if let Some(ref id) = opts.instance_filter {
205 filtered.retain(|i| i.instance_id == *id);
206 }
207 let start = opts.offset as usize;
208 if start > 0 && start < filtered.len() {
209 filtered = filtered[start..].to_vec();
210 }
211 if let Some(limit) = opts.limit {
212 filtered.truncate(limit as usize);
213 }
214 Ok(filtered)
215}
216
217pub fn list_instances(opts: &SwebenchOpts, repo_filter: Option<&str>) -> Result<()> {
218 let instances = load_dataset(opts)?;
219 let mut by_repo: std::collections::BTreeMap<String, Vec<String>> =
220 std::collections::BTreeMap::new();
221 for inst in &instances {
222 by_repo
223 .entry(inst.repo.clone())
224 .or_default()
225 .push(inst.instance_id.clone());
226 }
227
228 if let Some(repo) = repo_filter {
229 if let Some(ids) = by_repo.get(repo) {
230 println!("\n{} ({} instances):", repo, ids.len());
231 for id in ids {
232 println!(" {}", id);
233 }
234 } else {
235 println!("No instances found for repo '{}'", repo);
236 println!("Available repos:");
237 for r in by_repo.keys() {
238 println!(" {}", r);
239 }
240 }
241 } else {
242 println!(
243 "\n{} total instances across {} repos:\n",
244 instances.len(),
245 by_repo.len()
246 );
247 println!("{:<35} Instances", "Repository");
248 println!("{}", "-".repeat(50));
249 for (repo, ids) in &by_repo {
250 println!("{:<35} {}", repo, ids.len());
251 }
252 }
253 Ok(())
254}
255
/// Prepares an on-disk checkout of `instance.repo` at `base_commit`,
/// applies the instance's test patch, and best-effort installs the
/// project's dependencies. Returns the absolute workspace path.
///
/// Layout under `output_dir`:
///   `repos/<owner>__<name>`  — one cached clone per repository
///   `workspaces/<instance>`  — per-instance worktree (or fallback clone)
///
/// Only git failures while creating/resetting the checkout are hard
/// errors; test-patch and dependency-install failures are reported as
/// warnings and the run continues.
async fn setup_instance_workspace(instance: &SwebenchInstance, output_dir: &str) -> Result<String> {
    // Use an absolute path so git/current_dir behave the same regardless
    // of the caller's cwd; fall back to the raw path if it doesn't exist yet.
    let abs_output = std::fs::canonicalize(output_dir)
        .unwrap_or_else(|_| std::path::PathBuf::from(output_dir))
        .to_string_lossy()
        .to_string();
    let workspace = format!("{}/workspaces/{}", abs_output, instance.instance_id);

    // Reuse an existing workspace: hard-reset to the base commit and drop
    // untracked files instead of recloning. `clean` failure is tolerated.
    if Path::new(&format!("{}/.git", workspace)).exists() {
        println!(" Workspace exists, resetting to base commit...");
        let status = tokio::process::Command::new("git")
            .args(["reset", "--hard", &instance.base_commit])
            .current_dir(&workspace)
            .output()
            .await?;
        if !status.status.success() {
            return Err(anyhow::anyhow!("Failed to reset to base commit"));
        }
        let _ = tokio::process::Command::new("git")
            .args(["clean", "-fd"])
            .current_dir(&workspace)
            .output()
            .await;
        return Ok(workspace);
    }

    // Clone the repository once into a shared cache ("owner__name").
    let repo_slug = instance.repo.replace('/', "__");
    let cache_dir = format!("{}/repos/{}", abs_output, repo_slug);

    if !Path::new(&format!("{}/.git", cache_dir)).exists() {
        println!(
            " Cloning {} (first time, will be cached)...",
            instance.repo
        );
        tokio::fs::create_dir_all(&format!("{}/repos", abs_output)).await?;
        let clone_url = format!("https://github.com/{}.git", instance.repo);
        let status = tokio::process::Command::new("git")
            .args(["clone", "--quiet", &clone_url, &cache_dir])
            .output()
            .await?;
        if !status.status.success() {
            let stderr = String::from_utf8_lossy(&status.stderr);
            return Err(anyhow::anyhow!(
                "Failed to clone {}: {}",
                instance.repo,
                stderr
            ));
        }
    }

    // Prefer a worktree off the cached clone: cheap, shares object storage.
    println!(" Creating worktree at {}...", instance.instance_id);
    tokio::fs::create_dir_all(&format!("{}/workspaces", abs_output)).await?;
    // Best-effort fetch in case the cached clone predates the base commit.
    let _ = tokio::process::Command::new("git")
        .args(["fetch", "--quiet", "origin", &instance.base_commit])
        .current_dir(&cache_dir)
        .output()
        .await;

    let wt_status = tokio::process::Command::new("git")
        .args([
            "worktree",
            "add",
            "--detach",
            &workspace,
            &instance.base_commit,
        ])
        .current_dir(&cache_dir)
        .output()
        .await?;

    // Fallback: a local `--shared` clone of the cache, then checkout the
    // base commit. (--shared borrows the cache's object store.)
    if !wt_status.status.success() {
        println!(" Worktree failed, falling back to direct clone...");
        let _ = tokio::fs::remove_dir_all(&workspace).await;
        tokio::fs::create_dir_all(&workspace).await?;
        let status = tokio::process::Command::new("git")
            .args(["clone", "--quiet", "--shared", &cache_dir, &workspace])
            .output()
            .await?;
        if !status.status.success() {
            return Err(anyhow::anyhow!(
                "Failed to create workspace for {}",
                instance.instance_id
            ));
        }
        let status = tokio::process::Command::new("git")
            .args(["checkout", "--quiet", &instance.base_commit])
            .current_dir(&workspace)
            .output()
            .await?;
        if !status.status.success() {
            return Err(anyhow::anyhow!(
                "Failed to checkout {}",
                instance.base_commit
            ));
        }
    }

    // Apply the adjudicating-test patch so FAIL_TO_PASS tests exist in the
    // tree. Failure is only a warning — presumably some patches partially
    // overlap already-present tests; TODO confirm against the dataset.
    if !instance.test_patch.is_empty() {
        let test_patch_path = format!("{}/test_patch.diff", workspace);
        tokio::fs::write(&test_patch_path, &instance.test_patch).await?;
        let status = tokio::process::Command::new("git")
            .args(["apply", "--allow-empty", "test_patch.diff"])
            .current_dir(&workspace)
            .output()
            .await?;
        if !status.status.success() {
            let stderr = String::from_utf8_lossy(&status.stderr);
            println!(
                " Warning: test_patch apply failed (may be ok): {}",
                stderr.trim()
            );
        }
        let _ = tokio::fs::remove_file(&test_patch_path).await;
    }

    // Install dependencies at most once per workspace; a marker file
    // records both successful and merely-attempted installs.
    let marker = format!("{}/.battlecommand_deps_installed", workspace);
    if !Path::new(&marker).exists() {
        let config = get_repo_config(&instance.repo);
        println!(" Installing dependencies for {}...", instance.repo);
        let install = tokio::process::Command::new("sh")
            .arg("-c")
            .arg(config.install_cmd)
            .current_dir(&workspace)
            .output()
            .await;
        match install {
            Ok(output) if output.status.success() => {
                println!(" Dependencies installed.");
                let _ = tokio::fs::write(&marker, "installed").await;
            }
            Ok(output) => {
                let stderr = String::from_utf8_lossy(&output.stderr);
                println!(
                    " Warning: dependency install may have failed: {}",
                    stderr.lines().last().unwrap_or("unknown")
                );
                let _ = tokio::fs::write(&marker, "attempted").await;
            }
            Err(e) => println!(" Warning: could not install dependencies: {}", e),
        }
    }

    Ok(workspace)
}
403
/// Tool schema advertised to the model: read/search/list for exploring
/// the repo, shell execution, two edit primitives (full-file write and
/// targeted text replacement), and a `submit` terminator.
///
/// NOTE(review): the names here must match the dispatcher in
/// `swebench_tools::execute`, which run_agent_loop calls with
/// `tc.function.name` — keep the two in sync.
fn swebench_tools_def() -> Vec<OllamaTool> {
    vec![
        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
            name: "read_file".into(),
            description: "Read a file from the repository with line numbers. Max 200 lines per call.".into(),
            parameters: json!({"type":"object","properties":{"path":{"type":"string","description":"File path relative to repo root"},"start_line":{"type":"integer","description":"Starting line (default 1)"},"end_line":{"type":"integer","description":"Ending line"}},"required":["path"]}),
        }},
        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
            name: "grep_search".into(),
            description: "Search for text patterns in the repository. Returns matching lines with file paths and line numbers.".into(),
            parameters: json!({"type":"object","properties":{"pattern":{"type":"string","description":"Text or regex pattern to search for"},"path":{"type":"string","description":"Directory to search in (default: root)"}},"required":["pattern"]}),
        }},
        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
            name: "list_directory".into(),
            description: "List files and subdirectories. Shows file sizes.".into(),
            parameters: json!({"type":"object","properties":{"path":{"type":"string","description":"Directory path (default: root)"}},"required":[]}),
        }},
        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
            name: "run_command".into(),
            description: "Run a shell command in the repository directory. 30-second timeout.".into(),
            parameters: json!({"type":"object","properties":{"command":{"type":"string","description":"Shell command to execute"}},"required":["command"]}),
        }},
        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
            name: "write_file".into(),
            description: "Write complete file contents. Creates parent directories if needed.".into(),
            parameters: json!({"type":"object","properties":{"path":{"type":"string","description":"File path relative to repo root"},"content":{"type":"string","description":"Complete file contents"}},"required":["path","content"]}),
        }},
        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
            name: "apply_edit".into(),
            description: "Replace specific text in a file. Preferred for small changes (1-10 lines).".into(),
            parameters: json!({"type":"object","properties":{"path":{"type":"string","description":"File path"},"old_text":{"type":"string","description":"Exact text to find"},"new_text":{"type":"string","description":"Replacement text"}},"required":["path","old_text","new_text"]}),
        }},
        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
            name: "submit".into(),
            description: "Signal that you have finished fixing the bug.".into(),
            parameters: json!({"type":"object","properties":{},"required":[]}),
        }},
    ]
}
445
/// System prompt for the bug-fixing agent: a seven-step workflow plus
/// hard rules (minimal diffs, don't touch tests, verify before submit,
/// use `python3`). Returned as an owned String because ChatMessage
/// stores owned content.
fn build_system_prompt() -> String {
    r#"You are an expert software engineer tasked with fixing a bug in a Python repository.

You have tools to explore the codebase, understand the issue, and apply a fix.

## WORKFLOW
1. Read the issue carefully — identify what behavior is wrong and what is expected
2. Search the codebase to find relevant files (grep for error messages, class names, function names)
3. Read the relevant source files to understand the code structure
4. Identify the root cause of the bug
5. Make the MINIMAL change needed to fix the bug
6. Run the failing test(s) to verify your fix works
7. Call submit() when done

## RULES
- Make MINIMAL changes — fix the bug, do NOT refactor, add features, or change unrelated code
- Prefer apply_edit for small changes (1-10 lines)
- When using write_file, include the COMPLETE file contents
- Always verify your fix by running tests before calling submit()
- Do NOT modify test files — only fix the source code
- Use `python3` (not `python`) for all commands"#
    .to_string()
}
469
470async fn run_agent_loop(
473 llm: &LlmClient,
474 instance: &SwebenchInstance,
475 workspace: &str,
476 max_turns: u32,
477) -> Result<(u32, u64)> {
478 let tools = swebench_tools_def();
479 let system_prompt = build_system_prompt();
480
481 let _repo_config = get_repo_config(&instance.repo);
482 let test_hint = if instance.repo == "django/django" {
483 let test_labels: Vec<String> = instance
484 .fail_to_pass_tests()
485 .iter()
486 .map(|t| normalize_django_test(t))
487 .collect();
488 format!("\n\n## How to run tests\nDjango project. Run tests with:\n```\ncd tests && python3 runtests.py {}\n```", test_labels.join(" "))
489 } else if instance.repo == "sympy/sympy" {
490 let tests = instance.fail_to_pass_tests();
491 format!("\n\n## How to run tests\nsympy project. Run tests with:\n```\npython3 -m pytest -xvs -k \"{}\"\n```", tests.join(" or "))
492 } else {
493 let tests = instance.fail_to_pass_tests();
494 format!(
495 "\n\n## How to run tests\nRun failing tests with:\n```\npython3 -m pytest -xvs {}\n```",
496 tests.join(" ")
497 )
498 };
499
500 let user_content = format!(
501 "## Issue to Fix\n\n{}\n\n## Failing Tests\nThe following tests should PASS after your fix:\n{}{}",
502 instance.problem_statement, instance.fail_to_pass_tests().join("\n"), test_hint,
503 );
504
505 let mut messages: Vec<ChatMessage> = vec![
506 ChatMessage {
507 role: "system".into(),
508 content: system_prompt,
509 tool_calls: None,
510 tool_call_id: None,
511 },
512 ChatMessage {
513 role: "user".into(),
514 content: user_content,
515 tool_calls: None,
516 tool_call_id: None,
517 },
518 ];
519
520 let mut turns_used: u32 = 0;
521 let mut write_turns: u32 = 0;
522
523 for turn in 0..max_turns {
524 turns_used = turn + 1;
525
526 let resp = llm.chat_with_tools(&messages, &tools).await?;
527
528 if !resp.tool_calls.is_empty() {
529 messages.push(ChatMessage {
530 role: "assistant".into(),
531 content: resp.content.clone(),
532 tool_calls: Some(resp.tool_calls.clone()),
533 tool_call_id: None,
534 });
535
536 for tc in &resp.tool_calls {
537 let result =
538 swebench_tools::execute(&tc.function.name, &tc.function.arguments, workspace)
539 .await;
540 println!(
541 " [turn {}] {} → {}",
542 turn + 1,
543 tc.function.name,
544 if result.success { "ok" } else { "FAIL" }
545 );
546
547 if result.is_submit {
548 messages.push(ChatMessage {
549 role: "tool".into(),
550 content: result.content,
551 tool_calls: None,
552 tool_call_id: Some(tc.function.name.clone()),
553 });
554 return Ok((turns_used, 0));
555 }
556 if result.is_write {
557 write_turns += 1;
558 }
559
560 let content = if result.content.len() > 4096 {
561 format!(
562 "{}...\n[truncated, {} chars total]",
563 &result.content[..result.content.len().min(4096)],
564 result.content.len()
565 )
566 } else {
567 result.content
568 };
569 messages.push(ChatMessage {
570 role: "tool".into(),
571 content,
572 tool_calls: None,
573 tool_call_id: Some(tc.function.name.clone()),
574 });
575 }
576
577 if turn >= 14 && write_turns == 0 {
578 messages.push(ChatMessage {
579 role: "system".into(),
580 content: "You have used 15 turns without making any code changes. Please apply your fix NOW using apply_edit or write_file. Then run tests and submit.".into(),
581 tool_calls: None, tool_call_id: None,
582 });
583 }
584
585 if turn >= 19 && messages.len() > 30 {
586 compact_messages(&mut messages);
587 }
588 continue;
589 }
590
591 messages.push(ChatMessage {
592 role: "assistant".into(),
593 content: resp.content.clone(),
594 tool_calls: None,
595 tool_call_id: None,
596 });
597
598 let lower = resp.content.to_lowercase();
599 if lower.contains("submit")
600 || lower.contains("finished")
601 || lower.contains("fix has been applied")
602 {
603 println!(" Agent indicated completion at turn {}", turn + 1);
604 return Ok((turns_used, 0));
605 }
606
607 messages.push(ChatMessage {
608 role: "user".into(),
609 content: "Please continue. Use the tools to explore the code and fix the bug. Call submit() when done.".into(),
610 tool_calls: None, tool_call_id: None,
611 });
612 }
613
614 println!(" Agent reached max turns ({})", max_turns);
615 Ok((turns_used, 0))
616}
617
618fn compact_messages(messages: &mut Vec<ChatMessage>) {
619 if messages.len() <= 12 {
620 return;
621 }
622 let keep_start = 2;
623 let keep_end = messages.len().saturating_sub(10);
624 if keep_end <= keep_start {
625 return;
626 }
627
628 let mut summary = String::from("## Previous exploration summary\n");
629 for msg in &messages[keep_start..keep_end] {
630 if msg.role == "tool" {
631 let preview = if msg.content.len() > 150 {
632 &msg.content[..150]
633 } else {
634 &msg.content
635 };
636 summary.push_str(&format!("- tool: {}\n", preview));
637 } else if msg.role == "assistant" {
638 let preview = if msg.content.len() > 200 {
639 &msg.content[..200]
640 } else {
641 &msg.content
642 };
643 summary.push_str(&format!("- thought: {}\n", preview));
644 }
645 }
646
647 let replacement = ChatMessage {
648 role: "system".into(),
649 content: summary,
650 tool_calls: None,
651 tool_call_id: None,
652 };
653 messages.splice(keep_start..keep_end, std::iter::once(replacement));
654}
655
656async fn generate_patch(workspace: &str) -> Result<(String, Vec<String>)> {
659 let output = tokio::process::Command::new("git")
660 .args(["diff", "HEAD"])
661 .current_dir(workspace)
662 .output()
663 .await?;
664 let diff = String::from_utf8_lossy(&output.stdout).to_string();
665 let files_output = tokio::process::Command::new("git")
666 .args(["diff", "--name-only", "HEAD"])
667 .current_dir(workspace)
668 .output()
669 .await?;
670 let files: Vec<String> = String::from_utf8_lossy(&files_output.stdout)
671 .lines()
672 .filter(|l| !l.is_empty())
673 .map(String::from)
674 .collect();
675 Ok((diff, files))
676}
677
678fn load_completed_ids(output_dir: &str) -> std::collections::HashSet<String> {
681 let path = format!("{}/swebench_results.jsonl", output_dir);
682 let mut ids = std::collections::HashSet::new();
683 if let Ok(data) = std::fs::read_to_string(&path) {
684 for line in data.lines() {
685 if let Ok(result) = serde_json::from_str::<InstanceResult>(line) {
686 ids.insert(result.instance_id);
687 }
688 }
689 }
690 ids
691}
692
693fn append_result(output_dir: &str, result: &InstanceResult) -> Result<()> {
694 use std::io::Write;
695 let path = format!("{}/swebench_results.jsonl", output_dir);
696 let mut file = std::fs::OpenOptions::new()
697 .create(true)
698 .append(true)
699 .open(&path)?;
700 writeln!(file, "{}", serde_json::to_string(result)?)?;
701 Ok(())
702}
703
704fn append_prediction(output_dir: &str, prediction: &Prediction) -> Result<()> {
705 use std::io::Write;
706 let path = format!("{}/predictions.jsonl", output_dir);
707 let mut file = std::fs::OpenOptions::new()
708 .create(true)
709 .append(true)
710 .open(&path)?;
711 writeln!(file, "{}", serde_json::to_string(prediction)?)?;
712 Ok(())
713}
714
/// Runs the agent on a single instance end to end: workspace setup,
/// agent loop (bounded by `opts.timeout_secs`), patch extraction, and
/// FAIL_TO_PASS adjudication.
///
/// Agent errors and timeouts are converted into failed
/// `InstanceResult`s (with `error` set) rather than propagated, so a
/// batch run can continue; only workspace-setup and git failures return
/// `Err`.
pub async fn run_single(
    instance: &SwebenchInstance,
    opts: &SwebenchOpts,
) -> Result<InstanceResult> {
    let model = opts
        .model_override
        .as_deref()
        .unwrap_or("claude-sonnet-4-6");
    // 65536/8192 — presumably context vs. output token limits; confirm
    // against LlmClient::with_limits.
    let llm = LlmClient::with_limits(model, 65536, 8192);

    println!("\n[SWE-bench] Instance: {}", instance.instance_id);
    println!(
        " Repo: {} @ {}",
        instance.repo,
        // Short-hash display; git hashes are ASCII so byte slicing is safe,
        // and the min() guards against commits shorter than 7 chars.
        &instance.base_commit[..7.min(instance.base_commit.len())]
    );
    println!(
        " Tests: {} FAIL_TO_PASS",
        instance.fail_to_pass_tests().len()
    );
    println!(" Model: {}", model);

    let workspace = setup_instance_workspace(instance, &opts.output_dir).await?;

    let start = Instant::now();
    // Run the agent under a wall-clock timeout; both error paths produce a
    // failed result instead of aborting the batch.
    let (turns, tokens) = match tokio::time::timeout(
        std::time::Duration::from_secs(opts.timeout_secs),
        run_agent_loop(&llm, instance, &workspace, opts.max_turns),
    )
    .await
    {
        Ok(Ok(result)) => result,
        Ok(Err(e)) => {
            return Ok(InstanceResult {
                instance_id: instance.instance_id.clone(),
                repo: instance.repo.clone(),
                model: model.to_string(),
                resolved: false,
                fail_to_pass_count: instance.fail_to_pass_tests().len(),
                fail_to_pass_passed: 0,
                turns_used: 0,
                tokens_used: 0,
                duration_secs: start.elapsed().as_secs_f64(),
                files_modified: vec![],
                patch: String::new(),
                error: Some(format!("Agent error: {}", e)),
            });
        }
        Err(_) => {
            return Ok(InstanceResult {
                instance_id: instance.instance_id.clone(),
                repo: instance.repo.clone(),
                model: model.to_string(),
                resolved: false,
                fail_to_pass_count: instance.fail_to_pass_tests().len(),
                fail_to_pass_passed: 0,
                turns_used: opts.max_turns,
                tokens_used: 0,
                duration_secs: start.elapsed().as_secs_f64(),
                files_modified: vec![],
                patch: String::new(),
                error: Some(format!("Timeout after {}s", opts.timeout_secs)),
            });
        }
    };
    let duration = start.elapsed().as_secs_f64();

    let (patch, files_modified) = generate_patch(&workspace).await?;

    // Adjudicate: run each FAIL_TO_PASS test with the repo's command
    // template; a zero exit status counts as a pass.
    let fail_tests = instance.fail_to_pass_tests();
    let repo_config = get_repo_config(&instance.repo);
    let mut passed = 0;
    for test in &fail_tests {
        // Django test ids need the "(dotted.path)" form stripped first.
        let normalized = if instance.repo == "django/django" {
            normalize_django_test(test)
        } else {
            test.to_string()
        };
        let test_cmd = format!(
            "{}{}",
            repo_config.env_setup,
            repo_config.test_cmd_template.replace("{test}", &normalized)
        );
        let test_result = tokio::process::Command::new("sh")
            .arg("-c")
            .arg(&test_cmd)
            .current_dir(&workspace)
            .env("PYTHONDONTWRITEBYTECODE", "1")
            .output()
            .await;
        if let Ok(output) = test_result {
            if output.status.success() {
                passed += 1;
                println!(" PASS: {}", test);
            } else {
                // Surface the last non-empty, non-banner line as the reason.
                let stdout = String::from_utf8_lossy(&output.stdout);
                let stderr = String::from_utf8_lossy(&output.stderr);
                let last = stdout
                    .lines()
                    .rev()
                    .chain(stderr.lines().rev())
                    .find(|l| !l.trim().is_empty() && !l.starts_with("="))
                    .unwrap_or("unknown error");
                println!(" FAIL: {} — {}", test, last);
            }
        } else {
            println!(" ERROR: {} — could not execute test command", test);
        }
    }

    // Resolved only if every FAIL_TO_PASS test passed (and there was at
    // least one).
    let resolved = passed == fail_tests.len() && !fail_tests.is_empty();
    println!(
        " Result: {} ({}/{} tests) | {} turns | {:.0}s",
        if resolved { "RESOLVED" } else { "FAILED" },
        passed,
        fail_tests.len(),
        turns,
        duration
    );

    Ok(InstanceResult {
        instance_id: instance.instance_id.clone(),
        repo: instance.repo.clone(),
        model: model.to_string(),
        resolved,
        fail_to_pass_count: fail_tests.len(),
        fail_to_pass_passed: passed,
        turns_used: turns,
        tokens_used: tokens,
        duration_secs: duration,
        files_modified,
        patch,
        error: None,
    })
}
853
/// Runs every instance in the configured dataset slice sequentially,
/// appending each result to `swebench_results.jsonl` and each patch to
/// `predictions.jsonl` as it goes, and printing running statistics.
///
/// With `opts.resume`, instances already recorded in the results file
/// are skipped. A per-instance failure is recorded as a failed result
/// and the batch continues.
pub async fn run_batch(opts: &SwebenchOpts) -> Result<()> {
    tokio::fs::create_dir_all(&opts.output_dir).await?;
    tokio::fs::create_dir_all(&format!("{}/datasets", opts.output_dir)).await?;

    let instances = load_dataset(opts)?;
    if instances.is_empty() {
        println!("No instances to run.");
        return Ok(());
    }

    // Ids to skip when resuming; empty set otherwise.
    let completed = if opts.resume {
        let ids = load_completed_ids(&opts.output_dir);
        if !ids.is_empty() {
            println!("Resuming: {} instances already completed", ids.len());
        }
        ids
    } else {
        std::collections::HashSet::new()
    };

    let total = instances.len();
    let mut resolved_count = 0u32;
    // Instances actually run this session (skipped ones don't count).
    let mut completed_count = 0u32;
    let mut total_duration = 0.0f64;
    let mut total_tokens = 0u64;

    println!("\n========================================");
    println!(" SWE-bench Run: {} instances", total);
    println!("========================================\n");

    for (i, instance) in instances.iter().enumerate() {
        if completed.contains(&instance.instance_id) {
            println!(
                "[{}/{}] SKIP {} (already completed)",
                i + 1,
                total,
                instance.instance_id
            );
            continue;
        }

        println!("\n[{}/{}] {}", i + 1, total, instance.instance_id);

        // Convert hard errors (e.g. workspace setup failure) into a failed
        // result so the batch keeps going.
        let result = match run_single(instance, opts).await {
            Ok(r) => r,
            Err(e) => {
                println!(" ERROR: {}", e);
                InstanceResult {
                    instance_id: instance.instance_id.clone(),
                    repo: instance.repo.clone(),
                    model: opts
                        .model_override
                        .clone()
                        .unwrap_or_else(|| "default".into()),
                    resolved: false,
                    fail_to_pass_count: instance.fail_to_pass_tests().len(),
                    fail_to_pass_passed: 0,
                    turns_used: 0,
                    tokens_used: 0,
                    duration_secs: 0.0,
                    files_modified: vec![],
                    patch: String::new(),
                    error: Some(format!("{}", e)),
                }
            }
        };

        if result.resolved {
            resolved_count += 1;
        }
        completed_count += 1;
        total_duration += result.duration_secs;
        total_tokens += result.tokens_used;

        // Persist incrementally so progress survives a crash; failures to
        // write are warnings only.
        if let Err(e) = append_result(&opts.output_dir, &result) {
            eprintln!("Warning: failed to save result: {}", e);
        }
        if let Err(e) = append_prediction(
            &opts.output_dir,
            &Prediction {
                instance_id: result.instance_id.clone(),
                model_name_or_path: result.model.clone(),
                model_patch: result.patch.clone(),
            },
        ) {
            eprintln!("Warning: failed to save prediction: {}", e);
        }

        let rate = if completed_count > 0 {
            resolved_count as f64 / completed_count as f64 * 100.0
        } else {
            0.0
        };
        println!(
            "\n Progress: {}/{} completed | {}/{} resolved ({:.1}%) | {:.0}s avg",
            completed_count,
            total,
            resolved_count,
            completed_count,
            rate,
            if completed_count > 0 {
                total_duration / completed_count as f64
            } else {
                0.0
            }
        );
    }

    // Final summary over the instances run this session.
    println!("\n========================================");
    println!(" SWE-bench Run Complete");
    println!("========================================");
    let rate = if completed_count > 0 {
        resolved_count as f64 / completed_count as f64 * 100.0
    } else {
        0.0
    };
    println!(
        " Resolved: {}/{} ({:.1}%)",
        resolved_count, completed_count, rate
    );
    println!(" Total tokens: {}", total_tokens);
    println!(
        " Total time: {:.0}s ({:.0}s avg)",
        total_duration,
        if completed_count > 0 {
            total_duration / completed_count as f64
        } else {
            0.0
        }
    );
    println!(" Results: {}/swebench_results.jsonl", opts.output_dir);
    println!(" Predictions: {}/predictions.jsonl", opts.output_dir);
    Ok(())
}
988
#[cfg(test)]
mod tests {
    use super::*;

    // FAIL_TO_PASS delivered as a proper JSON array.
    #[test]
    fn test_parse_test_list_array() {
        let v = json!(["test_foo.py::TestBar", "test_baz.py::TestQux"]);
        let tests = parse_test_list(&v);
        assert_eq!(tests.len(), 2);
    }

    // FAIL_TO_PASS delivered as a string containing a JSON-encoded array.
    #[test]
    fn test_parse_test_list_string() {
        let v = json!("[\"test_foo.py::TestBar\", \"test_baz.py::TestQux\"]");
        let tests = parse_test_list(&v);
        assert_eq!(tests.len(), 2);
    }

    // A plain string that is not JSON is treated as a single test id.
    #[test]
    fn test_parse_test_list_single() {
        let v = json!("test_foo.py::TestBar");
        let tests = parse_test_list(&v);
        assert_eq!(tests.len(), 1);
    }
}
1013}