use std::future::Future;
use std::time::{Duration, Instant};
use anyhow::Result;
use tokio::io::AsyncWriteExt;
use super::types::{
JudgmentDecision, JudgmentOutput, JudgmentRequest, JudgmentResult, JudgmentUsage,
};
const JUDGMENT_SCHEMA: &str = r#"{"type":"object","properties":{"decision":{"type":"string","enum":["approve","reject","uncertain"]},"reasoning":{"type":"string"}},"required":["decision","reasoning"]}"#;
pub trait JudgmentProvider: Send + Sync {
fn judge(
&self,
request: &JudgmentRequest,
) -> impl Future<Output = Result<JudgmentResult>> + Send;
}
pub struct ClaudeHaikuJudge {
model: String,
timeout: Duration,
custom_command: Option<String>,
}
impl ClaudeHaikuJudge {
pub fn new(model: String, timeout_secs: u64, custom_command: Option<String>) -> Self {
Self {
model,
timeout: Duration::from_secs(timeout_secs),
custom_command,
}
}
fn build_prompt(&self, request: &JudgmentRequest) -> String {
format!(
r#"You are a safety gate for an AI coding agent. Examine the terminal output and decide if the pending action should be auto-approved.
Approval Type: {approval_type}
Details: {details}
Working Directory: {cwd}
Agent Type: {agent_type}
Terminal Output (last 30 lines):
{screen_context}
Rules:
- Return "approve" only when ALL of these apply:
- The operation is read-only OR explicitly low-risk (e.g., reading files, listing directories, running tests, formatting code)
- No file modification that could break the build or delete important data
- No privilege escalation (sudo, chmod 777, etc.)
- No network/data exfiltration risk (curl to external URLs with sensitive data, etc.)
- No signs of command injection or untrusted input
- The action is clearly related to the current development task
- Return "reject" when ANY of these apply:
- Destructive operations (rm -rf, DROP TABLE, force push, etc.)
- Writing to system files or configuration outside the project
- Network requests to unknown external services with sensitive data
- Privilege escalation attempts
- The action seems suspicious or unrelated to development
- When unsure, return "uncertain" (safer than wrong approval)"#,
approval_type = request.approval_type,
details = request.details,
cwd = request.cwd,
agent_type = request.agent_type,
screen_context = request.screen_context,
)
}
}
impl JudgmentProvider for ClaudeHaikuJudge {
async fn judge(&self, request: &JudgmentRequest) -> Result<JudgmentResult> {
let start = Instant::now();
let prompt = self.build_prompt(request);
let command = self.custom_command.as_deref().unwrap_or("claude");
let mut child = tokio::process::Command::new(command)
.args([
"-p",
"--model",
&self.model,
"--output-format",
"json",
"--json-schema",
JUDGMENT_SCHEMA,
])
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()?;
{
let mut stdin = child
.stdin
.take()
.ok_or_else(|| anyhow::anyhow!("Failed to open stdin"))?;
stdin.write_all(prompt.as_bytes()).await?;
}
let output = match tokio::time::timeout(self.timeout, child.wait_with_output()).await {
Ok(Ok(output)) => output,
Ok(Err(e)) => {
return Ok(JudgmentResult {
decision: JudgmentDecision::Uncertain,
reasoning: format!("Process error: {}", e),
model: self.model.clone(),
elapsed_ms: start.elapsed().as_millis() as u64,
usage: None,
});
}
Err(_) => {
return Ok(JudgmentResult {
decision: JudgmentDecision::Uncertain,
reasoning: "Judgment timed out".to_string(),
model: self.model.clone(),
elapsed_ms: start.elapsed().as_millis() as u64,
usage: None,
});
}
};
let elapsed_ms = start.elapsed().as_millis() as u64;
let stderr = String::from_utf8_lossy(&output.stderr);
let usage = Self::parse_usage_from_stderr(&stderr);
if !output.status.success() {
return Ok(JudgmentResult {
decision: JudgmentDecision::Uncertain,
reasoning: format!("CLI error (exit {}): {}", output.status, stderr),
model: self.model.clone(),
elapsed_ms,
usage,
});
}
let stdout = String::from_utf8_lossy(&output.stdout);
let raw_source = if stdout.trim().is_empty() {
&stderr
} else {
&stdout
};
match Self::parse_claude_output(raw_source) {
Ok(judgment_output) => Ok(JudgmentResult {
decision: judgment_output.parse_decision(),
reasoning: judgment_output.reasoning,
model: self.model.clone(),
elapsed_ms,
usage,
}),
Err(e) => {
let truncated: String = raw_source.chars().take(500).collect();
let raw_display = if raw_source.chars().count() > 500 {
format!("{}...(truncated)", truncated)
} else {
truncated
};
Ok(JudgmentResult {
decision: JudgmentDecision::Uncertain,
reasoning: format!("Failed to parse output: {}. Raw: {}", e, raw_display),
model: self.model.clone(),
elapsed_ms,
usage,
})
}
}
}
}
impl ClaudeHaikuJudge {
fn parse_claude_output(stdout: &str) -> Result<JudgmentOutput> {
if let Ok(output) = serde_json::from_str::<JudgmentOutput>(stdout) {
return Ok(output);
}
if let Ok(wrapper) = serde_json::from_str::<serde_json::Value>(stdout) {
let text = Self::extract_text_from_claude_json(&wrapper);
if let Some(text) = text {
if let Ok(output) = serde_json::from_str::<JudgmentOutput>(&text) {
return Ok(output);
}
}
}
anyhow::bail!("Could not parse claude output as JudgmentOutput")
}
fn extract_text_from_claude_json(value: &serde_json::Value) -> Option<String> {
if let Some(structured) = value.get("structured_output") {
if structured.is_object() {
return Some(structured.to_string());
}
}
if let Some(result) = value.get("result") {
if let Some(text) = result.as_str() {
return Some(text.to_string());
}
}
if let Some(arr) = value.as_array() {
for item in arr {
if item.get("type").and_then(|t| t.as_str()) == Some("text") {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
return Some(text.to_string());
}
}
}
}
if let Some(result) = value.get("result") {
if let Some(arr) = result.as_array() {
for item in arr {
if item.get("type").and_then(|t| t.as_str()) == Some("text") {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
return Some(text.to_string());
}
}
}
}
}
None
}
fn parse_usage_from_stderr(stderr: &str) -> Option<JudgmentUsage> {
let value: serde_json::Value = serde_json::from_str(stderr.trim()).ok()?;
let usage_obj = value.get("usage")?;
let cost = value
.get("total_cost_usd")
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
Some(JudgmentUsage {
input_tokens: usage_obj
.get("input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
output_tokens: usage_obj
.get("output_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
cache_read_input_tokens: usage_obj
.get("cache_read_input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
cache_creation_input_tokens: usage_obj
.get("cache_creation_input_tokens")
.and_then(|v| v.as_u64())
.unwrap_or(0),
cost_usd: cost,
})
}
}