Skip to main content

battlecommand_forge/
swebench.rs

//! SWE-bench evaluation framework for BattleCommand Forge.
//! Runs against real GitHub issues from the SWE-bench dataset.
//! ReAct agent loop explores repos and generates patches, then validates
//! against the repo's test suite.
5use anyhow::Result;
6use serde::{Deserialize, Serialize};
7use serde_json::json;
8use std::path::Path;
9use std::time::Instant;
10
11use crate::llm::{ChatMessage, LlmClient, OllamaTool, OllamaToolFunction};
12use crate::swebench_tools;
13
14// ─── Per-Repo Configuration ───
15
/// Shell commands used to prepare and test one supported repository.
struct RepoConfig {
    /// Shell command that installs the repo's Python package into the environment.
    install_cmd: &'static str,
    /// Test command with a `{test}` placeholder substituted per test id.
    test_cmd_template: &'static str,
    /// Prefix prepended to the test command (empty for all current repos).
    env_setup: &'static str,
}
21
22fn get_repo_config(repo: &str) -> RepoConfig {
23    match repo {
24        "django/django" => RepoConfig {
25            install_cmd: "python3 -m pip install -e . -q 2>/dev/null",
26            test_cmd_template: "cd tests && python3 runtests.py {test} --verbosity=0 2>&1",
27            env_setup: "",
28        },
29        "sympy/sympy" => RepoConfig {
30            install_cmd: "python3 -m pip install -e . -q 2>/dev/null",
31            test_cmd_template: "python3 -m pytest -xvs -k {test} 2>&1",
32            env_setup: "",
33        },
34        "matplotlib/matplotlib" => RepoConfig {
35            install_cmd: "python3 -m pip install -e '.[dev]' -q 2>/dev/null",
36            test_cmd_template: "python3 -m pytest -xvs {test} 2>&1",
37            env_setup: "",
38        },
39        "scikit-learn/scikit-learn" => RepoConfig {
40            install_cmd: "python3 -m pip install -e . -q 2>/dev/null",
41            test_cmd_template: "python3 -m pytest -xvs {test} 2>&1",
42            env_setup: "",
43        },
44        "pytest-dev/pytest" => RepoConfig {
45            install_cmd: "python3 -m pip install -e . -q 2>/dev/null",
46            test_cmd_template: "python3 -m pytest -xvs {test} 2>&1",
47            env_setup: "",
48        },
49        "sphinx-doc/sphinx" => RepoConfig {
50            install_cmd: "python3 -m pip install -e '.[test]' -q 2>/dev/null",
51            test_cmd_template: "python3 -m pytest -xvs {test} 2>&1",
52            env_setup: "",
53        },
54        _ => RepoConfig {
55            install_cmd: "python3 -m pip install -e . -q 2>/dev/null",
56            test_cmd_template: "python3 -m pytest -xvs {test} 2>&1",
57            env_setup: "",
58        },
59    }
60}
61
/// Convert a SWE-bench Django test id into a `runtests.py` label.
///
/// SWE-bench lists Django tests as `"test_method (module.tests.TestClass)"`.
/// Django's `runtests.py` accepts dotted labels, so this returns
/// `module.tests.TestClass.test_method`, targeting the single test.
/// (Previously only the class path inside the parentheses was returned,
/// which ran the entire test class and could mis-score the fix when an
/// unrelated test in the class failed.)
///
/// Inputs without parentheses are assumed to already be dotted labels and
/// are returned unchanged.
fn normalize_django_test(test: &str) -> String {
    if let Some(paren_start) = test.find('(') {
        let class_path = test[paren_start + 1..].trim_end_matches(')').trim();
        let method = test[..paren_start].trim();
        if method.is_empty() {
            // No method prefix present — fall back to the class path alone.
            class_path.to_string()
        } else {
            format!("{}.{}", class_path, method)
        }
    } else {
        test.to_string()
    }
}
72
73// ─── Types ───
74
/// One task from the SWE-bench dataset: a repository at a fixed commit,
/// the issue text, and the tests that define success.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SwebenchInstance {
    /// Unique dataset id for this task.
    pub instance_id: String,
    /// GitHub "owner/name" slug.
    pub repo: String,
    /// Commit the workspace is reset to before the agent runs.
    pub base_commit: String,
    /// Issue text handed to the agent.
    pub problem_statement: String,
    /// Optional extra hints from the dataset.
    #[serde(default)]
    pub hints_text: Option<String>,
    /// Diff adding/adjusting the validating tests; applied to the workspace.
    pub test_patch: String,
    /// Tests that fail before the fix and must pass after.
    /// May be a JSON array or a JSON-encoded string (see `parse_test_list`).
    #[serde(rename = "FAIL_TO_PASS")]
    pub fail_to_pass: serde_json::Value,
    /// Tests that must keep passing (same flexible encoding).
    #[serde(rename = "PASS_TO_PASS", default)]
    pub pass_to_pass: serde_json::Value,
    /// Dataset-reported project version, when present.
    #[serde(default)]
    pub version: Option<String>,
}
91
impl SwebenchInstance {
    /// Tests expected to go from failing to passing after the fix.
    pub fn fail_to_pass_tests(&self) -> Vec<String> {
        parse_test_list(&self.fail_to_pass)
    }
    /// Tests expected to pass both before and after the fix.
    pub fn pass_to_pass_tests(&self) -> Vec<String> {
        parse_test_list(&self.pass_to_pass)
    }
}
100
101fn parse_test_list(v: &serde_json::Value) -> Vec<String> {
102    match v {
103        serde_json::Value::Array(arr) => arr
104            .iter()
105            .filter_map(|v| v.as_str().map(String::from))
106            .collect(),
107        serde_json::Value::String(s) => {
108            if let Ok(arr) = serde_json::from_str::<Vec<String>>(s) {
109                arr
110            } else {
111                vec![s.clone()]
112            }
113        }
114        _ => vec![],
115    }
116}
117
/// Outcome of running the agent on one instance; one of these is appended
/// per instance to `swebench_results.jsonl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstanceResult {
    pub instance_id: String,
    pub repo: String,
    /// Model name used for this run.
    pub model: String,
    /// True when every FAIL_TO_PASS test passed after the agent's patch.
    pub resolved: bool,
    /// Total number of FAIL_TO_PASS tests.
    pub fail_to_pass_count: usize,
    /// How many of them passed after the patch.
    pub fail_to_pass_passed: usize,
    pub turns_used: u32,
    /// Token accounting; currently always 0 (not yet implemented).
    pub tokens_used: u64,
    pub duration_secs: f64,
    /// Files changed relative to the base commit (`git diff --name-only`).
    pub files_modified: Vec<String>,
    /// Unified diff of the agent's changes (`git diff HEAD`).
    pub patch: String,
    /// Failure description (agent error or timeout), absent on normal runs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
