use std::sync::Arc;
use anyhow::{Result, anyhow};
use serde::{Deserialize, Serialize};
use brainwires_core::{Provider, Task};
use crate::context::AgentContext;
use crate::planner_agent::DynamicTaskSpec;
use crate::system_prompts::judge_agent_prompt;
use crate::task_agent::{TaskAgent, TaskAgentConfig, TaskAgentResult};
/// Decision produced by the judge after evaluating a cycle.
///
/// Deserialized from a JSON object whose `verdict` field selects the variant
/// (`"complete"`, `"continue"`, `"fresh_restart"`, `"abort"`); the remaining
/// keys are that variant's fields. Fields marked `#[serde(default)]` may be
/// omitted by the judge and fall back to empty values.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "verdict", rename_all = "snake_case")]
pub enum JudgeVerdict {
    /// `"complete"`.
    Complete {
        /// Defaults to empty like every other variant's `summary`, so a bare
        /// `{"verdict": "complete"}` still parses instead of making
        /// `JudgeAgent::parse_verdict` return an error.
        #[serde(default)]
        summary: String,
    },
    /// `"continue"`: the judge requests further work.
    Continue {
        #[serde(default)]
        summary: String,
        /// New task specifications to schedule.
        #[serde(default)]
        additional_tasks: Vec<DynamicTaskSpec>,
        /// Task ids to retry — presumably ids of tasks from the evaluated
        /// cycle; confirm against the orchestrator.
        #[serde(default)]
        retry_tasks: Vec<String>,
        /// Free-form guidance; also exposed through `JudgeVerdict::hints`.
        #[serde(default)]
        hints: Vec<String>,
    },
    /// `"fresh_restart"`: `reason` is required.
    FreshRestart {
        reason: String,
        /// Free-form guidance; also exposed through `JudgeVerdict::hints`.
        #[serde(default)]
        hints: Vec<String>,
        #[serde(default)]
        summary: String,
    },
    /// `"abort"`: `reason` is required.
    Abort {
        reason: String,
        #[serde(default)]
        summary: String,
    },
}
impl JudgeVerdict {
pub fn verdict_type(&self) -> &'static str {
match self {
JudgeVerdict::Complete { .. } => "complete",
JudgeVerdict::Continue { .. } => "continue",
JudgeVerdict::FreshRestart { .. } => "fresh_restart",
JudgeVerdict::Abort { .. } => "abort",
}
}
pub fn hints(&self) -> &[String] {
match self {
JudgeVerdict::Continue { hints, .. } | JudgeVerdict::FreshRestart { hints, .. } => {
hints
}
_ => &[],
}
}
}
/// Outcome of merging a worker's branch; rendered into the judge's task
/// description through its `Display` impl.
///
/// Serialized with snake_case variant names (`merged`, `conflict_resolved`,
/// `conflict_failed`, `not_attempted`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MergeStatus {
    /// The branch was merged.
    Merged,
    /// A conflict occurred and was resolved.
    ConflictResolved,
    /// A conflict could not be resolved. Carries a description — presumably
    /// the merge error text; confirm at the call site that builds it.
    ConflictFailed(String),
    /// No merge was attempted for this branch.
    NotAttempted,
}
impl std::fmt::Display for MergeStatus {
    /// Writes the snake_case label of the status; `ConflictFailed` is written
    /// as `conflict_failed: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            MergeStatus::Merged => "merged",
            MergeStatus::ConflictResolved => "conflict_resolved",
            MergeStatus::NotAttempted => "not_attempted",
            // The only variant with a payload: format it and return directly.
            MergeStatus::ConflictFailed(msg) => return write!(f, "conflict_failed: {}", msg),
        };
        f.write_str(label)
    }
}
/// Outcome of a single worker task, as reported to the judge.
///
/// Every field is rendered into the "### Worker N" section of the judge's
/// task description.
#[derive(Debug, Clone)]
pub struct WorkerResult {
    /// Identifier of the task the worker ran (shown as "task: <id>").
    pub task_id: String,
    /// Description of that task.
    pub task_description: String,
    /// Result of the worker's agent run; its `success`, `summary` and
    /// `iterations` are shown to the judge.
    pub agent_result: TaskAgentResult,
    /// Branch holding the worker's changes — presumably a git branch; confirm.
    pub branch_name: String,
    /// How (or whether) that branch was merged.
    pub merge_status: MergeStatus,
}
/// Everything the judge is shown about the cycle it evaluates.
#[derive(Debug, Clone)]
pub struct JudgeContext {
    /// Overall goal; rendered under "## Original Goal".
    pub original_goal: String,
    /// Cycle being evaluated; used in the report header and in the judge's
    /// task id (`judge-cycle-<n>`).
    pub cycle_number: u32,
    /// One entry per worker, rendered as numbered "### Worker N" sections.
    pub worker_results: Vec<WorkerResult>,
    /// Planner's rationale for this cycle; rendered under "## Planner Rationale".
    pub planner_rationale: String,
    /// Verdicts from earlier cycles; only their `verdict_type()` labels are
    /// rendered, and the section is omitted when this is empty.
    pub previous_verdicts: Vec<JudgeVerdict>,
}
/// Tuning knobs for a [`JudgeAgent`].
#[derive(Debug, Clone)]
pub struct JudgeAgentConfig {
    /// Forwarded to `TaskAgentConfig::max_iterations`.
    pub max_iterations: u32,
    /// NOTE(review): not read by `JudgeAgent::new` in this file — confirm it is
    /// consumed elsewhere; otherwise setting it has no effect.
    pub inspect_files: bool,
    /// NOTE(review): not read by `JudgeAgent::new` in this file — same caveat
    /// as `inspect_files`.
    pub inspect_diffs: bool,
    /// Forwarded to `TaskAgentConfig::temperature`.
    pub temperature: f32,
    /// Forwarded to `TaskAgentConfig::max_tokens`.
    pub max_tokens: u32,
}
impl Default for JudgeAgentConfig {
    /// 15 iterations, file and diff inspection enabled, temperature 0.3 and
    /// `max_tokens` of 4096.
    fn default() -> Self {
        const MAX_ITERATIONS: u32 = 15;
        const TEMPERATURE: f32 = 0.3;
        const MAX_TOKENS: u32 = 4096;

        Self {
            max_iterations: MAX_ITERATIONS,
            inspect_files: true,
            inspect_diffs: true,
            temperature: TEMPERATURE,
            max_tokens: MAX_TOKENS,
        }
    }
}
/// Judge that evaluates one cycle by running a [`TaskAgent`].
///
/// The agent's task is a markdown report of the cycle; `JudgeAgent::execute`
/// runs it and parses a [`JudgeVerdict`] out of the agent's summary.
pub struct JudgeAgent {
    /// Shared handle to the underlying agent; exposed via `agent()`.
    agent: Arc<TaskAgent>,
}
impl JudgeAgent {
    /// Build a judge for the cycle described by `judge_context`.
    ///
    /// The system prompt comes from `judge_agent_prompt(id, working_directory)`.
    /// The task, with id `judge-cycle-<cycle_number>`, is the markdown report
    /// produced by `build_task_description`. Only `max_iterations`,
    /// `temperature` and `max_tokens` are taken from `config`;
    /// `validation_config` is set to `None`.
    pub fn new(
        id: String,
        judge_context: &JudgeContext,
        provider: Arc<dyn Provider>,
        context: Arc<AgentContext>,
        config: JudgeAgentConfig,
    ) -> Self {
        let agent_config = TaskAgentConfig {
            system_prompt: Some(judge_agent_prompt(&id, &context.working_directory)),
            max_iterations: config.max_iterations,
            temperature: config.temperature,
            max_tokens: config.max_tokens,
            validation_config: None,
            ..Default::default()
        };

        let report = Self::build_task_description(judge_context);
        let task_id = format!("judge-cycle-{}", judge_context.cycle_number);
        let task = Task::new(task_id, report);

        Self {
            agent: Arc::new(TaskAgent::new(id, task, provider, context, agent_config)),
        }
    }

