use crate::product::agent::function_tool::FunctionCallError;
use crate::product::agent::is_safe_command::is_known_safe_command;
use crate::product::agent::protocol::EventMsg;
use crate::product::agent::protocol::TerminalInteractionEvent;
use crate::product::agent::sandboxing::SandboxPermissions;
use crate::product::agent::shell::Shell;
use crate::product::agent::shell::get_shell_by_model_provided_path;
use crate::product::agent::tools::context::ToolInvocation;
use crate::product::agent::tools::context::ToolOutput;
use crate::product::agent::tools::context::ToolPayload;
use crate::product::agent::tools::handlers::apply_patch::intercept_apply_patch;
use crate::product::agent::tools::handlers::parse_arguments;
use crate::product::agent::tools::registry::ToolHandler;
use crate::product::agent::tools::registry::ToolKind;
use crate::product::agent::unified_exec::ExecCommandRequest;
use crate::product::agent::unified_exec::UnifiedExecContext;
use crate::product::agent::unified_exec::UnifiedExecProcessManager;
use crate::product::agent::unified_exec::UnifiedExecResponse;
use crate::product::agent::unified_exec::WriteStdinRequest;
use async_trait::async_trait;
use serde::Deserialize;
use std::path::PathBuf;
use std::sync::Arc;
/// Tool handler for the unified exec function tools (`exec_command` and
/// `write_stdin`), backed by the session's `UnifiedExecProcessManager`.
pub struct UnifiedExecHandler;
/// Arguments accepted by the `exec_command` tool, deserialized from the
/// model-provided JSON arguments.
#[derive(Debug, Deserialize)]
struct ExecCommandArgs {
    /// Command text handed to the selected shell's `derive_exec_args`; must
    /// contain non-whitespace characters (checked by `validate_exec_command_args`).
    cmd: String,
    /// Optional working directory. An empty string is treated as absent; a
    /// present value is resolved through the turn's `resolve_path`.
    #[serde(default)]
    workdir: Option<String>,
    /// Optional shell path chosen by the model. When absent, the session's user
    /// shell is used instead.
    #[serde(default)]
    shell: Option<String>,
    /// Login flag forwarded to `derive_exec_args`; defaults to `true`.
    #[serde(default = "default_login")]
    login: bool,
    /// Forwarded to the exec request as `tty`; defaults to `false`.
    #[serde(default = "default_tty")]
    tty: bool,
    /// Yield time forwarded to the exec request and to apply-patch
    /// interception; defaults to 10000 ms.
    #[serde(default = "default_exec_yield_time_ms")]
    yield_time_ms: u64,
    /// Optional output-token limit forwarded to the exec request.
    #[serde(default)]
    max_output_tokens: Option<usize>,
    /// Requested sandbox permissions. Requests that need escalated permissions
    /// are rejected unless the turn's approval policy is `OnRequest`.
    #[serde(default)]
    sandbox_permissions: SandboxPermissions,
    /// Optional justification forwarded to the exec request.
    #[serde(default)]
    justification: Option<String>,
    /// Optional command-prefix rule; discarded unless the `RequestRule`
    /// feature is enabled for the session.
    #[serde(default)]
    prefix_rule: Option<Vec<String>>,
}
/// Arguments accepted by the `write_stdin` tool.
#[derive(Debug, Deserialize)]
struct WriteStdinArgs {
    /// Identifier of the target process; converted to its string form to
    /// address the process in the unified exec manager.
    session_id: i32,
    /// Text written to the process's stdin. An empty value (the default) is
    /// reported to the model as a poll that wrote nothing.
    #[serde(default)]
    chars: String,
    /// Yield time forwarded to the write request; defaults to 250 ms.
    #[serde(default = "default_write_stdin_yield_time_ms")]
    yield_time_ms: u64,
    /// Optional output-token limit forwarded to the write request.
    #[serde(default)]
    max_output_tokens: Option<usize>,
}
/// Serde default for `ExecCommandArgs::yield_time_ms` (ten seconds, in ms).
fn default_exec_yield_time_ms() -> u64 {
    10_000
}
/// Serde default for `WriteStdinArgs::yield_time_ms` (a quarter second, in ms).
fn default_write_stdin_yield_time_ms() -> u64 {
    250_u64
}
/// Serde default for `ExecCommandArgs::login`: commands run with the login
/// flag set unless the model opts out.
fn default_login() -> bool {
    true
}
/// Serde default for `ExecCommandArgs::tty`: no TTY unless the model asks
/// for one.
fn default_tty() -> bool {
    false
}
#[async_trait]
impl ToolHandler for UnifiedExecHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(payload, ToolPayload::Function { .. })
}
async fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
let ToolPayload::Function { arguments } = &invocation.payload else {
tracing::error!(
"This should never happen, invocation payload is wrong: {:?}",
invocation.payload
);
return true;
};
let Ok(params) = serde_json::from_str::<ExecCommandArgs>(arguments) else {
return true;
};
let command = get_command(¶ms, invocation.session.user_shell());
!is_known_safe_command(&command)
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
tracker,
call_id,
tool_name,
payload,
..
} = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"unified_exec handler received unsupported payload".to_string(),
));
}
};
let manager: &UnifiedExecProcessManager = &session.services.unified_exec_manager;
let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone());
let (response, polled_output) = match tool_name.as_str() {
"exec_command" => {
let args: ExecCommandArgs = parse_arguments(&arguments)?;
validate_exec_command_args(&args)?;
let process_id = manager.allocate_process_id().await;
let command = get_command(&args, session.user_shell());
let ExecCommandArgs {
workdir,
tty,
yield_time_ms,
max_output_tokens,
sandbox_permissions,
justification,
prefix_rule,
..
} = args;
let features = session.features();
let request_rule_enabled =
features.enabled(crate::product::agent::features::Feature::RequestRule);
let prefix_rule = if request_rule_enabled {
prefix_rule
} else {
None
};
if sandbox_permissions.requires_escalated_permissions()
&& !matches!(
context.turn.approval_policy,
crate::product::protocol::protocol::AskForApproval::OnRequest
)
{
let approval_policy = context.turn.approval_policy;
manager.release_process_id(&process_id).await;
return Err(FunctionCallError::RespondToModel(format!(
"approval policy is {approval_policy:?}; reject command — you cannot ask for escalated permissions if the approval policy is {approval_policy:?}"
)));
}
let workdir = workdir.filter(|value| !value.is_empty());
let workdir = workdir.map(|dir| context.turn.resolve_path(Some(dir)));
let cwd = workdir.clone().unwrap_or_else(|| context.turn.cwd.clone());
if let Some(output) = intercept_apply_patch(
&command,
&cwd,
Some(yield_time_ms),
context.session.as_ref(),
context.turn.as_ref(),
Some(&tracker),
&context.call_id,
tool_name.as_str(),
)
.await?
{
manager.release_process_id(&process_id).await;
return Ok(output);
}
(
manager
.exec_command(
ExecCommandRequest {
command,
process_id,
yield_time_ms,
max_output_tokens,
workdir,
tty,
sandbox_permissions,
justification,
prefix_rule,
},
&context,
)
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!(
"exec_command failed: {err:?}"
))
})?,
false,
)
}
"write_stdin" => {
let args: WriteStdinArgs = parse_arguments(&arguments)?;
let polled_output = args.chars.is_empty();
let response = manager
.write_stdin(WriteStdinRequest {
process_id: &args.session_id.to_string(),
input: &args.chars,
yield_time_ms: args.yield_time_ms,
max_output_tokens: args.max_output_tokens,
})
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!("write_stdin failed: {err}"))
})?;
let interaction = TerminalInteractionEvent {
call_id: response.event_call_id.clone(),
process_id: args.session_id.to_string(),
stdin: args.chars.clone(),
};
session
.send_event(turn.as_ref(), EventMsg::TerminalInteraction(interaction))
.await;
(response, polled_output)
}
other => {
return Err(FunctionCallError::RespondToModel(format!(
"unsupported unified exec function {other}"
)));
}
};
let content = format_response(&response, polled_output);
Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
})
}
}
/// Rejects `exec_command` arguments whose `cmd` is empty or contains only
/// whitespace.
///
/// # Errors
/// Returns `FunctionCallError::RespondToModel("cmd must be non-empty")` when
/// every character of `args.cmd` is whitespace (including the empty string).
fn validate_exec_command_args(args: &ExecCommandArgs) -> Result<(), FunctionCallError> {
    if args.cmd.chars().all(char::is_whitespace) {
        return Err(FunctionCallError::RespondToModel(
            "cmd must be non-empty".to_string(),
        ));
    }
    Ok(())
}
/// Builds the argv that runs `args.cmd`.
///
/// If the model named a shell in `args.shell`, that shell is used. Its
/// `shell_snapshot` is replaced with an empty snapshot receiver. Otherwise the
/// session's shell is used. `args.login` is forwarded to `derive_exec_args`.
fn get_command(args: &ExecCommandArgs, session_shell: Arc<Shell>) -> Vec<String> {
    let requested_shell = match args.shell.as_deref() {
        None => None,
        Some(path) => {
            let mut requested = get_shell_by_model_provided_path(&PathBuf::from(path));
            requested.shell_snapshot =
                crate::product::agent::shell::empty_shell_snapshot_receiver();
            Some(requested)
        }
    };
    let chosen: &Shell = match &requested_shell {
        Some(requested) => requested,
        None => &*session_shell,
    };
    chosen.derive_exec_args(&args.cmd, args.login)
}
/// Renders a unified exec response as the newline-separated text returned to
/// the model.
///
/// Header lines appear in this order, each only when applicable:
/// - the chunk id;
/// - the polling notice;
/// - the wall time (always present, 4 decimals);
/// - the exit code;
/// - the running session id;
/// - the original token count.
///
/// These are followed by an `Output:` line and the captured output.
fn format_response(response: &UnifiedExecResponse, polled_output: bool) -> String {
    let header = [
        (!response.chunk_id.is_empty()).then(|| format!("Chunk ID: {}", response.chunk_id)),
        polled_output.then(|| "Polled process output without writing to stdin.".to_string()),
        Some(format!(
            "Wall time: {:.4} seconds",
            response.wall_time.as_secs_f64()
        )),
        response
            .exit_code
            .map(|exit_code| format!("Process exited with code {exit_code}")),
        response
            .process_id
            .as_ref()
            .map(|process_id| format!("Process running with session ID {process_id}")),
        response
            .original_token_count
            .map(|count| format!("Original token count: {count}")),
    ];
    header
        .into_iter()
        .flatten()
        .chain(["Output:".to_string(), response.output.clone()])
        .collect::<Vec<_>>()
        .join("\n")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::product::agent::shell::default_user_shell;
    use std::sync::Arc;
    use std::time::Duration;

    /// The validation error expected for a blank `cmd`.
    fn blank_cmd_error() -> Result<(), FunctionCallError> {
        Err(FunctionCallError::RespondToModel(
            "cmd must be non-empty".to_string(),
        ))
    }

    /// A still-running response for session "123" whose raw bytes mirror
    /// `output`; tests set `original_token_count` themselves.
    fn running_response(output: &str) -> UnifiedExecResponse {
        UnifiedExecResponse {
            event_call_id: "call".to_string(),
            chunk_id: "chunk".to_string(),
            wall_time: Duration::from_millis(25),
            output: output.to_string(),
            raw_output: output.as_bytes().to_vec(),
            process_id: Some("123".to_string()),
            exit_code: None,
            original_token_count: None,
            session_command: None,
        }
    }

    /// Parses `json` as exec args and derives the argv with the default user shell.
    fn command_from(json: &str) -> anyhow::Result<(ExecCommandArgs, Vec<String>)> {
        let args: ExecCommandArgs = parse_arguments(json)?;
        let command = get_command(&args, Arc::new(default_user_shell()));
        Ok((args, command))
    }

    #[test]
    fn validate_exec_command_args_rejects_empty_cmd() -> anyhow::Result<()> {
        let args: ExecCommandArgs = parse_arguments(r#"{"cmd": ""}"#)?;
        assert_eq!(validate_exec_command_args(&args), blank_cmd_error());
        Ok(())
    }

    #[test]
    fn validate_exec_command_args_rejects_whitespace_cmd() -> anyhow::Result<()> {
        let args: ExecCommandArgs = parse_arguments(r#"{"cmd": " \n\t"}"#)?;
        assert_eq!(validate_exec_command_args(&args), blank_cmd_error());
        Ok(())
    }

    #[test]
    fn validate_exec_command_args_accepts_non_empty_cmd() -> anyhow::Result<()> {
        let args: ExecCommandArgs = parse_arguments(r#"{"cmd": "echo hello"}"#)?;
        assert_eq!(validate_exec_command_args(&args), Ok(()));
        Ok(())
    }

    #[test]
    fn format_response_marks_empty_write_stdin_as_polling() {
        let mut response = running_response("still running");
        response.original_token_count = Some(2);
        let rendered = format_response(&response, true);
        assert!(rendered.contains("Polled process output without writing to stdin."));
        assert!(rendered.contains("Output:\nstill running"));
    }

    #[test]
    fn format_response_omits_polling_marker_for_stdin_writes() {
        let mut response = running_response("ok");
        response.original_token_count = Some(1);
        let rendered = format_response(&response, false);
        assert!(!rendered.contains("Polled process output"));
    }

    #[test]
    fn test_get_command_uses_default_shell_when_unspecified() -> anyhow::Result<()> {
        let (args, command) = command_from(r#"{"cmd": "echo hello"}"#)?;
        assert!(args.shell.is_none());
        assert_eq!(command.len(), 3);
        assert_eq!(command[2], "echo hello");
        Ok(())
    }

    #[test]
    fn test_get_command_respects_explicit_bash_shell() -> anyhow::Result<()> {
        let (args, command) = command_from(r#"{"cmd": "echo hello", "shell": "/bin/bash"}"#)?;
        assert_eq!(args.shell.as_deref(), Some("/bin/bash"));
        assert_eq!(command.last().map(String::as_str), Some("echo hello"));
        let has_command_flag = command
            .iter()
            .any(|arg| arg.eq_ignore_ascii_case("-Command"));
        if has_command_flag {
            assert!(command.iter().any(|arg| arg == "-NoProfile"));
        }
        Ok(())
    }

    #[test]
    fn test_get_command_respects_explicit_powershell_shell() -> anyhow::Result<()> {
        let (args, command) = command_from(r#"{"cmd": "echo hello", "shell": "powershell"}"#)?;
        assert_eq!(args.shell.as_deref(), Some("powershell"));
        assert_eq!(command[2], "echo hello");
        Ok(())
    }

    #[test]
    fn test_get_command_respects_explicit_cmd_shell() -> anyhow::Result<()> {
        let (args, command) = command_from(r#"{"cmd": "echo hello", "shell": "cmd"}"#)?;
        assert_eq!(args.shell.as_deref(), Some("cmd"));
        assert_eq!(command[2], "echo hello");
        Ok(())
    }
}