use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::proto::{RunSpec, TerminalStatus};
/// Category of a request a child run sends up to its host.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum HostRequestKind {
    /// A request for the host's approval; the details travel in the
    /// accompanying [`HostRequest::body`].
    Approval,
}

/// One request travelling from a child run to the host, bundled with the
/// one-shot channel on which the host answers.
pub struct HostRequest {
    /// Which kind of request this is.
    pub kind: HostRequestKind,
    /// Free-form JSON payload describing the request.
    pub body: serde_json::Value,
    /// Channel for the host's JSON reply. If it is dropped without sending,
    /// the waiting caller receives an error instead of a value.
    pub reply: oneshot::Sender<serde_json::Value>,
}
/// Cloneable handle through which a child run sends [`HostRequest`]s to its
/// host and awaits the replies.
#[derive(Clone)]
pub struct HostBridge {
    req_tx: mpsc::UnboundedSender<HostRequest>,
}

impl HostBridge {
    /// Creates a bridge plus the receiving end from which the host drains
    /// incoming [`HostRequest`]s.
    pub fn channel() -> (Self, mpsc::UnboundedReceiver<HostRequest>) {
        let (sender, receiver) = mpsc::unbounded_channel();
        let bridge = Self { req_tx: sender };
        (bridge, receiver)
    }

    /// Sends a [`HostRequestKind::Approval`] request carrying `body` and
    /// waits for the host's JSON reply.
    ///
    /// # Errors
    /// Returns `"host bridge closed"` if the host's receiver is gone, or
    /// `"host bridge dropped reply"` if the host discarded the reply channel
    /// without answering.
    pub async fn approval_call(
        &self,
        body: serde_json::Value,
    ) -> Result<serde_json::Value, String> {
        let kind = HostRequestKind::Approval;
        self.call(kind, body).await
    }

    /// Enqueues a request of `kind` and awaits the answer on a fresh
    /// one-shot channel.
    async fn call(
        &self,
        kind: HostRequestKind,
        body: serde_json::Value,
    ) -> Result<serde_json::Value, String> {
        let (reply, reply_rx) = oneshot::channel();
        let request = HostRequest { kind, body, reply };
        if self.req_tx.send(request).is_err() {
            return Err("host bridge closed".to_string());
        }
        match reply_rx.await {
            Ok(answer) => Ok(answer),
            Err(_) => Err("host bridge dropped reply".to_string()),
        }
    }
}
/// Cloneable outlet for JSON events produced by a child run. It can
/// optionally carry a [`HostBridge`] for request/reply round-trips with the
/// host.
#[derive(Clone)]
pub struct EventSink {
    tx: mpsc::UnboundedSender<serde_json::Value>,
    host: Option<HostBridge>,
}

impl EventSink {
    /// Creates a sink with no host bridge, paired with the receiver that
    /// observes every event passed to [`emit`](Self::emit).
    pub fn channel() -> (Self, mpsc::UnboundedReceiver<serde_json::Value>) {
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        let sink = Self {
            tx: event_tx,
            host: None,
        };
        (sink, event_rx)
    }

    /// Consumes the sink and returns it with `bridge` attached. Any bridge
    /// that was already set is replaced.
    pub fn with_host_bridge(self, bridge: HostBridge) -> Self {
        Self {
            host: Some(bridge),
            ..self
        }
    }

    /// Borrows the attached host bridge, or returns `None` if none was
    /// attached.
    pub fn host(&self) -> Option<&HostBridge> {
        self.host.as_ref()
    }