    /// Run the judge agent and parse its verdict from the agent's summary.
    ///
    /// # Errors
    /// Propagates errors from the agent run. Returns an error when the agent
    /// reports `success == false` (the message includes its summary), or when
    /// `parse_verdict` cannot extract or deserialize a verdict.
    pub async fn execute(&self) -> Result<(JudgeVerdict, TaskAgentResult)> {
        let outcome = self.agent.execute().await?;
        if outcome.success {
            Self::parse_verdict(&outcome.summary).map(|verdict| (verdict, outcome))
        } else {
            Err(anyhow!("Judge agent failed: {}", outcome.summary))
        }
    }

    /// Extract the JSON verdict from free-form judge output and deserialize it.
    ///
    /// # Errors
    /// "No JSON block found in judge output" when no JSON candidate exists, or
    /// "Failed to parse judge verdict JSON: …" when deserialization fails.
    pub fn parse_verdict(text: &str) -> Result<JudgeVerdict> {
        match extract_json_block(text) {
            Some(json) => serde_json::from_str(&json)
                .map_err(|err| anyhow!("Failed to parse judge verdict JSON: {}", err)),
            None => Err(anyhow!("No JSON block found in judge output")),
        }
    }

    /// Render the judge's task as markdown.
    ///
    /// Sections, in order: a header with the goal and planner rationale; one
    /// "### Worker N" section per worker result (numbered from 1); an optional
    /// list of earlier verdict labels; closing instructions asking for a JSON
    /// verdict.
    fn build_task_description(ctx: &JudgeContext) -> String {
        let mut desc = format!(
            "# Evaluate Cycle {} Results\n\n## Original Goal\n{}\n\n## Planner Rationale\n{}\n\n",
            ctx.cycle_number, ctx.original_goal, ctx.planner_rationale
        );
        desc += "## Worker Results\n\n";

        for (idx, worker) in ctx.worker_results.iter().enumerate() {
            let section = format!(
                "### Worker {} (task: {})\n- **Task**: {}\n- **Success**: {}\n- **Summary**: {}\n- **Branch**: {}\n- **Merge**: {}\n- **Iterations**: {}\n\n",
                idx + 1,
                worker.task_id,
                worker.task_description,
                worker.agent_result.success,
                worker.agent_result.summary,
                worker.branch_name,
                worker.merge_status,
                worker.agent_result.iterations,
            );
            desc += &section;
        }

        if !ctx.previous_verdicts.is_empty() {
            desc += "## Previous Verdicts\n\n";
            // NOTE(review): verdicts are labelled by their 0-based index while
            // workers are numbered from 1 — confirm this matches how
            // `cycle_number` is counted.
            let listing: String = ctx
                .previous_verdicts
                .iter()
                .enumerate()
                .map(|(cycle, verdict)| format!("- Cycle {}: {}\n", cycle, verdict.verdict_type()))
                .collect();
            desc += &listing;
            desc.push('\n');
        }

        desc += "## Your Task\n\n\
                 Evaluate the above results against the original goal. \
                 Output your verdict as a JSON block. \
                 If you need to inspect files or diffs for verification, use the available tools first.";
        desc
    }

