1use std::collections::HashMap;
170use std::future::Future;
171use std::pin::Pin;
172use std::process::Stdio;
173use std::sync::Arc;
174use std::time::Duration;
175
176use serde_json::Value;
177use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
178use tokio::process::{Child, ChildStdin, ChildStdout, Command};
179use tokio::sync::{broadcast, mpsc, oneshot, watch};
180use tokio::task::JoinHandle;
181use tracing::{debug, warn};
182
183use crate::Claude;
184use crate::error::{Error, Result};
185use crate::types::PermissionMode;
186
/// Capacity of the broadcast channel behind [`DuplexSession::subscribe`]
/// used when [`DuplexOptions::subscriber_capacity`] is not set.
pub const DEFAULT_SUBSCRIBER_CAPACITY: usize = 256;
192
/// A `can_use_tool` permission request received from the CLI as a
/// `control_request`.
#[derive(Debug, Clone)]
pub struct PermissionRequest {
    /// The envelope's `request_id`; a response must echo it back (for a
    /// deferred decision, pass it to `DuplexSession::respond_to_permission`).
    pub request_id: String,
    /// Name of the tool the CLI wants to run (`request.tool_name`).
    pub tool_name: String,
    /// Tool input (`request.input`), or `Value::Null` when absent.
    pub input: Value,
    /// The complete `request` object as received, for fields not surfaced above.
    pub raw: Value,
}
214
/// How to answer a [`PermissionRequest`].
#[derive(Debug, Clone)]
pub enum PermissionDecision {
    /// Let the tool run. Serialized as `{"behavior": "allow"}`, plus an
    /// `updatedInput` field when `updated_input` is `Some`.
    Allow {
        /// Replacement tool input; `None` omits `updatedInput` entirely.
        updated_input: Option<Value>,
    },
    /// Refuse the tool call. Serialized as `{"behavior": "deny", "message": ...}`.
    Deny {
        /// Explanation sent back to the CLI with the denial.
        message: String,
    },
    /// Answer later. Only meaningful as a handler's return value: the session
    /// writes nothing, and the caller must eventually answer through
    /// `DuplexSession::respond_to_permission` (which itself ignores `Defer`).
    Defer,
}
240
/// Boxed, type-erased future produced by a permission callback.
type PermissionFuture = Pin<Box<dyn Future<Output = PermissionDecision> + Send + 'static>>;
/// Type-erased permission callback stored inside [`PermissionHandler`].
type PermissionFn = dyn Fn(PermissionRequest) -> PermissionFuture + Send + Sync + 'static;
243
/// Async callback that decides [`PermissionRequest`]s for a duplex session.
///
/// Cloning is cheap: clones share the same underlying closure through an `Arc`.
#[derive(Clone)]
pub struct PermissionHandler {
    inner: Arc<PermissionFn>,
}
260
261impl PermissionHandler {
262 pub fn new<F, Fut>(f: F) -> Self
278 where
279 F: Fn(PermissionRequest) -> Fut + Send + Sync + 'static,
280 Fut: Future<Output = PermissionDecision> + Send + 'static,
281 {
282 Self {
283 inner: Arc::new(move |req| Box::pin(f(req))),
284 }
285 }
286
287 fn invoke(&self, req: PermissionRequest) -> PermissionFuture {
288 (self.inner)(req)
289 }
290}
291
292impl std::fmt::Debug for PermissionHandler {
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 f.debug_struct("PermissionHandler").finish_non_exhaustive()
295 }
296}
297
/// Builder-style configuration for [`DuplexSession::spawn`].
///
/// Most fields become CLI flags in `into_args`. `subscriber_capacity` and
/// `on_permission` configure the session itself; `on_permission` also adds
/// `--permission-prompt-tool stdio`.
#[derive(Debug, Default, Clone)]
pub struct DuplexOptions {
    /// `--model <name>`.
    model: Option<String>,
    /// `--system-prompt <text>`.
    system_prompt: Option<String>,
    /// `--append-system-prompt <text>`.
    append_system_prompt: Option<String>,
    /// `--resume <session_id>`.
    resume: Option<String>,
    /// `--continue`.
    continue_session: bool,
    /// `--worktree`, optionally followed by `worktree_name`.
    worktree: bool,
    /// Positional name emitted after `--worktree`; ignored unless `worktree` is set.
    worktree_name: Option<String>,
    /// `--agent <name>`.
    agent: Option<String>,
    /// `--agents <json>`.
    agents_json: Option<String>,
    /// `--permission-mode <mode>`.
    permission_mode: Option<PermissionMode>,
    /// `--dangerously-skip-permissions`.
    dangerously_skip_permissions: bool,
    /// Raw arguments appended after every generated flag.
    additional_args: Vec<String>,
    /// Broadcast channel capacity; `None` means [`DEFAULT_SUBSCRIBER_CAPACITY`].
    subscriber_capacity: Option<usize>,
    /// Callback answering `can_use_tool` permission requests.
    on_permission: Option<PermissionHandler>,
}
321
322impl DuplexOptions {
323 #[must_use]
325 pub fn model(mut self, model: impl Into<String>) -> Self {
326 self.model = Some(model.into());
327 self
328 }
329
330 #[must_use]
332 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
333 self.system_prompt = Some(prompt.into());
334 self
335 }
336
337 #[must_use]
339 pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
340 self.append_system_prompt = Some(prompt.into());
341 self
342 }
343
344 #[must_use]
361 pub fn resume(mut self, session_id: impl Into<String>) -> Self {
362 self.resume = Some(session_id.into());
363 self
364 }
365
366 #[must_use]
373 pub fn continue_session(mut self) -> Self {
374 self.continue_session = true;
375 self
376 }
377
378 #[must_use]
389 pub fn worktree(mut self, name: Option<impl Into<String>>) -> Self {
390 self.worktree = true;
391 if let Some(n) = name {
392 self.worktree_name = Some(n.into());
393 }
394 self
395 }
396
397 #[must_use]
411 pub fn agent(mut self, name: impl Into<String>) -> Self {
412 self.agent = Some(name.into());
413 self
414 }
415
416 #[must_use]
428 pub fn agents_json(mut self, json: impl Into<String>) -> Self {
429 self.agents_json = Some(json.into());
430 self
431 }
432
433 #[must_use]
449 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
450 self.permission_mode = Some(mode);
451 self
452 }
453
454 #[must_use]
462 pub fn dangerously_skip_permissions(mut self) -> Self {
463 self.dangerously_skip_permissions = true;
464 self
465 }
466
467 #[must_use]
472 pub fn arg(mut self, arg: impl Into<String>) -> Self {
473 self.additional_args.push(arg.into());
474 self
475 }
476
477 #[must_use]
485 pub fn subscriber_capacity(mut self, capacity: usize) -> Self {
486 self.subscriber_capacity = Some(capacity);
487 self
488 }
489
490 #[must_use]
502 pub fn on_permission(mut self, handler: PermissionHandler) -> Self {
503 self.on_permission = Some(handler);
504 self
505 }
506
507 fn into_args(self) -> Vec<String> {
508 let mut args = vec![
509 "--print".to_string(),
510 "--verbose".to_string(),
511 "--output-format".to_string(),
512 "stream-json".to_string(),
513 "--input-format".to_string(),
514 "stream-json".to_string(),
515 ];
516
517 if let Some(m) = self.model {
518 args.push("--model".to_string());
519 args.push(m);
520 }
521 if let Some(p) = self.system_prompt {
522 args.push("--system-prompt".to_string());
523 args.push(p);
524 }
525 if let Some(p) = self.append_system_prompt {
526 args.push("--append-system-prompt".to_string());
527 args.push(p);
528 }
529 if let Some(id) = self.resume {
530 args.push("--resume".to_string());
531 args.push(id);
532 }
533 if self.continue_session {
534 args.push("--continue".to_string());
535 }
536 if self.worktree {
537 args.push("--worktree".to_string());
538 if let Some(n) = self.worktree_name {
539 args.push(n);
540 }
541 }
542 if let Some(json) = self.agents_json {
543 args.push("--agents".to_string());
544 args.push(json);
545 }
546 if let Some(name) = self.agent {
547 args.push("--agent".to_string());
548 args.push(name);
549 }
550 if let Some(mode) = self.permission_mode {
551 args.push("--permission-mode".to_string());
552 args.push(mode.as_arg().to_string());
553 }
554 if self.dangerously_skip_permissions {
555 args.push("--dangerously-skip-permissions".to_string());
556 }
557 if self.on_permission.is_some() {
558 args.push("--permission-prompt-tool".to_string());
559 args.push("stdio".to_string());
560 }
561 args.extend(self.additional_args);
562
563 args
564 }
565}
566
/// Outcome of one [`DuplexSession::send`] turn.
#[derive(Debug, Clone)]
pub struct TurnResult {
    /// The `{"type": "result", ...}` message that ended the turn, verbatim.
    pub result: Value,
    /// Messages recorded while the turn was pending, in arrival order. This
    /// covers classified stream events and inbound `control_request`s; valid
    /// `control_response`s are not recorded.
    pub events: Vec<Value>,
}
580
581impl TurnResult {
582 #[must_use]
584 pub fn result_text(&self) -> Option<&str> {
585 self.result.get("result").and_then(Value::as_str)
586 }
587
588 #[must_use]
590 pub fn session_id(&self) -> Option<&str> {
591 self.result.get("session_id").and_then(Value::as_str)
592 }
593
594 #[must_use]
597 pub fn total_cost_usd(&self) -> Option<f64> {
598 self.result
599 .get("total_cost_usd")
600 .or_else(|| self.result.get("cost_usd"))
601 .and_then(Value::as_f64)
602 }
603
604 #[must_use]
606 pub fn duration_ms(&self) -> Option<u64> {
607 self.result.get("duration_ms").and_then(Value::as_u64)
608 }
609}
610
/// An inbound CLI message as delivered to [`DuplexSession::subscribe`]
/// receivers, classified by its `type` field.
///
/// `result` messages are not broadcast; they complete the pending turn instead.
#[derive(Debug, Clone)]
pub enum InboundEvent {
    /// `{"type": "system", "subtype": "init"}` carrying a string `session_id`.
    SystemInit {
        /// The session id announced in the init message.
        session_id: String,
    },
    /// A `type: "assistant"` message, verbatim.
    Assistant(Value),
    /// A `type: "stream_event"` message, verbatim.
    StreamEvent(Value),
    /// A `type: "user"` message, verbatim.
    User(Value),
    /// Anything else. This includes other `system` subtypes, unhandled or
    /// malformed control messages, and unknown or missing types.
    Other(Value),
}
646
647fn classify(msg: &Value) -> InboundEvent {
648 match msg.get("type").and_then(Value::as_str) {
649 Some("system") => {
650 if msg.get("subtype").and_then(Value::as_str) == Some("init")
651 && let Some(id) = msg.get("session_id").and_then(Value::as_str)
652 {
653 return InboundEvent::SystemInit {
654 session_id: id.to_string(),
655 };
656 }
657 InboundEvent::Other(msg.clone())
658 }
659 Some("assistant") => InboundEvent::Assistant(msg.clone()),
660 Some("stream_event") => InboundEvent::StreamEvent(msg.clone()),
661 Some("user") => InboundEvent::User(msg.clone()),
662 _ => InboundEvent::Other(msg.clone()),
663 }
664}
665
/// Lifecycle state of a [`DuplexSession`]'s background task, published
/// through a `watch` channel.
#[derive(Debug, Clone)]
pub enum SessionExitStatus {
    /// The background task has not finished yet.
    Running,
    /// The task finished without a stdout read error: the CLI closed its
    /// stdout, or the session was closed. The child's exit code is not
    /// consulted.
    Completed,
    /// The task finished because reading the CLI's stdout failed. Holds the
    /// error's display text.
    Failed(String),
}
691
/// Handle to a long-lived CLI process driven over stream-json stdin/stdout.
///
/// A background task (`run_session`) owns the child process. This handle:
/// - sends it commands over an unbounded channel;
/// - fans inbound events out through a `broadcast` channel;
/// - observes the task's exit state through a `watch` channel.
#[derive(Debug)]
pub struct DuplexSession {
    /// Commands (prompts, permission answers, interrupts) for the task.
    outbound_tx: mpsc::UnboundedSender<OutboundMsg>,
    /// Broadcast sender, kept so `subscribe` can create new receivers.
    events_tx: broadcast::Sender<InboundEvent>,
    /// Exit state published by the task.
    exit_rx: watch::Receiver<SessionExitStatus>,
    /// The task running `run_session`.
    join: JoinHandle<Result<()>>,
}
707
/// Commands sent from [`DuplexSession`] to its background task.
#[derive(Debug)]
enum OutboundMsg {
    /// Start a turn with `prompt`. `reply` resolves when the turn's `result`
    /// arrives, or with an error (turn already in flight, write failure,
    /// session closed).
    Send {
        prompt: String,
        reply: oneshot::Sender<Result<TurnResult>>,
    },
    /// Answer a permission request whose handler returned `Defer`.
    PermissionResponse {
        request_id: String,
        decision: PermissionDecision,
    },
    /// Send an `interrupt` control request. `reply` resolves with the matching
    /// `control_response`, or an error.
    Interrupt {
        reply: oneshot::Sender<Result<()>>,
    },
}
722
723impl DuplexSession {
724 pub async fn spawn(claude: &Claude, opts: DuplexOptions) -> Result<Self> {
732 let capacity = opts
733 .subscriber_capacity
734 .unwrap_or(DEFAULT_SUBSCRIBER_CAPACITY);
735 let permission_handler = opts.on_permission.clone();
736
737 let mut command_args = Vec::new();
738 command_args.extend(claude.global_args.clone());
739 command_args.extend(opts.into_args());
740
741 debug!(
742 binary = %claude.binary.display(),
743 args = ?command_args,
744 "spawning duplex claude session"
745 );
746
747 let mut cmd = Command::new(&claude.binary);
748 cmd.args(&command_args)
749 .env_remove("CLAUDECODE")
750 .env_remove("CLAUDE_CODE_ENTRYPOINT")
751 .envs(&claude.env)
752 .stdin(Stdio::piped())
753 .stdout(Stdio::piped())
754 .stderr(Stdio::piped())
755 .kill_on_drop(true);
756
757 if let Some(ref dir) = claude.working_dir {
758 cmd.current_dir(dir);
759 }
760
761 let mut child = cmd.spawn().map_err(|e| Error::Io {
762 message: format!("failed to spawn claude: {e}"),
763 source: e,
764 working_dir: claude.working_dir.clone(),
765 })?;
766
767 let stdin = child.stdin.take().expect("stdin was piped");
768 let stdout = child.stdout.take().expect("stdout was piped");
769
770 let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
771 let (events_tx, _initial_rx) = broadcast::channel(capacity);
772 let (exit_tx, exit_rx) = watch::channel(SessionExitStatus::Running);
773
774 let join = tokio::spawn(run_session(
775 child,
776 stdin,
777 stdout,
778 outbound_rx,
779 events_tx.clone(),
780 permission_handler,
781 exit_tx,
782 ));
783
784 Ok(Self {
785 outbound_tx,
786 events_tx,
787 exit_rx,
788 join,
789 })
790 }
791
792 pub async fn send(&self, prompt: impl Into<String>) -> Result<TurnResult> {
798 let (reply_tx, reply_rx) = oneshot::channel();
799 self.outbound_tx
800 .send(OutboundMsg::Send {
801 prompt: prompt.into(),
802 reply: reply_tx,
803 })
804 .map_err(|_| Error::DuplexClosed)?;
805 reply_rx.await.map_err(|_| Error::DuplexClosed)?
806 }
807
808 #[must_use]
843 pub fn subscribe(&self) -> broadcast::Receiver<InboundEvent> {
844 self.events_tx.subscribe()
845 }
846
847 #[must_use]
858 pub fn is_alive(&self) -> bool {
859 matches!(*self.exit_rx.borrow(), SessionExitStatus::Running)
860 }
861
862 #[must_use]
871 pub fn exit_status(&self) -> SessionExitStatus {
872 self.exit_rx.borrow().clone()
873 }
874
875 pub async fn wait_for_exit(&self) -> SessionExitStatus {
887 let mut rx = self.exit_rx.clone();
888 loop {
889 {
890 let value = rx.borrow_and_update();
891 if !matches!(*value, SessionExitStatus::Running) {
892 return value.clone();
893 }
894 }
895 if rx.changed().await.is_err() {
896 return rx.borrow().clone();
897 }
898 }
899 }
900
901 pub fn respond_to_permission(
946 &self,
947 request_id: impl Into<String>,
948 decision: PermissionDecision,
949 ) -> Result<()> {
950 if matches!(decision, PermissionDecision::Defer) {
951 warn!("respond_to_permission called with Defer; ignoring");
952 return Ok(());
953 }
954 self.outbound_tx
955 .send(OutboundMsg::PermissionResponse {
956 request_id: request_id.into(),
957 decision,
958 })
959 .map_err(|_| Error::DuplexClosed)?;
960 Ok(())
961 }
962
963 pub async fn interrupt(&self) -> Result<()> {
1004 let (reply_tx, reply_rx) = oneshot::channel();
1005 self.outbound_tx
1006 .send(OutboundMsg::Interrupt { reply: reply_tx })
1007 .map_err(|_| Error::DuplexClosed)?;
1008 reply_rx.await.map_err(|_| Error::DuplexClosed)?
1009 }
1010
1011 pub async fn close(self) -> Result<()> {
1017 drop(self.outbound_tx);
1018 drop(self.events_tx);
1019 match self.join.await {
1020 Ok(result) => result,
1021 Err(e) if e.is_cancelled() => Ok(()),
1022 Err(e) => Err(Error::Io {
1023 message: format!("duplex session task panicked: {e}"),
1024 source: std::io::Error::other(e.to_string()),
1025 working_dir: None,
1026 }),
1027 }
1028 }
1029}
1030
/// How long session shutdown waits for the child to exit after its stdin is
/// closed before killing it.
const SHUTDOWN_BUDGET: Duration = Duration::from_secs(5);
1035
1036async fn run_session(
1037 mut child: Child,
1038 mut stdin: ChildStdin,
1039 stdout: ChildStdout,
1040 mut outbound_rx: mpsc::UnboundedReceiver<OutboundMsg>,
1041 events_tx: broadcast::Sender<InboundEvent>,
1042 permission_handler: Option<PermissionHandler>,
1043 exit_tx: watch::Sender<SessionExitStatus>,
1044) -> Result<()> {
1045 let mut lines = BufReader::new(stdout).lines();
1046 let mut pending: Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)> = None;
1047 let mut pending_control: HashMap<String, oneshot::Sender<Result<()>>> = HashMap::new();
1048 let mut next_control_id: u64 = 0;
1049 let mut stream_err: Option<Error> = None;
1050
1051 loop {
1052 tokio::select! {
1053 biased;
1054
1055 line = lines.next_line() => match line {
1056 Ok(Some(l)) => {
1057 if l.trim().is_empty() {
1058 continue;
1059 }
1060 let parsed = match serde_json::from_str::<Value>(&l) {
1061 Ok(v) => v,
1062 Err(e) => {
1063 debug!(line = %l, error = %e, "failed to parse duplex event, skipping");
1064 continue;
1065 }
1066 };
1067 match handle_inbound(parsed, &mut pending, &events_tx) {
1068 InboundAction::None => {}
1069 InboundAction::Permission(req) => {
1070 let request_id = req.request_id.clone();
1071 let decision = match permission_handler.as_ref() {
1072 Some(h) => h.invoke(req).await,
1073 None => {
1074 warn!(
1075 request_id = %request_id,
1076 "received can_use_tool with no permission handler; auto-denying"
1077 );
1078 PermissionDecision::Deny {
1079 message:
1080 "no permission handler configured on duplex session"
1081 .into(),
1082 }
1083 }
1084 };
1085 if matches!(decision, PermissionDecision::Defer) {
1086 debug!(
1087 request_id = %request_id,
1088 "permission handler deferred; waiting for respond_to_permission"
1089 );
1090 } else if let Err(e) =
1091 write_permission_response(&mut stdin, &request_id, &decision).await
1092 {
1093 warn!(error = %e, "failed to write permission response");
1094 }
1095 }
1096 InboundAction::ControlResponse { request_id, outcome } => {
1097 if let Some(reply) = pending_control.remove(&request_id) {
1098 let _ = reply.send(outcome);
1099 } else {
1100 debug!(
1101 request_id = %request_id,
1102 "received control_response with no pending request"
1103 );
1104 }
1105 }
1106 }
1107 }
1108 Ok(None) => break,
1109 Err(e) => {
1110 stream_err = Some(Error::Io {
1111 message: "failed to read duplex stdout".to_string(),
1112 source: e,
1113 working_dir: None,
1114 });
1115 break;
1116 }
1117 },
1118
1119 msg = outbound_rx.recv() => match msg {
1120 Some(OutboundMsg::Send { prompt, reply }) => {
1121 if pending.is_some() {
1122 let _ = reply.send(Err(Error::DuplexTurnInFlight));
1123 continue;
1124 }
1125 if let Err(e) = write_user(&mut stdin, &prompt).await {
1126 let _ = reply.send(Err(e));
1127 continue;
1128 }
1129 pending = Some((reply, Vec::new()));
1130 }
1131 Some(OutboundMsg::PermissionResponse { request_id, decision }) => {
1132 if let Err(e) =
1133 write_permission_response(&mut stdin, &request_id, &decision).await
1134 {
1135 warn!(error = %e, "failed to write deferred permission response");
1136 }
1137 }
1138 Some(OutboundMsg::Interrupt { reply }) => {
1139 next_control_id += 1;
1140 let request_id = format!("interrupt-{next_control_id}");
1141 if let Err(e) =
1142 write_control_request(&mut stdin, &request_id, "interrupt").await
1143 {
1144 let _ = reply.send(Err(e));
1145 continue;
1146 }
1147 pending_control.insert(request_id, reply);
1148 }
1149 None => break,
1150 },
1151 }
1152 }
1153
1154 drop(stdin);
1155 match tokio::time::timeout(SHUTDOWN_BUDGET, child.wait()).await {
1156 Ok(Ok(_status)) => {}
1157 Ok(Err(e)) => {
1158 warn!(error = %e, "failed to wait for duplex child");
1159 }
1160 Err(_) => {
1161 warn!("duplex child did not exit within shutdown budget; killing");
1162 let _ = child.kill().await;
1163 }
1164 }
1165
1166 if let Some((reply, _)) = pending.take() {
1167 let _ = reply.send(Err(Error::DuplexClosed));
1168 }
1169 for (_, reply) in pending_control.drain() {
1170 let _ = reply.send(Err(Error::DuplexClosed));
1171 }
1172
1173 let result = match stream_err {
1174 Some(e) => Err(e),
1175 None => Ok(()),
1176 };
1177 let final_state = match &result {
1178 Ok(()) => SessionExitStatus::Completed,
1179 Err(e) => SessionExitStatus::Failed(e.to_string()),
1180 };
1181 let _ = exit_tx.send(final_state);
1182 result
1183}
1184
/// What `run_session` must do after `handle_inbound` processes a message.
enum InboundAction {
    /// Nothing further; the message was fully handled.
    None,
    /// A parsed `can_use_tool` request that needs a [`PermissionDecision`].
    Permission(PermissionRequest),
    /// A well-formed `control_response`, to be routed to the pending control
    /// request with this id.
    ControlResponse {
        /// The `request_id` echoed back by the CLI.
        request_id: String,
        /// `Ok` for `subtype: "success"`; [`Error::DuplexControlFailed`] for
        /// `subtype: "error"`.
        outcome: Result<()>,
    },
}
1204
1205fn handle_inbound(
1206 msg: Value,
1207 pending: &mut Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)>,
1208 events_tx: &broadcast::Sender<InboundEvent>,
1209) -> InboundAction {
1210 match msg.get("type").and_then(Value::as_str) {
1211 Some("result") => {
1212 if let Some((reply, events)) = pending.take() {
1213 let _ = reply.send(Ok(TurnResult {
1214 result: msg,
1215 events,
1216 }));
1217 } else {
1218 debug!("dropping orphan result event with no pending turn");
1219 }
1220 InboundAction::None
1221 }
1222 Some("control_request") => {
1223 if msg
1226 .get("request")
1227 .and_then(|r| r.get("subtype"))
1228 .and_then(Value::as_str)
1229 == Some("can_use_tool")
1230 && let Some(req) = parse_permission_request(&msg)
1231 {
1232 if let Some((_, events)) = pending.as_mut() {
1233 events.push(msg);
1234 }
1235 return InboundAction::Permission(req);
1236 }
1237 debug!(
1238 ?msg,
1239 "received unhandled control_request; treating as Other"
1240 );
1241 let _ = events_tx.send(InboundEvent::Other(msg.clone()));
1242 if let Some((_, events)) = pending.as_mut() {
1243 events.push(msg);
1244 }
1245 InboundAction::None
1246 }
1247 Some("control_response") => {
1248 if let Some((request_id, outcome)) = parse_control_response(&msg) {
1249 return InboundAction::ControlResponse {
1250 request_id,
1251 outcome,
1252 };
1253 }
1254 debug!(
1255 ?msg,
1256 "received malformed control_response; treating as Other"
1257 );
1258 let _ = events_tx.send(InboundEvent::Other(msg.clone()));
1259 if let Some((_, events)) = pending.as_mut() {
1260 events.push(msg);
1261 }
1262 InboundAction::None
1263 }
1264 _ => {
1265 let _ = events_tx.send(classify(&msg));
1268
1269 if let Some((_, events)) = pending.as_mut() {
1270 events.push(msg);
1271 } else {
1272 debug!("dropping inbound event with no pending turn");
1273 }
1274 InboundAction::None
1275 }
1276 }
1277}
1278
1279fn parse_permission_request(msg: &Value) -> Option<PermissionRequest> {
1280 let request_id = msg.get("request_id").and_then(Value::as_str)?;
1281 let request = msg.get("request")?;
1282 let tool_name = request.get("tool_name").and_then(Value::as_str)?;
1283 let input = request.get("input").cloned().unwrap_or(Value::Null);
1284 Some(PermissionRequest {
1285 request_id: request_id.to_string(),
1286 tool_name: tool_name.to_string(),
1287 input,
1288 raw: request.clone(),
1289 })
1290}
1291
1292fn parse_control_response(msg: &Value) -> Option<(String, Result<()>)> {
1298 let response = msg.get("response")?;
1299 let request_id = response.get("request_id").and_then(Value::as_str)?;
1300 let outcome = match response.get("subtype").and_then(Value::as_str) {
1301 Some("success") => Ok(()),
1302 Some("error") => {
1303 let message = response
1304 .get("error")
1305 .and_then(Value::as_str)
1306 .unwrap_or("unknown control_response error")
1307 .to_string();
1308 Err(Error::DuplexControlFailed { message })
1309 }
1310 _ => return None,
1311 };
1312 Some((request_id.to_string(), outcome))
1313}
1314
1315async fn write_user(stdin: &mut ChildStdin, prompt: &str) -> Result<()> {
1316 let user_msg = serde_json::json!({
1317 "type": "user",
1318 "message": {
1319 "role": "user",
1320 "content": prompt,
1321 },
1322 "parent_tool_use_id": null,
1323 });
1324 write_line(stdin, &user_msg, "user message").await
1325}
1326
1327async fn write_control_request(
1328 stdin: &mut ChildStdin,
1329 request_id: &str,
1330 subtype: &str,
1331) -> Result<()> {
1332 let envelope = serde_json::json!({
1333 "type": "control_request",
1334 "request_id": request_id,
1335 "request": { "subtype": subtype },
1336 });
1337 write_line(stdin, &envelope, "control_request").await
1338}
1339
1340async fn write_permission_response(
1341 stdin: &mut ChildStdin,
1342 request_id: &str,
1343 decision: &PermissionDecision,
1344) -> Result<()> {
1345 let inner = match decision {
1346 PermissionDecision::Allow { updated_input } => {
1347 let mut obj = serde_json::Map::new();
1348 obj.insert("behavior".to_string(), Value::String("allow".to_string()));
1349 if let Some(input) = updated_input {
1350 obj.insert("updatedInput".to_string(), input.clone());
1351 }
1352 Value::Object(obj)
1353 }
1354 PermissionDecision::Deny { message } => serde_json::json!({
1355 "behavior": "deny",
1356 "message": message,
1357 }),
1358 PermissionDecision::Defer => {
1359 return Ok(());
1361 }
1362 };
1363 let envelope = serde_json::json!({
1364 "type": "control_response",
1365 "response": {
1366 "request_id": request_id,
1367 "subtype": "success",
1368 "response": inner,
1369 },
1370 });
1371 write_line(stdin, &envelope, "control_response").await
1372}
1373
1374async fn write_line(stdin: &mut ChildStdin, value: &Value, what: &'static str) -> Result<()> {
1375 let mut line = serde_json::to_string(value).map_err(|e| Error::Json {
1376 message: format!("failed to serialize duplex {what}"),
1377 source: e,
1378 })?;
1379 line.push('\n');
1380 stdin
1381 .write_all(line.as_bytes())
1382 .await
1383 .map_err(|e| Error::Io {
1384 message: format!("failed to write {what} to duplex stdin"),
1385 source: e,
1386 working_dir: None,
1387 })?;
1388 stdin.flush().await.map_err(|e| Error::Io {
1389 message: "failed to flush duplex stdin".to_string(),
1390 source: e,
1391 working_dir: None,
1392 })?;
1393 Ok(())
1394}
1395
1396#[cfg(test)]
1397mod tests {
1398 use super::*;
1399 use serde_json::json;
1400
1401 #[test]
1402 fn into_args_default_includes_required_flags() {
1403 let args = DuplexOptions::default().into_args();
1404 assert!(args.contains(&"--print".to_string()));
1405 assert!(args.contains(&"--verbose".to_string()));
1406 assert!(
1407 args.windows(2)
1408 .any(|w| w == ["--output-format", "stream-json"])
1409 );
1410 assert!(
1411 args.windows(2)
1412 .any(|w| w == ["--input-format", "stream-json"])
1413 );
1414 }
1415
1416 #[test]
1417 fn into_args_includes_model() {
1418 let args = DuplexOptions::default().model("haiku").into_args();
1419 assert!(args.windows(2).any(|w| w == ["--model", "haiku"]));
1420 }
1421
1422 #[test]
1423 fn into_args_includes_system_prompts() {
1424 let args = DuplexOptions::default()
1425 .system_prompt("be concise")
1426 .append_system_prompt("also polite")
1427 .into_args();
1428 assert!(
1429 args.windows(2)
1430 .any(|w| w == ["--system-prompt", "be concise"])
1431 );
1432 assert!(
1433 args.windows(2)
1434 .any(|w| w == ["--append-system-prompt", "also polite"])
1435 );
1436 }
1437
1438 #[test]
1439 fn into_args_appends_raw_args_last() {
1440 let args = DuplexOptions::default()
1441 .arg("--add-dir")
1442 .arg("/tmp/foo")
1443 .into_args();
1444 assert_eq!(&args[args.len() - 2..], &["--add-dir", "/tmp/foo"]);
1446 }
1447
1448 #[test]
1449 fn into_args_includes_resume_when_set() {
1450 let args = DuplexOptions::default().resume("abc-123").into_args();
1451 assert!(args.windows(2).any(|w| w == ["--resume", "abc-123"]));
1452 }
1453
1454 #[test]
1455 fn into_args_omits_resume_by_default() {
1456 let args = DuplexOptions::default().into_args();
1457 assert!(
1458 !args.iter().any(|a| a == "--resume"),
1459 "--resume should not appear without an explicit resume(...) call; got {args:?}"
1460 );
1461 }
1462
1463 #[test]
1464 fn into_args_includes_continue_when_set() {
1465 let args = DuplexOptions::default().continue_session().into_args();
1466 assert!(args.iter().any(|a| a == "--continue"));
1467 }
1468
1469 #[test]
1470 fn into_args_omits_continue_by_default() {
1471 let args = DuplexOptions::default().into_args();
1472 assert!(!args.iter().any(|a| a == "--continue"));
1473 }
1474
1475 #[test]
1476 fn into_args_includes_worktree_flag_without_name() {
1477 let args = DuplexOptions::default().worktree(None::<&str>).into_args();
1478 assert!(args.iter().any(|a| a == "--worktree"));
1479 let pos = args.iter().position(|a| a == "--worktree").unwrap();
1481 assert!(
1482 args.get(pos + 1).is_none_or(|a| a.starts_with("--")),
1483 "--worktree without a name should not be followed by a positional; got {args:?}"
1484 );
1485 }
1486
1487 #[test]
1488 fn into_args_includes_worktree_flag_with_name() {
1489 let args = DuplexOptions::default()
1490 .worktree(Some("agent-xyz"))
1491 .into_args();
1492 let pos = args.iter().position(|a| a == "--worktree").unwrap();
1493 assert_eq!(args.get(pos + 1).map(String::as_str), Some("agent-xyz"));
1494 }
1495
1496 #[test]
1497 fn into_args_omits_worktree_by_default() {
1498 let args = DuplexOptions::default().into_args();
1499 assert!(
1500 !args.iter().any(|a| a == "--worktree"),
1501 "--worktree should not appear without an explicit worktree(...) call; got {args:?}"
1502 );
1503 }
1504
1505 #[test]
1506 fn worktree_lands_before_additional_args() {
1507 let args = DuplexOptions::default()
1509 .worktree(Some("foo"))
1510 .arg("--")
1511 .arg("trailing")
1512 .into_args();
1513 let wt_pos = args.iter().position(|a| a == "--worktree").unwrap();
1514 let dash_dash_pos = args.iter().position(|a| a == "--").unwrap();
1515 assert!(
1516 wt_pos < dash_dash_pos,
1517 "--worktree must precede `--` separator; got {args:?}"
1518 );
1519 }
1520
1521 #[test]
1522 fn into_args_includes_agent_when_set() {
1523 let args = DuplexOptions::default().agent("rust-qa").into_args();
1524 assert!(
1525 args.windows(2).any(|w| w == ["--agent", "rust-qa"]),
1526 "missing --agent rust-qa in {args:?}"
1527 );
1528 }
1529
1530 #[test]
1531 fn into_args_omits_agent_by_default() {
1532 let args = DuplexOptions::default().into_args();
1533 assert!(
1534 !args.iter().any(|a| a == "--agent"),
1535 "--agent should not appear without an explicit agent(...) call; got {args:?}"
1536 );
1537 }
1538
1539 #[test]
1540 fn into_args_includes_agents_json_when_set() {
1541 let json = r#"{"reviewer":{"description":"r","prompt":"p"}}"#;
1542 let args = DuplexOptions::default().agents_json(json).into_args();
1543 let pos = args.iter().position(|a| a == "--agents").unwrap();
1544 assert_eq!(args.get(pos + 1).map(String::as_str), Some(json));
1545 }
1546
1547 #[test]
1548 fn into_args_omits_agents_json_by_default() {
1549 let args = DuplexOptions::default().into_args();
1550 assert!(!args.iter().any(|a| a == "--agents"));
1551 }
1552
1553 #[test]
1554 fn agent_and_agents_json_compose() {
1555 let json = r#"{"reviewer":{"description":"r","prompt":"p"}}"#;
1556 let args = DuplexOptions::default()
1557 .agents_json(json)
1558 .agent("reviewer")
1559 .into_args();
1560 assert!(args.iter().any(|a| a == "--agents"));
1562 assert!(args.iter().any(|a| a == "--agent"));
1563 }
1564
1565 #[test]
1566 fn agent_lands_before_additional_args() {
1567 let args = DuplexOptions::default()
1568 .agent("rust-qa")
1569 .arg("--")
1570 .arg("trailing")
1571 .into_args();
1572 let agent_pos = args.iter().position(|a| a == "--agent").unwrap();
1573 let dash_dash_pos = args.iter().position(|a| a == "--").unwrap();
1574 assert!(
1575 agent_pos < dash_dash_pos,
1576 "--agent must precede `--` separator; got {args:?}"
1577 );
1578 }
1579
1580 #[test]
1581 fn agents_json_lands_before_additional_args() {
1582 let args = DuplexOptions::default()
1583 .agents_json("{}")
1584 .arg("--")
1585 .arg("trailing")
1586 .into_args();
1587 let agents_pos = args.iter().position(|a| a == "--agents").unwrap();
1588 let dash_dash_pos = args.iter().position(|a| a == "--").unwrap();
1589 assert!(
1590 agents_pos < dash_dash_pos,
1591 "--agents must precede `--` separator; got {args:?}"
1592 );
1593 }
1594
1595 #[test]
1596 fn resume_lands_before_additional_args() {
1597 let args = DuplexOptions::default()
1602 .resume("xyz")
1603 .arg("--")
1604 .arg("trailing")
1605 .into_args();
1606 let resume_pos = args.iter().position(|a| a == "--resume").unwrap();
1607 let dash_dash_pos = args.iter().position(|a| a == "--").unwrap();
1608 assert!(
1609 resume_pos < dash_dash_pos,
1610 "--resume must precede `--` separator; got {args:?}"
1611 );
1612 }
1613
1614 #[test]
1615 fn turn_result_accessors_pull_from_result() {
1616 let r = TurnResult {
1617 result: json!({
1618 "type": "result",
1619 "result": "hello",
1620 "session_id": "sess-123",
1621 "total_cost_usd": 0.0042,
1622 "duration_ms": 1234_u64,
1623 }),
1624 events: vec![],
1625 };
1626 assert_eq!(r.result_text(), Some("hello"));
1627 assert_eq!(r.session_id(), Some("sess-123"));
1628 assert_eq!(r.total_cost_usd(), Some(0.0042));
1629 assert_eq!(r.duration_ms(), Some(1234));
1630 }
1631
1632 #[test]
1633 fn turn_result_total_cost_falls_back_to_legacy_field() {
1634 let r = TurnResult {
1635 result: json!({ "cost_usd": 0.5 }),
1636 events: vec![],
1637 };
1638 assert_eq!(r.total_cost_usd(), Some(0.5));
1639 }
1640
1641 #[test]
1642 fn turn_result_accessors_return_none_when_missing() {
1643 let r = TurnResult {
1644 result: json!({}),
1645 events: vec![],
1646 };
1647 assert_eq!(r.result_text(), None);
1648 assert_eq!(r.session_id(), None);
1649 assert_eq!(r.total_cost_usd(), None);
1650 assert_eq!(r.duration_ms(), None);
1651 }
1652
1653 #[test]
1654 fn handle_inbound_appends_non_result_to_pending_events() {
1655 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1656 let (events_tx, _events_rx) = broadcast::channel(16);
1657 let mut pending = Some((tx, Vec::new()));
1658 handle_inbound(
1659 json!({ "type": "assistant", "message": {} }),
1660 &mut pending,
1661 &events_tx,
1662 );
1663 let (_, events) = pending.as_ref().unwrap();
1664 assert_eq!(events.len(), 1);
1665 assert_eq!(
1666 events[0].get("type").and_then(Value::as_str),
1667 Some("assistant")
1668 );
1669 }
1670
1671 #[test]
1672 fn handle_inbound_resolves_pending_on_result() {
1673 let (tx, rx) = oneshot::channel::<Result<TurnResult>>();
1674 let (events_tx, _events_rx) = broadcast::channel(16);
1675 let mut pending = Some((tx, vec![json!({ "type": "assistant" })]));
1676 handle_inbound(
1677 json!({ "type": "result", "result": "ok" }),
1678 &mut pending,
1679 &events_tx,
1680 );
1681 assert!(pending.is_none());
1682 let received = rx.blocking_recv().unwrap().unwrap();
1683 assert_eq!(received.result_text(), Some("ok"));
1684 assert_eq!(received.events.len(), 1);
1685 }
1686
1687 #[test]
1688 fn handle_inbound_drops_orphans_without_pending_turn() {
1689 let (events_tx, _events_rx) = broadcast::channel(16);
1690 let mut pending: Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)> = None;
1691 handle_inbound(json!({ "type": "assistant" }), &mut pending, &events_tx);
1692 handle_inbound(
1693 json!({ "type": "result", "result": "ok" }),
1694 &mut pending,
1695 &events_tx,
1696 );
1697 assert!(pending.is_none());
1698 }
1699
1700 #[test]
1701 fn handle_inbound_broadcasts_classified_event() {
1702 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1703 let (events_tx, mut events_rx) = broadcast::channel(16);
1704 let mut pending = Some((tx, Vec::new()));
1705 handle_inbound(
1706 json!({ "type": "assistant", "message": { "role": "assistant" } }),
1707 &mut pending,
1708 &events_tx,
1709 );
1710 let event = events_rx.try_recv().expect("classified event broadcast");
1711 assert!(matches!(event, InboundEvent::Assistant(_)));
1712 }
1713
1714 #[test]
1715 fn handle_inbound_does_not_broadcast_result() {
1716 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1717 let (events_tx, mut events_rx) = broadcast::channel(16);
1718 let mut pending = Some((tx, Vec::new()));
1719 handle_inbound(
1720 json!({ "type": "result", "result": "ok" }),
1721 &mut pending,
1722 &events_tx,
1723 );
1724 assert!(events_rx.try_recv().is_err());
1726 }
1727
1728 #[test]
1729 fn classify_system_init_pulls_session_id() {
1730 let v = json!({
1731 "type": "system",
1732 "subtype": "init",
1733 "session_id": "sess-abc",
1734 });
1735 match classify(&v) {
1736 InboundEvent::SystemInit { session_id } => assert_eq!(session_id, "sess-abc"),
1737 other => panic!("expected SystemInit, got {other:?}"),
1738 }
1739 }
1740
1741 #[test]
1742 fn classify_system_without_init_subtype_is_other() {
1743 let v = json!({ "type": "system", "subtype": "compaction" });
1744 assert!(matches!(classify(&v), InboundEvent::Other(_)));
1745 }
1746
1747 #[test]
1748 fn classify_system_init_without_session_id_is_other() {
1749 let v = json!({ "type": "system", "subtype": "init" });
1750 assert!(matches!(classify(&v), InboundEvent::Other(_)));
1751 }
1752
1753 #[test]
1754 fn classify_assistant_stream_event_user() {
1755 assert!(matches!(
1756 classify(&json!({ "type": "assistant" })),
1757 InboundEvent::Assistant(_)
1758 ));
1759 assert!(matches!(
1760 classify(&json!({ "type": "stream_event" })),
1761 InboundEvent::StreamEvent(_)
1762 ));
1763 assert!(matches!(
1764 classify(&json!({ "type": "user" })),
1765 InboundEvent::User(_)
1766 ));
1767 }
1768
1769 #[test]
1770 fn classify_unknown_type_is_other() {
1771 assert!(matches!(
1772 classify(&json!({ "type": "control_request" })),
1773 InboundEvent::Other(_)
1774 ));
1775 assert!(matches!(
1776 classify(&json!({ "type": "future_thing" })),
1777 InboundEvent::Other(_)
1778 ));
1779 assert!(matches!(classify(&json!({})), InboundEvent::Other(_)));
1780 }
1781
1782 #[test]
1783 fn into_args_does_not_emit_subscriber_capacity_flag() {
1784 let args = DuplexOptions::default().subscriber_capacity(64).into_args();
1786 assert!(!args.iter().any(|a| a.contains("subscriber")));
1787 assert!(!args.iter().any(|a| a.contains("capacity")));
1788 }
1789
1790 #[test]
1791 fn into_args_includes_permission_prompt_tool_when_handler_set() {
1792 let handler = PermissionHandler::new(|_req| async move {
1793 PermissionDecision::Allow {
1794 updated_input: None,
1795 }
1796 });
1797 let args = DuplexOptions::default().on_permission(handler).into_args();
1798 assert!(
1799 args.windows(2)
1800 .any(|w| w == ["--permission-prompt-tool", "stdio"])
1801 );
1802 }
1803
1804 #[test]
1805 fn into_args_omits_permission_prompt_tool_without_handler() {
1806 let args = DuplexOptions::default().into_args();
1807 assert!(!args.iter().any(|a| a == "--permission-prompt-tool"));
1808 }
1809
1810 #[test]
1811 fn into_args_emits_permission_mode_flag() {
1812 let args = DuplexOptions::default()
1813 .permission_mode(PermissionMode::AcceptEdits)
1814 .into_args();
1815 assert!(
1816 args.windows(2)
1817 .any(|w| w == ["--permission-mode", "acceptEdits"]),
1818 "missing --permission-mode acceptEdits in {args:?}"
1819 );
1820 }
1821
1822 #[test]
1823 fn into_args_emits_plan_mode() {
1824 let args = DuplexOptions::default()
1825 .permission_mode(PermissionMode::Plan)
1826 .into_args();
1827 assert!(args.windows(2).any(|w| w == ["--permission-mode", "plan"]));
1828 }
1829
1830 #[test]
1831 fn into_args_omits_permission_mode_by_default() {
1832 let args = DuplexOptions::default().into_args();
1833 assert!(!args.iter().any(|a| a == "--permission-mode"));
1834 }
1835
1836 #[test]
1837 fn into_args_emits_dangerously_skip_permissions_flag() {
1838 let args = DuplexOptions::default()
1839 .dangerously_skip_permissions()
1840 .into_args();
1841 assert!(args.iter().any(|a| a == "--dangerously-skip-permissions"));
1842 }
1843
1844 #[test]
1845 fn into_args_omits_dangerously_skip_by_default() {
1846 let args = DuplexOptions::default().into_args();
1847 assert!(!args.iter().any(|a| a == "--dangerously-skip-permissions"));
1848 }
1849
1850 #[test]
1851 fn parse_permission_request_extracts_fields() {
1852 let msg = json!({
1853 "type": "control_request",
1854 "request_id": "req-1",
1855 "request": {
1856 "subtype": "can_use_tool",
1857 "tool_name": "Bash",
1858 "input": { "command": "ls" }
1859 }
1860 });
1861 let req = parse_permission_request(&msg).expect("permission request");
1862 assert_eq!(req.request_id, "req-1");
1863 assert_eq!(req.tool_name, "Bash");
1864 assert_eq!(req.input, json!({ "command": "ls" }));
1865 assert_eq!(
1866 req.raw.get("subtype").and_then(Value::as_str),
1867 Some("can_use_tool")
1868 );
1869 }
1870
1871 #[test]
1872 fn parse_permission_request_returns_none_when_missing_request_id() {
1873 let msg = json!({
1874 "type": "control_request",
1875 "request": {
1876 "subtype": "can_use_tool",
1877 "tool_name": "Bash",
1878 }
1879 });
1880 assert!(parse_permission_request(&msg).is_none());
1881 }
1882
1883 #[test]
1884 fn parse_permission_request_returns_none_when_missing_tool_name() {
1885 let msg = json!({
1886 "type": "control_request",
1887 "request_id": "req-1",
1888 "request": { "subtype": "can_use_tool" }
1889 });
1890 assert!(parse_permission_request(&msg).is_none());
1891 }
1892
1893 #[test]
1894 fn parse_permission_request_handles_missing_input() {
1895 let msg = json!({
1896 "type": "control_request",
1897 "request_id": "req-1",
1898 "request": {
1899 "subtype": "can_use_tool",
1900 "tool_name": "Bash",
1901 }
1902 });
1903 let req = parse_permission_request(&msg).expect("request");
1904 assert_eq!(req.input, Value::Null);
1905 }
1906
1907 #[test]
1908 fn handle_inbound_returns_permission_for_can_use_tool() {
1909 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1910 let (events_tx, _events_rx) = broadcast::channel(16);
1911 let mut pending = Some((tx, Vec::new()));
1912 let action = handle_inbound(
1913 json!({
1914 "type": "control_request",
1915 "request_id": "req-1",
1916 "request": {
1917 "subtype": "can_use_tool",
1918 "tool_name": "Bash",
1919 "input": { "command": "ls" }
1920 }
1921 }),
1922 &mut pending,
1923 &events_tx,
1924 );
1925 match action {
1926 InboundAction::Permission(req) => {
1927 assert_eq!(req.request_id, "req-1");
1928 assert_eq!(req.tool_name, "Bash");
1929 }
1930 InboundAction::None | InboundAction::ControlResponse { .. } => {
1931 panic!("expected Permission action");
1932 }
1933 }
1934 let (_, events) = pending.as_ref().unwrap();
1936 assert_eq!(events.len(), 1);
1937 }
1938
1939 #[test]
1940 fn handle_inbound_treats_unknown_control_request_as_other() {
1941 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1942 let (events_tx, mut events_rx) = broadcast::channel(16);
1943 let mut pending = Some((tx, Vec::new()));
1944 let action = handle_inbound(
1945 json!({
1946 "type": "control_request",
1947 "request_id": "req-2",
1948 "request": { "subtype": "future_subtype" }
1949 }),
1950 &mut pending,
1951 &events_tx,
1952 );
1953 assert!(matches!(action, InboundAction::None));
1954 let event = events_rx.try_recv().expect("broadcast");
1955 assert!(matches!(event, InboundEvent::Other(_)));
1956 }
1957
1958 #[tokio::test]
1959 async fn permission_handler_invokes_closure_async() {
1960 let handler = PermissionHandler::new(|req| async move {
1961 if req.tool_name == "Bash" {
1962 PermissionDecision::Deny {
1963 message: "no bash".into(),
1964 }
1965 } else {
1966 PermissionDecision::Allow {
1967 updated_input: None,
1968 }
1969 }
1970 });
1971 let req = PermissionRequest {
1972 request_id: "r1".into(),
1973 tool_name: "Bash".into(),
1974 input: Value::Null,
1975 raw: Value::Null,
1976 };
1977 match handler.invoke(req).await {
1978 PermissionDecision::Deny { message } => assert_eq!(message, "no bash"),
1979 other => panic!("expected Deny, got {other:?}"),
1980 }
1981 }
1982
1983 #[test]
1984 fn parse_control_response_extracts_success() {
1985 let msg = json!({
1986 "type": "control_response",
1987 "response": {
1988 "request_id": "interrupt-1",
1989 "subtype": "success",
1990 "response": {}
1991 }
1992 });
1993 let (id, outcome) = parse_control_response(&msg).expect("parsed");
1994 assert_eq!(id, "interrupt-1");
1995 assert!(outcome.is_ok());
1996 }
1997
1998 #[test]
1999 fn parse_control_response_extracts_error_with_message() {
2000 let msg = json!({
2001 "type": "control_response",
2002 "response": {
2003 "request_id": "interrupt-2",
2004 "subtype": "error",
2005 "error": "no turn in flight"
2006 }
2007 });
2008 let (id, outcome) = parse_control_response(&msg).expect("parsed");
2009 assert_eq!(id, "interrupt-2");
2010 match outcome {
2011 Err(Error::DuplexControlFailed { message }) => {
2012 assert_eq!(message, "no turn in flight");
2013 }
2014 other => panic!("expected DuplexControlFailed, got {other:?}"),
2015 }
2016 }
2017
2018 #[test]
2019 fn parse_control_response_returns_none_on_missing_request_id() {
2020 let msg = json!({
2021 "type": "control_response",
2022 "response": { "subtype": "success" }
2023 });
2024 assert!(parse_control_response(&msg).is_none());
2025 }
2026
2027 #[test]
2028 fn parse_control_response_returns_none_on_unknown_subtype() {
2029 let msg = json!({
2030 "type": "control_response",
2031 "response": { "request_id": "x", "subtype": "future_subtype" }
2032 });
2033 assert!(parse_control_response(&msg).is_none());
2034 }
2035
2036 #[test]
2037 fn handle_inbound_returns_control_response_action() {
2038 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
2039 let (events_tx, _events_rx) = broadcast::channel(16);
2040 let mut pending = Some((tx, Vec::new()));
2041 let action = handle_inbound(
2042 json!({
2043 "type": "control_response",
2044 "response": {
2045 "request_id": "interrupt-1",
2046 "subtype": "success",
2047 "response": {}
2048 }
2049 }),
2050 &mut pending,
2051 &events_tx,
2052 );
2053 match action {
2054 InboundAction::ControlResponse {
2055 request_id,
2056 outcome,
2057 } => {
2058 assert_eq!(request_id, "interrupt-1");
2059 assert!(outcome.is_ok());
2060 }
2061 InboundAction::None | InboundAction::Permission(_) => {
2062 panic!("expected ControlResponse action");
2063 }
2064 }
2065 }
2066
2067 #[test]
2068 fn handle_inbound_treats_malformed_control_response_as_other() {
2069 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
2070 let (events_tx, mut events_rx) = broadcast::channel(16);
2071 let mut pending = Some((tx, Vec::new()));
2072 let action = handle_inbound(
2073 json!({
2074 "type": "control_response",
2075 "response": { "subtype": "success" }
2076 }),
2077 &mut pending,
2078 &events_tx,
2079 );
2080 assert!(matches!(action, InboundAction::None));
2081 let event = events_rx.try_recv().expect("broadcast");
2082 assert!(matches!(event, InboundEvent::Other(_)));
2083 }
2084
2085 #[tokio::test]
2086 async fn permission_handler_clones_arc() {
2087 let handler = PermissionHandler::new(|_req| async move {
2088 PermissionDecision::Allow {
2089 updated_input: None,
2090 }
2091 });
2092 let cloned = handler.clone();
2093 let req = PermissionRequest {
2094 request_id: "r1".into(),
2095 tool_name: "Read".into(),
2096 input: Value::Null,
2097 raw: Value::Null,
2098 };
2099 let _ = handler.invoke(req.clone()).await;
2101 let _ = cloned.invoke(req).await;
2102 }
2103
2104 fn fake_session(
2111 initial: SessionExitStatus,
2112 ) -> (
2113 DuplexSession,
2114 watch::Sender<SessionExitStatus>,
2115 oneshot::Sender<()>,
2116 ) {
2117 let (outbound_tx, outbound_rx) = mpsc::unbounded_channel::<OutboundMsg>();
2118 let (events_tx, _events_rx) = broadcast::channel::<InboundEvent>(16);
2119 let (exit_tx, exit_rx) = watch::channel(initial);
2120 let (stop_tx, stop_rx) = oneshot::channel::<()>();
2121
2122 let join = tokio::spawn(async move {
2123 let _outbound_rx = outbound_rx;
2124 let _ = stop_rx.await;
2125 Ok::<(), Error>(())
2126 });
2127
2128 let session = DuplexSession {
2129 outbound_tx,
2130 events_tx,
2131 exit_rx,
2132 join,
2133 };
2134 (session, exit_tx, stop_tx)
2135 }
2136
2137 #[tokio::test]
2138 async fn is_alive_true_while_running() {
2139 let (session, _exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2140 assert!(session.is_alive());
2141 }
2142
2143 #[tokio::test]
2144 async fn is_alive_false_after_completed() {
2145 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2146 exit_tx.send(SessionExitStatus::Completed).unwrap();
2147 assert!(!session.is_alive());
2148 }
2149
2150 #[tokio::test]
2151 async fn is_alive_false_after_failed() {
2152 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2153 exit_tx
2154 .send(SessionExitStatus::Failed("boom".into()))
2155 .unwrap();
2156 assert!(!session.is_alive());
2157 }
2158
2159 #[tokio::test]
2160 async fn exit_status_reports_running_initially() {
2161 let (session, _exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2162 assert!(matches!(session.exit_status(), SessionExitStatus::Running));
2163 }
2164
2165 #[tokio::test]
2166 async fn exit_status_reflects_completed() {
2167 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2168 exit_tx.send(SessionExitStatus::Completed).unwrap();
2169 assert!(matches!(
2170 session.exit_status(),
2171 SessionExitStatus::Completed
2172 ));
2173 }
2174
2175 #[tokio::test]
2176 async fn exit_status_reflects_failed_with_message() {
2177 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2178 exit_tx
2179 .send(SessionExitStatus::Failed("oh no".into()))
2180 .unwrap();
2181 match session.exit_status() {
2182 SessionExitStatus::Failed(msg) => assert_eq!(msg, "oh no"),
2183 other => panic!("expected Failed, got {other:?}"),
2184 }
2185 }
2186
2187 #[tokio::test]
2188 async fn wait_for_exit_returns_immediately_when_already_terminal() {
2189 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2190 exit_tx.send(SessionExitStatus::Completed).unwrap();
2191 let status = tokio::time::timeout(Duration::from_secs(1), session.wait_for_exit())
2192 .await
2193 .expect("wait_for_exit should not block when already terminal");
2194 assert!(matches!(status, SessionExitStatus::Completed));
2195 }
2196
2197 #[tokio::test]
2198 async fn wait_for_exit_blocks_until_state_transitions() {
2199 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2200
2201 let waiter = async { session.wait_for_exit().await };
2202 let driver = async {
2203 tokio::time::sleep(Duration::from_millis(20)).await;
2204 exit_tx.send(SessionExitStatus::Completed).unwrap();
2205 };
2206 let (status, ()) = tokio::join!(waiter, driver);
2207 assert!(matches!(status, SessionExitStatus::Completed));
2208 }
2209
2210 #[tokio::test]
2211 async fn wait_for_exit_supports_multiple_observers() {
2212 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2213
2214 let waiter1 = async { session.wait_for_exit().await };
2215 let waiter2 = async { session.wait_for_exit().await };
2216 let driver = async {
2217 tokio::time::sleep(Duration::from_millis(20)).await;
2218 exit_tx
2219 .send(SessionExitStatus::Failed("crash".into()))
2220 .unwrap();
2221 };
2222 let (s1, s2, ()) = tokio::join!(waiter1, waiter2, driver);
2223 match s1 {
2224 SessionExitStatus::Failed(msg) => assert_eq!(msg, "crash"),
2225 other => panic!("waiter1 expected Failed, got {other:?}"),
2226 }
2227 match s2 {
2228 SessionExitStatus::Failed(msg) => assert_eq!(msg, "crash"),
2229 other => panic!("waiter2 expected Failed, got {other:?}"),
2230 }
2231 }
2232
2233 #[tokio::test]
2234 async fn wait_for_exit_returns_last_value_when_sender_dropped() {
2235 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2239 let waiter = async { session.wait_for_exit().await };
2240 let driver = async {
2241 tokio::time::sleep(Duration::from_millis(20)).await;
2242 drop(exit_tx);
2243 };
2244 let (status, ()) = tokio::time::timeout(Duration::from_secs(1), async {
2245 tokio::join!(waiter, driver)
2246 })
2247 .await
2248 .expect("wait_for_exit must not hang when sender is dropped");
2249 assert!(matches!(status, SessionExitStatus::Running));
2250 }
2251}