use std::collections::HashMap;
use std::path::PathBuf;
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::Deserialize;
use tokio::process::Command;
use crate::evolution::trial::BenchmarkAdapter;
use crate::evolution::types::{Feedback, Task, Trajectory};
#[derive(Debug, Deserialize)]
struct SweBenchInstance {
instance_id: String,
problem_statement: String,
repo: String,
base_commit: String,
#[serde(default)]
hints_text: String,
#[serde(rename = "FAIL_TO_PASS", default)]
fail_to_pass: Vec<String>,
#[serde(rename = "PASS_TO_PASS", default)]
pass_to_pass: Vec<String>,
}
/// [`BenchmarkAdapter`] backed by a local SWE-bench instances file, scoring
/// agent patches by running the `swebench` Python harness.
pub struct SweBenchAdapter {
    /// Instances file (JSON Lines or a JSON array); also passed to the harness
    /// as `--swe_bench_tasks`.
    instances_path: PathBuf,
    /// Working directory the harness is launched from.
    eval_dir: PathBuf,
    /// Whether `--use_docker` is passed to the harness.
    use_docker: bool,
}

impl SweBenchAdapter {
    /// Creates an adapter reading instances from `instances_path` and running
    /// the harness in `eval_dir`. Docker mode starts disabled.
    pub fn new(instances_path: PathBuf, eval_dir: PathBuf) -> Self {
        Self {
            instances_path,
            eval_dir,
            use_docker: false,
        }
    }

    /// Builder-style setter controlling whether `--use_docker` is passed.
    pub fn with_docker(mut self, enabled: bool) -> Self {
        self.use_docker = enabled;
        self
    }

    /// Loads every instance from `instances_path`.
    ///
    /// Accepts JSON Lines (one object per non-blank line) or a single JSON
    /// array of objects. A leading UTF-8 byte-order mark is ignored.
    ///
    /// # Errors
    /// Fails if the file cannot be read or any record fails to parse; JSONL
    /// errors name the 1-based line number.
    fn load_instances(&self) -> Result<Vec<SweBenchInstance>> {
        let content = std::fs::read_to_string(&self.instances_path).with_context(|| {
            format!(
                "Failed to read SWE-bench instances: {}",
                self.instances_path.display()
            )
        })?;
        // serde_json rejects a BOM, which some editors prepend.
        let content = content.strip_prefix('\u{feff}').unwrap_or(content.as_str());

        // JSONL records are objects, so a leading `[` can only mean a JSON array.
        if content.trim_start().starts_with('[') {
            return serde_json::from_str::<Vec<SweBenchInstance>>(content).with_context(|| {
                format!(
                    "Failed to parse SWE-bench instances as a JSON array: {}",
                    self.instances_path.display()
                )
            });
        }

        content
            .lines()
            .enumerate()
            .filter(|(_, line)| !line.trim().is_empty())
            .map(|(i, line)| {
                serde_json::from_str::<SweBenchInstance>(line.trim()).with_context(|| {
                    format!("Failed to parse SWE-bench instance at line {}", i + 1)
                })
            })
            .collect()
    }

    /// Pulls a unified diff out of free-form agent output.
    ///
    /// Preference order:
    /// 1. the body of the first ```` ```diff ```` fence that has a closing fence;
    /// 2. everything from the first `diff --git` up to a closing fence line, or
    ///    to the end of the output if there is none;
    /// 3. the whole output, trimmed.
    ///
    /// A line starting with ```` ``` ```` cannot be part of a diff body (every
    /// diff content line starts with ` `, `+` or `-`), so cutting at
    /// `"\n```"` never truncates a real patch.
    fn extract_patch(output: &str) -> &str {
        const DIFF_FENCE: &str = "```diff\n";
        const CLOSING_FENCE: &str = "\n```";

        if let Some(fence_start) = output.find(DIFF_FENCE) {
            let body = &output[fence_start + DIFF_FENCE.len()..];
            if let Some(fence_end) = body.find(CLOSING_FENCE) {
                return &body[..fence_end];
            }
        }
        if let Some(start) = output.find("diff --git") {
            let rest = &output[start..];
            // The diff may sit in an untagged fence; drop the fence and anything after it.
            return match rest.find(CLOSING_FENCE) {
                Some(end) => &rest[..end],
                None => rest,
            };
        }
        output.trim()
    }

