1use async_trait::async_trait;
8use tokio::sync::{mpsc, oneshot};
9use tokio_util::sync::CancellationToken;
10
11use crate::proto::{RunSpec, TerminalStatus};
12
/// Kind of request a child forwards to its host through a [`HostBridge`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HostRequestKind {
    /// Request for an approval decision; issued by [`HostBridge::approval_call`].
    Approval,
}
19
/// A request forwarded from a child to its host, bundled with a one-shot
/// channel on which the host sends its answer.
pub struct HostRequest {
    /// Which kind of request this is.
    pub kind: HostRequestKind,
    /// Request payload; its schema presumably depends on `kind` (not defined
    /// in this file — confirm with the host implementation).
    pub body: serde_json::Value,
    /// Sender for the host's reply. Dropping it without sending makes the
    /// waiting [`HostBridge`] caller receive an error.
    pub reply: oneshot::Sender<serde_json::Value>,
}
27
28#[derive(Clone)]
32pub struct HostBridge {
33 req_tx: mpsc::UnboundedSender<HostRequest>,
34}
35
36impl HostBridge {
37 pub fn channel() -> (Self, mpsc::UnboundedReceiver<HostRequest>) {
39 let (req_tx, req_rx) = mpsc::unbounded_channel();
40 (HostBridge { req_tx }, req_rx)
41 }
42
43 pub async fn approval_call(
47 &self,
48 body: serde_json::Value,
49 ) -> Result<serde_json::Value, String> {
50 self.call(HostRequestKind::Approval, body).await
51 }
52
53 async fn call(
54 &self,
55 kind: HostRequestKind,
56 body: serde_json::Value,
57 ) -> Result<serde_json::Value, String> {
58 let (reply, rx) = oneshot::channel();
59 self.req_tx
60 .send(HostRequest { kind, body, reply })
61 .map_err(|_| "host bridge closed".to_string())?;
62 rx.await
63 .map_err(|_| "host bridge dropped reply".to_string())
64 }
65}
66
67#[derive(Clone)]
69pub struct EventSink {
70 tx: mpsc::UnboundedSender<serde_json::Value>,
71 host: Option<HostBridge>,
72}
73
74impl EventSink {
75 pub fn channel() -> (Self, mpsc::UnboundedReceiver<serde_json::Value>) {
77 let (tx, rx) = mpsc::unbounded_channel();
78 (EventSink { tx, host: None }, rx)
79 }
80 pub fn with_host_bridge(mut self, bridge: HostBridge) -> Self {
82 self.host = Some(bridge);
83 self
84 }
85 pub fn host(&self) -> Option<&HostBridge> {
87 self.host.as_ref()
88 }
89 pub fn emit(&self, event: serde_json::Value) {
91 let _ = self.tx.send(event);
92 }
93}
94
95#[derive(Debug, Clone, PartialEq)]
97pub struct ChildOutcome {
98 pub status: TerminalStatus,
99 pub result: Option<String>,
100 pub error: Option<String>,
101 pub transcript: Vec<serde_json::Value>,
104}
105
106impl ChildOutcome {
107 pub fn completed(result: impl Into<String>) -> Self {
108 Self {
109 status: TerminalStatus::Completed,
110 result: Some(result.into()),
111 error: None,
112 transcript: Vec::new(),
113 }
114 }
115 pub fn error(msg: impl Into<String>) -> Self {
116 Self {
117 status: TerminalStatus::Error,
118 result: None,
119 error: Some(msg.into()),
120 transcript: Vec::new(),
121 }
122 }
123 pub fn cancelled() -> Self {
124 Self {
125 status: TerminalStatus::Cancelled,
126 result: None,
127 error: None,
128 transcript: Vec::new(),
129 }
130 }
131 pub fn suspended(transcript: Vec<serde_json::Value>) -> Self {
134 Self {
135 status: TerminalStatus::Suspended,
136 result: None,
137 error: None,
138 transcript,
139 }
140 }
141}
142
143pub struct SteerInbox {
147 rx: mpsc::UnboundedReceiver<String>,
148}
149
150impl SteerInbox {
151 pub fn channel() -> (mpsc::UnboundedSender<String>, Self) {
153 let (tx, rx) = mpsc::unbounded_channel();
154 (tx, SteerInbox { rx })
155 }
156 pub fn disconnected() -> Self {
158 let (_tx, rx) = mpsc::unbounded_channel();
159 SteerInbox { rx }
160 }
161 pub async fn recv(&mut self) -> Option<String> {
163 self.rx.recv().await
164 }
165}
166
/// Executes one child run to completion and reports a [`ChildOutcome`].
///
/// Implementors receive the [`RunSpec`] to execute, an [`EventSink`] for
/// streaming JSON events, a [`SteerInbox`] of steering messages, and a
/// [`CancellationToken`] the caller can fire to ask the run to stop.
/// Implementors must be `Send + Sync + 'static`.
#[async_trait]
pub trait ChildExecutor: Send + Sync + 'static {
    async fn run(
        &self,
        spec: RunSpec,
        events: EventSink,
        steer: SteerInbox,
        cancel: CancellationToken,
    ) -> ChildOutcome;
}
178
/// A [`ChildExecutor`] that streams each whitespace-separated word of the
/// assignment back as a `token` event and completes with `"echo: <words>"`.
pub struct EchoExecutor;
186
/// Prefix of an assignment token of the form `__sleep_ms:<u64>` that makes
/// [`EchoExecutor`] sleep for that many milliseconds before echoing. Only the
/// first well-formed occurrence is consumed; any other occurrence (or a
/// malformed one) is echoed as an ordinary word.
pub const ECHO_SLEEP_PREFIX: &str = "__sleep_ms:";
189
190#[async_trait]
191impl ChildExecutor for EchoExecutor {
192 async fn run(
193 &self,
194 spec: RunSpec,
195 events: EventSink,
196 _steer: SteerInbox,
197 cancel: CancellationToken,
198 ) -> ChildOutcome {
199 let mut sleep_ms: Option<u64> = None;
203 let mut words: Vec<&str> = Vec::new();
204 for word in spec.assignment.split_whitespace() {
205 match word
206 .strip_prefix(ECHO_SLEEP_PREFIX)
207 .and_then(|n| n.parse::<u64>().ok())
208 {
209 Some(ms) if sleep_ms.is_none() => sleep_ms = Some(ms),
210 _ => words.push(word),
211 }
212 }
213 if let Some(ms) = sleep_ms {
214 tokio::select! {
215 _ = tokio::time::sleep(std::time::Duration::from_millis(ms)) => {}
216 _ = cancel.cancelled() => return ChildOutcome::cancelled(),
217 }
218 }
219
220 for word in &words {
221 if cancel.is_cancelled() {
222 return ChildOutcome::cancelled();
223 }
224 events.emit(serde_json::json!({ "type": "token", "content": format!("{word} ") }));
225 tokio::task::yield_now().await;
227 }
228 events.emit(serde_json::json!({ "type": "complete" }));
229 ChildOutcome::completed(format!("echo: {}", words.join(" ")))
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal [`RunSpec`] carrying only `assignment`.
    fn spec(assignment: &str) -> RunSpec {
        RunSpec {
            assignment: assignment.into(),
            reasoning_effort: None,
            messages: Vec::new(),
        }
    }

    /// Drains every event currently buffered in `rx` without waiting.
    fn drain(rx: &mut mpsc::UnboundedReceiver<serde_json::Value>) -> Vec<serde_json::Value> {
        std::iter::from_fn(|| rx.try_recv().ok()).collect()
    }

    #[tokio::test]
    async fn echo_streams_then_completes() {
        let (sink, mut rx) = EventSink::channel();
        let outcome = EchoExecutor
            .run(
                spec("alpha beta"),
                sink,
                SteerInbox::disconnected(),
                CancellationToken::new(),
            )
            .await;
        assert_eq!(outcome.status, TerminalStatus::Completed);
        assert_eq!(outcome.result.as_deref(), Some("echo: alpha beta"));

        let events = drain(&mut rx);
        assert_eq!(events.len(), 3);
        assert_eq!(events[0]["type"], "token");
        assert_eq!(events[0]["content"], "alpha ");
        assert_eq!(events[1]["content"], "beta ");
        assert_eq!(events[2]["type"], "complete");
    }

    #[tokio::test]
    async fn echo_honors_cancel() {
        let (sink, mut rx) = EventSink::channel();
        let cancel = CancellationToken::new();
        cancel.cancel();
        let outcome = EchoExecutor
            .run(spec("a b c"), sink, SteerInbox::disconnected(), cancel)
            .await;
        assert_eq!(outcome.status, TerminalStatus::Cancelled);
        assert!(drain(&mut rx).is_empty());
    }

    #[tokio::test]
    async fn echo_consumes_only_first_valid_sleep_directive() {
        let (sink, mut rx) = EventSink::channel();
        let outcome = EchoExecutor
            .run(
                spec("__sleep_ms:oops __sleep_ms:1 a __sleep_ms:2"),
                sink,
                SteerInbox::disconnected(),
                CancellationToken::new(),
            )
            .await;
        assert_eq!(outcome.status, TerminalStatus::Completed);
        assert_eq!(
            outcome.result.as_deref(),
            Some("echo: __sleep_ms:oops a __sleep_ms:2")
        );
        // Three echoed words plus the trailing "complete" event.
        assert_eq!(drain(&mut rx).len(), 4);
    }

    #[tokio::test]
    async fn echo_cancel_interrupts_sleep() {
        let (sink, mut rx) = EventSink::channel();
        let cancel = CancellationToken::new();
        let run = tokio::spawn({
            let cancel = cancel.clone();
            async move {
                EchoExecutor
                    .run(
                        spec("__sleep_ms:60000 never"),
                        sink,
                        SteerInbox::disconnected(),
                        cancel,
                    )
                    .await
            }
        });
        cancel.cancel();
        let outcome = run.await.expect("echo task panicked");
        assert_eq!(outcome.status, TerminalStatus::Cancelled);
        assert!(drain(&mut rx).is_empty());
    }

    #[tokio::test]
    async fn steer_inbox_delivers_then_reports_disconnect() {
        let (tx, mut inbox) = SteerInbox::channel();
        tx.send("left".to_string()).expect("inbox alive");
        drop(tx);
        assert_eq!(inbox.recv().await.as_deref(), Some("left"));
        assert_eq!(inbox.recv().await, None);

        let mut dead = SteerInbox::disconnected();
        assert_eq!(dead.recv().await, None);
    }

    #[test]
    fn event_sink_host_bridge_is_opt_in() {
        let (sink, _rx) = EventSink::channel();
        assert!(sink.host().is_none());
        let (bridge, _req_rx) = HostBridge::channel();
        let sink = sink.with_host_bridge(bridge);
        assert!(sink.host().is_some());
    }

    #[tokio::test]
    async fn approval_call_sends_approval_kind_and_round_trips_reply() {
        let (bridge, mut req_rx) = HostBridge::channel();
        let caller = tokio::spawn(async move {
            bridge
                .approval_call(serde_json::json!({"resource": "/tmp/x"}))
                .await
        });
        let req = req_rx.recv().await.expect("a host request");
        assert_eq!(req.kind, HostRequestKind::Approval);
        assert_eq!(req.body["resource"], "/tmp/x");
        let _ = req.reply.send(serde_json::json!({"approved": true}));
        let reply = caller.await.unwrap().expect("decision");
        assert_eq!(reply["approved"], true);
    }

    #[tokio::test]
    async fn approval_call_reports_closed_bridge() {
        let (bridge, req_rx) = HostBridge::channel();
        drop(req_rx);
        let err = bridge
            .approval_call(serde_json::json!({}))
            .await
            .unwrap_err();
        assert_eq!(err, "host bridge closed");
    }

    #[tokio::test]
    async fn approval_call_reports_dropped_reply() {
        let (bridge, mut req_rx) = HostBridge::channel();
        let caller =
            tokio::spawn(async move { bridge.approval_call(serde_json::json!({})).await });
        let req = req_rx.recv().await.expect("a host request");
        drop(req.reply);
        let err = caller.await.unwrap().unwrap_err();
        assert_eq!(err, "host bridge dropped reply");
    }
}