// cap-rs-orchestrator 0.1.0
//
// Fleet orchestration engine for CAP (CLI Agent Protocol) — declarative YAML-driven multi-agent coordination.
// Documentation
//! Test doubles: `StubDriver`, a `Driver` that emits scripted events, and
//! `StubDriverFactory`, a driver factory that hands those drivers out by
//! session id, so the engine can be tested with zero real LLM / network.

use std::collections::HashMap;
use std::collections::VecDeque;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use async_trait::async_trait;
use cap_rs::core::{
    AgentEvent, CAP_PROTOCOL_VERSION, ClientFrame, PermissionScope, RiskLevel, StopReason,
    TextChannel, Usage,
};
use cap_rs::driver::{Driver, DriverError};

use crate::OrchestratorError;
use crate::config::{DriverKind, PermissionPolicy, SessionId};
use crate::factory::DriverFactory;

/// A scripted driver. Build it with chained helpers, then it replays the queued
/// events on successive `next_event()` calls and returns `None` afterwards.
///
/// `Default` is implemented by hand rather than derived: a derived impl would
/// start with `alive == false`, so a default-constructed driver would report
/// itself dead before emitting anything, unlike one built with `new`.
#[derive(Debug)]
pub struct StubDriver {
    /// Label used to build the ids of scripted events
    /// (`<name>-sess`, `<name>-m`, `<name>-req`).
    name: String,
    /// Scripted events that have not been handed out yet, in emission order.
    queue: VecDeque<AgentEvent>,
    /// Value reported by `is_alive()`; cleared once the script runs dry or on
    /// `shutdown()`.
    alive: bool,
    /// Decision carried by the most recent `PermissionResponse` frame, if any.
    last_decision: Option<cap_rs::core::PermissionDecision>,
    /// Optional sink receiving the concatenated text of every `Prompt` frame.
    captured: Option<Arc<Mutex<Vec<String>>>>,
    /// Optional sink receiving the kind label of every frame passed to `send()`.
    captured_frame_kinds: Option<Arc<Mutex<Vec<&'static str>>>>,
    /// Mirrors `Driver::prompt_after_ready` — makes the actor wait for a
    /// scripted `Ready` before sending the prompt.
    await_ready: bool,
    /// Pause applied before an event is handed out; it is taken (and thereby
    /// cleared) by the first `next_event()` call.
    event_delay: Option<Duration>,
}

impl Default for StubDriver {
    /// An unnamed, alive driver with an empty script and no capture sinks.
    fn default() -> Self {
        Self {
            name: String::new(),
            queue: VecDeque::new(),
            // A fresh driver must look alive, exactly as one built by `new`.
            alive: true,
            last_decision: None,
            captured: None,
            captured_frame_kinds: None,
            await_ready: false,
            event_delay: None,
        }
    }
}

impl StubDriver {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            queue: VecDeque::new(),
            alive: true,
            last_decision: None,
            captured: None,
            captured_frame_kinds: None,
            await_ready: false,
            event_delay: None,
        }
    }

    /// Script a `Ready` event (e.g. a PTY agent finishing boot).
    pub fn ready(mut self) -> Self {
        self.queue.push_back(AgentEvent::Ready {
            session_id: Some(format!("{}-sess", self.name)),
            version: CAP_PROTOCOL_VERSION.into(),
            model: None,
        });
        self
    }

    /// Behave like a PTY driver: require a `Ready` before the prompt is sent.
    pub fn await_ready(mut self) -> Self {
        self.await_ready = true;
        self
    }

    pub fn text(mut self, t: &str) -> Self {
        self.queue.push_back(AgentEvent::TextChunk {
            msg_id: format!("{}-m", self.name),
            text: t.to_string(),
            channel: TextChannel::Assistant,
        });
        self
    }

    /// Script a permission request the engine must resolve before `done`.
    pub fn permission(mut self, tool: &str, risk: RiskLevel) -> Self {
        self.queue.push_back(AgentEvent::PermissionRequest {
            req_id: format!("{}-req", self.name),
            tool: tool.to_string(),
            intent: serde_json::json!({}),
            scope: PermissionScope::Execute,
            risk_level: risk,
        });
        self
    }

    pub fn done(mut self, stop: StopReason) -> Self {
        self.queue.push_back(AgentEvent::Done {
            stop_reason: stop,
            usage: Usage::default(),
        });
        self
    }

    pub fn usage_cost(mut self, cost: f64) -> Self {
        let mut usage = Usage::default();
        usage.cost_usd_estimate = Some(cost);
        self.queue.push_back(AgentEvent::Usage { usage });
        self
    }

    /// Record the text of every Prompt frame this driver receives, for assertions.
    pub fn capture(mut self, sink: Arc<Mutex<Vec<String>>>) -> Self {
        self.captured = Some(sink);
        self
    }

    /// Record the kind of every frame received via `send()`.
    pub fn capture_frame_kinds(mut self, sink: Arc<Mutex<Vec<&'static str>>>) -> Self {
        self.captured_frame_kinds = Some(sink);
        self
    }

    pub fn delay_events(mut self, duration: Duration) -> Self {
        self.event_delay = Some(duration);
        self
    }

    /// The most recent `PermissionDecision` received via `send()`, if any.
    pub fn last_decision(&self) -> Option<cap_rs::core::PermissionDecision> {
        self.last_decision
    }
}

