use std::collections::HashMap;
use std::collections::VecDeque;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use cap_rs::core::{
AgentEvent, CAP_PROTOCOL_VERSION, ClientFrame, PermissionScope, RiskLevel, StopReason,
TextChannel, Usage,
};
use cap_rs::driver::{Driver, DriverError};
use crate::OrchestratorError;
use crate::config::{DriverKind, PermissionPolicy, SessionId};
use crate::factory::DriverFactory;
#[derive(Debug, Default)]
pub struct StubDriver {
name: String,
queue: VecDeque<AgentEvent>,
alive: bool,
last_decision: Option<cap_rs::core::PermissionDecision>,
captured: Option<Arc<Mutex<Vec<String>>>>,
captured_frame_kinds: Option<Arc<Mutex<Vec<&'static str>>>>,
await_ready: bool,
event_delay: Option<Duration>,
}
impl StubDriver {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
queue: VecDeque::new(),
alive: true,
last_decision: None,
captured: None,
captured_frame_kinds: None,
await_ready: false,
event_delay: None,
}
}
pub fn ready(mut self) -> Self {
self.queue.push_back(AgentEvent::Ready {
session_id: Some(format!("{}-sess", self.name)),
version: CAP_PROTOCOL_VERSION.into(),
model: None,
});
self
}
pub fn await_ready(mut self) -> Self {
self.await_ready = true;
self
}
pub fn text(mut self, t: &str) -> Self {
self.queue.push_back(AgentEvent::TextChunk {
msg_id: format!("{}-m", self.name),
text: t.to_string(),
channel: TextChannel::Assistant,
});
self
}
pub fn permission(mut self, tool: &str, risk: RiskLevel) -> Self {
self.queue.push_back(AgentEvent::PermissionRequest {
req_id: format!("{}-req", self.name),
tool: tool.to_string(),
intent: serde_json::json!({}),
scope: PermissionScope::Execute,
risk_level: risk,
});
self
}
pub fn done(mut self, stop: StopReason) -> Self {
self.queue.push_back(AgentEvent::Done {
stop_reason: stop,
usage: Usage::default(),
});
self
}
pub fn usage_cost(mut self, cost: f64) -> Self {
let mut usage = Usage::default();
usage.cost_usd_estimate = Some(cost);
self.queue.push_back(AgentEvent::Usage { usage });
self
}
pub fn capture(mut self, sink: Arc<Mutex<Vec<String>>>) -> Self {
self.captured = Some(sink);
self
}
pub fn capture_frame_kinds(mut self, sink: Arc<Mutex<Vec<&'static str>>>) -> Self {
self.captured_frame_kinds = Some(sink);
self
}
pub fn delay_events(mut self, duration: Duration) -> Self {
self.event_delay = Some(duration);
self
}
pub fn last_decision(&self) -> Option<cap_rs::core::PermissionDecision> {
self.last_decision
}
}
#[async_trait::async_trait]
impl Driver for StubDriver {
async fn send(&mut self, frame: ClientFrame) -> Result<(), DriverError> {
if let Some(sink) = &self.captured_frame_kinds {
let kind = match &frame {
ClientFrame::SessionConfig(_) => "SessionConfig",
ClientFrame::Prompt { .. } => "Prompt",
ClientFrame::AskUserAnswer { .. } => "AskUserAnswer",
ClientFrame::PermissionResponse { .. } => "PermissionResponse",
ClientFrame::Cancel { .. } => "Cancel",
ClientFrame::ReverseRpcResult { .. } => "ReverseRpcResult",
_ => "Unknown",
};
sink.lock().expect("frame sink mutex poisoned").push(kind);
}
match frame {
ClientFrame::PermissionResponse { decision, .. } => {
self.last_decision = Some(decision);
}
ClientFrame::Prompt { content } => {
if let Some(sink) = &self.captured {
let text: String = content
.iter()
.filter_map(|c| match c {
cap_rs::core::Content::Text { text } => Some(text.as_str()),
_ => None,
})
.collect();
sink.lock().expect("capture sink mutex poisoned").push(text);
}
}
_ => {}
}
Ok(())
}
async fn next_event(&mut self) -> Option<AgentEvent> {
if let Some(delay) = self.event_delay.take() {
tokio::time::sleep(delay).await;
}
let ev = self.queue.pop_front();
if ev.is_none() {
self.alive = false;
}
ev
}
async fn shutdown(&mut self) -> Result<(), DriverError> {
self.alive = false;
Ok(())
}
fn is_alive(&self) -> bool {
self.alive
}
fn prompt_after_ready(&self) -> bool {
self.await_ready
}
}
/// A [`DriverFactory`] that hands out pre-scripted [`StubDriver`]s keyed by session id.
///
/// Each stub is handed out at most once, because `build` removes it from the map.
#[derive(Debug, Default)]
pub struct StubDriverFactory {
    /// Stubs registered via `with` and not yet consumed by `build`, keyed by session id.
    scripts: Mutex<HashMap<SessionId, StubDriver>>,
}
impl StubDriverFactory {
    /// Creates a factory with no scripted sessions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `driver` as the stub returned for `session`. Any stub previously
    /// registered for the same id is replaced.
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn with(mut self, session: &str, driver: StubDriver) -> Self {
        // `self` is owned, so exclusive access is statically guaranteed. `get_mut` reaches
        // the map without taking the lock.
        self.scripts
            .get_mut()
            .expect("stub factory mutex poisoned")
            .insert(session.to_string(), driver);
        self
    }
}
#[async_trait]
impl DriverFactory for StubDriverFactory {
async fn build(
&self,
session: &SessionId,
_kind: &DriverKind,
_cwd: &Path,
_policy: PermissionPolicy,
) -> Result<Box<dyn cap_rs::driver::Driver>, OrchestratorError> {
self.scripts
.lock()
.expect("stub factory mutex poisoned")
.remove(session)
.map(|d| Box::new(d) as Box<dyn cap_rs::driver::Driver>)
.ok_or_else(|| OrchestratorError::Config(format!("no stub for session '{session}'")))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use cap_rs::core::{ClientFrame, Content, StopReason};
    use cap_rs::driver::Driver;

    /// A scripted stub replays its text chunks in order and reports `Done`.
    #[tokio::test]
    async fn stub_emits_scripted_events_then_done() {
        let mut driver = StubDriver::new("s1")
            .text("hello ")
            .text("world")
            .done(StopReason::EndTurn);

        let prompt = ClientFrame::Prompt {
            content: vec![Content::text("hi")],
        };
        driver.send(prompt).await.unwrap();

        // Drain the whole script before inspecting it.
        let mut events = Vec::new();
        while let Some(event) = driver.next_event().await {
            events.push(event);
        }

        let transcript: String = events
            .iter()
            .filter_map(|event| match event {
                cap_rs::core::AgentEvent::TextChunk { text, .. } => Some(text.as_str()),
                _ => None,
            })
            .collect();
        let finished = events
            .iter()
            .any(|event| matches!(event, cap_rs::core::AgentEvent::Done { .. }));

        assert_eq!(transcript, "hello world");
        assert!(finished);
    }
}