    /// Queues `event` for the receiver.
    ///
    /// This is best-effort: if the receiving half has been dropped, the send
    /// error is discarded and the event is lost.
    pub fn emit(&self, event: serde_json::Value) {
        self.tx.send(event).ok();
    }
}
#[derive(Debug, Clone, PartialEq)]
pub struct ChildOutcome {
pub status: TerminalStatus,
pub result: Option<String>,
pub error: Option<String>,
pub transcript: Vec<serde_json::Value>,
}
impl ChildOutcome {
pub fn completed(result: impl Into<String>) -> Self {
Self {
status: TerminalStatus::Completed,
result: Some(result.into()),
error: None,
transcript: Vec::new(),
}
}
pub fn error(msg: impl Into<String>) -> Self {
Self {
status: TerminalStatus::Error,
result: None,
error: Some(msg.into()),
transcript: Vec::new(),
}
}
pub fn cancelled() -> Self {
Self {
status: TerminalStatus::Cancelled,
result: None,
error: None,
transcript: Vec::new(),
}
}
pub fn suspended(transcript: Vec<serde_json::Value>) -> Self {
Self {
status: TerminalStatus::Suspended,
result: None,
error: None,
transcript,
}
}
}
/// Receiving side of a steering channel: text messages pushed to a running
/// child.
pub struct SteerInbox {
    rx: mpsc::UnboundedReceiver<String>,
}

impl SteerInbox {
    /// Creates a connected inbox and returns the sender that feeds it.
    pub fn channel() -> (mpsc::UnboundedSender<String>, Self) {
        let (sender, receiver) = mpsc::unbounded_channel();
        (sender, Self { rx: receiver })
    }

    /// Creates an inbox whose only sender is dropped before returning, so no
    /// message can ever arrive.
    pub fn disconnected() -> Self {
        let (_, inbox) = Self::channel();
        inbox
    }

    /// Waits for the next steering message. Following tokio's `mpsc`
    /// semantics, this yields `None` once every sender is gone and the
    /// buffer is empty.
    pub async fn recv(&mut self) -> Option<String> {
        self.rx.recv().await
    }
}
/// Drives one child run to a terminal [`ChildOutcome`].
///
/// An implementation is handed:
/// - the run's [`RunSpec`];
/// - an [`EventSink`] to stream events through;
/// - a [`SteerInbox`] carrying steering messages;
/// - a [`CancellationToken`] signalling that the run should stop.
///
/// `EchoExecutor` in this module shows one way of honoring the token.
#[async_trait]
pub trait ChildExecutor: Send + Sync + 'static {
    /// Executes the run described by `spec` and reports how it ended.
    async fn run(
        &self,
        spec: RunSpec,
        events: EventSink,
        steer: SteerInbox,
        cancel: CancellationToken,
    ) -> ChildOutcome;
}
/// Trivial executor that echoes its assignment back word by word. The tests
/// in this module use it to exercise event streaming and cancellation.
#[derive(Debug, Clone, Copy, Default)]
pub struct EchoExecutor;