    /// Borrow the shared handle to the underlying task agent.
    pub fn agent(&self) -> &Arc<TaskAgent> {
        &self.agent
    }
}
/// Extract the most likely JSON object from free-form LLM output.
///
/// Strategies are tried in this order:
/// 1. The body of the first fenced block tagged `json`.
/// 2. The body of the first generic fenced block, after skipping its opening
///    line, if that body starts with `{`.
/// 3. The first balanced `{ ... }` span in the raw text. Braces inside JSON
///    string literals (including escaped quotes) are ignored, so a value like
///    `"a } b"` does not end the object early.
///
/// The result is not validated as JSON; callers are expected to parse it.
/// Returns `None` when no strategy yields a candidate.
fn extract_json_block(text: &str) -> Option<String> {
    const JSON_FENCE: &str = "```json";
    const FENCE: &str = "```";

    // 1. Explicitly tagged fence.
    if let Some(start) = text.find(JSON_FENCE) {
        let body_start = start + JSON_FENCE.len();
        if let Some(len) = text[body_start..].find(FENCE) {
            return Some(text[body_start..body_start + len].trim().to_string());
        }
    }

    // 2. Generic fence: skip the rest of the opening line (the info string),
    //    then take everything up to the closing fence.
    if let Some(start) = text.find(FENCE) {
        let after_fence = start + FENCE.len();
        if let Some(newline) = text[after_fence..].find('\n') {
            let body_start = after_fence + newline + 1;
            if let Some(len) = text[body_start..].find(FENCE) {
                let candidate = text[body_start..body_start + len].trim();
                if candidate.starts_with('{') {
                    return Some(candidate.to_string());
                }
            }
        }
    }

    // 3. First balanced object in the raw text.
    let start = text.find('{')?;
    let len = balanced_object_len(&text[start..])?;
    Some(text[start..start + len].to_string())
}

