1use std::collections::HashMap;
179use std::future::Future;
180use std::pin::Pin;
181use std::process::Stdio;
182use std::sync::Arc;
183use std::time::Duration;
184
185use serde_json::Value;
186use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
187use tokio::process::{Child, ChildStdin, ChildStdout, Command};
188use tokio::sync::{broadcast, mpsc, oneshot, watch};
189use tokio::task::JoinHandle;
190use tracing::{debug, warn};
191
192use crate::Claude;
193use crate::error::{Error, Result};
194use crate::types::PermissionMode;
195
/// Capacity of the event broadcast channel created by
/// [`DuplexSession::spawn`] when [`DuplexOptions::subscriber_capacity`] was
/// never called.
pub const DEFAULT_SUBSCRIBER_CAPACITY: usize = 256;
201
/// A tool-permission prompt parsed from an inbound `control_request` whose
/// `request.subtype` is `can_use_tool`.
#[derive(Debug, Clone)]
pub struct PermissionRequest {
    /// Top-level `request_id` of the control request; echoed back in the
    /// `control_response` that answers it.
    pub request_id: String,
    /// Value of `request.tool_name`.
    pub tool_name: String,
    /// Value of `request.input`, or `Value::Null` when the field is absent.
    pub input: Value,
    /// Clone of the entire `request` object as received.
    pub raw: Value,
}
223
/// Answer to a [`PermissionRequest`].
#[derive(Debug, Clone)]
pub enum PermissionDecision {
    /// Serialized as `{"behavior": "allow"}`.
    Allow {
        /// When `Some`, also serialized under the `updatedInput` key.
        updated_input: Option<Value>,
    },
    /// Serialized as `{"behavior": "deny", "message": ...}`.
    Deny {
        /// Text sent in the `message` field.
        message: String,
    },
    /// Nothing is written to the child; the session logs that it is waiting
    /// for a later [`DuplexSession::respond_to_permission`] call.
    Defer,
}
249
/// Boxed, type-erased future produced by a permission callback.
type PermissionFuture = Pin<Box<dyn Future<Output = PermissionDecision> + Send + 'static>>;
/// Type-erased permission callback stored inside [`PermissionHandler`].
type PermissionFn = dyn Fn(PermissionRequest) -> PermissionFuture + Send + Sync + 'static;
252
/// Cloneable wrapper around an async permission callback.
///
/// Clones share the same callback through an [`Arc`].
#[derive(Clone)]
pub struct PermissionHandler {
    /// The type-erased callback.
    inner: Arc<PermissionFn>,
}
269
270impl PermissionHandler {
271 pub fn new<F, Fut>(f: F) -> Self
287 where
288 F: Fn(PermissionRequest) -> Fut + Send + Sync + 'static,
289 Fut: Future<Output = PermissionDecision> + Send + 'static,
290 {
291 Self {
292 inner: Arc::new(move |req| Box::pin(f(req))),
293 }
294 }
295
296 fn invoke(&self, req: PermissionRequest) -> PermissionFuture {
297 (self.inner)(req)
298 }
299}
300
301impl std::fmt::Debug for PermissionHandler {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 f.debug_struct("PermissionHandler").finish_non_exhaustive()
304 }
305}
306
/// Builder-style options for [`DuplexSession::spawn`].
///
/// Most fields are turned into CLI arguments by `into_args`; the exceptions
/// are noted on the individual fields.
#[derive(Debug, Default, Clone)]
pub struct DuplexOptions {
    /// Emitted as `--model <value>`.
    model: Option<String>,
    /// Emitted as `--system-prompt <value>`.
    system_prompt: Option<String>,
    /// Emitted as `--append-system-prompt <value>`.
    append_system_prompt: Option<String>,
    /// Emitted as `--resume <value>`.
    resume: Option<String>,
    /// Emitted as `--continue` when true.
    continue_session: bool,
    /// Emitted as `--worktree` when true.
    worktree: bool,
    /// Emitted right after `--worktree`; ignored unless `worktree` is true.
    worktree_name: Option<String>,
    /// Emitted as `--agent <value>`.
    agent: Option<String>,
    /// Emitted as `--agents <value>`.
    agents_json: Option<String>,
    /// Emitted as `--permission-mode <mode.as_arg()>`.
    permission_mode: Option<PermissionMode>,
    /// Emitted as `--dangerously-skip-permissions` when true.
    dangerously_skip_permissions: bool,
    /// Raw arguments appended after every generated argument.
    additional_args: Vec<String>,
    /// Broadcast channel capacity; never turned into a CLI argument.
    subscriber_capacity: Option<usize>,
    /// Permission callback; its presence adds `--permission-prompt-tool stdio`.
    on_permission: Option<PermissionHandler>,
}
330
331impl DuplexOptions {
332 #[must_use]
334 pub fn model(mut self, model: impl Into<String>) -> Self {
335 self.model = Some(model.into());
336 self
337 }
338
339 #[must_use]
341 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
342 self.system_prompt = Some(prompt.into());
343 self
344 }
345
346 #[must_use]
348 pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
349 self.append_system_prompt = Some(prompt.into());
350 self
351 }
352
353 #[must_use]
370 pub fn resume(mut self, session_id: impl Into<String>) -> Self {
371 self.resume = Some(session_id.into());
372 self
373 }
374
375 #[must_use]
382 pub fn continue_session(mut self) -> Self {
383 self.continue_session = true;
384 self
385 }
386
387 #[must_use]
398 pub fn worktree(mut self, name: Option<impl Into<String>>) -> Self {
399 self.worktree = true;
400 if let Some(n) = name {
401 self.worktree_name = Some(n.into());
402 }
403 self
404 }
405
406 #[must_use]
420 pub fn agent(mut self, name: impl Into<String>) -> Self {
421 self.agent = Some(name.into());
422 self
423 }
424
425 #[must_use]
437 pub fn agents_json(mut self, json: impl Into<String>) -> Self {
438 self.agents_json = Some(json.into());
439 self
440 }
441
442 #[must_use]
458 pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
459 self.permission_mode = Some(mode);
460 self
461 }
462
463 #[must_use]
471 pub fn dangerously_skip_permissions(mut self) -> Self {
472 self.dangerously_skip_permissions = true;
473 self
474 }
475
476 #[must_use]
481 pub fn arg(mut self, arg: impl Into<String>) -> Self {
482 self.additional_args.push(arg.into());
483 self
484 }
485
486 #[must_use]
494 pub fn subscriber_capacity(mut self, capacity: usize) -> Self {
495 self.subscriber_capacity = Some(capacity);
496 self
497 }
498
499 #[must_use]
517 pub fn on_permission(mut self, handler: PermissionHandler) -> Self {
518 self.on_permission = Some(handler);
519 self
520 }
521
522 fn into_args(self) -> Vec<String> {
523 let mut args = vec![
524 "--print".to_string(),
525 "--verbose".to_string(),
526 "--output-format".to_string(),
527 "stream-json".to_string(),
528 "--input-format".to_string(),
529 "stream-json".to_string(),
530 ];
531
532 if let Some(m) = self.model {
533 args.push("--model".to_string());
534 args.push(m);
535 }
536 if let Some(p) = self.system_prompt {
537 args.push("--system-prompt".to_string());
538 args.push(p);
539 }
540 if let Some(p) = self.append_system_prompt {
541 args.push("--append-system-prompt".to_string());
542 args.push(p);
543 }
544 if let Some(id) = self.resume {
545 args.push("--resume".to_string());
546 args.push(id);
547 }
548 if self.continue_session {
549 args.push("--continue".to_string());
550 }
551 if self.worktree {
552 args.push("--worktree".to_string());
553 if let Some(n) = self.worktree_name {
554 args.push(n);
555 }
556 }
557 if let Some(json) = self.agents_json {
558 args.push("--agents".to_string());
559 args.push(json);
560 }
561 if let Some(name) = self.agent {
562 args.push("--agent".to_string());
563 args.push(name);
564 }
565 if let Some(mode) = self.permission_mode {
566 args.push("--permission-mode".to_string());
567 args.push(mode.as_arg().to_string());
568 }
569 if self.dangerously_skip_permissions {
570 args.push("--dangerously-skip-permissions".to_string());
571 }
572 if self.on_permission.is_some() {
573 args.push("--permission-prompt-tool".to_string());
574 args.push("stdio".to_string());
575 }
576 args.extend(self.additional_args);
577
578 args
579 }
580}
581
/// Outcome of one completed turn.
#[derive(Debug, Clone)]
pub struct TurnResult {
    /// The inbound message whose `type` was `result`.
    pub result: Value,
    /// Inbound messages recorded while the turn was pending, in arrival order.
    pub events: Vec<Value>,
}
595
596impl TurnResult {
597 #[must_use]
599 pub fn result_text(&self) -> Option<&str> {
600 self.result.get("result").and_then(Value::as_str)
601 }
602
603 #[must_use]
605 pub fn session_id(&self) -> Option<&str> {
606 self.result.get("session_id").and_then(Value::as_str)
607 }
608
609 #[must_use]
612 pub fn total_cost_usd(&self) -> Option<f64> {
613 self.result
614 .get("total_cost_usd")
615 .or_else(|| self.result.get("cost_usd"))
616 .and_then(Value::as_f64)
617 }
618
619 #[must_use]
621 pub fn duration_ms(&self) -> Option<u64> {
622 self.result.get("duration_ms").and_then(Value::as_u64)
623 }
624}
625
/// Classified inbound message, as delivered to broadcast subscribers.
#[derive(Debug, Clone)]
pub enum InboundEvent {
    /// A `system` message with subtype `init` that carried a string
    /// `session_id`.
    SystemInit {
        /// The `session_id` field of the init message.
        session_id: String,
    },
    /// A message whose `type` is `assistant` (full message).
    Assistant(Value),
    /// A message whose `type` is `stream_event` (full message).
    StreamEvent(Value),
    /// A message whose `type` is `user` (full message).
    User(Value),
    /// Any other message, including `system` messages that are not a
    /// well-formed init.
    Other(Value),
}
661
662fn classify(msg: &Value) -> InboundEvent {
663 match msg.get("type").and_then(Value::as_str) {
664 Some("system") => {
665 if msg.get("subtype").and_then(Value::as_str) == Some("init")
666 && let Some(id) = msg.get("session_id").and_then(Value::as_str)
667 {
668 return InboundEvent::SystemInit {
669 session_id: id.to_string(),
670 };
671 }
672 InboundEvent::Other(msg.clone())
673 }
674 Some("assistant") => InboundEvent::Assistant(msg.clone()),
675 Some("stream_event") => InboundEvent::StreamEvent(msg.clone()),
676 Some("user") => InboundEvent::User(msg.clone()),
677 _ => InboundEvent::Other(msg.clone()),
678 }
679}
680
/// Lifecycle state published by the background session task.
#[derive(Debug, Clone)]
pub enum SessionExitStatus {
    /// Initial state; no final state has been published yet.
    Running,
    /// The session loop ended without a stdout read error.
    Completed,
    /// The session loop ended because reading stdout failed; carries the
    /// error rendered as text.
    Failed(String),
}
706
/// Handle to a running duplex session backed by a child process and a
/// background task.
#[derive(Debug)]
pub struct DuplexSession {
    /// Commands for the background task (prompts, permission answers,
    /// interrupts).
    outbound_tx: mpsc::UnboundedSender<OutboundMsg>,
    /// Sender half of the event broadcast; used to create subscribers.
    events_tx: broadcast::Sender<InboundEvent>,
    /// Latest exit status published by the background task.
    exit_rx: watch::Receiver<SessionExitStatus>,
    /// Join handle of the background task.
    join: JoinHandle<Result<()>>,
}
722
/// Command sent from a [`DuplexSession`] handle to its background task.
#[derive(Debug)]
enum OutboundMsg {
    /// Start a turn with `prompt`; `reply` receives the turn outcome.
    Send {
        prompt: String,
        reply: oneshot::Sender<Result<TurnResult>>,
    },
    /// Write a permission answer for `request_id` to the child.
    PermissionResponse {
        request_id: String,
        decision: PermissionDecision,
    },
    /// Send an `interrupt` control request; `reply` receives the outcome of
    /// the matching control response.
    Interrupt {
        reply: oneshot::Sender<Result<()>>,
    },
}
737
738impl DuplexSession {
739 pub async fn spawn(claude: &Claude, opts: DuplexOptions) -> Result<Self> {
747 let capacity = opts
748 .subscriber_capacity
749 .unwrap_or(DEFAULT_SUBSCRIBER_CAPACITY);
750 let permission_handler = opts.on_permission.clone();
751
752 let mut command_args = Vec::new();
753 command_args.extend(claude.global_args.clone());
754 command_args.extend(opts.into_args());
755
756 debug!(
757 binary = %claude.binary.display(),
758 args = ?command_args,
759 "spawning duplex claude session"
760 );
761
762 let mut cmd = Command::new(&claude.binary);
763 cmd.args(&command_args)
764 .env_remove("CLAUDECODE")
765 .env_remove("CLAUDE_CODE_ENTRYPOINT")
766 .envs(&claude.env)
767 .stdin(Stdio::piped())
768 .stdout(Stdio::piped())
769 .stderr(Stdio::piped())
770 .kill_on_drop(true);
771
772 if let Some(ref dir) = claude.working_dir {
773 cmd.current_dir(dir);
774 }
775
776 let mut child = cmd.spawn().map_err(|e| Error::Io {
777 message: format!("failed to spawn claude: {e}"),
778 source: e,
779 working_dir: claude.working_dir.clone(),
780 })?;
781
782 let stdin = child.stdin.take().expect("stdin was piped");
783 let stdout = child.stdout.take().expect("stdout was piped");
784
785 let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
786 let (events_tx, _initial_rx) = broadcast::channel(capacity);
787 let (exit_tx, exit_rx) = watch::channel(SessionExitStatus::Running);
788
789 let join = tokio::spawn(run_session(
790 child,
791 stdin,
792 stdout,
793 outbound_rx,
794 events_tx.clone(),
795 permission_handler,
796 exit_tx,
797 ));
798
799 Ok(Self {
800 outbound_tx,
801 events_tx,
802 exit_rx,
803 join,
804 })
805 }
806
807 pub async fn send(&self, prompt: impl Into<String>) -> Result<TurnResult> {
813 let (reply_tx, reply_rx) = oneshot::channel();
814 self.outbound_tx
815 .send(OutboundMsg::Send {
816 prompt: prompt.into(),
817 reply: reply_tx,
818 })
819 .map_err(|_| Error::DuplexClosed)?;
820 reply_rx.await.map_err(|_| Error::DuplexClosed)?
821 }
822
823 #[must_use]
858 pub fn subscribe(&self) -> broadcast::Receiver<InboundEvent> {
859 self.events_tx.subscribe()
860 }
861
862 #[must_use]
873 pub fn is_alive(&self) -> bool {
874 matches!(*self.exit_rx.borrow(), SessionExitStatus::Running)
875 }
876
877 #[must_use]
886 pub fn exit_status(&self) -> SessionExitStatus {
887 self.exit_rx.borrow().clone()
888 }
889
890 pub async fn wait_for_exit(&self) -> SessionExitStatus {
902 let mut rx = self.exit_rx.clone();
903 loop {
904 {
905 let value = rx.borrow_and_update();
906 if !matches!(*value, SessionExitStatus::Running) {
907 return value.clone();
908 }
909 }
910 if rx.changed().await.is_err() {
911 return rx.borrow().clone();
912 }
913 }
914 }
915
916 pub fn respond_to_permission(
961 &self,
962 request_id: impl Into<String>,
963 decision: PermissionDecision,
964 ) -> Result<()> {
965 if matches!(decision, PermissionDecision::Defer) {
966 warn!("respond_to_permission called with Defer; ignoring");
967 return Ok(());
968 }
969 self.outbound_tx
970 .send(OutboundMsg::PermissionResponse {
971 request_id: request_id.into(),
972 decision,
973 })
974 .map_err(|_| Error::DuplexClosed)?;
975 Ok(())
976 }
977
978 pub async fn interrupt(&self) -> Result<()> {
1019 let (reply_tx, reply_rx) = oneshot::channel();
1020 self.outbound_tx
1021 .send(OutboundMsg::Interrupt { reply: reply_tx })
1022 .map_err(|_| Error::DuplexClosed)?;
1023 reply_rx.await.map_err(|_| Error::DuplexClosed)?
1024 }
1025
1026 pub async fn close(self) -> Result<()> {
1032 drop(self.outbound_tx);
1033 drop(self.events_tx);
1034 match self.join.await {
1035 Ok(result) => result,
1036 Err(e) if e.is_cancelled() => Ok(()),
1037 Err(e) => Err(Error::Io {
1038 message: format!("duplex session task panicked: {e}"),
1039 source: std::io::Error::other(e.to_string()),
1040 working_dir: None,
1041 }),
1042 }
1043 }
1044}
1045
/// How long `run_session` waits for the child to exit after its stdin is
/// dropped before killing it.
const SHUTDOWN_BUDGET: Duration = Duration::from_secs(5);
1050
/// Background task that owns the child process for one duplex session.
///
/// It multiplexes two sources until either ends: lines from the child's
/// stdout and commands from `outbound_rx`. Afterwards it drops stdin, waits
/// up to [`SHUTDOWN_BUDGET`] for the child (killing it on timeout), fails any
/// outstanding replies with `Error::DuplexClosed`, and publishes the final
/// [`SessionExitStatus`] on `exit_tx`.
///
/// Returns `Err` only when reading stdout failed; that same error is what
/// turns the published status into `Failed`.
///
/// NOTE(review): the child's exit status is not inspected, so a child that
/// exits non-zero still yields `Completed` as long as stdout reached EOF.
async fn run_session(
    mut child: Child,
    mut stdin: ChildStdin,
    stdout: ChildStdout,
    mut outbound_rx: mpsc::UnboundedReceiver<OutboundMsg>,
    events_tx: broadcast::Sender<InboundEvent>,
    permission_handler: Option<PermissionHandler>,
    exit_tx: watch::Sender<SessionExitStatus>,
) -> Result<()> {
    let mut lines = BufReader::new(stdout).lines();
    // Reply channel and recorded events of the turn currently in flight.
    let mut pending: Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)> = None;
    // Control requests we sent, keyed by request id, awaiting a response.
    let mut pending_control: HashMap<String, oneshot::Sender<Result<()>>> = HashMap::new();
    let mut next_control_id: u64 = 0;
    let mut stream_err: Option<Error> = None;

    loop {
        tokio::select! {
            // Poll in declaration order: stdout first, then commands.
            biased;

            line = lines.next_line() => match line {
                Ok(Some(l)) => {
                    if l.trim().is_empty() {
                        continue;
                    }
                    // Unparseable lines are logged and skipped, not fatal.
                    let parsed = match serde_json::from_str::<Value>(&l) {
                        Ok(v) => v,
                        Err(e) => {
                            debug!(line = %l, error = %e, "failed to parse duplex event, skipping");
                            continue;
                        }
                    };
                    match handle_inbound(parsed, &mut pending, &events_tx) {
                        InboundAction::None => {}
                        InboundAction::Permission(req) => {
                            let request_id = req.request_id.clone();
                            // NOTE(review): the handler is awaited inline, so
                            // neither stdout nor `outbound_rx` is polled until
                            // it resolves.
                            let decision = match permission_handler.as_ref() {
                                Some(h) => h.invoke(req).await,
                                None => {
                                    warn!(
                                        request_id = %request_id,
                                        "received can_use_tool with no permission handler; auto-denying"
                                    );
                                    PermissionDecision::Deny {
                                        message:
                                            "no permission handler configured on duplex session"
                                                .into(),
                                    }
                                }
                            };
                            if matches!(decision, PermissionDecision::Defer) {
                                debug!(
                                    request_id = %request_id,
                                    "permission handler deferred; waiting for respond_to_permission"
                                );
                            } else if let Err(e) =
                                write_permission_response(&mut stdin, &request_id, &decision).await
                            {
                                warn!(error = %e, "failed to write permission response");
                            }
                        }
                        InboundAction::ControlResponse { request_id, outcome } => {
                            if let Some(reply) = pending_control.remove(&request_id) {
                                let _ = reply.send(outcome);
                            } else {
                                debug!(
                                    request_id = %request_id,
                                    "received control_response with no pending request"
                                );
                            }
                        }
                    }
                }
                // Stdout reached EOF.
                Ok(None) => break,
                Err(e) => {
                    stream_err = Some(Error::Io {
                        message: "failed to read duplex stdout".to_string(),
                        source: e,
                        working_dir: None,
                    });
                    break;
                }
            },

            msg = outbound_rx.recv() => match msg {
                Some(OutboundMsg::Send { prompt, reply }) => {
                    // Only one turn may be in flight at a time.
                    if pending.is_some() {
                        let _ = reply.send(Err(Error::DuplexTurnInFlight));
                        continue;
                    }
                    if let Err(e) = write_user(&mut stdin, &prompt).await {
                        let _ = reply.send(Err(e));
                        continue;
                    }
                    pending = Some((reply, Vec::new()));
                }
                Some(OutboundMsg::PermissionResponse { request_id, decision }) => {
                    if let Err(e) =
                        write_permission_response(&mut stdin, &request_id, &decision).await
                    {
                        warn!(error = %e, "failed to write deferred permission response");
                    }
                }
                Some(OutboundMsg::Interrupt { reply }) => {
                    next_control_id += 1;
                    let request_id = format!("interrupt-{next_control_id}");
                    if let Err(e) =
                        write_control_request(&mut stdin, &request_id, "interrupt").await
                    {
                        let _ = reply.send(Err(e));
                        continue;
                    }
                    // Resolved when the matching control_response arrives.
                    pending_control.insert(request_id, reply);
                }
                // Every sender is gone.
                None => break,
            },
        }
    }

    // Release our end of the child's stdin before waiting for it to exit.
    drop(stdin);
    match tokio::time::timeout(SHUTDOWN_BUDGET, child.wait()).await {
        Ok(Ok(_status)) => {}
        Ok(Err(e)) => {
            warn!(error = %e, "failed to wait for duplex child");
        }
        Err(_) => {
            warn!("duplex child did not exit within shutdown budget; killing");
            let _ = child.kill().await;
        }
    }

    // Fail every caller still waiting on a reply.
    if let Some((reply, _)) = pending.take() {
        let _ = reply.send(Err(Error::DuplexClosed));
    }
    for (_, reply) in pending_control.drain() {
        let _ = reply.send(Err(Error::DuplexClosed));
    }

    let result = match stream_err {
        Some(e) => Err(e),
        None => Ok(()),
    };
    let final_state = match &result {
        Ok(()) => SessionExitStatus::Completed,
        Err(e) => SessionExitStatus::Failed(e.to_string()),
    };
    let _ = exit_tx.send(final_state);
    result
}
1199
/// Follow-up work that `handle_inbound` hands back to the session loop.
enum InboundAction {
    /// Nothing further to do.
    None,
    /// A permission request that needs a decision.
    Permission(PermissionRequest),
    /// A parsed response to a control request.
    ControlResponse {
        /// The `response.request_id` of the control response.
        request_id: String,
        /// `Ok` for subtype `success`, `Err` for subtype `error`.
        outcome: Result<()>,
    },
}
1219
1220fn handle_inbound(
1221 msg: Value,
1222 pending: &mut Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)>,
1223 events_tx: &broadcast::Sender<InboundEvent>,
1224) -> InboundAction {
1225 match msg.get("type").and_then(Value::as_str) {
1226 Some("result") => {
1227 if let Some((reply, events)) = pending.take() {
1228 let _ = reply.send(Ok(TurnResult {
1229 result: msg,
1230 events,
1231 }));
1232 } else {
1233 debug!("dropping orphan result event with no pending turn");
1234 }
1235 InboundAction::None
1236 }
1237 Some("control_request") => {
1238 if msg
1241 .get("request")
1242 .and_then(|r| r.get("subtype"))
1243 .and_then(Value::as_str)
1244 == Some("can_use_tool")
1245 && let Some(req) = parse_permission_request(&msg)
1246 {
1247 if let Some((_, events)) = pending.as_mut() {
1248 events.push(msg);
1249 }
1250 return InboundAction::Permission(req);
1251 }
1252 debug!(
1253 ?msg,
1254 "received unhandled control_request; treating as Other"
1255 );
1256 let _ = events_tx.send(InboundEvent::Other(msg.clone()));
1257 if let Some((_, events)) = pending.as_mut() {
1258 events.push(msg);
1259 }
1260 InboundAction::None
1261 }
1262 Some("control_response") => {
1263 if let Some((request_id, outcome)) = parse_control_response(&msg) {
1264 return InboundAction::ControlResponse {
1265 request_id,
1266 outcome,
1267 };
1268 }
1269 debug!(
1270 ?msg,
1271 "received malformed control_response; treating as Other"
1272 );
1273 let _ = events_tx.send(InboundEvent::Other(msg.clone()));
1274 if let Some((_, events)) = pending.as_mut() {
1275 events.push(msg);
1276 }
1277 InboundAction::None
1278 }
1279 _ => {
1280 let _ = events_tx.send(classify(&msg));
1283
1284 if let Some((_, events)) = pending.as_mut() {
1285 events.push(msg);
1286 } else {
1287 debug!("dropping inbound event with no pending turn");
1288 }
1289 InboundAction::None
1290 }
1291 }
1292}
1293
1294fn parse_permission_request(msg: &Value) -> Option<PermissionRequest> {
1295 let request_id = msg.get("request_id").and_then(Value::as_str)?;
1296 let request = msg.get("request")?;
1297 let tool_name = request.get("tool_name").and_then(Value::as_str)?;
1298 let input = request.get("input").cloned().unwrap_or(Value::Null);
1299 Some(PermissionRequest {
1300 request_id: request_id.to_string(),
1301 tool_name: tool_name.to_string(),
1302 input,
1303 raw: request.clone(),
1304 })
1305}
1306
1307fn parse_control_response(msg: &Value) -> Option<(String, Result<()>)> {
1313 let response = msg.get("response")?;
1314 let request_id = response.get("request_id").and_then(Value::as_str)?;
1315 let outcome = match response.get("subtype").and_then(Value::as_str) {
1316 Some("success") => Ok(()),
1317 Some("error") => {
1318 let message = response
1319 .get("error")
1320 .and_then(Value::as_str)
1321 .unwrap_or("unknown control_response error")
1322 .to_string();
1323 Err(Error::DuplexControlFailed { message })
1324 }
1325 _ => return None,
1326 };
1327 Some((request_id.to_string(), outcome))
1328}
1329
1330async fn write_user(stdin: &mut ChildStdin, prompt: &str) -> Result<()> {
1331 let user_msg = serde_json::json!({
1332 "type": "user",
1333 "message": {
1334 "role": "user",
1335 "content": prompt,
1336 },
1337 "parent_tool_use_id": null,
1338 });
1339 write_line(stdin, &user_msg, "user message").await
1340}
1341
1342async fn write_control_request(
1343 stdin: &mut ChildStdin,
1344 request_id: &str,
1345 subtype: &str,
1346) -> Result<()> {
1347 let envelope = serde_json::json!({
1348 "type": "control_request",
1349 "request_id": request_id,
1350 "request": { "subtype": subtype },
1351 });
1352 write_line(stdin, &envelope, "control_request").await
1353}
1354
1355async fn write_permission_response(
1356 stdin: &mut ChildStdin,
1357 request_id: &str,
1358 decision: &PermissionDecision,
1359) -> Result<()> {
1360 let inner = match decision {
1361 PermissionDecision::Allow { updated_input } => {
1362 let mut obj = serde_json::Map::new();
1363 obj.insert("behavior".to_string(), Value::String("allow".to_string()));
1364 if let Some(input) = updated_input {
1365 obj.insert("updatedInput".to_string(), input.clone());
1366 }
1367 Value::Object(obj)
1368 }
1369 PermissionDecision::Deny { message } => serde_json::json!({
1370 "behavior": "deny",
1371 "message": message,
1372 }),
1373 PermissionDecision::Defer => {
1374 return Ok(());
1376 }
1377 };
1378 let envelope = serde_json::json!({
1379 "type": "control_response",
1380 "response": {
1381 "request_id": request_id,
1382 "subtype": "success",
1383 "response": inner,
1384 },
1385 });
1386 write_line(stdin, &envelope, "control_response").await
1387}
1388
1389async fn write_line(stdin: &mut ChildStdin, value: &Value, what: &'static str) -> Result<()> {
1390 let mut line = serde_json::to_string(value).map_err(|e| Error::Json {
1391 message: format!("failed to serialize duplex {what}"),
1392 source: e,
1393 })?;
1394 line.push('\n');
1395 stdin
1396 .write_all(line.as_bytes())
1397 .await
1398 .map_err(|e| Error::Io {
1399 message: format!("failed to write {what} to duplex stdin"),
1400 source: e,
1401 working_dir: None,
1402 })?;
1403 stdin.flush().await.map_err(|e| Error::Io {
1404 message: "failed to flush duplex stdin".to_string(),
1405 source: e,
1406 working_dir: None,
1407 })?;
1408 Ok(())
1409}
1410
1411#[cfg(test)]
1412mod tests {
1413 use super::*;
1414 use serde_json::json;
1415
1416 #[test]
1417 fn into_args_default_includes_required_flags() {
1418 let args = DuplexOptions::default().into_args();
1419 assert!(args.contains(&"--print".to_string()));
1420 assert!(args.contains(&"--verbose".to_string()));
1421 assert!(
1422 args.windows(2)
1423 .any(|w| w == ["--output-format", "stream-json"])
1424 );
1425 assert!(
1426 args.windows(2)
1427 .any(|w| w == ["--input-format", "stream-json"])
1428 );
1429 }
1430
1431 #[test]
1432 fn into_args_includes_model() {
1433 let args = DuplexOptions::default().model("haiku").into_args();
1434 assert!(args.windows(2).any(|w| w == ["--model", "haiku"]));
1435 }
1436
1437 #[test]
1438 fn into_args_includes_system_prompts() {
1439 let args = DuplexOptions::default()
1440 .system_prompt("be concise")
1441 .append_system_prompt("also polite")
1442 .into_args();
1443 assert!(
1444 args.windows(2)
1445 .any(|w| w == ["--system-prompt", "be concise"])
1446 );
1447 assert!(
1448 args.windows(2)
1449 .any(|w| w == ["--append-system-prompt", "also polite"])
1450 );
1451 }
1452
1453 #[test]
1454 fn into_args_appends_raw_args_last() {
1455 let args = DuplexOptions::default()
1456 .arg("--add-dir")
1457 .arg("/tmp/foo")
1458 .into_args();
1459 assert_eq!(&args[args.len() - 2..], &["--add-dir", "/tmp/foo"]);
1461 }
1462
1463 #[test]
1464 fn into_args_includes_resume_when_set() {
1465 let args = DuplexOptions::default().resume("abc-123").into_args();
1466 assert!(args.windows(2).any(|w| w == ["--resume", "abc-123"]));
1467 }
1468
1469 #[test]
1470 fn into_args_omits_resume_by_default() {
1471 let args = DuplexOptions::default().into_args();
1472 assert!(
1473 !args.iter().any(|a| a == "--resume"),
1474 "--resume should not appear without an explicit resume(...) call; got {args:?}"
1475 );
1476 }
1477
1478 #[test]
1479 fn into_args_includes_continue_when_set() {
1480 let args = DuplexOptions::default().continue_session().into_args();
1481 assert!(args.iter().any(|a| a == "--continue"));
1482 }
1483
1484 #[test]
1485 fn into_args_omits_continue_by_default() {
1486 let args = DuplexOptions::default().into_args();
1487 assert!(!args.iter().any(|a| a == "--continue"));
1488 }
1489
1490 #[test]
1491 fn into_args_includes_worktree_flag_without_name() {
1492 let args = DuplexOptions::default().worktree(None::<&str>).into_args();
1493 assert!(args.iter().any(|a| a == "--worktree"));
1494 let pos = args.iter().position(|a| a == "--worktree").unwrap();
1496 assert!(
1497 args.get(pos + 1).is_none_or(|a| a.starts_with("--")),
1498 "--worktree without a name should not be followed by a positional; got {args:?}"
1499 );
1500 }
1501
1502 #[test]
1503 fn into_args_includes_worktree_flag_with_name() {
1504 let args = DuplexOptions::default()
1505 .worktree(Some("agent-xyz"))
1506 .into_args();
1507 let pos = args.iter().position(|a| a == "--worktree").unwrap();
1508 assert_eq!(args.get(pos + 1).map(String::as_str), Some("agent-xyz"));
1509 }
1510
1511 #[test]
1512 fn into_args_omits_worktree_by_default() {
1513 let args = DuplexOptions::default().into_args();
1514 assert!(
1515 !args.iter().any(|a| a == "--worktree"),
1516 "--worktree should not appear without an explicit worktree(...) call; got {args:?}"
1517 );
1518 }
1519
1520 #[test]
1521 fn worktree_lands_before_additional_args() {
1522 let args = DuplexOptions::default()
1524 .worktree(Some("foo"))
1525 .arg("--")
1526 .arg("trailing")
1527 .into_args();
1528 let wt_pos = args.iter().position(|a| a == "--worktree").unwrap();
1529 let dash_dash_pos = args.iter().position(|a| a == "--").unwrap();
1530 assert!(
1531 wt_pos < dash_dash_pos,
1532 "--worktree must precede `--` separator; got {args:?}"
1533 );
1534 }
1535
1536 #[test]
1537 fn into_args_includes_agent_when_set() {
1538 let args = DuplexOptions::default().agent("rust-qa").into_args();
1539 assert!(
1540 args.windows(2).any(|w| w == ["--agent", "rust-qa"]),
1541 "missing --agent rust-qa in {args:?}"
1542 );
1543 }
1544
1545 #[test]
1546 fn into_args_omits_agent_by_default() {
1547 let args = DuplexOptions::default().into_args();
1548 assert!(
1549 !args.iter().any(|a| a == "--agent"),
1550 "--agent should not appear without an explicit agent(...) call; got {args:?}"
1551 );
1552 }
1553
1554 #[test]
1555 fn into_args_includes_agents_json_when_set() {
1556 let json = r#"{"reviewer":{"description":"r","prompt":"p"}}"#;
1557 let args = DuplexOptions::default().agents_json(json).into_args();
1558 let pos = args.iter().position(|a| a == "--agents").unwrap();
1559 assert_eq!(args.get(pos + 1).map(String::as_str), Some(json));
1560 }
1561
1562 #[test]
1563 fn into_args_omits_agents_json_by_default() {
1564 let args = DuplexOptions::default().into_args();
1565 assert!(!args.iter().any(|a| a == "--agents"));
1566 }
1567
1568 #[test]
1569 fn agent_and_agents_json_compose() {
1570 let json = r#"{"reviewer":{"description":"r","prompt":"p"}}"#;
1571 let args = DuplexOptions::default()
1572 .agents_json(json)
1573 .agent("reviewer")
1574 .into_args();
1575 assert!(args.iter().any(|a| a == "--agents"));
1577 assert!(args.iter().any(|a| a == "--agent"));
1578 }
1579
1580 #[test]
1581 fn agent_lands_before_additional_args() {
1582 let args = DuplexOptions::default()
1583 .agent("rust-qa")
1584 .arg("--")
1585 .arg("trailing")
1586 .into_args();
1587 let agent_pos = args.iter().position(|a| a == "--agent").unwrap();
1588 let dash_dash_pos = args.iter().position(|a| a == "--").unwrap();
1589 assert!(
1590 agent_pos < dash_dash_pos,
1591 "--agent must precede `--` separator; got {args:?}"
1592 );
1593 }
1594
1595 #[test]
1596 fn agents_json_lands_before_additional_args() {
1597 let args = DuplexOptions::default()
1598 .agents_json("{}")
1599 .arg("--")
1600 .arg("trailing")
1601 .into_args();
1602 let agents_pos = args.iter().position(|a| a == "--agents").unwrap();
1603 let dash_dash_pos = args.iter().position(|a| a == "--").unwrap();
1604 assert!(
1605 agents_pos < dash_dash_pos,
1606 "--agents must precede `--` separator; got {args:?}"
1607 );
1608 }
1609
1610 #[test]
1611 fn resume_lands_before_additional_args() {
1612 let args = DuplexOptions::default()
1617 .resume("xyz")
1618 .arg("--")
1619 .arg("trailing")
1620 .into_args();
1621 let resume_pos = args.iter().position(|a| a == "--resume").unwrap();
1622 let dash_dash_pos = args.iter().position(|a| a == "--").unwrap();
1623 assert!(
1624 resume_pos < dash_dash_pos,
1625 "--resume must precede `--` separator; got {args:?}"
1626 );
1627 }
1628
1629 #[test]
1630 fn turn_result_accessors_pull_from_result() {
1631 let r = TurnResult {
1632 result: json!({
1633 "type": "result",
1634 "result": "hello",
1635 "session_id": "sess-123",
1636 "total_cost_usd": 0.0042,
1637 "duration_ms": 1234_u64,
1638 }),
1639 events: vec![],
1640 };
1641 assert_eq!(r.result_text(), Some("hello"));
1642 assert_eq!(r.session_id(), Some("sess-123"));
1643 assert_eq!(r.total_cost_usd(), Some(0.0042));
1644 assert_eq!(r.duration_ms(), Some(1234));
1645 }
1646
1647 #[test]
1648 fn turn_result_total_cost_falls_back_to_legacy_field() {
1649 let r = TurnResult {
1650 result: json!({ "cost_usd": 0.5 }),
1651 events: vec![],
1652 };
1653 assert_eq!(r.total_cost_usd(), Some(0.5));
1654 }
1655
1656 #[test]
1657 fn turn_result_accessors_return_none_when_missing() {
1658 let r = TurnResult {
1659 result: json!({}),
1660 events: vec![],
1661 };
1662 assert_eq!(r.result_text(), None);
1663 assert_eq!(r.session_id(), None);
1664 assert_eq!(r.total_cost_usd(), None);
1665 assert_eq!(r.duration_ms(), None);
1666 }
1667
1668 #[test]
1669 fn handle_inbound_appends_non_result_to_pending_events() {
1670 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1671 let (events_tx, _events_rx) = broadcast::channel(16);
1672 let mut pending = Some((tx, Vec::new()));
1673 handle_inbound(
1674 json!({ "type": "assistant", "message": {} }),
1675 &mut pending,
1676 &events_tx,
1677 );
1678 let (_, events) = pending.as_ref().unwrap();
1679 assert_eq!(events.len(), 1);
1680 assert_eq!(
1681 events[0].get("type").and_then(Value::as_str),
1682 Some("assistant")
1683 );
1684 }
1685
1686 #[test]
1687 fn handle_inbound_resolves_pending_on_result() {
1688 let (tx, rx) = oneshot::channel::<Result<TurnResult>>();
1689 let (events_tx, _events_rx) = broadcast::channel(16);
1690 let mut pending = Some((tx, vec![json!({ "type": "assistant" })]));
1691 handle_inbound(
1692 json!({ "type": "result", "result": "ok" }),
1693 &mut pending,
1694 &events_tx,
1695 );
1696 assert!(pending.is_none());
1697 let received = rx.blocking_recv().unwrap().unwrap();
1698 assert_eq!(received.result_text(), Some("ok"));
1699 assert_eq!(received.events.len(), 1);
1700 }
1701
1702 #[test]
1703 fn handle_inbound_drops_orphans_without_pending_turn() {
1704 let (events_tx, _events_rx) = broadcast::channel(16);
1705 let mut pending: Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)> = None;
1706 handle_inbound(json!({ "type": "assistant" }), &mut pending, &events_tx);
1707 handle_inbound(
1708 json!({ "type": "result", "result": "ok" }),
1709 &mut pending,
1710 &events_tx,
1711 );
1712 assert!(pending.is_none());
1713 }
1714
1715 #[test]
1716 fn handle_inbound_broadcasts_classified_event() {
1717 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1718 let (events_tx, mut events_rx) = broadcast::channel(16);
1719 let mut pending = Some((tx, Vec::new()));
1720 handle_inbound(
1721 json!({ "type": "assistant", "message": { "role": "assistant" } }),
1722 &mut pending,
1723 &events_tx,
1724 );
1725 let event = events_rx.try_recv().expect("classified event broadcast");
1726 assert!(matches!(event, InboundEvent::Assistant(_)));
1727 }
1728
1729 #[test]
1730 fn handle_inbound_does_not_broadcast_result() {
1731 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1732 let (events_tx, mut events_rx) = broadcast::channel(16);
1733 let mut pending = Some((tx, Vec::new()));
1734 handle_inbound(
1735 json!({ "type": "result", "result": "ok" }),
1736 &mut pending,
1737 &events_tx,
1738 );
1739 assert!(events_rx.try_recv().is_err());
1741 }
1742
/// A `system`/`init` frame with a `session_id` classifies as `SystemInit`
/// carrying that id.
#[test]
fn classify_system_init_pulls_session_id() {
    let frame = json!({
        "type": "system",
        "subtype": "init",
        "session_id": "sess-abc",
    });

    let session_id = match classify(&frame) {
        InboundEvent::SystemInit { session_id } => session_id,
        other => panic!("expected SystemInit, got {other:?}"),
    };
    assert_eq!(session_id, "sess-abc");
}
1755
/// A `system` frame whose subtype is not `init` classifies as `Other`.
#[test]
fn classify_system_without_init_subtype_is_other() {
    let frame = json!({ "type": "system", "subtype": "compaction" });
    let event = classify(&frame);
    assert!(matches!(event, InboundEvent::Other(_)));
}
1761
/// A `system`/`init` frame that lacks `session_id` classifies as `Other`.
#[test]
fn classify_system_init_without_session_id_is_other() {
    let frame = json!({ "type": "system", "subtype": "init" });
    let event = classify(&frame);
    assert!(matches!(event, InboundEvent::Other(_)));
}
1767
/// The `assistant`, `stream_event`, and `user` frame types each classify to
/// their own `InboundEvent` variant.
#[test]
fn classify_assistant_stream_event_user() {
    let assistant = classify(&json!({ "type": "assistant" }));
    assert!(matches!(assistant, InboundEvent::Assistant(_)));

    let stream_event = classify(&json!({ "type": "stream_event" }));
    assert!(matches!(stream_event, InboundEvent::StreamEvent(_)));

    let user = classify(&json!({ "type": "user" }));
    assert!(matches!(user, InboundEvent::User(_)));
}
1783
/// Frames typed `control_request`, frames with an unrecognised type, and
/// frames with no `type` at all classify as `Other`.
#[test]
fn classify_unknown_type_is_other() {
    let control_request = classify(&json!({ "type": "control_request" }));
    assert!(matches!(control_request, InboundEvent::Other(_)));

    let unrecognised = classify(&json!({ "type": "future_thing" }));
    assert!(matches!(unrecognised, InboundEvent::Other(_)));

    let untyped = classify(&json!({}));
    assert!(matches!(untyped, InboundEvent::Other(_)));
}
1796
/// Setting `subscriber_capacity` must not produce any argument that mentions
/// "subscriber" or "capacity".
#[test]
fn into_args_does_not_emit_subscriber_capacity_flag() {
    let args = DuplexOptions::default().subscriber_capacity(64).into_args();
    for needle in ["subscriber", "capacity"] {
        assert!(args.iter().all(|arg| !arg.contains(needle)));
    }
}
1804
/// Registering a permission handler adds the adjacent argument pair
/// `--permission-prompt-tool stdio`.
#[test]
fn into_args_includes_permission_prompt_tool_when_handler_set() {
    let allow_all = PermissionHandler::new(|_req| async move {
        PermissionDecision::Allow {
            updated_input: None,
        }
    });
    let args = DuplexOptions::default()
        .on_permission(allow_all)
        .into_args();

    let expected = ["--permission-prompt-tool", "stdio"];
    let has_pair = args.windows(2).any(|pair| pair == expected);
    assert!(has_pair);
}
1818
/// Default options must not contain `--permission-prompt-tool`.
#[test]
fn into_args_omits_permission_prompt_tool_without_handler() {
    let args = DuplexOptions::default().into_args();
    let flag = "--permission-prompt-tool";
    assert!(args.iter().all(|arg| arg != flag));
}
1824
/// `PermissionMode::AcceptEdits` is rendered as the adjacent pair
/// `--permission-mode acceptEdits`.
#[test]
fn into_args_emits_permission_mode_flag() {
    let args = DuplexOptions::default()
        .permission_mode(PermissionMode::AcceptEdits)
        .into_args();

    let wanted = ["--permission-mode", "acceptEdits"];
    let present = args.windows(2).any(|pair| pair == wanted);
    assert!(
        present,
        "missing --permission-mode acceptEdits in {args:?}"
    );
}
1836
/// `PermissionMode::Plan` is rendered as the adjacent pair
/// `--permission-mode plan`.
#[test]
fn into_args_emits_plan_mode() {
    let args = DuplexOptions::default()
        .permission_mode(PermissionMode::Plan)
        .into_args();

    let wanted = ["--permission-mode", "plan"];
    let present = args.windows(2).any(|pair| pair == wanted);
    assert!(present);
}
1844
/// Default options must not contain `--permission-mode`.
#[test]
fn into_args_omits_permission_mode_by_default() {
    let args = DuplexOptions::default().into_args();
    let flag = "--permission-mode";
    assert!(args.iter().all(|arg| arg != flag));
}
1850
/// Calling `dangerously_skip_permissions` adds the
/// `--dangerously-skip-permissions` argument.
#[test]
fn into_args_emits_dangerously_skip_permissions_flag() {
    let args = DuplexOptions::default()
        .dangerously_skip_permissions()
        .into_args();

    let flag = "--dangerously-skip-permissions";
    let present = args.iter().any(|arg| arg == flag);
    assert!(present);
}
1858
/// Default options must not contain `--dangerously-skip-permissions`.
#[test]
fn into_args_omits_dangerously_skip_by_default() {
    let args = DuplexOptions::default().into_args();
    let flag = "--dangerously-skip-permissions";
    assert!(args.iter().all(|arg| arg != flag));
}
1864
/// A well-formed `can_use_tool` control request yields a `PermissionRequest`
/// with the request id, tool name, and input copied out, and with `raw`
/// exposing the inner request's `subtype`.
#[test]
fn parse_permission_request_extracts_fields() {
    let frame = json!({
        "type": "control_request",
        "request_id": "req-1",
        "request": {
            "subtype": "can_use_tool",
            "tool_name": "Bash",
            "input": { "command": "ls" }
        }
    });

    let PermissionRequest {
        request_id,
        tool_name,
        input,
        raw,
    } = parse_permission_request(&frame).expect("permission request");

    assert_eq!(request_id, "req-1");
    assert_eq!(tool_name, "Bash");
    assert_eq!(input, json!({ "command": "ls" }));
    let subtype = raw.get("subtype").and_then(Value::as_str);
    assert_eq!(subtype, Some("can_use_tool"));
}
1885
/// Without a top-level `request_id`, parsing yields `None`.
#[test]
fn parse_permission_request_returns_none_when_missing_request_id() {
    let frame = json!({
        "type": "control_request",
        "request": {
            "subtype": "can_use_tool",
            "tool_name": "Bash",
        }
    });

    let parsed = parse_permission_request(&frame);
    assert!(parsed.is_none());
}
1897
/// Without `tool_name` in the inner request, parsing yields `None`.
#[test]
fn parse_permission_request_returns_none_when_missing_tool_name() {
    let frame = json!({
        "type": "control_request",
        "request_id": "req-1",
        "request": { "subtype": "can_use_tool" }
    });

    let parsed = parse_permission_request(&frame);
    assert!(parsed.is_none());
}
1907
/// A request with no `input` key still parses, and its `input` is JSON null.
#[test]
fn parse_permission_request_handles_missing_input() {
    let frame = json!({
        "type": "control_request",
        "request_id": "req-1",
        "request": {
            "subtype": "can_use_tool",
            "tool_name": "Bash",
        }
    });

    let parsed = parse_permission_request(&frame).expect("request");
    assert!(parsed.input.is_null());
}
1921
/// A `can_use_tool` control request produces `InboundAction::Permission`
/// with the parsed request, and the pending turn ends up holding exactly one
/// recorded event.
#[test]
fn handle_inbound_returns_permission_for_can_use_tool() {
    let (reply_tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
    let (bus, _bus_rx) = broadcast::channel(16);
    let mut in_flight = Some((reply_tx, Vec::new()));
    let frame = json!({
        "type": "control_request",
        "request_id": "req-1",
        "request": {
            "subtype": "can_use_tool",
            "tool_name": "Bash",
            "input": { "command": "ls" }
        }
    });

    let request = match handle_inbound(frame, &mut in_flight, &bus) {
        InboundAction::Permission(request) => request,
        // Every other variant is listed so a new one forces a revisit here.
        InboundAction::None | InboundAction::ControlResponse { .. } => {
            panic!("expected Permission action");
        }
    };
    assert_eq!(request.request_id, "req-1");
    assert_eq!(request.tool_name, "Bash");

    let recorded = &in_flight.as_ref().unwrap().1;
    assert_eq!(recorded.len(), 1);
}
1953
/// A control request with an unrecognised subtype yields no action and is
/// broadcast as `InboundEvent::Other`.
#[test]
fn handle_inbound_treats_unknown_control_request_as_other() {
    let (reply_tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
    let (bus, mut bus_rx) = broadcast::channel(16);
    let mut in_flight = Some((reply_tx, Vec::new()));
    let frame = json!({
        "type": "control_request",
        "request_id": "req-2",
        "request": { "subtype": "future_subtype" }
    });

    let action = handle_inbound(frame, &mut in_flight, &bus);
    assert!(matches!(action, InboundAction::None));

    let seen = bus_rx.try_recv().expect("broadcast");
    assert!(matches!(seen, InboundEvent::Other(_)));
}
1972
1973 #[tokio::test]
1974 async fn permission_handler_invokes_closure_async() {
1975 let handler = PermissionHandler::new(|req| async move {
1976 if req.tool_name == "Bash" {
1977 PermissionDecision::Deny {
1978 message: "no bash".into(),
1979 }
1980 } else {
1981 PermissionDecision::Allow {
1982 updated_input: None,
1983 }
1984 }
1985 });
1986 let req = PermissionRequest {
1987 request_id: "r1".into(),
1988 tool_name: "Bash".into(),
1989 input: Value::Null,
1990 raw: Value::Null,
1991 };
1992 match handler.invoke(req).await {
1993 PermissionDecision::Deny { message } => assert_eq!(message, "no bash"),
1994 other => panic!("expected Deny, got {other:?}"),
1995 }
1996 }
1997
/// A `success` control response parses to its request id and an `Ok`
/// outcome.
#[test]
fn parse_control_response_extracts_success() {
    let frame = json!({
        "type": "control_response",
        "response": {
            "request_id": "interrupt-1",
            "subtype": "success",
            "response": {}
        }
    });

    let parsed = parse_control_response(&frame).expect("parsed");
    assert_eq!(parsed.0, "interrupt-1");
    assert!(parsed.1.is_ok());
}
2012
/// An `error` control response parses to its request id and an
/// `Error::DuplexControlFailed` carrying the `error` text.
#[test]
fn parse_control_response_extracts_error_with_message() {
    let frame = json!({
        "type": "control_response",
        "response": {
            "request_id": "interrupt-2",
            "subtype": "error",
            "error": "no turn in flight"
        }
    });

    let (request_id, outcome) = parse_control_response(&frame).expect("parsed");
    assert_eq!(request_id, "interrupt-2");

    let message = match outcome {
        Err(Error::DuplexControlFailed { message }) => message,
        other => panic!("expected DuplexControlFailed, got {other:?}"),
    };
    assert_eq!(message, "no turn in flight");
}
2032
/// A control response without `request_id` does not parse.
#[test]
fn parse_control_response_returns_none_on_missing_request_id() {
    let frame = json!({
        "type": "control_response",
        "response": { "subtype": "success" }
    });

    let parsed = parse_control_response(&frame);
    assert!(parsed.is_none());
}
2041
/// A control response with an unrecognised subtype does not parse.
#[test]
fn parse_control_response_returns_none_on_unknown_subtype() {
    let frame = json!({
        "type": "control_response",
        "response": { "request_id": "x", "subtype": "future_subtype" }
    });

    let parsed = parse_control_response(&frame);
    assert!(parsed.is_none());
}
2050
/// A well-formed `success` control response produces
/// `InboundAction::ControlResponse` with the request id and an `Ok` outcome.
#[test]
fn handle_inbound_returns_control_response_action() {
    let (reply_tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
    let (bus, _bus_rx) = broadcast::channel(16);
    let mut in_flight = Some((reply_tx, Vec::new()));
    let frame = json!({
        "type": "control_response",
        "response": {
            "request_id": "interrupt-1",
            "subtype": "success",
            "response": {}
        }
    });

    let (request_id, outcome) = match handle_inbound(frame, &mut in_flight, &bus) {
        InboundAction::ControlResponse {
            request_id,
            outcome,
        } => (request_id, outcome),
        // Every other variant is listed so a new one forces a revisit here.
        InboundAction::None | InboundAction::Permission(_) => {
            panic!("expected ControlResponse action");
        }
    };
    assert_eq!(request_id, "interrupt-1");
    assert!(outcome.is_ok());
}
2081
/// A control response missing `request_id` yields no action and is broadcast
/// as `InboundEvent::Other`.
#[test]
fn handle_inbound_treats_malformed_control_response_as_other() {
    let (reply_tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
    let (bus, mut bus_rx) = broadcast::channel(16);
    let mut in_flight = Some((reply_tx, Vec::new()));
    let frame = json!({
        "type": "control_response",
        "response": { "subtype": "success" }
    });

    let action = handle_inbound(frame, &mut in_flight, &bus);
    assert!(matches!(action, InboundAction::None));

    let seen = bus_rx.try_recv().expect("broadcast");
    assert!(matches!(seen, InboundEvent::Other(_)));
}
2099
2100 #[tokio::test]
2101 async fn permission_handler_clones_arc() {
2102 let handler = PermissionHandler::new(|_req| async move {
2103 PermissionDecision::Allow {
2104 updated_input: None,
2105 }
2106 });
2107 let cloned = handler.clone();
2108 let req = PermissionRequest {
2109 request_id: "r1".into(),
2110 tool_name: "Read".into(),
2111 input: Value::Null,
2112 raw: Value::Null,
2113 };
2114 let _ = handler.invoke(req.clone()).await;
2116 let _ = cloned.invoke(req).await;
2117 }
2118
2119 fn fake_session(
2126 initial: SessionExitStatus,
2127 ) -> (
2128 DuplexSession,
2129 watch::Sender<SessionExitStatus>,
2130 oneshot::Sender<()>,
2131 ) {
2132 let (outbound_tx, outbound_rx) = mpsc::unbounded_channel::<OutboundMsg>();
2133 let (events_tx, _events_rx) = broadcast::channel::<InboundEvent>(16);
2134 let (exit_tx, exit_rx) = watch::channel(initial);
2135 let (stop_tx, stop_rx) = oneshot::channel::<()>();
2136
2137 let join = tokio::spawn(async move {
2138 let _outbound_rx = outbound_rx;
2139 let _ = stop_rx.await;
2140 Ok::<(), Error>(())
2141 });
2142
2143 let session = DuplexSession {
2144 outbound_tx,
2145 events_tx,
2146 exit_rx,
2147 join,
2148 };
2149 (session, exit_tx, stop_tx)
2150 }
2151
2152 #[tokio::test]
2153 async fn is_alive_true_while_running() {
2154 let (session, _exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2155 assert!(session.is_alive());
2156 }
2157
2158 #[tokio::test]
2159 async fn is_alive_false_after_completed() {
2160 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2161 exit_tx.send(SessionExitStatus::Completed).unwrap();
2162 assert!(!session.is_alive());
2163 }
2164
2165 #[tokio::test]
2166 async fn is_alive_false_after_failed() {
2167 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2168 exit_tx
2169 .send(SessionExitStatus::Failed("boom".into()))
2170 .unwrap();
2171 assert!(!session.is_alive());
2172 }
2173
2174 #[tokio::test]
2175 async fn exit_status_reports_running_initially() {
2176 let (session, _exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2177 assert!(matches!(session.exit_status(), SessionExitStatus::Running));
2178 }
2179
2180 #[tokio::test]
2181 async fn exit_status_reflects_completed() {
2182 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2183 exit_tx.send(SessionExitStatus::Completed).unwrap();
2184 assert!(matches!(
2185 session.exit_status(),
2186 SessionExitStatus::Completed
2187 ));
2188 }
2189
2190 #[tokio::test]
2191 async fn exit_status_reflects_failed_with_message() {
2192 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2193 exit_tx
2194 .send(SessionExitStatus::Failed("oh no".into()))
2195 .unwrap();
2196 match session.exit_status() {
2197 SessionExitStatus::Failed(msg) => assert_eq!(msg, "oh no"),
2198 other => panic!("expected Failed, got {other:?}"),
2199 }
2200 }
2201
2202 #[tokio::test]
2203 async fn wait_for_exit_returns_immediately_when_already_terminal() {
2204 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2205 exit_tx.send(SessionExitStatus::Completed).unwrap();
2206 let status = tokio::time::timeout(Duration::from_secs(1), session.wait_for_exit())
2207 .await
2208 .expect("wait_for_exit should not block when already terminal");
2209 assert!(matches!(status, SessionExitStatus::Completed));
2210 }
2211
2212 #[tokio::test]
2213 async fn wait_for_exit_blocks_until_state_transitions() {
2214 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2215
2216 let waiter = async { session.wait_for_exit().await };
2217 let driver = async {
2218 tokio::time::sleep(Duration::from_millis(20)).await;
2219 exit_tx.send(SessionExitStatus::Completed).unwrap();
2220 };
2221 let (status, ()) = tokio::join!(waiter, driver);
2222 assert!(matches!(status, SessionExitStatus::Completed));
2223 }
2224
2225 #[tokio::test]
2226 async fn wait_for_exit_supports_multiple_observers() {
2227 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2228
2229 let waiter1 = async { session.wait_for_exit().await };
2230 let waiter2 = async { session.wait_for_exit().await };
2231 let driver = async {
2232 tokio::time::sleep(Duration::from_millis(20)).await;
2233 exit_tx
2234 .send(SessionExitStatus::Failed("crash".into()))
2235 .unwrap();
2236 };
2237 let (s1, s2, ()) = tokio::join!(waiter1, waiter2, driver);
2238 match s1 {
2239 SessionExitStatus::Failed(msg) => assert_eq!(msg, "crash"),
2240 other => panic!("waiter1 expected Failed, got {other:?}"),
2241 }
2242 match s2 {
2243 SessionExitStatus::Failed(msg) => assert_eq!(msg, "crash"),
2244 other => panic!("waiter2 expected Failed, got {other:?}"),
2245 }
2246 }
2247
2248 #[tokio::test]
2249 async fn wait_for_exit_returns_last_value_when_sender_dropped() {
2250 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
2254 let waiter = async { session.wait_for_exit().await };
2255 let driver = async {
2256 tokio::time::sleep(Duration::from_millis(20)).await;
2257 drop(exit_tx);
2258 };
2259 let (status, ()) = tokio::time::timeout(Duration::from_secs(1), async {
2260 tokio::join!(waiter, driver)
2261 })
2262 .await
2263 .expect("wait_for_exit must not hang when sender is dropped");
2264 assert!(matches!(status, SessionExitStatus::Running));
2265 }
2266}