    /// Derives `(resolved, score, detail)` from a per-instance harness log.
    ///
    /// * `resolved` — some line mentions `instance_id` and carries a positive
    ///   `RESOLVED` marker (see [`Self::line_reports_resolved`]).
    /// * `score` — fraction of `FAIL_TO_PASS` lines that also say
    ///   `passed`/`PASSED`. With no such lines it is 1.0 if resolved, else 0.0.
    /// * `detail` — a one-line human-readable summary.
    fn parse_eval_log(log: &str, instance_id: &str) -> (bool, f64, String) {
        let resolved = log
            .lines()
            .any(|l| l.contains(instance_id) && Self::line_reports_resolved(l));

        let ftp_lines = || log.lines().filter(|l| l.contains("FAIL_TO_PASS"));
        let ftp_total = ftp_lines().count();
        let ftp_passed = ftp_lines()
            .filter(|l| l.contains("passed") || l.contains("PASSED"))
            .count();

        let score = if ftp_total > 0 {
            ftp_passed as f64 / ftp_total as f64
        } else if resolved {
            1.0
        } else {
            0.0
        };
        let verdict = if resolved { "RESOLVED" } else { "NOT RESOLVED" };
        let detail = format!("{verdict} — {ftp_passed}/{ftp_total} FAIL_TO_PASS tests passed");
        (resolved, score, detail)
    }

    /// Returns `true` if `line` carries a positive, upper-case `RESOLVED` marker.
    ///
    /// Matching is on whole tokens (runs of ASCII alphanumerics and `_`), so
    /// `UNRESOLVED`, `RESOLVED_NO` and `RESOLVED_PARTIAL` do not count. A
    /// `RESOLVED` token directly preceded by `NOT` is negative. `RESOLVED_FULL`
    /// counts as resolved.
    fn line_reports_resolved(line: &str) -> bool {
        let tokens: Vec<&str> = line
            .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .filter(|t| !t.is_empty())
            .collect();
        tokens.iter().enumerate().any(|(i, &tok)| match tok {
            "RESOLVED_FULL" => true,
            "RESOLVED" => i == 0 || tokens[i - 1] != "NOT",
            _ => false,
        })
    }
}
/// Fallback verdict used when the harness wrote no per-instance log: scans the
/// harness stdout for a positive "resolved" marker.
///
/// Tokens (runs of ASCII alphanumerics and `_`) are compared case-insensitively,
/// so `unresolved`, `RESOLVED_NO` and `RESOLVED_PARTIAL` never count. A
/// `resolved` token is ignored when preceded by `not` or followed by `0`,
/// `false` or `no` (e.g. an `Instances resolved: 0` summary line).
fn stdout_reports_resolved(stdout: &str) -> bool {
    stdout.lines().any(|line| {
        let tokens: Vec<String> = line
            .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .filter(|t| !t.is_empty())
            .map(str::to_ascii_lowercase)
            .collect();
        tokens.iter().enumerate().any(|(i, tok)| match tok.as_str() {
            "resolved_full" => true,
            "resolved" => {
                let negated = i > 0 && tokens[i - 1] == "not";
                let zero_or_false = matches!(
                    tokens.get(i + 1).map(String::as_str),
                    Some("0" | "false" | "no")
                );
                !negated && !zero_or_false
            }
            _ => false,
        })
    })
}

#[async_trait]
impl BenchmarkAdapter for SweBenchAdapter {
    /// Loads instances and turns them into [`Task`]s.
    ///
    /// `_split` is ignored; the instances file is used as-is. A `limit` of 0
    /// means no limit. The task input is the problem statement, followed by a
    /// `Hints:` section when hints are non-empty. `repo`, `base_commit`,
    /// `fail_to_pass` and `pass_to_pass` are stored in the task metadata.
    async fn get_tasks(&self, _split: &str, limit: usize) -> Result<Vec<Task>> {
        let limit = if limit == 0 { usize::MAX } else { limit };
        let tasks = self
            .load_instances()?
            .into_iter()
            .take(limit)
            .map(|inst| {
                let metadata: HashMap<String, serde_json::Value> = HashMap::from([
                    ("repo".to_string(), serde_json::Value::String(inst.repo)),
                    (
                        "base_commit".to_string(),
                        serde_json::Value::String(inst.base_commit),
                    ),
                    ("fail_to_pass".to_string(), serde_json::json!(inst.fail_to_pass)),
                    ("pass_to_pass".to_string(), serde_json::json!(inst.pass_to_pass)),
                ]);
                let input = if inst.hints_text.is_empty() {
                    inst.problem_statement
                } else {
                    format!("{}\n\nHints:\n{}", inst.problem_statement, inst.hints_text)
                };
                Task {
                    id: inst.instance_id,
                    input,
                    metadata,
                }
            })
            .collect();
        Ok(tasks)
    }