#[async_trait::async_trait]
impl Driver for StubDriver {
    /// Accept a client frame without talking to any real agent.
    ///
    /// Side effects, all optional:
    /// - the frame's kind label is pushed to the frame-kind sink, if one is set;
    /// - a `PermissionResponse` stores its decision for `last_decision()`;
    /// - a `Prompt` pushes the concatenation of its text parts (non-text parts
    ///   are skipped, no separator is inserted) to the prompt sink, if one is set.
    ///
    /// Always returns `Ok(())`; panics only if a sink mutex is poisoned.
    async fn send(&mut self, frame: ClientFrame) -> Result<(), DriverError> {
        if let Some(kinds) = self.captured_frame_kinds.as_ref() {
            let label: &'static str = match &frame {
                ClientFrame::SessionConfig(_) => "SessionConfig",
                ClientFrame::Prompt { .. } => "Prompt",
                ClientFrame::AskUserAnswer { .. } => "AskUserAnswer",
                ClientFrame::PermissionResponse { .. } => "PermissionResponse",
                ClientFrame::Cancel { .. } => "Cancel",
                ClientFrame::ReverseRpcResult { .. } => "ReverseRpcResult",
                _ => "Unknown",
            };
            kinds.lock().expect("frame sink mutex poisoned").push(label);
        }

        match frame {
            ClientFrame::Prompt { content } => {
                if let Some(prompts) = self.captured.as_ref() {
                    // Glue all text parts of the prompt into one string.
                    let mut joined = String::new();
                    for part in &content {
                        if let cap_rs::core::Content::Text { text } = part {
                            joined.push_str(text);
                        }
                    }
                    prompts
                        .lock()
                        .expect("capture sink mutex poisoned")
                        .push(joined);
                }
            }
            ClientFrame::PermissionResponse { decision, .. } => {
                self.last_decision = Some(decision);
            }
            // Every other frame kind is accepted and ignored.
            _ => {}
        }

        Ok(())
    }

    /// Hand out the next scripted event, or `None` once the script is spent.
    ///
    /// A configured delay is consumed here with `take()`, so it is slept
    /// through on the first call only. Running out of events also marks the
    /// driver as no longer alive.
    async fn next_event(&mut self) -> Option<AgentEvent> {
        if let Some(pause) = self.event_delay.take() {
            tokio::time::sleep(pause).await;
        }

        match self.queue.pop_front() {
            Some(event) => Some(event),
            None => {
                self.alive = false;
                None
            }
        }
    }

    /// Mark the driver as dead; there is nothing else to tear down.
    async fn shutdown(&mut self) -> Result<(), DriverError> {
        self.alive = false;
        Ok(())
    }

    /// `true` until the script has been exhausted or `shutdown()` was called.
    fn is_alive(&self) -> bool {
        self.alive
    }

    /// Whether the prompt should be held back until a `Ready` event was seen;
    /// enabled through the `await_ready()` builder.
    fn prompt_after_ready(&self) -> bool {
        self.await_ready
    }
}

/// A factory that hands out pre-scripted `StubDriver`s by session id.
///
/// Each registered driver is handed out once: `build` removes it from the map,
/// so a second `build` for the same session id fails with a config error.
#[derive(Debug, Default)]
pub struct StubDriverFactory {
    /// Drivers waiting to be handed out, keyed by the session id they belong
    /// to. Kept behind a mutex because `build` only receives `&self`.
    scripts: Mutex<HashMap<SessionId, StubDriver>>,
}

impl StubDriverFactory {
    pub fn new() -> Self {
        Self::default()
    }

    /// Register the driver a given session id should receive.
    pub fn with(self, session: &str, driver: StubDriver) -> Self {
        self.scripts
            .lock()
            .expect("stub factory mutex poisoned")
            .insert(session.to_string(), driver);
        self
    }
}

#[async_trait]
impl DriverFactory for StubDriverFactory {
    /// Hand out the driver registered for `session`, removing it from the
    /// factory. The driver kind, working directory and permission policy are
    /// ignored.
    ///
    /// # Errors
    /// Returns `OrchestratorError::Config` when no driver is registered for
    /// `session` (including when it has already been handed out).
    async fn build(
        &self,
        session: &SessionId,
        _kind: &DriverKind,
        _cwd: &Path,
        _policy: PermissionPolicy,
    ) -> Result<Box<dyn cap_rs::driver::Driver>, OrchestratorError> {
        let scripted = self
            .scripts
            .lock()
            .expect("stub factory mutex poisoned")
            .remove(session);

        match scripted {
            Some(driver) => {
                let boxed: Box<dyn cap_rs::driver::Driver> = Box::new(driver);
                Ok(boxed)
            }
            None => Err(OrchestratorError::Config(format!(
                "no stub for session '{session}'"
            ))),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use cap_rs::core::{ClientFrame, Content, StopReason};
    use cap_rs::driver::Driver;

    /// A scripted driver replays its text chunks in order, finishes with
    /// `Done`, and then reports end-of-stream.
    #[tokio::test]
    async fn stub_emits_scripted_events_then_done() {
        let mut driver = StubDriver::new("s1")
            .text("hello ")
            .text("world")
            .done(StopReason::EndTurn);

        // With no capture sink configured, sending a prompt changes nothing,
        // but it must still succeed.
        let prompt = ClientFrame::Prompt {
            content: vec![Content::text("hi")],
        };
        driver.send(prompt).await.unwrap();

        // Drain the whole script first, then inspect what came out.
        let mut events = Vec::new();
        while let Some(event) = driver.next_event().await {
            events.push(event);
        }

        let transcript: String = events
            .iter()
            .filter_map(|event| match event {
                cap_rs::core::AgentEvent::TextChunk { text, .. } => Some(text.as_str()),
                _ => None,
            })
            .collect();
        let finished = events
            .iter()
            .any(|event| matches!(event, cap_rs::core::AgentEvent::Done { .. }));

        assert_eq!(transcript, "hello world");
        assert!(finished);
    }
}