1use std::collections::HashMap;
170use std::future::Future;
171use std::pin::Pin;
172use std::process::Stdio;
173use std::sync::Arc;
174use std::time::Duration;
175
176use serde_json::Value;
177use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
178use tokio::process::{Child, ChildStdin, ChildStdout, Command};
179use tokio::sync::{broadcast, mpsc, oneshot, watch};
180use tokio::task::JoinHandle;
181use tracing::{debug, warn};
182
183use crate::Claude;
184use crate::error::{Error, Result};
185
/// Default capacity of the broadcast channel behind [`DuplexSession::subscribe`],
/// used when [`DuplexOptions::subscriber_capacity`] was not set.
pub const DEFAULT_SUBSCRIBER_CAPACITY: usize = 256;
191
/// A tool-permission request (`can_use_tool` control request) received from the
/// child process.
///
/// Built from a `control_request` message and handed to the session's
/// [`PermissionHandler`].
#[derive(Debug, Clone)]
pub struct PermissionRequest {
    /// Identifier of the control request; echoed back in the permission response
    /// (see [`DuplexSession::respond_to_permission`]).
    pub request_id: String,
    /// Name of the tool the child wants to use (the request's `tool_name` field).
    pub tool_name: String,
    /// The tool's proposed input (the request's `input` field), or `Value::Null`
    /// when the field is absent.
    pub input: Value,
    /// The full `request` value as received, for fields not surfaced above.
    pub raw: Value,
}
213
/// The answer a [`PermissionHandler`] gives to a [`PermissionRequest`].
#[derive(Debug, Clone)]
pub enum PermissionDecision {
    /// Permit the tool call; written to the child as `"behavior": "allow"`.
    Allow {
        /// Replacement tool input, sent as `updatedInput` when `Some` and omitted
        /// when `None`.
        updated_input: Option<Value>,
    },
    /// Refuse the tool call; written to the child as `"behavior": "deny"`.
    Deny {
        /// Explanation sent to the child alongside the denial.
        message: String,
    },
    /// Write nothing now; the caller answers later through
    /// [`DuplexSession::respond_to_permission`]. Passing `Defer` to that method
    /// is itself ignored.
    Defer,
}
239
/// Boxed, type-erased future produced by a permission callback.
type PermissionFuture = Pin<Box<dyn Future<Output = PermissionDecision> + Send + 'static>>;
/// Type-erased permission callback stored inside [`PermissionHandler`].
type PermissionFn = dyn Fn(PermissionRequest) -> PermissionFuture + Send + Sync + 'static;
242
243#[derive(Clone)]
256pub struct PermissionHandler {
257 inner: Arc<PermissionFn>,
258}
259
260impl PermissionHandler {
261 pub fn new<F, Fut>(f: F) -> Self
277 where
278 F: Fn(PermissionRequest) -> Fut + Send + Sync + 'static,
279 Fut: Future<Output = PermissionDecision> + Send + 'static,
280 {
281 Self {
282 inner: Arc::new(move |req| Box::pin(f(req))),
283 }
284 }
285
286 fn invoke(&self, req: PermissionRequest) -> PermissionFuture {
287 (self.inner)(req)
288 }
289}
290
291impl std::fmt::Debug for PermissionHandler {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 f.debug_struct("PermissionHandler").finish_non_exhaustive()
294 }
295}
296
297#[derive(Debug, Default, Clone)]
304pub struct DuplexOptions {
305 model: Option<String>,
306 system_prompt: Option<String>,
307 append_system_prompt: Option<String>,
308 additional_args: Vec<String>,
309 subscriber_capacity: Option<usize>,
310 on_permission: Option<PermissionHandler>,
311}
312
313impl DuplexOptions {
314 #[must_use]
316 pub fn model(mut self, model: impl Into<String>) -> Self {
317 self.model = Some(model.into());
318 self
319 }
320
321 #[must_use]
323 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
324 self.system_prompt = Some(prompt.into());
325 self
326 }
327
328 #[must_use]
330 pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
331 self.append_system_prompt = Some(prompt.into());
332 self
333 }
334
335 #[must_use]
340 pub fn arg(mut self, arg: impl Into<String>) -> Self {
341 self.additional_args.push(arg.into());
342 self
343 }
344
345 #[must_use]
353 pub fn subscriber_capacity(mut self, capacity: usize) -> Self {
354 self.subscriber_capacity = Some(capacity);
355 self
356 }
357
358 #[must_use]
370 pub fn on_permission(mut self, handler: PermissionHandler) -> Self {
371 self.on_permission = Some(handler);
372 self
373 }
374
375 fn into_args(self) -> Vec<String> {
376 let mut args = vec![
377 "--print".to_string(),
378 "--verbose".to_string(),
379 "--output-format".to_string(),
380 "stream-json".to_string(),
381 "--input-format".to_string(),
382 "stream-json".to_string(),
383 ];
384
385 if let Some(m) = self.model {
386 args.push("--model".to_string());
387 args.push(m);
388 }
389 if let Some(p) = self.system_prompt {
390 args.push("--system-prompt".to_string());
391 args.push(p);
392 }
393 if let Some(p) = self.append_system_prompt {
394 args.push("--append-system-prompt".to_string());
395 args.push(p);
396 }
397 if self.on_permission.is_some() {
398 args.push("--permission-prompt-tool".to_string());
399 args.push("stdio".to_string());
400 }
401 args.extend(self.additional_args);
402
403 args
404 }
405}
406
407#[derive(Debug, Clone)]
414pub struct TurnResult {
415 pub result: Value,
417 pub events: Vec<Value>,
419}
420
421impl TurnResult {
422 #[must_use]
424 pub fn result_text(&self) -> Option<&str> {
425 self.result.get("result").and_then(Value::as_str)
426 }
427
428 #[must_use]
430 pub fn session_id(&self) -> Option<&str> {
431 self.result.get("session_id").and_then(Value::as_str)
432 }
433
434 #[must_use]
437 pub fn total_cost_usd(&self) -> Option<f64> {
438 self.result
439 .get("total_cost_usd")
440 .or_else(|| self.result.get("cost_usd"))
441 .and_then(Value::as_f64)
442 }
443
444 #[must_use]
446 pub fn duration_ms(&self) -> Option<u64> {
447 self.result.get("duration_ms").and_then(Value::as_u64)
448 }
449}
450
/// A message from the child's stdout, classified for broadcast to subscribers
/// (see [`DuplexSession::subscribe`]).
///
/// Every variant except `SystemInit` carries the full raw JSON message.
#[derive(Debug, Clone)]
pub enum InboundEvent {
    /// A `system` message with subtype `init` that carries a string `session_id`.
    SystemInit {
        /// The session id reported by the child.
        session_id: String,
    },
    /// A message whose `type` is `assistant`.
    Assistant(Value),
    /// A message whose `type` is `stream_event`.
    StreamEvent(Value),
    /// A message whose `type` is `user`.
    User(Value),
    /// Anything else, including non-init `system` messages, unhandled
    /// `control_request`s and malformed `control_response`s.
    Other(Value),
}
486
487fn classify(msg: &Value) -> InboundEvent {
488 match msg.get("type").and_then(Value::as_str) {
489 Some("system") => {
490 if msg.get("subtype").and_then(Value::as_str) == Some("init")
491 && let Some(id) = msg.get("session_id").and_then(Value::as_str)
492 {
493 return InboundEvent::SystemInit {
494 session_id: id.to_string(),
495 };
496 }
497 InboundEvent::Other(msg.clone())
498 }
499 Some("assistant") => InboundEvent::Assistant(msg.clone()),
500 Some("stream_event") => InboundEvent::StreamEvent(msg.clone()),
501 Some("user") => InboundEvent::User(msg.clone()),
502 _ => InboundEvent::Other(msg.clone()),
503 }
504}
505
/// Lifecycle state of a [`DuplexSession`], published by its background task.
#[derive(Debug, Clone)]
pub enum SessionExitStatus {
    /// The background task has not published a final state yet.
    Running,
    /// The session loop shut down without an error.
    Completed,
    /// The session loop shut down with an error; carries that error's display text.
    Failed(String),
}
531
/// A long-lived `claude` child process that exchanges stream-json over its
/// stdin/stdout.
///
/// All process I/O happens on a background task; this handle talks to that task
/// over channels.
#[derive(Debug)]
pub struct DuplexSession {
    /// Commands (prompts, permission responses, interrupts) for the background task.
    outbound_tx: mpsc::UnboundedSender<OutboundMsg>,
    /// Broadcast sender used to hand out subscriber receivers.
    events_tx: broadcast::Sender<InboundEvent>,
    /// Latest lifecycle state published by the background task.
    exit_rx: watch::Receiver<SessionExitStatus>,
    /// The background task; its result is surfaced by `close`.
    join: JoinHandle<Result<()>>,
}
547
/// Commands sent from a [`DuplexSession`] handle to its background task.
#[derive(Debug)]
enum OutboundMsg {
    /// Write `prompt` as a user message; `reply` receives the turn's result.
    Send {
        prompt: String,
        reply: oneshot::Sender<Result<TurnResult>>,
    },
    /// Write a (typically deferred) permission response for `request_id`.
    PermissionResponse {
        request_id: String,
        decision: PermissionDecision,
    },
    /// Write an `interrupt` control request; `reply` receives the child's answer.
    Interrupt {
        reply: oneshot::Sender<Result<()>>,
    },
}
562
563impl DuplexSession {
564 pub async fn spawn(claude: &Claude, opts: DuplexOptions) -> Result<Self> {
572 let capacity = opts
573 .subscriber_capacity
574 .unwrap_or(DEFAULT_SUBSCRIBER_CAPACITY);
575 let permission_handler = opts.on_permission.clone();
576
577 let mut command_args = Vec::new();
578 command_args.extend(claude.global_args.clone());
579 command_args.extend(opts.into_args());
580
581 debug!(
582 binary = %claude.binary.display(),
583 args = ?command_args,
584 "spawning duplex claude session"
585 );
586
587 let mut cmd = Command::new(&claude.binary);
588 cmd.args(&command_args)
589 .env_remove("CLAUDECODE")
590 .env_remove("CLAUDE_CODE_ENTRYPOINT")
591 .envs(&claude.env)
592 .stdin(Stdio::piped())
593 .stdout(Stdio::piped())
594 .stderr(Stdio::piped())
595 .kill_on_drop(true);
596
597 if let Some(ref dir) = claude.working_dir {
598 cmd.current_dir(dir);
599 }
600
601 let mut child = cmd.spawn().map_err(|e| Error::Io {
602 message: format!("failed to spawn claude: {e}"),
603 source: e,
604 working_dir: claude.working_dir.clone(),
605 })?;
606
607 let stdin = child.stdin.take().expect("stdin was piped");
608 let stdout = child.stdout.take().expect("stdout was piped");
609
610 let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
611 let (events_tx, _initial_rx) = broadcast::channel(capacity);
612 let (exit_tx, exit_rx) = watch::channel(SessionExitStatus::Running);
613
614 let join = tokio::spawn(run_session(
615 child,
616 stdin,
617 stdout,
618 outbound_rx,
619 events_tx.clone(),
620 permission_handler,
621 exit_tx,
622 ));
623
624 Ok(Self {
625 outbound_tx,
626 events_tx,
627 exit_rx,
628 join,
629 })
630 }
631
632 pub async fn send(&self, prompt: impl Into<String>) -> Result<TurnResult> {
638 let (reply_tx, reply_rx) = oneshot::channel();
639 self.outbound_tx
640 .send(OutboundMsg::Send {
641 prompt: prompt.into(),
642 reply: reply_tx,
643 })
644 .map_err(|_| Error::DuplexClosed)?;
645 reply_rx.await.map_err(|_| Error::DuplexClosed)?
646 }
647
648 #[must_use]
683 pub fn subscribe(&self) -> broadcast::Receiver<InboundEvent> {
684 self.events_tx.subscribe()
685 }
686
687 #[must_use]
698 pub fn is_alive(&self) -> bool {
699 matches!(*self.exit_rx.borrow(), SessionExitStatus::Running)
700 }
701
702 #[must_use]
711 pub fn exit_status(&self) -> SessionExitStatus {
712 self.exit_rx.borrow().clone()
713 }
714
715 pub async fn wait_for_exit(&self) -> SessionExitStatus {
727 let mut rx = self.exit_rx.clone();
728 loop {
729 {
730 let value = rx.borrow_and_update();
731 if !matches!(*value, SessionExitStatus::Running) {
732 return value.clone();
733 }
734 }
735 if rx.changed().await.is_err() {
736 return rx.borrow().clone();
737 }
738 }
739 }
740
741 pub fn respond_to_permission(
786 &self,
787 request_id: impl Into<String>,
788 decision: PermissionDecision,
789 ) -> Result<()> {
790 if matches!(decision, PermissionDecision::Defer) {
791 warn!("respond_to_permission called with Defer; ignoring");
792 return Ok(());
793 }
794 self.outbound_tx
795 .send(OutboundMsg::PermissionResponse {
796 request_id: request_id.into(),
797 decision,
798 })
799 .map_err(|_| Error::DuplexClosed)?;
800 Ok(())
801 }
802
803 pub async fn interrupt(&self) -> Result<()> {
844 let (reply_tx, reply_rx) = oneshot::channel();
845 self.outbound_tx
846 .send(OutboundMsg::Interrupt { reply: reply_tx })
847 .map_err(|_| Error::DuplexClosed)?;
848 reply_rx.await.map_err(|_| Error::DuplexClosed)?
849 }
850
851 pub async fn close(self) -> Result<()> {
857 drop(self.outbound_tx);
858 drop(self.events_tx);
859 match self.join.await {
860 Ok(result) => result,
861 Err(e) if e.is_cancelled() => Ok(()),
862 Err(e) => Err(Error::Io {
863 message: format!("duplex session task panicked: {e}"),
864 source: std::io::Error::other(e.to_string()),
865 working_dir: None,
866 }),
867 }
868 }
869}
870
/// How long the session loop waits for the child to exit after closing its stdin
/// before killing it.
const SHUTDOWN_BUDGET: Duration = Duration::from_secs(5);
875
876async fn run_session(
877 mut child: Child,
878 mut stdin: ChildStdin,
879 stdout: ChildStdout,
880 mut outbound_rx: mpsc::UnboundedReceiver<OutboundMsg>,
881 events_tx: broadcast::Sender<InboundEvent>,
882 permission_handler: Option<PermissionHandler>,
883 exit_tx: watch::Sender<SessionExitStatus>,
884) -> Result<()> {
885 let mut lines = BufReader::new(stdout).lines();
886 let mut pending: Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)> = None;
887 let mut pending_control: HashMap<String, oneshot::Sender<Result<()>>> = HashMap::new();
888 let mut next_control_id: u64 = 0;
889 let mut stream_err: Option<Error> = None;
890
891 loop {
892 tokio::select! {
893 biased;
894
895 line = lines.next_line() => match line {
896 Ok(Some(l)) => {
897 if l.trim().is_empty() {
898 continue;
899 }
900 let parsed = match serde_json::from_str::<Value>(&l) {
901 Ok(v) => v,
902 Err(e) => {
903 debug!(line = %l, error = %e, "failed to parse duplex event, skipping");
904 continue;
905 }
906 };
907 match handle_inbound(parsed, &mut pending, &events_tx) {
908 InboundAction::None => {}
909 InboundAction::Permission(req) => {
910 let request_id = req.request_id.clone();
911 let decision = match permission_handler.as_ref() {
912 Some(h) => h.invoke(req).await,
913 None => {
914 warn!(
915 request_id = %request_id,
916 "received can_use_tool with no permission handler; auto-denying"
917 );
918 PermissionDecision::Deny {
919 message:
920 "no permission handler configured on duplex session"
921 .into(),
922 }
923 }
924 };
925 if matches!(decision, PermissionDecision::Defer) {
926 debug!(
927 request_id = %request_id,
928 "permission handler deferred; waiting for respond_to_permission"
929 );
930 } else if let Err(e) =
931 write_permission_response(&mut stdin, &request_id, &decision).await
932 {
933 warn!(error = %e, "failed to write permission response");
934 }
935 }
936 InboundAction::ControlResponse { request_id, outcome } => {
937 if let Some(reply) = pending_control.remove(&request_id) {
938 let _ = reply.send(outcome);
939 } else {
940 debug!(
941 request_id = %request_id,
942 "received control_response with no pending request"
943 );
944 }
945 }
946 }
947 }
948 Ok(None) => break,
949 Err(e) => {
950 stream_err = Some(Error::Io {
951 message: "failed to read duplex stdout".to_string(),
952 source: e,
953 working_dir: None,
954 });
955 break;
956 }
957 },
958
959 msg = outbound_rx.recv() => match msg {
960 Some(OutboundMsg::Send { prompt, reply }) => {
961 if pending.is_some() {
962 let _ = reply.send(Err(Error::DuplexTurnInFlight));
963 continue;
964 }
965 if let Err(e) = write_user(&mut stdin, &prompt).await {
966 let _ = reply.send(Err(e));
967 continue;
968 }
969 pending = Some((reply, Vec::new()));
970 }
971 Some(OutboundMsg::PermissionResponse { request_id, decision }) => {
972 if let Err(e) =
973 write_permission_response(&mut stdin, &request_id, &decision).await
974 {
975 warn!(error = %e, "failed to write deferred permission response");
976 }
977 }
978 Some(OutboundMsg::Interrupt { reply }) => {
979 next_control_id += 1;
980 let request_id = format!("interrupt-{next_control_id}");
981 if let Err(e) =
982 write_control_request(&mut stdin, &request_id, "interrupt").await
983 {
984 let _ = reply.send(Err(e));
985 continue;
986 }
987 pending_control.insert(request_id, reply);
988 }
989 None => break,
990 },
991 }
992 }
993
994 drop(stdin);
995 match tokio::time::timeout(SHUTDOWN_BUDGET, child.wait()).await {
996 Ok(Ok(_status)) => {}
997 Ok(Err(e)) => {
998 warn!(error = %e, "failed to wait for duplex child");
999 }
1000 Err(_) => {
1001 warn!("duplex child did not exit within shutdown budget; killing");
1002 let _ = child.kill().await;
1003 }
1004 }
1005
1006 if let Some((reply, _)) = pending.take() {
1007 let _ = reply.send(Err(Error::DuplexClosed));
1008 }
1009 for (_, reply) in pending_control.drain() {
1010 let _ = reply.send(Err(Error::DuplexClosed));
1011 }
1012
1013 let result = match stream_err {
1014 Some(e) => Err(e),
1015 None => Ok(()),
1016 };
1017 let final_state = match &result {
1018 Ok(()) => SessionExitStatus::Completed,
1019 Err(e) => SessionExitStatus::Failed(e.to_string()),
1020 };
1021 let _ = exit_tx.send(final_state);
1022 result
1023}
1024
/// What the session loop must do after [`handle_inbound`] processes a message.
enum InboundAction {
    /// Nothing further; the message was fully handled.
    None,
    /// A `can_use_tool` request that must be answered, either by the permission
    /// handler or later through `respond_to_permission`.
    Permission(PermissionRequest),
    /// A parsed `control_response`, used to resolve the pending control request
    /// with the same id.
    ControlResponse {
        /// Id of the control request this message answers.
        request_id: String,
        /// `Ok(())` for a `success` subtype, `Err(DuplexControlFailed)` for `error`.
        outcome: Result<()>,
    },
}
1044
1045fn handle_inbound(
1046 msg: Value,
1047 pending: &mut Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)>,
1048 events_tx: &broadcast::Sender<InboundEvent>,
1049) -> InboundAction {
1050 match msg.get("type").and_then(Value::as_str) {
1051 Some("result") => {
1052 if let Some((reply, events)) = pending.take() {
1053 let _ = reply.send(Ok(TurnResult {
1054 result: msg,
1055 events,
1056 }));
1057 } else {
1058 debug!("dropping orphan result event with no pending turn");
1059 }
1060 InboundAction::None
1061 }
1062 Some("control_request") => {
1063 if msg
1066 .get("request")
1067 .and_then(|r| r.get("subtype"))
1068 .and_then(Value::as_str)
1069 == Some("can_use_tool")
1070 && let Some(req) = parse_permission_request(&msg)
1071 {
1072 if let Some((_, events)) = pending.as_mut() {
1073 events.push(msg);
1074 }
1075 return InboundAction::Permission(req);
1076 }
1077 debug!(
1078 ?msg,
1079 "received unhandled control_request; treating as Other"
1080 );
1081 let _ = events_tx.send(InboundEvent::Other(msg.clone()));
1082 if let Some((_, events)) = pending.as_mut() {
1083 events.push(msg);
1084 }
1085 InboundAction::None
1086 }
1087 Some("control_response") => {
1088 if let Some((request_id, outcome)) = parse_control_response(&msg) {
1089 return InboundAction::ControlResponse {
1090 request_id,
1091 outcome,
1092 };
1093 }
1094 debug!(
1095 ?msg,
1096 "received malformed control_response; treating as Other"
1097 );
1098 let _ = events_tx.send(InboundEvent::Other(msg.clone()));
1099 if let Some((_, events)) = pending.as_mut() {
1100 events.push(msg);
1101 }
1102 InboundAction::None
1103 }
1104 _ => {
1105 let _ = events_tx.send(classify(&msg));
1108
1109 if let Some((_, events)) = pending.as_mut() {
1110 events.push(msg);
1111 } else {
1112 debug!("dropping inbound event with no pending turn");
1113 }
1114 InboundAction::None
1115 }
1116 }
1117}
1118
1119fn parse_permission_request(msg: &Value) -> Option<PermissionRequest> {
1120 let request_id = msg.get("request_id").and_then(Value::as_str)?;
1121 let request = msg.get("request")?;
1122 let tool_name = request.get("tool_name").and_then(Value::as_str)?;
1123 let input = request.get("input").cloned().unwrap_or(Value::Null);
1124 Some(PermissionRequest {
1125 request_id: request_id.to_string(),
1126 tool_name: tool_name.to_string(),
1127 input,
1128 raw: request.clone(),
1129 })
1130}
1131
1132fn parse_control_response(msg: &Value) -> Option<(String, Result<()>)> {
1138 let response = msg.get("response")?;
1139 let request_id = response.get("request_id").and_then(Value::as_str)?;
1140 let outcome = match response.get("subtype").and_then(Value::as_str) {
1141 Some("success") => Ok(()),
1142 Some("error") => {
1143 let message = response
1144 .get("error")
1145 .and_then(Value::as_str)
1146 .unwrap_or("unknown control_response error")
1147 .to_string();
1148 Err(Error::DuplexControlFailed { message })
1149 }
1150 _ => return None,
1151 };
1152 Some((request_id.to_string(), outcome))
1153}
1154
1155async fn write_user(stdin: &mut ChildStdin, prompt: &str) -> Result<()> {
1156 let user_msg = serde_json::json!({
1157 "type": "user",
1158 "message": {
1159 "role": "user",
1160 "content": prompt,
1161 },
1162 "parent_tool_use_id": null,
1163 });
1164 write_line(stdin, &user_msg, "user message").await
1165}
1166
1167async fn write_control_request(
1168 stdin: &mut ChildStdin,
1169 request_id: &str,
1170 subtype: &str,
1171) -> Result<()> {
1172 let envelope = serde_json::json!({
1173 "type": "control_request",
1174 "request_id": request_id,
1175 "request": { "subtype": subtype },
1176 });
1177 write_line(stdin, &envelope, "control_request").await
1178}
1179
1180async fn write_permission_response(
1181 stdin: &mut ChildStdin,
1182 request_id: &str,
1183 decision: &PermissionDecision,
1184) -> Result<()> {
1185 let inner = match decision {
1186 PermissionDecision::Allow { updated_input } => {
1187 let mut obj = serde_json::Map::new();
1188 obj.insert("behavior".to_string(), Value::String("allow".to_string()));
1189 if let Some(input) = updated_input {
1190 obj.insert("updatedInput".to_string(), input.clone());
1191 }
1192 Value::Object(obj)
1193 }
1194 PermissionDecision::Deny { message } => serde_json::json!({
1195 "behavior": "deny",
1196 "message": message,
1197 }),
1198 PermissionDecision::Defer => {
1199 return Ok(());
1201 }
1202 };
1203 let envelope = serde_json::json!({
1204 "type": "control_response",
1205 "response": {
1206 "request_id": request_id,
1207 "subtype": "success",
1208 "response": inner,
1209 },
1210 });
1211 write_line(stdin, &envelope, "control_response").await
1212}
1213
1214async fn write_line(stdin: &mut ChildStdin, value: &Value, what: &'static str) -> Result<()> {
1215 let mut line = serde_json::to_string(value).map_err(|e| Error::Json {
1216 message: format!("failed to serialize duplex {what}"),
1217 source: e,
1218 })?;
1219 line.push('\n');
1220 stdin
1221 .write_all(line.as_bytes())
1222 .await
1223 .map_err(|e| Error::Io {
1224 message: format!("failed to write {what} to duplex stdin"),
1225 source: e,
1226 working_dir: None,
1227 })?;
1228 stdin.flush().await.map_err(|e| Error::Io {
1229 message: "failed to flush duplex stdin".to_string(),
1230 source: e,
1231 working_dir: None,
1232 })?;
1233 Ok(())
1234}
1235
1236#[cfg(test)]
1237mod tests {
1238 use super::*;
1239 use serde_json::json;
1240
1241 #[test]
1242 fn into_args_default_includes_required_flags() {
1243 let args = DuplexOptions::default().into_args();
1244 assert!(args.contains(&"--print".to_string()));
1245 assert!(args.contains(&"--verbose".to_string()));
1246 assert!(
1247 args.windows(2)
1248 .any(|w| w == ["--output-format", "stream-json"])
1249 );
1250 assert!(
1251 args.windows(2)
1252 .any(|w| w == ["--input-format", "stream-json"])
1253 );
1254 }
1255
1256 #[test]
1257 fn into_args_includes_model() {
1258 let args = DuplexOptions::default().model("haiku").into_args();
1259 assert!(args.windows(2).any(|w| w == ["--model", "haiku"]));
1260 }
1261
1262 #[test]
1263 fn into_args_includes_system_prompts() {
1264 let args = DuplexOptions::default()
1265 .system_prompt("be concise")
1266 .append_system_prompt("also polite")
1267 .into_args();
1268 assert!(
1269 args.windows(2)
1270 .any(|w| w == ["--system-prompt", "be concise"])
1271 );
1272 assert!(
1273 args.windows(2)
1274 .any(|w| w == ["--append-system-prompt", "also polite"])
1275 );
1276 }
1277
1278 #[test]
1279 fn into_args_appends_raw_args_last() {
1280 let args = DuplexOptions::default()
1281 .arg("--add-dir")
1282 .arg("/tmp/foo")
1283 .into_args();
1284 assert_eq!(&args[args.len() - 2..], &["--add-dir", "/tmp/foo"]);
1286 }
1287
1288 #[test]
1289 fn turn_result_accessors_pull_from_result() {
1290 let r = TurnResult {
1291 result: json!({
1292 "type": "result",
1293 "result": "hello",
1294 "session_id": "sess-123",
1295 "total_cost_usd": 0.0042,
1296 "duration_ms": 1234_u64,
1297 }),
1298 events: vec![],
1299 };
1300 assert_eq!(r.result_text(), Some("hello"));
1301 assert_eq!(r.session_id(), Some("sess-123"));
1302 assert_eq!(r.total_cost_usd(), Some(0.0042));
1303 assert_eq!(r.duration_ms(), Some(1234));
1304 }
1305
1306 #[test]
1307 fn turn_result_total_cost_falls_back_to_legacy_field() {
1308 let r = TurnResult {
1309 result: json!({ "cost_usd": 0.5 }),
1310 events: vec![],
1311 };
1312 assert_eq!(r.total_cost_usd(), Some(0.5));
1313 }
1314
1315 #[test]
1316 fn turn_result_accessors_return_none_when_missing() {
1317 let r = TurnResult {
1318 result: json!({}),
1319 events: vec![],
1320 };
1321 assert_eq!(r.result_text(), None);
1322 assert_eq!(r.session_id(), None);
1323 assert_eq!(r.total_cost_usd(), None);
1324 assert_eq!(r.duration_ms(), None);
1325 }
1326
1327 #[test]
1328 fn handle_inbound_appends_non_result_to_pending_events() {
1329 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1330 let (events_tx, _events_rx) = broadcast::channel(16);
1331 let mut pending = Some((tx, Vec::new()));
1332 handle_inbound(
1333 json!({ "type": "assistant", "message": {} }),
1334 &mut pending,
1335 &events_tx,
1336 );
1337 let (_, events) = pending.as_ref().unwrap();
1338 assert_eq!(events.len(), 1);
1339 assert_eq!(
1340 events[0].get("type").and_then(Value::as_str),
1341 Some("assistant")
1342 );
1343 }
1344
1345 #[test]
1346 fn handle_inbound_resolves_pending_on_result() {
1347 let (tx, rx) = oneshot::channel::<Result<TurnResult>>();
1348 let (events_tx, _events_rx) = broadcast::channel(16);
1349 let mut pending = Some((tx, vec![json!({ "type": "assistant" })]));
1350 handle_inbound(
1351 json!({ "type": "result", "result": "ok" }),
1352 &mut pending,
1353 &events_tx,
1354 );
1355 assert!(pending.is_none());
1356 let received = rx.blocking_recv().unwrap().unwrap();
1357 assert_eq!(received.result_text(), Some("ok"));
1358 assert_eq!(received.events.len(), 1);
1359 }
1360
1361 #[test]
1362 fn handle_inbound_drops_orphans_without_pending_turn() {
1363 let (events_tx, _events_rx) = broadcast::channel(16);
1364 let mut pending: Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)> = None;
1365 handle_inbound(json!({ "type": "assistant" }), &mut pending, &events_tx);
1366 handle_inbound(
1367 json!({ "type": "result", "result": "ok" }),
1368 &mut pending,
1369 &events_tx,
1370 );
1371 assert!(pending.is_none());
1372 }
1373
1374 #[test]
1375 fn handle_inbound_broadcasts_classified_event() {
1376 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1377 let (events_tx, mut events_rx) = broadcast::channel(16);
1378 let mut pending = Some((tx, Vec::new()));
1379 handle_inbound(
1380 json!({ "type": "assistant", "message": { "role": "assistant" } }),
1381 &mut pending,
1382 &events_tx,
1383 );
1384 let event = events_rx.try_recv().expect("classified event broadcast");
1385 assert!(matches!(event, InboundEvent::Assistant(_)));
1386 }
1387
1388 #[test]
1389 fn handle_inbound_does_not_broadcast_result() {
1390 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1391 let (events_tx, mut events_rx) = broadcast::channel(16);
1392 let mut pending = Some((tx, Vec::new()));
1393 handle_inbound(
1394 json!({ "type": "result", "result": "ok" }),
1395 &mut pending,
1396 &events_tx,
1397 );
1398 assert!(events_rx.try_recv().is_err());
1400 }
1401
1402 #[test]
1403 fn classify_system_init_pulls_session_id() {
1404 let v = json!({
1405 "type": "system",
1406 "subtype": "init",
1407 "session_id": "sess-abc",
1408 });
1409 match classify(&v) {
1410 InboundEvent::SystemInit { session_id } => assert_eq!(session_id, "sess-abc"),
1411 other => panic!("expected SystemInit, got {other:?}"),
1412 }
1413 }
1414
1415 #[test]
1416 fn classify_system_without_init_subtype_is_other() {
1417 let v = json!({ "type": "system", "subtype": "compaction" });
1418 assert!(matches!(classify(&v), InboundEvent::Other(_)));
1419 }
1420
1421 #[test]
1422 fn classify_system_init_without_session_id_is_other() {
1423 let v = json!({ "type": "system", "subtype": "init" });
1424 assert!(matches!(classify(&v), InboundEvent::Other(_)));
1425 }
1426
1427 #[test]
1428 fn classify_assistant_stream_event_user() {
1429 assert!(matches!(
1430 classify(&json!({ "type": "assistant" })),
1431 InboundEvent::Assistant(_)
1432 ));
1433 assert!(matches!(
1434 classify(&json!({ "type": "stream_event" })),
1435 InboundEvent::StreamEvent(_)
1436 ));
1437 assert!(matches!(
1438 classify(&json!({ "type": "user" })),
1439 InboundEvent::User(_)
1440 ));
1441 }
1442
1443 #[test]
1444 fn classify_unknown_type_is_other() {
1445 assert!(matches!(
1446 classify(&json!({ "type": "control_request" })),
1447 InboundEvent::Other(_)
1448 ));
1449 assert!(matches!(
1450 classify(&json!({ "type": "future_thing" })),
1451 InboundEvent::Other(_)
1452 ));
1453 assert!(matches!(classify(&json!({})), InboundEvent::Other(_)));
1454 }
1455
1456 #[test]
1457 fn into_args_does_not_emit_subscriber_capacity_flag() {
1458 let args = DuplexOptions::default().subscriber_capacity(64).into_args();
1460 assert!(!args.iter().any(|a| a.contains("subscriber")));
1461 assert!(!args.iter().any(|a| a.contains("capacity")));
1462 }
1463
1464 #[test]
1465 fn into_args_includes_permission_prompt_tool_when_handler_set() {
1466 let handler = PermissionHandler::new(|_req| async move {
1467 PermissionDecision::Allow {
1468 updated_input: None,
1469 }
1470 });
1471 let args = DuplexOptions::default().on_permission(handler).into_args();
1472 assert!(
1473 args.windows(2)
1474 .any(|w| w == ["--permission-prompt-tool", "stdio"])
1475 );
1476 }
1477
1478 #[test]
1479 fn into_args_omits_permission_prompt_tool_without_handler() {
1480 let args = DuplexOptions::default().into_args();
1481 assert!(!args.iter().any(|a| a == "--permission-prompt-tool"));
1482 }
1483
1484 #[test]
1485 fn parse_permission_request_extracts_fields() {
1486 let msg = json!({
1487 "type": "control_request",
1488 "request_id": "req-1",
1489 "request": {
1490 "subtype": "can_use_tool",
1491 "tool_name": "Bash",
1492 "input": { "command": "ls" }
1493 }
1494 });
1495 let req = parse_permission_request(&msg).expect("permission request");
1496 assert_eq!(req.request_id, "req-1");
1497 assert_eq!(req.tool_name, "Bash");
1498 assert_eq!(req.input, json!({ "command": "ls" }));
1499 assert_eq!(
1500 req.raw.get("subtype").and_then(Value::as_str),
1501 Some("can_use_tool")
1502 );
1503 }
1504
1505 #[test]
1506 fn parse_permission_request_returns_none_when_missing_request_id() {
1507 let msg = json!({
1508 "type": "control_request",
1509 "request": {
1510 "subtype": "can_use_tool",
1511 "tool_name": "Bash",
1512 }
1513 });
1514 assert!(parse_permission_request(&msg).is_none());
1515 }
1516
1517 #[test]
1518 fn parse_permission_request_returns_none_when_missing_tool_name() {
1519 let msg = json!({
1520 "type": "control_request",
1521 "request_id": "req-1",
1522 "request": { "subtype": "can_use_tool" }
1523 });
1524 assert!(parse_permission_request(&msg).is_none());
1525 }
1526
1527 #[test]
1528 fn parse_permission_request_handles_missing_input() {
1529 let msg = json!({
1530 "type": "control_request",
1531 "request_id": "req-1",
1532 "request": {
1533 "subtype": "can_use_tool",
1534 "tool_name": "Bash",
1535 }
1536 });
1537 let req = parse_permission_request(&msg).expect("request");
1538 assert_eq!(req.input, Value::Null);
1539 }
1540
1541 #[test]
1542 fn handle_inbound_returns_permission_for_can_use_tool() {
1543 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1544 let (events_tx, _events_rx) = broadcast::channel(16);
1545 let mut pending = Some((tx, Vec::new()));
1546 let action = handle_inbound(
1547 json!({
1548 "type": "control_request",
1549 "request_id": "req-1",
1550 "request": {
1551 "subtype": "can_use_tool",
1552 "tool_name": "Bash",
1553 "input": { "command": "ls" }
1554 }
1555 }),
1556 &mut pending,
1557 &events_tx,
1558 );
1559 match action {
1560 InboundAction::Permission(req) => {
1561 assert_eq!(req.request_id, "req-1");
1562 assert_eq!(req.tool_name, "Bash");
1563 }
1564 InboundAction::None | InboundAction::ControlResponse { .. } => {
1565 panic!("expected Permission action");
1566 }
1567 }
1568 let (_, events) = pending.as_ref().unwrap();
1570 assert_eq!(events.len(), 1);
1571 }
1572
1573 #[test]
1574 fn handle_inbound_treats_unknown_control_request_as_other() {
1575 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1576 let (events_tx, mut events_rx) = broadcast::channel(16);
1577 let mut pending = Some((tx, Vec::new()));
1578 let action = handle_inbound(
1579 json!({
1580 "type": "control_request",
1581 "request_id": "req-2",
1582 "request": { "subtype": "future_subtype" }
1583 }),
1584 &mut pending,
1585 &events_tx,
1586 );
1587 assert!(matches!(action, InboundAction::None));
1588 let event = events_rx.try_recv().expect("broadcast");
1589 assert!(matches!(event, InboundEvent::Other(_)));
1590 }
1591
1592 #[tokio::test]
1593 async fn permission_handler_invokes_closure_async() {
1594 let handler = PermissionHandler::new(|req| async move {
1595 if req.tool_name == "Bash" {
1596 PermissionDecision::Deny {
1597 message: "no bash".into(),
1598 }
1599 } else {
1600 PermissionDecision::Allow {
1601 updated_input: None,
1602 }
1603 }
1604 });
1605 let req = PermissionRequest {
1606 request_id: "r1".into(),
1607 tool_name: "Bash".into(),
1608 input: Value::Null,
1609 raw: Value::Null,
1610 };
1611 match handler.invoke(req).await {
1612 PermissionDecision::Deny { message } => assert_eq!(message, "no bash"),
1613 other => panic!("expected Deny, got {other:?}"),
1614 }
1615 }
1616
1617 #[test]
1618 fn parse_control_response_extracts_success() {
1619 let msg = json!({
1620 "type": "control_response",
1621 "response": {
1622 "request_id": "interrupt-1",
1623 "subtype": "success",
1624 "response": {}
1625 }
1626 });
1627 let (id, outcome) = parse_control_response(&msg).expect("parsed");
1628 assert_eq!(id, "interrupt-1");
1629 assert!(outcome.is_ok());
1630 }
1631
1632 #[test]
1633 fn parse_control_response_extracts_error_with_message() {
1634 let msg = json!({
1635 "type": "control_response",
1636 "response": {
1637 "request_id": "interrupt-2",
1638 "subtype": "error",
1639 "error": "no turn in flight"
1640 }
1641 });
1642 let (id, outcome) = parse_control_response(&msg).expect("parsed");
1643 assert_eq!(id, "interrupt-2");
1644 match outcome {
1645 Err(Error::DuplexControlFailed { message }) => {
1646 assert_eq!(message, "no turn in flight");
1647 }
1648 other => panic!("expected DuplexControlFailed, got {other:?}"),
1649 }
1650 }
1651
1652 #[test]
1653 fn parse_control_response_returns_none_on_missing_request_id() {
1654 let msg = json!({
1655 "type": "control_response",
1656 "response": { "subtype": "success" }
1657 });
1658 assert!(parse_control_response(&msg).is_none());
1659 }
1660
1661 #[test]
1662 fn parse_control_response_returns_none_on_unknown_subtype() {
1663 let msg = json!({
1664 "type": "control_response",
1665 "response": { "request_id": "x", "subtype": "future_subtype" }
1666 });
1667 assert!(parse_control_response(&msg).is_none());
1668 }
1669
1670 #[test]
1671 fn handle_inbound_returns_control_response_action() {
1672 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1673 let (events_tx, _events_rx) = broadcast::channel(16);
1674 let mut pending = Some((tx, Vec::new()));
1675 let action = handle_inbound(
1676 json!({
1677 "type": "control_response",
1678 "response": {
1679 "request_id": "interrupt-1",
1680 "subtype": "success",
1681 "response": {}
1682 }
1683 }),
1684 &mut pending,
1685 &events_tx,
1686 );
1687 match action {
1688 InboundAction::ControlResponse {
1689 request_id,
1690 outcome,
1691 } => {
1692 assert_eq!(request_id, "interrupt-1");
1693 assert!(outcome.is_ok());
1694 }
1695 InboundAction::None | InboundAction::Permission(_) => {
1696 panic!("expected ControlResponse action");
1697 }
1698 }
1699 }
1700
1701 #[test]
1702 fn handle_inbound_treats_malformed_control_response_as_other() {
1703 let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
1704 let (events_tx, mut events_rx) = broadcast::channel(16);
1705 let mut pending = Some((tx, Vec::new()));
1706 let action = handle_inbound(
1707 json!({
1708 "type": "control_response",
1709 "response": { "subtype": "success" }
1710 }),
1711 &mut pending,
1712 &events_tx,
1713 );
1714 assert!(matches!(action, InboundAction::None));
1715 let event = events_rx.try_recv().expect("broadcast");
1716 assert!(matches!(event, InboundEvent::Other(_)));
1717 }
1718
1719 #[tokio::test]
1720 async fn permission_handler_clones_arc() {
1721 let handler = PermissionHandler::new(|_req| async move {
1722 PermissionDecision::Allow {
1723 updated_input: None,
1724 }
1725 });
1726 let cloned = handler.clone();
1727 let req = PermissionRequest {
1728 request_id: "r1".into(),
1729 tool_name: "Read".into(),
1730 input: Value::Null,
1731 raw: Value::Null,
1732 };
1733 let _ = handler.invoke(req.clone()).await;
1735 let _ = cloned.invoke(req).await;
1736 }
1737
1738 fn fake_session(
1745 initial: SessionExitStatus,
1746 ) -> (
1747 DuplexSession,
1748 watch::Sender<SessionExitStatus>,
1749 oneshot::Sender<()>,
1750 ) {
1751 let (outbound_tx, outbound_rx) = mpsc::unbounded_channel::<OutboundMsg>();
1752 let (events_tx, _events_rx) = broadcast::channel::<InboundEvent>(16);
1753 let (exit_tx, exit_rx) = watch::channel(initial);
1754 let (stop_tx, stop_rx) = oneshot::channel::<()>();
1755
1756 let join = tokio::spawn(async move {
1757 let _outbound_rx = outbound_rx;
1758 let _ = stop_rx.await;
1759 Ok::<(), Error>(())
1760 });
1761
1762 let session = DuplexSession {
1763 outbound_tx,
1764 events_tx,
1765 exit_rx,
1766 join,
1767 };
1768 (session, exit_tx, stop_tx)
1769 }
1770
1771 #[tokio::test]
1772 async fn is_alive_true_while_running() {
1773 let (session, _exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1774 assert!(session.is_alive());
1775 }
1776
1777 #[tokio::test]
1778 async fn is_alive_false_after_completed() {
1779 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1780 exit_tx.send(SessionExitStatus::Completed).unwrap();
1781 assert!(!session.is_alive());
1782 }
1783
1784 #[tokio::test]
1785 async fn is_alive_false_after_failed() {
1786 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1787 exit_tx
1788 .send(SessionExitStatus::Failed("boom".into()))
1789 .unwrap();
1790 assert!(!session.is_alive());
1791 }
1792
1793 #[tokio::test]
1794 async fn exit_status_reports_running_initially() {
1795 let (session, _exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1796 assert!(matches!(session.exit_status(), SessionExitStatus::Running));
1797 }
1798
1799 #[tokio::test]
1800 async fn exit_status_reflects_completed() {
1801 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1802 exit_tx.send(SessionExitStatus::Completed).unwrap();
1803 assert!(matches!(
1804 session.exit_status(),
1805 SessionExitStatus::Completed
1806 ));
1807 }
1808
1809 #[tokio::test]
1810 async fn exit_status_reflects_failed_with_message() {
1811 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1812 exit_tx
1813 .send(SessionExitStatus::Failed("oh no".into()))
1814 .unwrap();
1815 match session.exit_status() {
1816 SessionExitStatus::Failed(msg) => assert_eq!(msg, "oh no"),
1817 other => panic!("expected Failed, got {other:?}"),
1818 }
1819 }
1820
1821 #[tokio::test]
1822 async fn wait_for_exit_returns_immediately_when_already_terminal() {
1823 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1824 exit_tx.send(SessionExitStatus::Completed).unwrap();
1825 let status = tokio::time::timeout(Duration::from_secs(1), session.wait_for_exit())
1826 .await
1827 .expect("wait_for_exit should not block when already terminal");
1828 assert!(matches!(status, SessionExitStatus::Completed));
1829 }
1830
1831 #[tokio::test]
1832 async fn wait_for_exit_blocks_until_state_transitions() {
1833 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1834
1835 let waiter = async { session.wait_for_exit().await };
1836 let driver = async {
1837 tokio::time::sleep(Duration::from_millis(20)).await;
1838 exit_tx.send(SessionExitStatus::Completed).unwrap();
1839 };
1840 let (status, ()) = tokio::join!(waiter, driver);
1841 assert!(matches!(status, SessionExitStatus::Completed));
1842 }
1843
1844 #[tokio::test]
1845 async fn wait_for_exit_supports_multiple_observers() {
1846 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1847
1848 let waiter1 = async { session.wait_for_exit().await };
1849 let waiter2 = async { session.wait_for_exit().await };
1850 let driver = async {
1851 tokio::time::sleep(Duration::from_millis(20)).await;
1852 exit_tx
1853 .send(SessionExitStatus::Failed("crash".into()))
1854 .unwrap();
1855 };
1856 let (s1, s2, ()) = tokio::join!(waiter1, waiter2, driver);
1857 match s1 {
1858 SessionExitStatus::Failed(msg) => assert_eq!(msg, "crash"),
1859 other => panic!("waiter1 expected Failed, got {other:?}"),
1860 }
1861 match s2 {
1862 SessionExitStatus::Failed(msg) => assert_eq!(msg, "crash"),
1863 other => panic!("waiter2 expected Failed, got {other:?}"),
1864 }
1865 }
1866
1867 #[tokio::test]
1868 async fn wait_for_exit_returns_last_value_when_sender_dropped() {
1869 let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
1873 let waiter = async { session.wait_for_exit().await };
1874 let driver = async {
1875 tokio::time::sleep(Duration::from_millis(20)).await;
1876 drop(exit_tx);
1877 };
1878 let (status, ()) = tokio::time::timeout(Duration::from_secs(1), async {
1879 tokio::join!(waiter, driver)
1880 })
1881 .await
1882 .expect("wait_for_exit must not hang when sender is dropped");
1883 assert!(matches!(status, SessionExitStatus::Running));
1884 }
1885}