    /// Scores `trajectory` by running the SWE-bench harness on the patch in its output.
    ///
    /// The patch is written as a one-record predictions file under a per-task
    /// temp directory. Then `python -m swebench.harness.run_evaluation` is run
    /// from `eval_dir`. The verdict is taken from the per-instance eval log when
    /// one exists, otherwise from a heuristic scan of the harness stdout. The
    /// harness stdout/stderr are returned in `raw`.
    ///
    /// # Errors
    /// Fails if the temp files cannot be prepared or the harness cannot be
    /// spawned. A harness that runs but exits non-zero is reported through the
    /// returned [`Feedback`] rather than as an error.
    async fn evaluate(&self, task: &Task, trajectory: &Trajectory) -> Result<Feedback> {
        let extracted = Self::extract_patch(&trajectory.output);
        if extracted.trim().is_empty() {
            return Ok(Feedback {
                success: false,
                score: 0.0,
                detail: "No patch found in agent output".to_string(),
                raw: Default::default(),
            });
        }
        // Fenced extraction drops the final newline, and `git apply` can reject
        // such a patch as corrupt.
        let patch = if extracted.ends_with('\n') {
            extracted.to_string()
        } else {
            format!("{extracted}\n")
        };

        let tmp_dir =
            std::env::temp_dir().join(format!("collet_swe_{}", task.id.replace('/', "_")));
        std::fs::create_dir_all(&tmp_dir)?;
        let predictions_path = tmp_dir.join("predictions.jsonl");
        let log_dir = tmp_dir.join("logs");
        // The directory is reused for every run of this task: remove old logs so
        // a harness failure cannot be scored from a previous run's log.
        match std::fs::remove_dir_all(&log_dir) {
            Err(e) if e.kind() != std::io::ErrorKind::NotFound => {
                return Err(e).with_context(|| {
                    format!("Failed to clear stale SWE-bench logs: {}", log_dir.display())
                });
            }
            _ => {}
        }
        std::fs::create_dir_all(&log_dir)?;

        // `model_name_or_path` must match the `.collet.` segment of `log_path` below.
        let prediction = serde_json::json!({
            "instance_id": &task.id,
            "model_patch": &patch,
            "model_name_or_path": "collet",
        });
        std::fs::write(&predictions_path, serde_json::to_string(&prediction)?)?;

        let mut cmd = Command::new("python");
        cmd.args([
            "-m",
            "swebench.harness.run_evaluation",
            "--predictions_path",
            &predictions_path.to_string_lossy(),
            "--swe_bench_tasks",
            &self.instances_path.to_string_lossy(),
            "--log_dir",
            &log_dir.to_string_lossy(),
            "--instance_ids",
            &task.id,
        ]);
        if self.use_docker {
            cmd.arg("--use_docker");
        }
        cmd.current_dir(&self.eval_dir);
        // Kill the harness process if this future is dropped (e.g. the caller cancels).
        cmd.kill_on_drop(true);

        let output = cmd
            .output()
            .await
            .context("Failed to run SWE-bench harness. Is it installed? (pip install swebench)")?;
        let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
        let stderr = String::from_utf8_lossy(&output.stderr).into_owned();

        let log_path = log_dir.join(format!("{}.collet.eval.log", task.id));
        let (success, score, detail) = match std::fs::read_to_string(&log_path) {
            Ok(log) => Self::parse_eval_log(&log, &task.id),
            Err(_) => {
                let resolved = stdout_reports_resolved(&stdout);
                let s = if resolved { 1.0 } else { 0.0 };
                let d = format!("exit={}\nstdout={stdout}\nstderr={stderr}", output.status);
                (resolved, s, d)
            }
        };

        let mut raw: HashMap<String, serde_json::Value> = HashMap::new();
        raw.insert("stdout".into(), serde_json::Value::String(stdout));
        raw.insert("stderr".into(), serde_json::Value::String(stderr));
        Ok(Feedback {
            success,
            score,
            detail,
            raw,
        })
    }
}