134
/// Row serialized to `predictions.jsonl` (see `append_prediction`);
/// field names match the SWE-bench prediction format.
#[derive(Debug, Serialize)]
struct Prediction {
    instance_id: String,
    model_name_or_path: String,
    model_patch: String,
}
141
/// Configuration for a SWE-bench run (dataset selection, instance slicing,
/// agent limits, and output locations).
#[derive(Debug, Clone)]
pub struct SwebenchOpts {
    /// Explicit dataset file; when None, `{output_dir}/datasets/{variant}.json` is used.
    pub dataset_path: Option<String>,
    /// Dataset variant name, e.g. "lite".
    pub variant: String,
    /// When set, run only the instance with this exact id.
    pub instance_filter: Option<String>,
    /// Maximum number of instances to run after filtering/offset.
    pub limit: Option<u32>,
    /// Number of instances to skip from the start of the (filtered) list.
    pub offset: u32,
    /// Root directory for datasets, repos, workspaces, and result files.
    pub output_dir: String,
    /// Overrides the default model when set.
    pub model_override: Option<String>,
    /// Maximum ReAct turns per instance.
    pub max_turns: u32,
    /// Wall-clock budget per instance for the agent loop.
    pub timeout_secs: u64,
    /// Skip instances already recorded in the results file.
    pub resume: bool,
}
155
156impl Default for SwebenchOpts {
157    fn default() -> Self {
158        Self {
159            dataset_path: None,
160            variant: "lite".into(),
161            instance_filter: None,
162            limit: None,
163            offset: 0,
164            output_dir: ".battlecommand/swebench".into(),
165            model_override: None,
166            max_turns: 25,
167            timeout_secs: 1800,
168            resume: false,
169        }
170    }
171}
172
173// ─── Dataset Loading ───
174
175pub fn load_dataset(opts: &SwebenchOpts) -> Result<Vec<SwebenchInstance>> {
176    let path = if let Some(ref p) = opts.dataset_path {
177        p.clone()
178    } else {
179        format!("{}/datasets/{}.json", opts.output_dir, opts.variant)
180    };
181
182    if !Path::new(&path).exists() {
183        return Err(anyhow::anyhow!(
184            "Dataset not found at '{}'. Download it first:\n  \
185             wget -O {} https://raw.githubusercontent.com/princeton-nlp/SWE-bench/main/swebench/test/{}.json\n  \
186             Or use --dataset <path> to specify a local file.",
187            path, path, opts.variant
188        ));
189    }
190
191    let data = std::fs::read_to_string(&path)?;
192    let instances: Vec<SwebenchInstance> = if data.trim_start().starts_with('[') {
193        serde_json::from_str(&data)?
194    } else {
195        data.lines()
196            .filter(|l| !l.trim().is_empty())
197            .map(serde_json::from_str)
198            .collect::<Result<Vec<_>, _>>()?
199    };
200
201    println!("Loaded {} instances from {}", instances.len(), path);
202
203    let mut filtered = instances;
204    if let Some(ref id) = opts.instance_filter {
205        filtered.retain(|i| i.instance_id == *id);
206    }
207    let start = opts.offset as usize;
208    if start > 0 && start < filtered.len() {
209        filtered = filtered[start..].to_vec();
210    }
211    if let Some(limit) = opts.limit {
212        filtered.truncate(limit as usize);
213    }
214    Ok(filtered)
215}
216
217pub fn list_instances(opts: &SwebenchOpts, repo_filter: Option<&str>) -> Result<()> {
218    let instances = load_dataset(opts)?;
219    let mut by_repo: std::collections::BTreeMap<String, Vec<String>> =
220        std::collections::BTreeMap::new();
221    for inst in &instances {
222        by_repo
223            .entry(inst.repo.clone())
224            .or_default()
225            .push(inst.instance_id.clone());
226    }
227
228    if let Some(repo) = repo_filter {
229        if let Some(ids) = by_repo.get(repo) {
230            println!("\n{} ({} instances):", repo, ids.len());
231            for id in ids {
232                println!("  {}", id);
233            }
234        } else {
235            println!("No instances found for repo '{}'", repo);
236            println!("Available repos:");
237            for r in by_repo.keys() {
238                println!("  {}", r);
239            }
240        }
241    } else {
242        println!(
243            "\n{} total instances across {} repos:\n",
244            instances.len(),
245            by_repo.len()
246        );
247        println!("{:<35} Instances", "Repository");
248        println!("{}", "-".repeat(50));
249        for (repo, ids) in &by_repo {
250            println!("{:<35} {}", repo, ids.len());
251        }
252    }
253    Ok(())
254}
255
256// ─── Workspace Setup ───
257
/// Prepare an isolated checkout of the instance's repository at its base
/// commit, apply the dataset's test patch, and install dependencies.
///
/// Layout under `output_dir`: `repos/<owner>__<name>` is a cached clone
/// shared across instances; `workspaces/<instance_id>` is a per-instance
/// git worktree (or fallback shared clone). Returns the absolute workspace
/// path on success.
async fn setup_instance_workspace(instance: &SwebenchInstance, output_dir: &str) -> Result<String> {
    // Canonicalize so child processes get absolute paths; fall back to the
    // raw string when the directory does not exist yet.
    let abs_output = std::fs::canonicalize(output_dir)
        .unwrap_or_else(|_| std::path::PathBuf::from(output_dir))
        .to_string_lossy()
        .to_string();
    let workspace = format!("{}/workspaces/{}", abs_output, instance.instance_id);

    // Fast path: workspace already exists — reset it to the base commit and
    // drop untracked files instead of re-creating it.
    if Path::new(&format!("{}/.git", workspace)).exists() {
        println!("  Workspace exists, resetting to base commit...");
        let status = tokio::process::Command::new("git")
            .args(["reset", "--hard", &instance.base_commit])
            .current_dir(&workspace)
            .output()
            .await?;
        if !status.status.success() {
            return Err(anyhow::anyhow!("Failed to reset to base commit"));
        }
        // Best-effort cleanup of untracked files; failure is non-fatal.
        let _ = tokio::process::Command::new("git")
            .args(["clean", "-fd"])
            .current_dir(&workspace)
            .output()
            .await;
        return Ok(workspace);
    }

    let repo_slug = instance.repo.replace('/', "__");
    let cache_dir = format!("{}/repos/{}", abs_output, repo_slug);

    // Clone the repo once into the shared cache; later instances reuse it.
    if !Path::new(&format!("{}/.git", cache_dir)).exists() {
        println!(
            "  Cloning {} (first time, will be cached)...",
            instance.repo
        );
        tokio::fs::create_dir_all(&format!("{}/repos", abs_output)).await?;
        let clone_url = format!("https://github.com/{}.git", instance.repo);
        let status = tokio::process::Command::new("git")
            .args(["clone", "--quiet", &clone_url, &cache_dir])
            .output()
            .await?;
        if !status.status.success() {
            let stderr = String::from_utf8_lossy(&status.stderr);
            return Err(anyhow::anyhow!(
                "Failed to clone {}: {}",
                instance.repo,
                stderr
            ));
        }
    }

    println!("  Creating worktree at {}...", instance.instance_id);
    tokio::fs::create_dir_all(&format!("{}/workspaces", abs_output)).await?;
    // Best-effort fetch in case the cached clone predates the base commit.
    let _ = tokio::process::Command::new("git")
        .args(["fetch", "--quiet", "origin", &instance.base_commit])
        .current_dir(&cache_dir)
        .output()
        .await;

    // Preferred path: a detached worktree shares the cache's object store.
    let wt_status = tokio::process::Command::new("git")
        .args([
            "worktree",
            "add",
            "--detach",
            &workspace,
            &instance.base_commit,
        ])
        .current_dir(&cache_dir)
        .output()
        .await?;

    if !wt_status.status.success() {
        // Fallback: a --shared clone of the cache, then checkout the commit.
        println!("  Worktree failed, falling back to direct clone...");
        let _ = tokio::fs::remove_dir_all(&workspace).await;
        tokio::fs::create_dir_all(&workspace).await?;
        let status = tokio::process::Command::new("git")
            .args(["clone", "--quiet", "--shared", &cache_dir, &workspace])
            .output()
            .await?;
        if !status.status.success() {
            return Err(anyhow::anyhow!(
                "Failed to create workspace for {}",
                instance.instance_id
            ));
        }
        let status = tokio::process::Command::new("git")
            .args(["checkout", "--quiet", &instance.base_commit])
            .current_dir(&workspace)
            .output()
            .await?;
        if !status.status.success() {
            return Err(anyhow::anyhow!(
                "Failed to checkout {}",
                instance.base_commit
            ));
        }
    }

    // Apply test patch
    // The dataset's test patch adds/updates the validating tests. Failure is
    // only a warning: a reused workspace may already have it applied.
    if !instance.test_patch.is_empty() {
        let test_patch_path = format!("{}/test_patch.diff", workspace);
        tokio::fs::write(&test_patch_path, &instance.test_patch).await?;
        let status = tokio::process::Command::new("git")
            .args(["apply", "--allow-empty", "test_patch.diff"])
            .current_dir(&workspace)
            .output()
            .await?;
        if !status.status.success() {
            let stderr = String::from_utf8_lossy(&status.stderr);
            println!(
                "  Warning: test_patch apply failed (may be ok): {}",
                stderr.trim()
            );
        }
        let _ = tokio::fs::remove_file(&test_patch_path).await;
    }

    // Install dependencies
    // Marker file prevents repeating the (slow) install on reused workspaces;
    // note it is written even when the install reports failure ("attempted").
    let marker = format!("{}/.battlecommand_deps_installed", workspace);
    if !Path::new(&marker).exists() {
        let config = get_repo_config(&instance.repo);
        println!("  Installing dependencies for {}...", instance.repo);
        let install = tokio::process::Command::new("sh")
            .arg("-c")
            .arg(config.install_cmd)
            .current_dir(&workspace)
            .output()
            .await;
        match install {
            Ok(output) if output.status.success() => {
                println!("  Dependencies installed.");
                let _ = tokio::fs::write(&marker, "installed").await;
            }
            Ok(output) => {
                let stderr = String::from_utf8_lossy(&output.stderr);
                println!(
                    "  Warning: dependency install may have failed: {}",
                    stderr.lines().last().unwrap_or("unknown")
                );
                let _ = tokio::fs::write(&marker, "attempted").await;
            }
            Err(e) => println!("  Warning: could not install dependencies: {}", e),
        }
    }

    Ok(workspace)
}
403
404// ─── Tool Definitions ───
405
406fn swebench_tools_def() -> Vec<OllamaTool> {
407    vec![
408        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
409            name: "read_file".into(),
410            description: "Read a file from the repository with line numbers. Max 200 lines per call.".into(),
411            parameters: json!({"type":"object","properties":{"path":{"type":"string","description":"File path relative to repo root"},"start_line":{"type":"integer","description":"Starting line (default 1)"},"end_line":{"type":"integer","description":"Ending line"}},"required":["path"]}),
412        }},
413        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
414            name: "grep_search".into(),
415            description: "Search for text patterns in the repository. Returns matching lines with file paths and line numbers.".into(),
416            parameters: json!({"type":"object","properties":{"pattern":{"type":"string","description":"Text or regex pattern to search for"},"path":{"type":"string","description":"Directory to search in (default: root)"}},"required":["pattern"]}),
417        }},
418        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
419            name: "list_directory".into(),
420            description: "List files and subdirectories. Shows file sizes.".into(),
421            parameters: json!({"type":"object","properties":{"path":{"type":"string","description":"Directory path (default: root)"}},"required":[]}),
422        }},
423        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
424            name: "run_command".into(),
425            description: "Run a shell command in the repository directory. 30-second timeout.".into(),
426            parameters: json!({"type":"object","properties":{"command":{"type":"string","description":"Shell command to execute"}},"required":["command"]}),
427        }},
428        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
429            name: "write_file".into(),
430            description: "Write complete file contents. Creates parent directories if needed.".into(),
431            parameters: json!({"type":"object","properties":{"path":{"type":"string","description":"File path relative to repo root"},"content":{"type":"string","description":"Complete file contents"}},"required":["path","content"]}),
432        }},
433        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
434            name: "apply_edit".into(),
435            description: "Replace specific text in a file. Preferred for small changes (1-10 lines).".into(),
436            parameters: json!({"type":"object","properties":{"path":{"type":"string","description":"File path"},"old_text":{"type":"string","description":"Exact text to find"},"new_text":{"type":"string","description":"Replacement text"}},"required":["path","old_text","new_text"]}),
437        }},
438        OllamaTool { tool_type: "function".into(), function: OllamaToolFunction {
439            name: "submit".into(),
440            description: "Signal that you have finished fixing the bug.".into(),
441            parameters: json!({"type":"object","properties":{},"required":[]}),
442        }},
443    ]
444}
445
/// System prompt handed to the model: the expected workflow and the ground
/// rules for producing a minimal, test-verified bug fix.
fn build_system_prompt() -> String {
    String::from(
        r#"You are an expert software engineer tasked with fixing a bug in a Python repository.

You have tools to explore the codebase, understand the issue, and apply a fix.

## WORKFLOW
1. Read the issue carefully — identify what behavior is wrong and what is expected
2. Search the codebase to find relevant files (grep for error messages, class names, function names)
3. Read the relevant source files to understand the code structure
4. Identify the root cause of the bug
5. Make the MINIMAL change needed to fix the bug
6. Run the failing test(s) to verify your fix works
7. Call submit() when done

## RULES
- Make MINIMAL changes — fix the bug, do NOT refactor, add features, or change unrelated code
- Prefer apply_edit for small changes (1-10 lines)
- When using write_file, include the COMPLETE file contents
- Always verify your fix by running tests before calling submit()
- Do NOT modify test files — only fix the source code
- Use `python3` (not `python`) for all commands"#,
    )
}
469
470// ─── ReAct Agent Loop ───
471
472async fn run_agent_loop(
473    llm: &LlmClient,
474    instance: &SwebenchInstance,
475    workspace: &str,
476    max_turns: u32,
477) -> Result<(u32, u64)> {
478    let tools = swebench_tools_def();
479    let system_prompt = build_system_prompt();
480
481    let _repo_config = get_repo_config(&instance.repo);
482    let test_hint = if instance.repo == "django/django" {
483        let test_labels: Vec<String> = instance
484            .fail_to_pass_tests()
485            .iter()
486            .map(|t| normalize_django_test(t))
487            .collect();
488        format!("\n\n## How to run tests\nDjango project. Run tests with:\n```\ncd tests && python3 runtests.py {}\n```", test_labels.join(" "))
489    } else if instance.repo == "sympy/sympy" {
490        let tests = instance.fail_to_pass_tests();
491        format!("\n\n## How to run tests\nsympy project. Run tests with:\n```\npython3 -m pytest -xvs -k \"{}\"\n```", tests.join(" or "))
492    } else {
493        let tests = instance.fail_to_pass_tests();
494        format!(
495            "\n\n## How to run tests\nRun failing tests with:\n```\npython3 -m pytest -xvs {}\n```",
496            tests.join(" ")
497        )
498    };
499
500    let user_content = format!(
501        "## Issue to Fix\n\n{}\n\n## Failing Tests\nThe following tests should PASS after your fix:\n{}{}",
502        instance.problem_statement, instance.fail_to_pass_tests().join("\n"), test_hint,
503    );
504
505    let mut messages: Vec<ChatMessage> = vec![
506        ChatMessage {
507            role: "system".into(),
508            content: system_prompt,
509            tool_calls: None,
510            tool_call_id: None,
511        },
512        ChatMessage {
513            role: "user".into(),
514            content: user_content,
515            tool_calls: None,
516            tool_call_id: None,
517        },
518    ];
519
520    let mut turns_used: u32 = 0;
521    let mut write_turns: u32 = 0;
522
523    for turn in 0..max_turns {
524        turns_used = turn + 1;
525
526        let resp = llm.chat_with_tools(&messages, &tools).await?;
527
528        if !resp.tool_calls.is_empty() {
529            messages.push(ChatMessage {
530                role: "assistant".into(),
531                content: resp.content.clone(),
532                tool_calls: Some(resp.tool_calls.clone()),
533                tool_call_id: None,
534            });
535
536            for tc in &resp.tool_calls {
537                let result =
538                    swebench_tools::execute(&tc.function.name, &tc.function.arguments, workspace)
539                        .await;
540                println!(
541                    "    [turn {}] {} → {}",
542                    turn + 1,
543                    tc.function.name,
544                    if result.success { "ok" } else { "FAIL" }
545                );
546
547                if result.is_submit {
548                    messages.push(ChatMessage {
549                        role: "tool".into(),
550                        content: result.content,
551                        tool_calls: None,
552                        tool_call_id: Some(tc.function.name.clone()),
553                    });
554                    return Ok((turns_used, 0));
555                }
556                if result.is_write {
557                    write_turns += 1;
558                }
559
560                let content = if result.content.len() > 4096 {
561                    format!(
562                        "{}...\n[truncated, {} chars total]",
563                        &result.content[..result.content.len().min(4096)],
564                        result.content.len()
565                    )
566                } else {
567                    result.content
568                };
569                messages.push(ChatMessage {
570                    role: "tool".into(),
571                    content,
572                    tool_calls: None,
573                    tool_call_id: Some(tc.function.name.clone()),
574                });
575            }
576
577            if turn >= 14 && write_turns == 0 {
578                messages.push(ChatMessage {
579                    role: "system".into(),
580                    content: "You have used 15 turns without making any code changes. Please apply your fix NOW using apply_edit or write_file. Then run tests and submit.".into(),
581                    tool_calls: None, tool_call_id: None,
582                });
583            }
584
585            if turn >= 19 && messages.len() > 30 {
586                compact_messages(&mut messages);
587            }
588            continue;
589        }
590
591        messages.push(ChatMessage {
592            role: "assistant".into(),
593            content: resp.content.clone(),
594            tool_calls: None,
595            tool_call_id: None,
596        });
597
598        let lower = resp.content.to_lowercase();
599        if lower.contains("submit")
600            || lower.contains("finished")
601            || lower.contains("fix has been applied")
602        {
603            println!("    Agent indicated completion at turn {}", turn + 1);
604            return Ok((turns_used, 0));
605        }
606
607        messages.push(ChatMessage {
608            role: "user".into(),
609            content: "Please continue. Use the tools to explore the code and fix the bug. Call submit() when done.".into(),
610            tool_calls: None, tool_call_id: None,
611        });
612    }
613
614    println!("    Agent reached max turns ({})", max_turns);
615    Ok((turns_used, 0))
616}
617
618fn compact_messages(messages: &mut Vec<ChatMessage>) {
619    if messages.len() <= 12 {
620        return;
621    }
622    let keep_start = 2;
623    let keep_end = messages.len().saturating_sub(10);
624    if keep_end <= keep_start {
625        return;
626    }
627
628    let mut summary = String::from("## Previous exploration summary\n");
629    for msg in &messages[keep_start..keep_end] {
630        if msg.role == "tool" {
631            let preview = if msg.content.len() > 150 {
632                &msg.content[..150]
633            } else {
634                &msg.content
635            };
636            summary.push_str(&format!("- tool: {}\n", preview));
637        } else if msg.role == "assistant" {
638            let preview = if msg.content.len() > 200 {
639                &msg.content[..200]
640            } else {
641                &msg.content
642            };
643            summary.push_str(&format!("- thought: {}\n", preview));
644        }
645    }
646
647    let replacement = ChatMessage {
648        role: "system".into(),
649        content: summary,
650        tool_calls: None,
651        tool_call_id: None,
652    };
653    messages.splice(keep_start..keep_end, std::iter::once(replacement));
654}
655
656// ─── Patch Generation ───
657
658async fn generate_patch(workspace: &str) -> Result<(String, Vec<String>)> {
659    let output = tokio::process::Command::new("git")
660        .args(["diff", "HEAD"])
661        .current_dir(workspace)
662        .output()
663        .await?;
664    let diff = String::from_utf8_lossy(&output.stdout).to_string();
665    let files_output = tokio::process::Command::new("git")
666        .args(["diff", "--name-only", "HEAD"])
667        .current_dir(workspace)
668        .output()
669        .await?;
670    let files: Vec<String> = String::from_utf8_lossy(&files_output.stdout)
671        .lines()
672        .filter(|l| !l.is_empty())
673        .map(String::from)
674        .collect();
675    Ok((diff, files))
676}
677
678// ─── Checkpoint / Resume ───
679
680fn load_completed_ids(output_dir: &str) -> std::collections::HashSet<String> {
681    let path = format!("{}/swebench_results.jsonl", output_dir);
682    let mut ids = std::collections::HashSet::new();
683    if let Ok(data) = std::fs::read_to_string(&path) {
684        for line in data.lines() {
685            if let Ok(result) = serde_json::from_str::<InstanceResult>(line) {
686                ids.insert(result.instance_id);
687            }
688        }
689    }
690    ids
691}
692
693fn append_result(output_dir: &str, result: &InstanceResult) -> Result<()> {
694    use std::io::Write;
695    let path = format!("{}/swebench_results.jsonl", output_dir);
696    let mut file = std::fs::OpenOptions::new()
697        .create(true)
698        .append(true)
699        .open(&path)?;
700    writeln!(file, "{}", serde_json::to_string(result)?)?;
701    Ok(())
702}
703
704fn append_prediction(output_dir: &str, prediction: &Prediction) -> Result<()> {
705    use std::io::Write;
706    let path = format!("{}/predictions.jsonl", output_dir);
707    let mut file = std::fs::OpenOptions::new()
708        .create(true)
709        .append(true)
710        .open(&path)?;
711    writeln!(file, "{}", serde_json::to_string(prediction)?)?;
712    Ok(())
713}
714
715// ─── Entry Points ───
716
/// Run one SWE-bench instance end-to-end: set up the workspace, run the
/// agent loop under a timeout, extract the resulting patch, and re-run the
/// FAIL_TO_PASS tests to decide whether the instance is resolved.
///
/// Agent errors and timeouts are reported inside the returned
/// `InstanceResult` (with `error` set), not as an `Err`; `Err` is reserved
/// for workspace/patch-extraction failures.
pub async fn run_single(
    instance: &SwebenchInstance,
    opts: &SwebenchOpts,
) -> Result<InstanceResult> {
    let model = opts
        .model_override
        .as_deref()
        .unwrap_or("claude-sonnet-4-6");
    // 64K context window, 8K max output tokens.
    let llm = LlmClient::with_limits(model, 65536, 8192);

    println!("\n[SWE-bench] Instance: {}", instance.instance_id);
    println!(
        "  Repo: {} @ {}",
        instance.repo,
        // Short commit prefix; min() guards against commits shorter than 7
        // chars (hex is ASCII, so byte slicing is safe here).
        &instance.base_commit[..7.min(instance.base_commit.len())]
    );
    println!(
        "  Tests: {} FAIL_TO_PASS",
        instance.fail_to_pass_tests().len()
    );
    println!("  Model: {}", model);

    let workspace = setup_instance_workspace(instance, &opts.output_dir).await?;

    let start = Instant::now();
    // The agent loop runs under a hard wall-clock timeout. Both error paths
    // return a populated InstanceResult so the batch can keep going.
    let (turns, tokens) = match tokio::time::timeout(
        std::time::Duration::from_secs(opts.timeout_secs),
        run_agent_loop(&llm, instance, &workspace, opts.max_turns),
    )
    .await
    {
        Ok(Ok(result)) => result,
        Ok(Err(e)) => {
            return Ok(InstanceResult {
                instance_id: instance.instance_id.clone(),
                repo: instance.repo.clone(),
                model: model.to_string(),
                resolved: false,
                fail_to_pass_count: instance.fail_to_pass_tests().len(),
                fail_to_pass_passed: 0,
                turns_used: 0,
                tokens_used: 0,
                duration_secs: start.elapsed().as_secs_f64(),
                files_modified: vec![],
                patch: String::new(),
                error: Some(format!("Agent error: {}", e)),
            });
        }
        Err(_) => {
            return Ok(InstanceResult {
                instance_id: instance.instance_id.clone(),
                repo: instance.repo.clone(),
                model: model.to_string(),
                resolved: false,
                fail_to_pass_count: instance.fail_to_pass_tests().len(),
                fail_to_pass_passed: 0,
                turns_used: opts.max_turns,
                tokens_used: 0,
                duration_secs: start.elapsed().as_secs_f64(),
                files_modified: vec![],
                patch: String::new(),
                error: Some(format!("Timeout after {}s", opts.timeout_secs)),
            });
        }
    };
    let duration = start.elapsed().as_secs_f64();

    let (patch, files_modified) = generate_patch(&workspace).await?;

    // Run tests to check if resolved
    // Each FAIL_TO_PASS test runs individually via the repo's test command;
    // the shell exit status decides pass/fail.
    let fail_tests = instance.fail_to_pass_tests();
    let repo_config = get_repo_config(&instance.repo);
    let mut passed = 0;
    for test in &fail_tests {
        // Django test ids need translation to runtests.py labels.
        let normalized = if instance.repo == "django/django" {
            normalize_django_test(test)
        } else {
            test.to_string()
        };
        let test_cmd = format!(
            "{}{}",
            repo_config.env_setup,
            repo_config.test_cmd_template.replace("{test}", &normalized)
        );
        let test_result = tokio::process::Command::new("sh")
            .arg("-c")
            .arg(&test_cmd)
            .current_dir(&workspace)
            .env("PYTHONDONTWRITEBYTECODE", "1")
            .output()
            .await;
        if let Ok(output) = test_result {
            if output.status.success() {
                passed += 1;
                println!("  PASS: {}", test);
            } else {
                // Report the last informative line of output as the reason.
                let stdout = String::from_utf8_lossy(&output.stdout);
                let stderr = String::from_utf8_lossy(&output.stderr);
                let last = stdout
                    .lines()
                    .rev()
                    .chain(stderr.lines().rev())
                    .find(|l| !l.trim().is_empty() && !l.starts_with("="))
                    .unwrap_or("unknown error");
                println!("  FAIL: {} — {}", test, last);
            }
        } else {
            println!("  ERROR: {} — could not execute test command", test);
        }
    }

    // Resolved means every FAIL_TO_PASS test passed (and there was at least one).
    let resolved = passed == fail_tests.len() && !fail_tests.is_empty();
    println!(
        "  Result: {} ({}/{} tests) | {} turns | {:.0}s",
        if resolved { "RESOLVED" } else { "FAILED" },
        passed,
        fail_tests.len(),
        turns,
        duration
    );

    Ok(InstanceResult {
        instance_id: instance.instance_id.clone(),
        repo: instance.repo.clone(),
        model: model.to_string(),
        resolved,
        fail_to_pass_count: fail_tests.len(),
        fail_to_pass_passed: passed,
        turns_used: turns,
        tokens_used: tokens,
        duration_secs: duration,
        files_modified,
        patch,
        error: None,
    })
}
853
854pub async fn run_batch(opts: &SwebenchOpts) -> Result<()> {
855    tokio::fs::create_dir_all(&opts.output_dir).await?;
856    tokio::fs::create_dir_all(&format!("{}/datasets", opts.output_dir)).await?;
857
858    let instances = load_dataset(opts)?;
859    if instances.is_empty() {
860        println!("No instances to run.");
861        return Ok(());
862    }
863
864    let completed = if opts.resume {
865        let ids = load_completed_ids(&opts.output_dir);
866        if !ids.is_empty() {
867            println!("Resuming: {} instances already completed", ids.len());
868        }
869        ids
870    } else {
871        std::collections::HashSet::new()
872    };
873
874    let total = instances.len();
875    let mut resolved_count = 0u32;
876    let mut completed_count = 0u32;
877    let mut total_duration = 0.0f64;
878    let mut total_tokens = 0u64;
879
880    println!("\n========================================");
881    println!(" SWE-bench Run: {} instances", total);
882    println!("========================================\n");
883
884    for (i, instance) in instances.iter().enumerate() {
885        if completed.contains(&instance.instance_id) {
886            println!(
887                "[{}/{}] SKIP {} (already completed)",
888                i + 1,
889                total,
890                instance.instance_id
891            );
892            continue;
893        }
894
895        println!("\n[{}/{}] {}", i + 1, total, instance.instance_id);
896
897        let result = match run_single(instance, opts).await {
898            Ok(r) => r,
899            Err(e) => {
900                println!("  ERROR: {}", e);
901                InstanceResult {
902                    instance_id: instance.instance_id.clone(),
903                    repo: instance.repo.clone(),
904                    model: opts
905                        .model_override
906                        .clone()
907                        .unwrap_or_else(|| "default".into()),
908                    resolved: false,
909                    fail_to_pass_count: instance.fail_to_pass_tests().len(),
910                    fail_to_pass_passed: 0,
911                    turns_used: 0,
912                    tokens_used: 0,
913                    duration_secs: 0.0,
914                    files_modified: vec![],
915                    patch: String::new(),
916                    error: Some(format!("{}", e)),
917                }
918            }
919        };
920
921        if result.resolved {
922            resolved_count += 1;
923        }
924        completed_count += 1;
925        total_duration += result.duration_secs;
926        total_tokens += result.tokens_used;
927
928        if let Err(e) = append_result(&opts.output_dir, &result) {
929            eprintln!("Warning: failed to save result: {}", e);
930        }
931        if let Err(e) = append_prediction(
932            &opts.output_dir,
933            &Prediction {
934                instance_id: result.instance_id.clone(),
935                model_name_or_path: result.model.clone(),
936                model_patch: result.patch.clone(),
937            },
938        ) {
939            eprintln!("Warning: failed to save prediction: {}", e);
940        }
941
942        let rate = if completed_count > 0 {
943            resolved_count as f64 / completed_count as f64 * 100.0
944        } else {
945            0.0
946        };
947        println!(
948            "\n  Progress: {}/{} completed | {}/{} resolved ({:.1}%) | {:.0}s avg",
949            completed_count,
950            total,
951            resolved_count,
952            completed_count,
953            rate,
954            if completed_count > 0 {
955                total_duration / completed_count as f64
956            } else {
957                0.0
958            }
959        );
960    }
961
962    println!("\n========================================");
963    println!(" SWE-bench Run Complete");
964    println!("========================================");
965    let rate = if completed_count > 0 {
966        resolved_count as f64 / completed_count as f64 * 100.0
967    } else {
968        0.0
969    };
970    println!(
971        "  Resolved: {}/{} ({:.1}%)",
972        resolved_count, completed_count, rate
973    );
974    println!("  Total tokens: {}", total_tokens);
975    println!(
976        "  Total time: {:.0}s ({:.0}s avg)",
977        total_duration,
978        if completed_count > 0 {
979            total_duration / completed_count as f64
980        } else {
981            0.0
982        }
983    );
984    println!("  Results: {}/swebench_results.jsonl", opts.output_dir);
985    println!("  Predictions: {}/predictions.jsonl", opts.output_dir);
986    Ok(())
987}
988
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON array of test identifiers parses to one entry per element.
    #[test]
    fn test_parse_test_list_array() {
        let value = json!(["test_foo.py::TestBar", "test_baz.py::TestQux"]);
        assert_eq!(parse_test_list(&value).len(), 2);
    }

    /// A JSON string holding a serialized array is decoded and split.
    #[test]
    fn test_parse_test_list_string() {
        let value = json!("[\"test_foo.py::TestBar\", \"test_baz.py::TestQux\"]");
        assert_eq!(parse_test_list(&value).len(), 2);
    }

    /// A plain string that is not array syntax yields exactly one test.
    #[test]
    fn test_parse_test_list_single() {
        let value = json!("test_foo.py::TestBar");
        assert_eq!(parse_test_list(&value).len(), 1);
    }
}