/// Byte length of the balanced JSON object at the start of `s`, which is
/// expected to begin with `{`. Returns `None` if the object never closes or a
/// `}` appears at depth zero.
///
/// Quote tracking follows JSON string rules: inside a `"..."` literal, `\`
/// escapes the next character and braces are not counted.
fn balanced_object_len(s: &str) -> Option<usize> {
    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;

    for (i, ch) in s.char_indices() {
        if in_string {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            '"' => in_string = true,
            '{' => depth += 1,
            '}' => {
                depth = depth.checked_sub(1)?;
                if depth == 0 {
                    // `}` is one byte, so `i + 1` is a valid char boundary.
                    return Some(i + 1);
                }
            }
            _ => {}
        }
    }
    None
}
#[cfg(test)]
mod tests {
    //! Unit tests for verdict parsing, verdict accessors, merge-status
    //! display and task-description rendering.
    use super::*;

    #[test]
    fn test_parse_complete_verdict() {
        let text = r#"```json
{
"verdict": "complete",
"summary": "All tasks completed successfully"
}
```"#;
        let verdict = JudgeAgent::parse_verdict(text).unwrap();
        assert!(matches!(verdict, JudgeVerdict::Complete { .. }));
        assert_eq!(verdict.verdict_type(), "complete");
    }

    #[test]
    fn test_parse_continue_verdict() {
        let text = r#"```json
{
"verdict": "continue",
"summary": "Two tasks still need work",
"additional_tasks": [
{
"id": "fix-1",
"description": "Fix the remaining bug",
"files_involved": ["src/bug.rs"],
"depends_on": [],
"priority": "high"
}
],
"retry_tasks": ["task-3"],
"hints": ["Focus on error handling"]
}
```"#;
        let verdict = JudgeAgent::parse_verdict(text).unwrap();
        match &verdict {
            JudgeVerdict::Continue {
                additional_tasks,
                retry_tasks,
                hints,
                ..
            } => {
                assert_eq!(additional_tasks.len(), 1);
                assert_eq!(retry_tasks, &["task-3"]);
                assert_eq!(hints, &["Focus on error handling"]);
            }
            _ => panic!("Expected Continue verdict"),
        }
    }

    #[test]
    fn test_parse_fresh_restart_verdict() {
        let text = r#"```json
{
"verdict": "fresh_restart",
"reason": "Agents went down the wrong path",
"hints": ["Try a different approach", "Focus on the API first"],
"summary": "Need to restart"
}
```"#;
        let verdict = JudgeAgent::parse_verdict(text).unwrap();
        match &verdict {
            JudgeVerdict::FreshRestart { reason, hints, .. } => {
                assert!(reason.contains("wrong path"));
                assert_eq!(hints.len(), 2);
            }
            _ => panic!("Expected FreshRestart verdict"),
        }
    }

    #[test]
    fn test_parse_abort_verdict() {
        let text = r#"```json
{
"verdict": "abort",
"reason": "The goal requires external API access we don't have",
"summary": "Cannot proceed"
}
```"#;
        let verdict = JudgeAgent::parse_verdict(text).unwrap();
        assert!(matches!(verdict, JudgeVerdict::Abort { .. }));
        assert_eq!(verdict.verdict_type(), "abort");
    }

    /// Untagged fences fall back to the generic-fence extraction path.
    #[test]
    fn test_parse_generic_fence() {
        let text = "Verdict follows.\n```\n{\"verdict\": \"complete\", \"summary\": \"done\"}\n```";
        let verdict = JudgeAgent::parse_verdict(text).unwrap();
        assert_eq!(verdict.verdict_type(), "complete");
    }

    /// Unfenced output falls back to the first balanced object; omitted
    /// defaulted fields come back empty.
    #[test]
    fn test_parse_bare_json_object() {
        let text = r#"After review: {"verdict": "abort", "reason": "missing credentials"} -- end"#;
        match JudgeAgent::parse_verdict(text).unwrap() {
            JudgeVerdict::Abort { reason, summary } => {
                assert_eq!(reason, "missing credentials");
                assert!(summary.is_empty());
            }
            other => panic!("Expected Abort verdict, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_missing_json_is_error() {
        let err = JudgeAgent::parse_verdict("no structured output here").unwrap_err();
        assert!(err.to_string().contains("No JSON block"));
    }

    #[test]
    fn test_parse_unknown_verdict_is_error() {
        let text = r#"{"verdict": "maybe", "summary": "unsure"}"#;
        assert!(JudgeAgent::parse_verdict(text).is_err());
    }

    #[test]
    fn test_verdict_hints() {
        let complete = JudgeVerdict::Complete {
            summary: "done".into(),
        };
        assert!(complete.hints().is_empty());
        let cont = JudgeVerdict::Continue {
            summary: "partial".into(),
            additional_tasks: vec![],
            retry_tasks: vec![],
            hints: vec!["hint1".into()],
        };
        assert_eq!(cont.hints().len(), 1);
    }

    #[test]
    fn test_merge_status_display() {
        assert_eq!(MergeStatus::Merged.to_string(), "merged");
        assert_eq!(MergeStatus::ConflictResolved.to_string(), "conflict_resolved");
        assert_eq!(MergeStatus::NotAttempted.to_string(), "not_attempted");
        assert!(
            MergeStatus::ConflictFailed("oops".into())
                .to_string()
                .contains("oops")
        );
        assert_eq!(
            MergeStatus::ConflictFailed("oops".into()).to_string(),
            "conflict_failed: oops"
        );
    }

    /// Pins the section layout of the judge's task description.
    #[test]
    fn test_build_task_description_without_workers() {
        let ctx = JudgeContext {
            original_goal: "Ship the feature".into(),
            cycle_number: 2,
            worker_results: vec![],
            planner_rationale: "Split by module".into(),
            previous_verdicts: vec![JudgeVerdict::Abort {
                reason: "r".into(),
                summary: String::new(),
            }],
        };
        let desc = JudgeAgent::build_task_description(&ctx);
        assert!(desc.starts_with(
            "# Evaluate Cycle 2 Results\n\n## Original Goal\nShip the feature\n\n"
        ));
        assert!(desc.contains(
            "## Planner Rationale\nSplit by module\n\n## Worker Results\n\n## Previous Verdicts\n\n"
        ));
        assert!(desc.contains("- Cycle 0: abort\n\n## Your Task\n\n"));
        assert!(desc.ends_with("use the available tools first."));
    }
}