use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::{broadcast, mpsc, oneshot, watch};
use tokio::task::JoinHandle;
use tracing::{debug, warn};
use crate::Claude;
use crate::error::{Error, Result};
/// Capacity of the broadcast channel behind [`DuplexSession::subscribe`] when
/// [`DuplexOptions::subscriber_capacity`] has not been set.
pub const DEFAULT_SUBSCRIBER_CAPACITY: usize = 256;
/// A tool-permission request received from the CLI: a `control_request`
/// message whose inner `request.subtype` is `can_use_tool`.
#[derive(Debug, Clone)]
pub struct PermissionRequest {
    /// Top-level `request_id` of the control request; echoed back in the
    /// permission response.
    pub request_id: String,
    /// The `request.tool_name` field: the tool the CLI wants to run.
    pub tool_name: String,
    /// The `request.input` field, or `Value::Null` when it is absent.
    pub input: Value,
    /// The whole inner `request` object, exactly as received.
    pub raw: Value,
}
/// The answer to a [`PermissionRequest`].
#[derive(Debug, Clone)]
pub enum PermissionDecision {
    /// Allow the tool call. When `updated_input` is `Some`, it is sent to the
    /// CLI as `updatedInput`.
    Allow {
        updated_input: Option<Value>,
    },
    /// Deny the tool call; `message` is sent back to the CLI.
    Deny {
        message: String,
    },
    /// Send nothing now. The caller is expected to answer later through
    /// [`DuplexSession::respond_to_permission`].
    Defer,
}
/// Boxed, pinned future returned by a type-erased permission callback.
type PermissionFuture = Pin<Box<dyn Future<Output = PermissionDecision> + Send + 'static>>;
/// Type-erased permission callback stored inside [`PermissionHandler`].
type PermissionFn = dyn Fn(PermissionRequest) -> PermissionFuture + Send + Sync + 'static;
/// Async callback consulted for each `can_use_tool` permission request.
///
/// Cloning is cheap: every clone shares the same callback through an `Arc`.
#[derive(Clone)]
pub struct PermissionHandler {
    inner: Arc<PermissionFn>,
}
impl PermissionHandler {
pub fn new<F, Fut>(f: F) -> Self
where
F: Fn(PermissionRequest) -> Fut + Send + Sync + 'static,
Fut: Future<Output = PermissionDecision> + Send + 'static,
{
Self {
inner: Arc::new(move |req| Box::pin(f(req))),
}
}
fn invoke(&self, req: PermissionRequest) -> PermissionFuture {
(self.inner)(req)
}
}
impl std::fmt::Debug for PermissionHandler {
    /// Renders as `PermissionHandler { .. }`, because the wrapped closure has
    /// no printable state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("PermissionHandler { .. }")
    }
}
/// Builder for the CLI flags and session settings consumed by
/// [`DuplexSession::spawn`].
#[derive(Debug, Default, Clone)]
pub struct DuplexOptions {
    /// Value for `--model`, if set.
    model: Option<String>,
    /// Value for `--system-prompt`, if set.
    system_prompt: Option<String>,
    /// Value for `--append-system-prompt`, if set.
    append_system_prompt: Option<String>,
    /// Raw arguments appended after all generated flags.
    additional_args: Vec<String>,
    /// Broadcast capacity for subscribers; not turned into a CLI flag.
    subscriber_capacity: Option<usize>,
    /// Permission callback; when set, also enables `--permission-prompt-tool stdio`.
    on_permission: Option<PermissionHandler>,
}
impl DuplexOptions {
    /// Sets the model passed as `--model`.
    #[must_use]
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the prompt passed as `--system-prompt`.
    #[must_use]
    pub fn system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Sets the text passed as `--append-system-prompt`.
    #[must_use]
    pub fn append_system_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            append_system_prompt: Some(prompt.into()),
            ..self
        }
    }

    /// Appends one raw CLI argument. Raw arguments are emitted after every
    /// flag generated by this builder, in the order they were added.
    #[must_use]
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        let raw = arg.into();
        self.additional_args.push(raw);
        self
    }

    /// Sets the broadcast capacity used for subscriber receivers. This is a
    /// session setting only and does not emit any CLI flag.
    #[must_use]
    pub fn subscriber_capacity(self, capacity: usize) -> Self {
        Self {
            subscriber_capacity: Some(capacity),
            ..self
        }
    }

    /// Installs the handler consulted for `can_use_tool` control requests.
    /// Setting a handler also emits `--permission-prompt-tool stdio`.
    #[must_use]
    pub fn on_permission(self, handler: PermissionHandler) -> Self {
        Self {
            on_permission: Some(handler),
            ..self
        }
    }

    /// Converts the options into CLI arguments.
    ///
    /// The order is:
    /// 1. the fixed stream-JSON flags;
    /// 2. `--model`, `--system-prompt`, `--append-system-prompt` and
    ///    `--permission-prompt-tool stdio`, each only when set;
    /// 3. the raw arguments added with [`DuplexOptions::arg`].
    fn into_args(self) -> Vec<String> {
        const BASE: [&str; 6] = [
            "--print",
            "--verbose",
            "--output-format",
            "stream-json",
            "--input-format",
            "stream-json",
        ];
        let Self {
            model,
            system_prompt,
            append_system_prompt,
            additional_args,
            on_permission,
            ..
        } = self;
        let optional = [
            ("--model", model),
            ("--system-prompt", system_prompt),
            ("--append-system-prompt", append_system_prompt),
            (
                "--permission-prompt-tool",
                on_permission.map(|_| "stdio".to_string()),
            ),
        ];
        let mut args: Vec<String> = BASE.iter().map(|flag| (*flag).to_string()).collect();
        for (flag, value) in optional {
            if let Some(value) = value {
                args.push(flag.to_string());
                args.push(value);
            }
        }
        args.extend(additional_args);
        args
    }
}
/// Outcome of one completed turn.
#[derive(Debug, Clone)]
pub struct TurnResult {
    /// The raw `result` message that ended the turn.
    pub result: Value,
    /// Other inbound messages recorded while the turn was in flight, in
    /// arrival order. Successfully parsed `control_response`s are not
    /// recorded here.
    pub events: Vec<Value>,
}
impl TurnResult {
    /// The turn's final text: the `result` field, when it is a string.
    #[must_use]
    pub fn result_text(&self) -> Option<&str> {
        match self.result.get("result") {
            Some(Value::String(text)) => Some(text.as_str()),
            _ => None,
        }
    }