/// Prefix marking an assignment word as a delay directive for
/// [`EchoExecutor`]. For example, `__sleep_ms:250` waits 250 ms before any
/// word is echoed.
pub const ECHO_SLEEP_PREFIX: &str = "__sleep_ms:";
#[async_trait]
impl ChildExecutor for EchoExecutor {
    /// Echoes the assignment back as a stream of `token` events.
    ///
    /// The assignment is split on whitespace. The first word of the form
    /// `__sleep_ms:<u64>` (see [`ECHO_SLEEP_PREFIX`]) is consumed as a delay
    /// applied before any output. Later occurrences, and prefixed words whose
    /// suffix does not parse as `u64`, are echoed like ordinary words.
    ///
    /// Each echoed word is emitted as `{"type": "token", "content": "<word> "}`,
    /// followed by one `{"type": "complete"}` event. The result is then
    /// [`ChildOutcome::completed`] with `"echo: "` plus the words joined by
    /// single spaces.
    ///
    /// If cancellation is observed during the delay, before any word, or just
    /// before the `complete` event, it returns [`ChildOutcome::cancelled`]
    /// instead. Steering messages are ignored.
    async fn run(
        &self,
        spec: RunSpec,
        events: EventSink,
        _steer: SteerInbox,
        cancel: CancellationToken,
    ) -> ChildOutcome {
        let mut sleep_ms: Option<u64> = None;
        let mut words: Vec<&str> = Vec::new();
        for word in spec.assignment.split_whitespace() {
            match word
                .strip_prefix(ECHO_SLEEP_PREFIX)
                .and_then(|n| n.parse::<u64>().ok())
            {
                // Only the first sleep directive is honored; later ones are echoed verbatim.
                Some(ms) if sleep_ms.is_none() => sleep_ms = Some(ms),
                _ => words.push(word),
            }
        }

        if let Some(ms) = sleep_ms {
            tokio::select! {
                // Poll cancellation first so an already-cancelled token wins
                // deterministically over an elapsed (e.g. zero-length) sleep.
                biased;
                _ = cancel.cancelled() => return ChildOutcome::cancelled(),
                _ = tokio::time::sleep(std::time::Duration::from_millis(ms)) => {}
            }
        }

        for word in &words {
            if cancel.is_cancelled() {
                return ChildOutcome::cancelled();
            }
            events.emit(serde_json::json!({ "type": "token", "content": format!("{word} ") }));
            // Presumably here so other tasks (e.g. a canceller) can run between tokens.
            tokio::task::yield_now().await;
        }

        // Re-check after the loop. With no echo words (an empty or sleep-only
        // assignment) the loop never consults `cancel`, and a cancel landing
        // during the final yield would otherwise be reported as success.
        if cancel.is_cancelled() {
            return ChildOutcome::cancelled();
        }
        events.emit(serde_json::json!({ "type": "complete" }));
        ChildOutcome::completed(format!("echo: {}", words.join(" ")))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal spec that carries only an assignment.
    fn spec_for(assignment: &'static str) -> RunSpec {
        RunSpec {
            assignment: assignment.into(),
            reasoning_effort: None,
            messages: Vec::new(),
        }
    }

    /// Collects every event already queued on `rx`, stopping at the first
    /// empty or closed read.
    fn drain(rx: &mut mpsc::UnboundedReceiver<serde_json::Value>) -> Vec<serde_json::Value> {
        std::iter::from_fn(|| rx.try_recv().ok()).collect()
    }

    #[tokio::test]
    async fn echo_streams_then_completes() {
        let (sink, mut rx) = EventSink::channel();
        let outcome = EchoExecutor
            .run(
                spec_for("alpha beta"),
                sink,
                SteerInbox::disconnected(),
                CancellationToken::new(),
            )
            .await;
        assert_eq!(outcome.status, TerminalStatus::Completed);
        assert_eq!(outcome.result.as_deref(), Some("echo: alpha beta"));

        // Two token events followed by the completion marker.
        let events = drain(&mut rx);
        assert_eq!(events.len(), 3);
        assert_eq!(events[0]["content"], "alpha ");
    }

    #[tokio::test]
    async fn echo_honors_cancel() {
        let (sink, _rx) = EventSink::channel();
        let token = CancellationToken::new();
        token.cancel();
        let outcome = EchoExecutor
            .run(spec_for("a b c"), sink, SteerInbox::disconnected(), token)
            .await;
        assert_eq!(outcome.status, TerminalStatus::Cancelled);
    }

    #[tokio::test]
    async fn approval_call_sends_approval_kind_and_round_trips_reply() {
        let (bridge, mut req_rx) = HostBridge::channel();
        let caller = tokio::spawn(async move {
            bridge
                .approval_call(serde_json::json!({"resource": "/tmp/x"}))
                .await
        });

        let request = req_rx.recv().await.expect("a host request");
        assert_eq!(request.kind, HostRequestKind::Approval);
        assert_eq!(request.body["resource"], "/tmp/x");
        let _ = request.reply.send(serde_json::json!({"approved": true}));

        let decision = caller.await.unwrap().expect("decision");
        assert_eq!(decision["approved"], true);
    }
}