use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::state::TokenUsage;
use super::{Agent, AgentEvent, AgentOutcome, AgentRequest, StopReason};
/// A single scripted action that a [`DryRunAgent`] replays, in order, during a run.
#[derive(Debug, Clone)]
pub enum DryRunStep {
    /// Send a clone of this event on the run's event channel.
    Emit(AgentEvent),
    /// Sleep for this long before moving on to the next step.
    Wait(Duration),
}
/// How a [`DryRunAgent`] run ends once every scripted step has been replayed.
#[derive(Debug, Clone)]
pub enum DryRunFinal {
    /// Finish with `StopReason::Completed`, reporting the given exit code and
    /// token usage.
    Success {
        /// Exit code copied into the outcome.
        exit_code: i32,
        /// Token usage cloned into the outcome.
        tokens: TokenUsage,
    },
    /// Finish with `StopReason::Error` carrying this message; the outcome uses
    /// exit code 1 and default token usage.
    Error(String),
    /// Never finish on its own: the run can only end through cancellation or
    /// the request timeout.
    Hang,
}
/// Scripted agent: replays a fixed list of [`DryRunStep`]s and then ends the
/// way its configured [`DryRunFinal`] dictates.
pub struct DryRunAgent {
    /// Label reported by `Agent::name`.
    name: String,
    /// Steps replayed in insertion order on every run.
    script: Vec<DryRunStep>,
    /// Ending applied after the whole script has been replayed.
    finish: DryRunFinal,
}
impl DryRunAgent {
    /// Creates an agent called `name` with an empty script.
    ///
    /// Until [`DryRunAgent::finish`] overrides it, the agent ends with
    /// `DryRunFinal::Success` using exit code 0 and default token usage.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            script: Vec::new(),
            finish: DryRunFinal::Success {
                exit_code: 0,
                tokens: TokenUsage::default(),
            },
        }
    }

    /// Appends a step that emits `event`, and returns the updated agent.
    ///
    /// The builder consumes `self`, so dropping the return value throws the
    /// whole agent away; `#[must_use]` turns that mistake into a warning.
    #[must_use = "builder methods return the updated agent; ignoring it discards the agent"]
    pub fn emit(mut self, event: AgentEvent) -> Self {
        self.script.push(DryRunStep::Emit(event));
        self
    }

    /// Appends a step that pauses for `d`, and returns the updated agent.
    #[must_use = "builder methods return the updated agent; ignoring it discards the agent"]
    pub fn wait(mut self, d: Duration) -> Self {
        self.script.push(DryRunStep::Wait(d));
        self
    }

    /// Replaces the configured ending with `finish`, and returns the updated
    /// agent. Only the most recent call takes effect.
    #[must_use = "builder methods return the updated agent; ignoring it discards the agent"]
    pub fn finish(mut self, finish: DryRunFinal) -> Self {
        self.finish = finish;
        self
    }
}
#[async_trait]
impl Agent for DryRunAgent {
fn name(&self) -> &str {
&self.name
}
async fn run(
&self,
req: AgentRequest,
events: mpsc::Sender<AgentEvent>,
cancel: CancellationToken,
) -> Result<AgentOutcome> {
let log_path = req.log_path.clone();
let work = async {
for step in &self.script {
match step {
DryRunStep::Emit(e) => {
let _ = events.send(e.clone()).await;
}
DryRunStep::Wait(d) => tokio::time::sleep(*d).await,
}
}
match &self.finish {
DryRunFinal::Success { exit_code, tokens } => AgentOutcome {
exit_code: *exit_code,
stop_reason: StopReason::Completed,
tokens: tokens.clone(),
log_path: log_path.clone(),
},
DryRunFinal::Error(msg) => AgentOutcome {
exit_code: 1,
stop_reason: StopReason::Error(msg.clone()),
tokens: TokenUsage::default(),
log_path: log_path.clone(),
},
DryRunFinal::Hang => {
std::future::pending::<()>().await;
unreachable!("std::future::pending never resolves");
}
}
};
let outcome = tokio::select! {
biased;
_ = cancel.cancelled() => AgentOutcome {
exit_code: -1,
stop_reason: StopReason::Cancelled,
tokens: TokenUsage::default(),
log_path: log_path.clone(),
},
_ = tokio::time::sleep(req.timeout) => AgentOutcome {
exit_code: -1,
stop_reason: StopReason::Timeout,
tokens: TokenUsage::default(),
log_path: log_path.clone(),
},
o = work => o,
};
Ok(outcome)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::Role;
    use std::path::PathBuf;

    /// Builds a minimal implementer request whose only varying input is `timeout`.
    fn request_with_timeout(timeout: Duration) -> AgentRequest {
        AgentRequest {
            role: Role::Implementer,
            model: "test-model".into(),
            system_prompt: String::new(),
            user_prompt: String::new(),
            workdir: PathBuf::from("/tmp"),
            log_path: PathBuf::from("/tmp/dry-run.log"),
            timeout,
            env: std::collections::HashMap::new(),
        }
    }

    /// Reads from `rx` until the channel is closed and returns everything received.
    async fn collect_all<T>(mut rx: mpsc::Receiver<T>) -> Vec<T> {
        let mut received = Vec::new();
        loop {
            match rx.recv().await {
                Some(item) => received.push(item),
                None => break received,
            }
        }
    }

    /// An agent with an empty script that never finishes on its own.
    fn hanging_agent() -> DryRunAgent {
        DryRunAgent::new("test").finish(DryRunFinal::Hang)
    }

    #[tokio::test]
    async fn success_path_streams_events_and_returns_completed() {
        let expected_tokens = TokenUsage {
            input: 10,
            output: 5,
            ..Default::default()
        };
        let ending = DryRunFinal::Success {
            exit_code: 0,
            tokens: expected_tokens.clone(),
        };
        let agent = DryRunAgent::new("test")
            .emit(AgentEvent::Stdout("hello".into()))
            .emit(AgentEvent::ToolUse("write".into()))
            .finish(ending);

        let (sender, receiver) = mpsc::channel(8);
        let token = CancellationToken::new();
        let outcome = agent
            .run(request_with_timeout(Duration::from_secs(5)), sender, token)
            .await
            .unwrap();

        assert_eq!(outcome.stop_reason, StopReason::Completed);
        assert_eq!(outcome.exit_code, 0);
        assert_eq!(outcome.tokens, expected_tokens);
        assert_eq!(outcome.log_path, PathBuf::from("/tmp/dry-run.log"));

        // The sender was moved into `run`, so the channel is closed by now and
        // draining terminates.
        let events = collect_all(receiver).await;
        assert_eq!(events.len(), 2);
        assert!(matches!(&events[0], AgentEvent::Stdout(text) if text == "hello"));
        assert!(matches!(&events[1], AgentEvent::ToolUse(tool) if tool == "write"));
    }

    #[tokio::test]
    async fn failure_path_returns_error_stop_reason() {
        let agent = DryRunAgent::new("test").finish(DryRunFinal::Error("boom".into()));
        let (sender, _receiver) = mpsc::channel(1);
        let token = CancellationToken::new();

        let outcome = agent
            .run(request_with_timeout(Duration::from_secs(1)), sender, token)
            .await
            .unwrap();

        match outcome.stop_reason {
            StopReason::Error(message) => assert_eq!(message, "boom"),
            other => panic!("expected Error, got {:?}", other),
        }
        assert_eq!(outcome.exit_code, 1);
    }

    #[tokio::test]
    async fn timeout_path_fires_when_agent_hangs() {
        let (sender, _receiver) = mpsc::channel(1);
        let token = CancellationToken::new();

        let outcome = hanging_agent()
            .run(request_with_timeout(Duration::from_millis(40)), sender, token)
            .await
            .unwrap();

        assert_eq!(outcome.stop_reason, StopReason::Timeout);
        assert_eq!(outcome.exit_code, -1);
    }

    #[tokio::test]
    async fn cancellation_path_aborts_a_hanging_agent() {
        let (sender, _receiver) = mpsc::channel(1);
        let token = CancellationToken::new();

        // Signal cancellation from a separate task shortly after the run starts.
        let trigger = tokio::spawn({
            let token = token.clone();
            async move {
                tokio::time::sleep(Duration::from_millis(20)).await;
                token.cancel();
            }
        });

        let outcome = hanging_agent()
            .run(request_with_timeout(Duration::from_secs(60)), sender, token)
            .await
            .unwrap();

        assert_eq!(outcome.stop_reason, StopReason::Cancelled);
        assert_eq!(outcome.exit_code, -1);
        trigger.await.unwrap();
    }

    #[tokio::test]
    async fn cancellation_wins_when_already_signalled_at_start() {
        let ending = DryRunFinal::Success {
            exit_code: 0,
            tokens: TokenUsage::default(),
        };
        let agent = DryRunAgent::new("test")
            .emit(AgentEvent::Stdout("never sent".into()))
            .finish(ending);

        let (sender, _receiver) = mpsc::channel(1);
        let token = CancellationToken::new();
        token.cancel();

        let outcome = agent
            .run(request_with_timeout(Duration::from_secs(5)), sender, token)
            .await
            .unwrap();

        assert_eq!(outcome.stop_reason, StopReason::Cancelled);
    }
}