    /// The `session_id` reported on the result message, when it is a string.
    #[must_use]
    pub fn session_id(&self) -> Option<&str> {
        match self.result.get("session_id") {
            Some(Value::String(id)) => Some(id.as_str()),
            _ => None,
        }
    }

    /// Total cost in USD.
    ///
    /// Reads `total_cost_usd`. Only when that key is absent does it fall back
    /// to the legacy `cost_usd` key. A key that is present but not numeric
    /// yields `None`.
    #[must_use]
    pub fn total_cost_usd(&self) -> Option<f64> {
        let cost = match self.result.get("total_cost_usd") {
            Some(current) => current,
            None => self.result.get("cost_usd")?,
        };
        cost.as_f64()
    }

    /// The `duration_ms` field, when it is a non-negative integer.
    #[must_use]
    pub fn duration_ms(&self) -> Option<u64> {
        self.result.get("duration_ms")?.as_u64()
    }
}
/// A classified inbound message, broadcast to receivers obtained from
/// [`DuplexSession::subscribe`].
#[derive(Debug, Clone)]
pub enum InboundEvent {
    /// A `system` message with subtype `init` that carries a string `session_id`.
    SystemInit {
        session_id: String,
    },
    /// A message with `type: "assistant"`.
    Assistant(Value),
    /// A message with `type: "stream_event"`.
    StreamEvent(Value),
    /// A message with `type: "user"`.
    User(Value),
    /// Anything else, including other `system` messages and unhandled or
    /// malformed control messages.
    Other(Value),
}
/// Maps a raw stdout message to an [`InboundEvent`] based on its `type` field.
///
/// A `system` message becomes `SystemInit` only when its subtype is `init`
/// and it carries a string `session_id`. Otherwise it becomes `Other`.
fn classify(msg: &Value) -> InboundEvent {
    let kind = msg.get("type").and_then(Value::as_str);

    if kind == Some("system") {
        let is_init = msg.get("subtype").and_then(Value::as_str) == Some("init");
        let session_id = msg.get("session_id").and_then(Value::as_str);
        return match (is_init, session_id) {
            (true, Some(id)) => InboundEvent::SystemInit {
                session_id: id.to_owned(),
            },
            _ => InboundEvent::Other(msg.clone()),
        };
    }

    let wrap: fn(Value) -> InboundEvent = match kind {
        Some("assistant") => InboundEvent::Assistant,
        Some("stream_event") => InboundEvent::StreamEvent,
        Some("user") => InboundEvent::User,
        _ => InboundEvent::Other,
    };
    wrap(msg.clone())
}
/// Lifecycle state of a session's background task, as published through
/// [`DuplexSession::exit_status`] and [`DuplexSession::wait_for_exit`].
#[derive(Debug, Clone)]
pub enum SessionExitStatus {
    /// The background task has not published a final state yet.
    Running,
    /// The task finished without an error.
    Completed,
    /// The task finished with an error; holds that error's `Display` text.
    Failed(String),
}
/// Handle to a long-lived `claude` process that speaks stream-JSON on its
/// stdin and stdout.
#[derive(Debug)]
pub struct DuplexSession {
    /// Requests to the background task: turns, permission answers, interrupts.
    outbound_tx: mpsc::UnboundedSender<OutboundMsg>,
    /// Broadcast of classified inbound events; new subscribers are minted from it.
    events_tx: broadcast::Sender<InboundEvent>,
    /// Lifecycle state published by the background task.
    exit_rx: watch::Receiver<SessionExitStatus>,
    /// Handle of the spawned background task.
    join: JoinHandle<Result<()>>,
}
/// Requests sent from [`DuplexSession`] methods to the background task.
#[derive(Debug)]
enum OutboundMsg {
    /// Start a turn with `prompt`. `reply` receives the turn's outcome.
    Send {
        prompt: String,
        reply: oneshot::Sender<Result<TurnResult>>,
    },
    /// Write an answer to a permission request, typically one that was
    /// previously deferred.
    PermissionResponse {
        request_id: String,
        decision: PermissionDecision,
    },
    /// Send an `interrupt` control request. `reply` resolves when the matching
    /// `control_response` arrives or the session closes.
    Interrupt {
        reply: oneshot::Sender<Result<()>>,
    },
}
impl DuplexSession {
pub async fn spawn(claude: &Claude, opts: DuplexOptions) -> Result<Self> {
let capacity = opts
.subscriber_capacity
.unwrap_or(DEFAULT_SUBSCRIBER_CAPACITY);
let permission_handler = opts.on_permission.clone();
let mut command_args = Vec::new();
command_args.extend(claude.global_args.clone());
command_args.extend(opts.into_args());
debug!(
binary = %claude.binary.display(),
args = ?command_args,
"spawning duplex claude session"
);
let mut cmd = Command::new(&claude.binary);
cmd.args(&command_args)
.env_remove("CLAUDECODE")
.env_remove("CLAUDE_CODE_ENTRYPOINT")
.envs(&claude.env)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true);
if let Some(ref dir) = claude.working_dir {
cmd.current_dir(dir);
}
let mut child = cmd.spawn().map_err(|e| Error::Io {
message: format!("failed to spawn claude: {e}"),
source: e,
working_dir: claude.working_dir.clone(),
})?;
let stdin = child.stdin.take().expect("stdin was piped");
let stdout = child.stdout.take().expect("stdout was piped");
let (outbound_tx, outbound_rx) = mpsc::unbounded_channel();
let (events_tx, _initial_rx) = broadcast::channel(capacity);
let (exit_tx, exit_rx) = watch::channel(SessionExitStatus::Running);
let join = tokio::spawn(run_session(
child,
stdin,
stdout,
outbound_rx,
events_tx.clone(),
permission_handler,
exit_tx,
));
Ok(Self {
outbound_tx,
events_tx,
exit_rx,
join,
})
}
pub async fn send(&self, prompt: impl Into<String>) -> Result<TurnResult> {
let (reply_tx, reply_rx) = oneshot::channel();
self.outbound_tx
.send(OutboundMsg::Send {
prompt: prompt.into(),
reply: reply_tx,
})
.map_err(|_| Error::DuplexClosed)?;
reply_rx.await.map_err(|_| Error::DuplexClosed)?
}
#[must_use]
pub fn subscribe(&self) -> broadcast::Receiver<InboundEvent> {
self.events_tx.subscribe()
}
#[must_use]
pub fn is_alive(&self) -> bool {
matches!(*self.exit_rx.borrow(), SessionExitStatus::Running)
}
#[must_use]
pub fn exit_status(&self) -> SessionExitStatus {
self.exit_rx.borrow().clone()
}
pub async fn wait_for_exit(&self) -> SessionExitStatus {
let mut rx = self.exit_rx.clone();
loop {
{
let value = rx.borrow_and_update();
if !matches!(*value, SessionExitStatus::Running) {
return value.clone();
}
}
if rx.changed().await.is_err() {
return rx.borrow().clone();
}
}
}
pub fn respond_to_permission(
&self,
request_id: impl Into<String>,
decision: PermissionDecision,
) -> Result<()> {
if matches!(decision, PermissionDecision::Defer) {
warn!("respond_to_permission called with Defer; ignoring");
return Ok(());
}
self.outbound_tx
.send(OutboundMsg::PermissionResponse {
request_id: request_id.into(),
decision,
})
.map_err(|_| Error::DuplexClosed)?;
Ok(())
}
pub async fn interrupt(&self) -> Result<()> {
let (reply_tx, reply_rx) = oneshot::channel();
self.outbound_tx
.send(OutboundMsg::Interrupt { reply: reply_tx })
.map_err(|_| Error::DuplexClosed)?;
reply_rx.await.map_err(|_| Error::DuplexClosed)?
}
pub async fn close(self) -> Result<()> {
drop(self.outbound_tx);
drop(self.events_tx);
match self.join.await {
Ok(result) => result,
Err(e) if e.is_cancelled() => Ok(()),
Err(e) => Err(Error::Io {
message: format!("duplex session task panicked: {e}"),
source: std::io::Error::other(e.to_string()),
working_dir: None,
}),
}
}
}
/// How long the session waits for the child to exit after its stdin is closed
/// before killing it.
const SHUTDOWN_BUDGET: Duration = Duration::from_secs(5);
async fn run_session(
mut child: Child,
mut stdin: ChildStdin,
stdout: ChildStdout,
mut outbound_rx: mpsc::UnboundedReceiver<OutboundMsg>,
events_tx: broadcast::Sender<InboundEvent>,
permission_handler: Option<PermissionHandler>,
exit_tx: watch::Sender<SessionExitStatus>,
) -> Result<()> {
let mut lines = BufReader::new(stdout).lines();
let mut pending: Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)> = None;
let mut pending_control: HashMap<String, oneshot::Sender<Result<()>>> = HashMap::new();
let mut next_control_id: u64 = 0;
let mut stream_err: Option<Error> = None;
loop {
tokio::select! {
biased;
line = lines.next_line() => match line {
Ok(Some(l)) => {
if l.trim().is_empty() {
continue;
}
let parsed = match serde_json::from_str::<Value>(&l) {
Ok(v) => v,
Err(e) => {
debug!(line = %l, error = %e, "failed to parse duplex event, skipping");
continue;
}
};
match handle_inbound(parsed, &mut pending, &events_tx) {
InboundAction::None => {}
InboundAction::Permission(req) => {
let request_id = req.request_id.clone();
let decision = match permission_handler.as_ref() {
Some(h) => h.invoke(req).await,
None => {
warn!(
request_id = %request_id,
"received can_use_tool with no permission handler; auto-denying"
);
PermissionDecision::Deny {
message:
"no permission handler configured on duplex session"
.into(),
}
}
};
if matches!(decision, PermissionDecision::Defer) {
debug!(
request_id = %request_id,
"permission handler deferred; waiting for respond_to_permission"
);
} else if let Err(e) =
write_permission_response(&mut stdin, &request_id, &decision).await
{
warn!(error = %e, "failed to write permission response");
}
}
InboundAction::ControlResponse { request_id, outcome } => {
if let Some(reply) = pending_control.remove(&request_id) {
let _ = reply.send(outcome);
} else {
debug!(
request_id = %request_id,
"received control_response with no pending request"
);
}
}
}
}
Ok(None) => break,
Err(e) => {
stream_err = Some(Error::Io {
message: "failed to read duplex stdout".to_string(),
source: e,
working_dir: None,
});
break;
}
},
msg = outbound_rx.recv() => match msg {
Some(OutboundMsg::Send { prompt, reply }) => {
if pending.is_some() {
let _ = reply.send(Err(Error::DuplexTurnInFlight));
continue;
}
if let Err(e) = write_user(&mut stdin, &prompt).await {
let _ = reply.send(Err(e));
continue;
}
pending = Some((reply, Vec::new()));
}
Some(OutboundMsg::PermissionResponse { request_id, decision }) => {
if let Err(e) =
write_permission_response(&mut stdin, &request_id, &decision).await
{
warn!(error = %e, "failed to write deferred permission response");
}
}
Some(OutboundMsg::Interrupt { reply }) => {
next_control_id += 1;
let request_id = format!("interrupt-{next_control_id}");
if let Err(e) =
write_control_request(&mut stdin, &request_id, "interrupt").await
{
let _ = reply.send(Err(e));
continue;
}
pending_control.insert(request_id, reply);
}
None => break,
},
}
}
drop(stdin);
match tokio::time::timeout(SHUTDOWN_BUDGET, child.wait()).await {
Ok(Ok(_status)) => {}
Ok(Err(e)) => {
warn!(error = %e, "failed to wait for duplex child");
}
Err(_) => {
warn!("duplex child did not exit within shutdown budget; killing");
let _ = child.kill().await;
}
}
if let Some((reply, _)) = pending.take() {
let _ = reply.send(Err(Error::DuplexClosed));
}
for (_, reply) in pending_control.drain() {
let _ = reply.send(Err(Error::DuplexClosed));
}
let result = match stream_err {
Some(e) => Err(e),
None => Ok(()),
};
let final_state = match &result {
Ok(()) => SessionExitStatus::Completed,
Err(e) => SessionExitStatus::Failed(e.to_string()),
};
let _ = exit_tx.send(final_state);
result
}
/// Follow-up work for the session loop after [`handle_inbound`] has processed
/// one message.
enum InboundAction {
    /// Nothing further to do.
    None,
    /// Consult the permission handler about this request, then answer it.
    Permission(PermissionRequest),
    /// Resolve the pending control request with this id.
    ControlResponse {
        request_id: String,
        outcome: Result<()>,
    },
}
/// Routes one parsed stdout message and reports any follow-up work.
///
/// - `result` completes the in-flight turn and is not broadcast. With no turn
///   in flight it is dropped.
/// - A `control_request` whose `request.subtype` is `can_use_tool` and which
///   parses as a [`PermissionRequest`] is recorded on the in-flight turn and
///   returned as [`InboundAction::Permission`]. It is not broadcast.
/// - A `control_response` that parses is returned as
///   [`InboundAction::ControlResponse`] and is neither recorded nor broadcast.
/// - Everything else, including unhandled or malformed control messages, is
///   broadcast and recorded on the in-flight turn, if there is one.
fn handle_inbound(
    msg: Value,
    pending: &mut Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)>,
    events_tx: &broadcast::Sender<InboundEvent>,
) -> InboundAction {
    /// Appends `msg` to the in-flight turn's events. Returns `true` if a turn
    /// was in flight.
    fn record(
        pending: &mut Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)>,
        msg: Value,
    ) -> bool {
        if let Some((_, events)) = pending.as_mut() {
            events.push(msg);
            true
        } else {
            false
        }
    }

    let kind = msg.get("type").and_then(Value::as_str);
    match kind {
        Some("result") => {
            match pending.take() {
                Some((reply, events)) => {
                    let _ = reply.send(Ok(TurnResult {
                        result: msg,
                        events,
                    }));
                }
                None => {
                    debug!("dropping orphan result event with no pending turn");
                }
            }
            InboundAction::None
        }
        Some("control_request") => {
            let subtype = msg
                .get("request")
                .and_then(|r| r.get("subtype"))
                .and_then(Value::as_str);
            let permission = (subtype == Some("can_use_tool"))
                .then(|| parse_permission_request(&msg))
                .flatten();
            if let Some(req) = permission {
                record(pending, msg);
                return InboundAction::Permission(req);
            }
            debug!(
                ?msg,
                "received unhandled control_request; treating as Other"
            );
            let _ = events_tx.send(InboundEvent::Other(msg.clone()));
            record(pending, msg);
            InboundAction::None
        }
        Some("control_response") => {
            if let Some((request_id, outcome)) = parse_control_response(&msg) {
                return InboundAction::ControlResponse {
                    request_id,
                    outcome,
                };
            }
            debug!(
                ?msg,
                "received malformed control_response; treating as Other"
            );
            let _ = events_tx.send(InboundEvent::Other(msg.clone()));
            record(pending, msg);
            InboundAction::None
        }
        _ => {
            let _ = events_tx.send(classify(&msg));
            if !record(pending, msg) {
                debug!("dropping inbound event with no pending turn");
            }
            InboundAction::None
        }
    }
}
/// Extracts a [`PermissionRequest`] from a `can_use_tool` control request.
///
/// Returns `None` if the top-level `request_id`, the `request` object, or
/// `request.tool_name` is missing or not a string. A missing `request.input`
/// becomes `Value::Null`.
fn parse_permission_request(msg: &Value) -> Option<PermissionRequest> {
    let request_id = msg.get("request_id")?.as_str()?.to_owned();
    let request = msg.get("request")?;
    let tool_name = request.get("tool_name")?.as_str()?.to_owned();
    let input = request.get("input").map_or(Value::Null, Value::clone);
    Some(PermissionRequest {
        request_id,
        tool_name,
        input,
        raw: request.clone(),
    })
}
/// Parses a `control_response` into its request id and outcome.
///
/// Subtype `success` maps to `Ok(())`. Subtype `error` maps to
/// [`Error::DuplexControlFailed`] carrying the `error` string, or a generic
/// message when that field is missing. Any other subtype, or a missing
/// `response` / `request_id`, yields `None`.
fn parse_control_response(msg: &Value) -> Option<(String, Result<()>)> {
    let response = msg.get("response")?;
    let request_id = response.get("request_id")?.as_str()?;
    let subtype = response.get("subtype").and_then(Value::as_str);

    let outcome = if subtype == Some("success") {
        Ok(())
    } else if subtype == Some("error") {
        let message = response
            .get("error")
            .and_then(Value::as_str)
            .map_or_else(
                || "unknown control_response error".to_string(),
                str::to_string,
            );
        Err(Error::DuplexControlFailed { message })
    } else {
        return None;
    };

    Some((request_id.to_owned(), outcome))
}
/// Writes `prompt` to the child as a stream-JSON `user` message with no
/// parent tool use.
async fn write_user(stdin: &mut ChildStdin, prompt: &str) -> Result<()> {
    let message = serde_json::json!({
        "role": "user",
        "content": prompt,
    });
    let envelope = serde_json::json!({
        "type": "user",
        "message": message,
        "parent_tool_use_id": Value::Null,
    });
    write_line(stdin, &envelope, "user message").await
}
/// Writes a `control_request` envelope with the given id and request subtype
/// (for example `interrupt`).
async fn write_control_request(
    stdin: &mut ChildStdin,
    request_id: &str,
    subtype: &str,
) -> Result<()> {
    let request = serde_json::json!({ "subtype": subtype });
    let envelope = serde_json::json!({
        "type": "control_request",
        "request_id": request_id,
        "request": request,
    });
    write_line(stdin, &envelope, "control_request").await
}
/// Writes the `control_response` that answers permission request `request_id`.
///
/// The inner payload depends on the decision:
/// - `Allow`: `{"behavior": "allow"}`, plus `updatedInput` when one was given;
/// - `Deny`: `{"behavior": "deny", "message": ...}`;
/// - `Defer`: nothing is written and `Ok(())` is returned.
async fn write_permission_response(
    stdin: &mut ChildStdin,
    request_id: &str,
    decision: &PermissionDecision,
) -> Result<()> {
    let mut body = serde_json::Map::new();
    match decision {
        PermissionDecision::Defer => return Ok(()),
        PermissionDecision::Deny { message } => {
            body.insert("behavior".to_string(), Value::from("deny"));
            body.insert("message".to_string(), Value::from(message.as_str()));
        }
        PermissionDecision::Allow { updated_input } => {
            body.insert("behavior".to_string(), Value::from("allow"));
            if let Some(input) = updated_input {
                body.insert("updatedInput".to_string(), input.clone());
            }
        }
    }
    let envelope = serde_json::json!({
        "type": "control_response",
        "response": {
            "request_id": request_id,
            "subtype": "success",
            "response": Value::Object(body),
        },
    });
    write_line(stdin, &envelope, "control_response").await
}
async fn write_line(stdin: &mut ChildStdin, value: &Value, what: &'static str) -> Result<()> {
let mut line = serde_json::to_string(value).map_err(|e| Error::Json {
message: format!("failed to serialize duplex {what}"),
source: e,
})?;
line.push('\n');
stdin
.write_all(line.as_bytes())
.await
.map_err(|e| Error::Io {
message: format!("failed to write {what} to duplex stdin"),
source: e,
working_dir: None,
})?;
stdin.flush().await.map_err(|e| Error::Io {
message: "failed to flush duplex stdin".to_string(),
source: e,
working_dir: None,
})?;
Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for argument building, message classification and routing,
    // control-protocol parsing, and session-state observation. None of them
    // spawns a real `claude` process; `fake_session` stands in for one where
    // a `DuplexSession` is needed.
    use super::*;
    use serde_json::json;

    // --- DuplexOptions::into_args ---

    #[test]
    fn into_args_default_includes_required_flags() {
        let args = DuplexOptions::default().into_args();
        assert!(args.contains(&"--print".to_string()));
        assert!(args.contains(&"--verbose".to_string()));
        assert!(
            args.windows(2)
                .any(|w| w == ["--output-format", "stream-json"])
        );
        assert!(
            args.windows(2)
                .any(|w| w == ["--input-format", "stream-json"])
        );
    }
    #[test]
    fn into_args_includes_model() {
        let args = DuplexOptions::default().model("haiku").into_args();
        assert!(args.windows(2).any(|w| w == ["--model", "haiku"]));
    }
    #[test]
    fn into_args_includes_system_prompts() {
        let args = DuplexOptions::default()
            .system_prompt("be concise")
            .append_system_prompt("also polite")
            .into_args();
        assert!(
            args.windows(2)
                .any(|w| w == ["--system-prompt", "be concise"])
        );
        assert!(
            args.windows(2)
                .any(|w| w == ["--append-system-prompt", "also polite"])
        );
    }
    #[test]
    fn into_args_appends_raw_args_last() {
        let args = DuplexOptions::default()
            .arg("--add-dir")
            .arg("/tmp/foo")
            .into_args();
        assert_eq!(&args[args.len() - 2..], &["--add-dir", "/tmp/foo"]);
    }

    // --- TurnResult accessors ---

    #[test]
    fn turn_result_accessors_pull_from_result() {
        let r = TurnResult {
            result: json!({
                "type": "result",
                "result": "hello",
                "session_id": "sess-123",
                "total_cost_usd": 0.0042,
                "duration_ms": 1234_u64,
            }),
            events: vec![],
        };
        assert_eq!(r.result_text(), Some("hello"));
        assert_eq!(r.session_id(), Some("sess-123"));
        assert_eq!(r.total_cost_usd(), Some(0.0042));
        assert_eq!(r.duration_ms(), Some(1234));
    }
    #[test]
    fn turn_result_total_cost_falls_back_to_legacy_field() {
        let r = TurnResult {
            result: json!({ "cost_usd": 0.5 }),
            events: vec![],
        };
        assert_eq!(r.total_cost_usd(), Some(0.5));
    }
    #[test]
    fn turn_result_accessors_return_none_when_missing() {
        let r = TurnResult {
            result: json!({}),
            events: vec![],
        };
        assert_eq!(r.result_text(), None);
        assert_eq!(r.session_id(), None);
        assert_eq!(r.total_cost_usd(), None);
        assert_eq!(r.duration_ms(), None);
    }

    // --- handle_inbound routing ---

    #[test]
    fn handle_inbound_appends_non_result_to_pending_events() {
        let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
        let (events_tx, _events_rx) = broadcast::channel(16);
        let mut pending = Some((tx, Vec::new()));
        handle_inbound(
            json!({ "type": "assistant", "message": {} }),
            &mut pending,
            &events_tx,
        );
        let (_, events) = pending.as_ref().unwrap();
        assert_eq!(events.len(), 1);
        assert_eq!(
            events[0].get("type").and_then(Value::as_str),
            Some("assistant")
        );
    }
    #[test]
    fn handle_inbound_resolves_pending_on_result() {
        let (tx, rx) = oneshot::channel::<Result<TurnResult>>();
        let (events_tx, _events_rx) = broadcast::channel(16);
        let mut pending = Some((tx, vec![json!({ "type": "assistant" })]));
        handle_inbound(
            json!({ "type": "result", "result": "ok" }),
            &mut pending,
            &events_tx,
        );
        assert!(pending.is_none());
        let received = rx.blocking_recv().unwrap().unwrap();
        assert_eq!(received.result_text(), Some("ok"));
        assert_eq!(received.events.len(), 1);
    }
    #[test]
    fn handle_inbound_drops_orphans_without_pending_turn() {
        let (events_tx, _events_rx) = broadcast::channel(16);
        let mut pending: Option<(oneshot::Sender<Result<TurnResult>>, Vec<Value>)> = None;
        handle_inbound(json!({ "type": "assistant" }), &mut pending, &events_tx);
        handle_inbound(
            json!({ "type": "result", "result": "ok" }),
            &mut pending,
            &events_tx,
        );
        assert!(pending.is_none());
    }
    #[test]
    fn handle_inbound_broadcasts_classified_event() {
        let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
        let (events_tx, mut events_rx) = broadcast::channel(16);
        let mut pending = Some((tx, Vec::new()));
        handle_inbound(
            json!({ "type": "assistant", "message": { "role": "assistant" } }),
            &mut pending,
            &events_tx,
        );
        let event = events_rx.try_recv().expect("classified event broadcast");
        assert!(matches!(event, InboundEvent::Assistant(_)));
    }
    #[test]
    fn handle_inbound_does_not_broadcast_result() {
        let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
        let (events_tx, mut events_rx) = broadcast::channel(16);
        let mut pending = Some((tx, Vec::new()));
        handle_inbound(
            json!({ "type": "result", "result": "ok" }),
            &mut pending,
            &events_tx,
        );
        assert!(events_rx.try_recv().is_err());
    }

    // --- classify ---

    #[test]
    fn classify_system_init_pulls_session_id() {
        let v = json!({
            "type": "system",
            "subtype": "init",
            "session_id": "sess-abc",
        });
        match classify(&v) {
            InboundEvent::SystemInit { session_id } => assert_eq!(session_id, "sess-abc"),
            other => panic!("expected SystemInit, got {other:?}"),
        }
    }
    #[test]
    fn classify_system_without_init_subtype_is_other() {
        let v = json!({ "type": "system", "subtype": "compaction" });
        assert!(matches!(classify(&v), InboundEvent::Other(_)));
    }
    #[test]
    fn classify_system_init_without_session_id_is_other() {
        let v = json!({ "type": "system", "subtype": "init" });
        assert!(matches!(classify(&v), InboundEvent::Other(_)));
    }
    #[test]
    fn classify_assistant_stream_event_user() {
        assert!(matches!(
            classify(&json!({ "type": "assistant" })),
            InboundEvent::Assistant(_)
        ));
        assert!(matches!(
            classify(&json!({ "type": "stream_event" })),
            InboundEvent::StreamEvent(_)
        ));
        assert!(matches!(
            classify(&json!({ "type": "user" })),
            InboundEvent::User(_)
        ));
    }
    #[test]
    fn classify_unknown_type_is_other() {
        assert!(matches!(
            classify(&json!({ "type": "control_request" })),
            InboundEvent::Other(_)
        ));
        assert!(matches!(
            classify(&json!({ "type": "future_thing" })),
            InboundEvent::Other(_)
        ));
        assert!(matches!(classify(&json!({})), InboundEvent::Other(_)));
    }

    // --- Permission flags and parsing ---

    #[test]
    fn into_args_does_not_emit_subscriber_capacity_flag() {
        let args = DuplexOptions::default().subscriber_capacity(64).into_args();
        assert!(!args.iter().any(|a| a.contains("subscriber")));
        assert!(!args.iter().any(|a| a.contains("capacity")));
    }
    #[test]
    fn into_args_includes_permission_prompt_tool_when_handler_set() {
        let handler = PermissionHandler::new(|_req| async move {
            PermissionDecision::Allow {
                updated_input: None,
            }
        });
        let args = DuplexOptions::default().on_permission(handler).into_args();
        assert!(
            args.windows(2)
                .any(|w| w == ["--permission-prompt-tool", "stdio"])
        );
    }
    #[test]
    fn into_args_omits_permission_prompt_tool_without_handler() {
        let args = DuplexOptions::default().into_args();
        assert!(!args.iter().any(|a| a == "--permission-prompt-tool"));
    }
    #[test]
    fn parse_permission_request_extracts_fields() {
        let msg = json!({
            "type": "control_request",
            "request_id": "req-1",
            "request": {
                "subtype": "can_use_tool",
                "tool_name": "Bash",
                "input": { "command": "ls" }
            }
        });
        let req = parse_permission_request(&msg).expect("permission request");
        assert_eq!(req.request_id, "req-1");
        assert_eq!(req.tool_name, "Bash");
        assert_eq!(req.input, json!({ "command": "ls" }));
        assert_eq!(
            req.raw.get("subtype").and_then(Value::as_str),
            Some("can_use_tool")
        );
    }
    #[test]
    fn parse_permission_request_returns_none_when_missing_request_id() {
        let msg = json!({
            "type": "control_request",
            "request": {
                "subtype": "can_use_tool",
                "tool_name": "Bash",
            }
        });
        assert!(parse_permission_request(&msg).is_none());
    }
    #[test]
    fn parse_permission_request_returns_none_when_missing_tool_name() {
        let msg = json!({
            "type": "control_request",
            "request_id": "req-1",
            "request": { "subtype": "can_use_tool" }
        });
        assert!(parse_permission_request(&msg).is_none());
    }
    #[test]
    fn parse_permission_request_handles_missing_input() {
        let msg = json!({
            "type": "control_request",
            "request_id": "req-1",
            "request": {
                "subtype": "can_use_tool",
                "tool_name": "Bash",
            }
        });
        let req = parse_permission_request(&msg).expect("request");
        assert_eq!(req.input, Value::Null);
    }
    #[test]
    fn handle_inbound_returns_permission_for_can_use_tool() {
        let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
        let (events_tx, _events_rx) = broadcast::channel(16);
        let mut pending = Some((tx, Vec::new()));
        let action = handle_inbound(
            json!({
                "type": "control_request",
                "request_id": "req-1",
                "request": {
                    "subtype": "can_use_tool",
                    "tool_name": "Bash",
                    "input": { "command": "ls" }
                }
            }),
            &mut pending,
            &events_tx,
        );
        match action {
            InboundAction::Permission(req) => {
                assert_eq!(req.request_id, "req-1");
                assert_eq!(req.tool_name, "Bash");
            }
            InboundAction::None | InboundAction::ControlResponse { .. } => {
                panic!("expected Permission action");
            }
        }
        let (_, events) = pending.as_ref().unwrap();
        assert_eq!(events.len(), 1);
    }
    #[test]
    fn handle_inbound_treats_unknown_control_request_as_other() {
        let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
        let (events_tx, mut events_rx) = broadcast::channel(16);
        let mut pending = Some((tx, Vec::new()));
        let action = handle_inbound(
            json!({
                "type": "control_request",
                "request_id": "req-2",
                "request": { "subtype": "future_subtype" }
            }),
            &mut pending,
            &events_tx,
        );
        assert!(matches!(action, InboundAction::None));
        let event = events_rx.try_recv().expect("broadcast");
        assert!(matches!(event, InboundEvent::Other(_)));
    }
    #[tokio::test]
    async fn permission_handler_invokes_closure_async() {
        let handler = PermissionHandler::new(|req| async move {
            if req.tool_name == "Bash" {
                PermissionDecision::Deny {
                    message: "no bash".into(),
                }
            } else {
                PermissionDecision::Allow {
                    updated_input: None,
                }
            }
        });
        let req = PermissionRequest {
            request_id: "r1".into(),
            tool_name: "Bash".into(),
            input: Value::Null,
            raw: Value::Null,
        };
        match handler.invoke(req).await {
            PermissionDecision::Deny { message } => assert_eq!(message, "no bash"),
            other => panic!("expected Deny, got {other:?}"),
        }
    }

    // --- control_response parsing and routing ---

    #[test]
    fn parse_control_response_extracts_success() {
        let msg = json!({
            "type": "control_response",
            "response": {
                "request_id": "interrupt-1",
                "subtype": "success",
                "response": {}
            }
        });
        let (id, outcome) = parse_control_response(&msg).expect("parsed");
        assert_eq!(id, "interrupt-1");
        assert!(outcome.is_ok());
    }
    #[test]
    fn parse_control_response_extracts_error_with_message() {
        let msg = json!({
            "type": "control_response",
            "response": {
                "request_id": "interrupt-2",
                "subtype": "error",
                "error": "no turn in flight"
            }
        });
        let (id, outcome) = parse_control_response(&msg).expect("parsed");
        assert_eq!(id, "interrupt-2");
        match outcome {
            Err(Error::DuplexControlFailed { message }) => {
                assert_eq!(message, "no turn in flight");
            }
            other => panic!("expected DuplexControlFailed, got {other:?}"),
        }
    }
    #[test]
    fn parse_control_response_returns_none_on_missing_request_id() {
        let msg = json!({
            "type": "control_response",
            "response": { "subtype": "success" }
        });
        assert!(parse_control_response(&msg).is_none());
    }
    #[test]
    fn parse_control_response_returns_none_on_unknown_subtype() {
        let msg = json!({
            "type": "control_response",
            "response": { "request_id": "x", "subtype": "future_subtype" }
        });
        assert!(parse_control_response(&msg).is_none());
    }
    #[test]
    fn handle_inbound_returns_control_response_action() {
        let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
        let (events_tx, _events_rx) = broadcast::channel(16);
        let mut pending = Some((tx, Vec::new()));
        let action = handle_inbound(
            json!({
                "type": "control_response",
                "response": {
                    "request_id": "interrupt-1",
                    "subtype": "success",
                    "response": {}
                }
            }),
            &mut pending,
            &events_tx,
        );
        match action {
            InboundAction::ControlResponse {
                request_id,
                outcome,
            } => {
                assert_eq!(request_id, "interrupt-1");
                assert!(outcome.is_ok());
            }
            InboundAction::None | InboundAction::Permission(_) => {
                panic!("expected ControlResponse action");
            }
        }
    }
    #[test]
    fn handle_inbound_treats_malformed_control_response_as_other() {
        let (tx, _reply_rx) = oneshot::channel::<Result<TurnResult>>();
        let (events_tx, mut events_rx) = broadcast::channel(16);
        let mut pending = Some((tx, Vec::new()));
        let action = handle_inbound(
            json!({
                "type": "control_response",
                "response": { "subtype": "success" }
            }),
            &mut pending,
            &events_tx,
        );
        assert!(matches!(action, InboundAction::None));
        let event = events_rx.try_recv().expect("broadcast");
        assert!(matches!(event, InboundEvent::Other(_)));
    }
    #[tokio::test]
    async fn permission_handler_clones_arc() {
        let handler = PermissionHandler::new(|_req| async move {
            PermissionDecision::Allow {
                updated_input: None,
            }
        });
        let cloned = handler.clone();
        let req = PermissionRequest {
            request_id: "r1".into(),
            tool_name: "Read".into(),
            input: Value::Null,
            raw: Value::Null,
        };
        let _ = handler.invoke(req.clone()).await;
        let _ = cloned.invoke(req).await;
    }

    // --- Session state observation ---

    // Builds a `DuplexSession` backed by a placeholder task instead of a real
    // child process. The returned watch sender drives the published exit
    // state. The placeholder task holds the outbound receiver and finishes
    // once the returned oneshot sender is used or dropped.
    fn fake_session(
        initial: SessionExitStatus,
    ) -> (
        DuplexSession,
        watch::Sender<SessionExitStatus>,
        oneshot::Sender<()>,
    ) {
        let (outbound_tx, outbound_rx) = mpsc::unbounded_channel::<OutboundMsg>();
        let (events_tx, _events_rx) = broadcast::channel::<InboundEvent>(16);
        let (exit_tx, exit_rx) = watch::channel(initial);
        let (stop_tx, stop_rx) = oneshot::channel::<()>();
        let join = tokio::spawn(async move {
            let _outbound_rx = outbound_rx;
            let _ = stop_rx.await;
            Ok::<(), Error>(())
        });
        let session = DuplexSession {
            outbound_tx,
            events_tx,
            exit_rx,
            join,
        };
        (session, exit_tx, stop_tx)
    }
    #[tokio::test]
    async fn is_alive_true_while_running() {
        let (session, _exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        assert!(session.is_alive());
    }
    #[tokio::test]
    async fn is_alive_false_after_completed() {
        let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        exit_tx.send(SessionExitStatus::Completed).unwrap();
        assert!(!session.is_alive());
    }
    #[tokio::test]
    async fn is_alive_false_after_failed() {
        let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        exit_tx
            .send(SessionExitStatus::Failed("boom".into()))
            .unwrap();
        assert!(!session.is_alive());
    }
    #[tokio::test]
    async fn exit_status_reports_running_initially() {
        let (session, _exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        assert!(matches!(session.exit_status(), SessionExitStatus::Running));
    }
    #[tokio::test]
    async fn exit_status_reflects_completed() {
        let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        exit_tx.send(SessionExitStatus::Completed).unwrap();
        assert!(matches!(
            session.exit_status(),
            SessionExitStatus::Completed
        ));
    }
    #[tokio::test]
    async fn exit_status_reflects_failed_with_message() {
        let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        exit_tx
            .send(SessionExitStatus::Failed("oh no".into()))
            .unwrap();
        match session.exit_status() {
            SessionExitStatus::Failed(msg) => assert_eq!(msg, "oh no"),
            other => panic!("expected Failed, got {other:?}"),
        }
    }
    #[tokio::test]
    async fn wait_for_exit_returns_immediately_when_already_terminal() {
        let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        exit_tx.send(SessionExitStatus::Completed).unwrap();
        let status = tokio::time::timeout(Duration::from_secs(1), session.wait_for_exit())
            .await
            .expect("wait_for_exit should not block when already terminal");
        assert!(matches!(status, SessionExitStatus::Completed));
    }
    #[tokio::test]
    async fn wait_for_exit_blocks_until_state_transitions() {
        let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        let waiter = async { session.wait_for_exit().await };
        let driver = async {
            tokio::time::sleep(Duration::from_millis(20)).await;
            exit_tx.send(SessionExitStatus::Completed).unwrap();
        };
        let (status, ()) = tokio::join!(waiter, driver);
        assert!(matches!(status, SessionExitStatus::Completed));
    }
    #[tokio::test]
    async fn wait_for_exit_supports_multiple_observers() {
        let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        let waiter1 = async { session.wait_for_exit().await };
        let waiter2 = async { session.wait_for_exit().await };
        let driver = async {
            tokio::time::sleep(Duration::from_millis(20)).await;
            exit_tx
                .send(SessionExitStatus::Failed("crash".into()))
                .unwrap();
        };
        let (s1, s2, ()) = tokio::join!(waiter1, waiter2, driver);
        match s1 {
            SessionExitStatus::Failed(msg) => assert_eq!(msg, "crash"),
            other => panic!("waiter1 expected Failed, got {other:?}"),
        }
        match s2 {
            SessionExitStatus::Failed(msg) => assert_eq!(msg, "crash"),
            other => panic!("waiter2 expected Failed, got {other:?}"),
        }
    }
    // Dropping the state sender without a terminal update must not hang the
    // waiter; it returns the last published value, here still `Running`.
    #[tokio::test]
    async fn wait_for_exit_returns_last_value_when_sender_dropped() {
        let (session, exit_tx, _stop) = fake_session(SessionExitStatus::Running);
        let waiter = async { session.wait_for_exit().await };
        let driver = async {
            tokio::time::sleep(Duration::from_millis(20)).await;
            drop(exit_tx);
        };
        let (status, ()) = tokio::time::timeout(Duration::from_secs(1), async {
            tokio::join!(waiter, driver)
        })
        .await
        .expect("wait_for_exit must not hang when sender is dropped");
        assert!(matches!(status, SessionExitStatus::Running));
    }
}