use std::collections::BTreeMap;
use std::sync::atomic::Ordering;
use anyhow::{Context, Result};
use uuid::Uuid;
use agent_os_sidecar::protocol::{
EventPayload, ExecuteRequest, KillProcessRequest, OwnershipScope, ProcessStartedResponse,
RejectedResponse, RequestPayload, ResponsePayload, StreamChannel, WriteStdinRequest,
};
use crate::agent_os::{AgentOs, ShellEntry};
use crate::error::ClientError;
use crate::process::{install_output_callback, OutputCallback, StdinInput};
use crate::stream::ByteStream;
/// Buffer size of each per-shell / per-terminal output broadcast channel (stdout and stderr).
const SHELL_DATA_CHANNEL_CAPACITY: usize = 1_024;
/// Program spawned when the caller does not specify `command`.
const DEFAULT_SHELL_COMMAND: &str = "sh";
/// Options for [`AgentOs::open_shell`]; also embedded as `base` in
/// [`ConnectTerminalOptions`] for [`AgentOs::connect_terminal`].
#[derive(Default)]
pub struct OpenShellOptions {
    /// Program to spawn; `None` falls back to `DEFAULT_SHELL_COMMAND` (`"sh"`).
    pub command: Option<String>,
    /// Arguments passed to `command`.
    pub args: Vec<String>,
    /// Environment variables sent with the spawn request.
    pub env: BTreeMap<String, String>,
    /// Working directory sent with the spawn request (`None` sends no cwd).
    pub cwd: Option<String>,
    /// Requested terminal width.
    /// NOTE(review): not read by the shell/terminal code in this module (there is no
    /// winsize wire op) — confirm before relying on it.
    pub cols: Option<u16>,
    /// Requested terminal height.
    /// NOTE(review): not read by the shell/terminal code in this module — see `cols`.
    pub rows: Option<u16>,
    /// Callback fed with the process's stderr chunks.
    pub on_stderr: Option<OutputCallback>,
}
/// Options for [`AgentOs::connect_terminal`].
#[derive(Default)]
pub struct ConnectTerminalOptions {
    /// Command, arguments, environment, cwd and stderr callback, as for `open_shell`.
    pub base: OpenShellOptions,
    /// Callback fed with the terminal process's stdout chunks.
    pub on_data: Option<OutputCallback>,
}
/// Handle returned by [`AgentOs::open_shell`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ShellHandle {
    /// Locally assigned identifier (`shell-<n>`) accepted by the other shell methods.
    pub shell_id: String,
}
fn rejected_to_error(rejected: RejectedResponse) -> ClientError {
ClientError::Kernel {
code: rejected.code,
message: rejected.message,
}
}
/// Flattens stdin input into raw bytes; text is taken as its UTF-8 encoding.
fn stdin_chunk(data: StdinInput) -> Vec<u8> {
    match data {
        StdinInput::Bytes(raw) => raw,
        StdinInput::Text(text) => Vec::from(text),
    }
}
impl AgentOs {
    /// Builds the VM-level ownership scope (connection id, wire session id, VM id) that
    /// shell and terminal requests in this module are sent under.
    fn vm_ownership(&self) -> OwnershipScope {
        let connection = self.connection_id().to_string();
        let session = self.wire_session_id().to_string();
        let vm = self.vm_id().to_string();
        OwnershipScope::vm(connection, session, vm)
    }
}
impl AgentOs {
pub fn open_shell(&self, mut options: OpenShellOptions) -> Result<ShellHandle> {
let inner = self.inner();
let counter = inner.shell_counter.fetch_add(1, Ordering::SeqCst) + 1;
let shell_id = format!("shell-{counter}");
let process_id = format!("shell-{}", Uuid::new_v4());
let (data_tx, _) = tokio::sync::broadcast::channel(SHELL_DATA_CHANNEL_CAPACITY);
let (stderr_tx, _) = tokio::sync::broadcast::channel(SHELL_DATA_CHANNEL_CAPACITY);
let (spawned_tx, _) = tokio::sync::watch::channel(false);
if let Some(cb) = options.on_stderr.take() {
install_output_callback(stderr_tx.clone(), cb);
}
let entry = ShellEntry {
pid: 0,
data_tx: data_tx.clone(),
stderr_tx: stderr_tx.clone(),
process_id: process_id.clone(),
spawned_tx: spawned_tx.clone(),
};
let _ = inner.shells.insert(shell_id.clone(), entry);
let command = options
.command
.clone()
.unwrap_or_else(|| DEFAULT_SHELL_COMMAND.to_string());
let execute = ExecuteRequest {
process_id: process_id.clone(),
command: Some(command),
runtime: None,
entrypoint: None,
args: options.args.clone(),
env: options.env.clone().into_iter().collect(),
cwd: options.cwd.clone(),
wasm_permission_tier: None,
};
let agent = self.clone();
let ownership = self.vm_ownership();
let route_process_id = process_id.clone();
let exit_shell_id = shell_id.clone();
let exit_key = counter;
let handle = tokio::spawn(async move {
let mut events = agent.transport().subscribe_events();
let response = match agent
.transport()
.request(ownership.clone(), RequestPayload::Execute(execute))
.await
{
Ok(response) => response,
Err(error) => {
tracing::warn!(?error, shell_id = %exit_shell_id, "open_shell spawn failed");
agent.inner().shells.remove(&exit_shell_id);
agent.inner().pending_shell_exits.remove(&exit_key);
return;
}
};
if let ResponsePayload::ProcessStarted(ProcessStartedResponse { pid, .. }) = response {
if let Some(pid) = pid {
agent
.inner()
.shells
.update(&exit_shell_id, |_, existing| existing.pid = pid);
}
}
let _ = spawned_tx.send(true);
loop {
let (_scope, payload) = match events.recv().await {
Ok(value) => value,
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
};
match payload {
EventPayload::ProcessOutput(output) => {
if output.process_id != route_process_id {
continue;
}
match output.channel {
StreamChannel::Stdout => {
let _ = data_tx.send(output.chunk);
}
StreamChannel::Stderr => {
let _ = stderr_tx.send(output.chunk);
}
}
}
EventPayload::ProcessExited(exited) => {
if exited.process_id == route_process_id {
break;
}
}
EventPayload::VmLifecycle(_) | EventPayload::Structured(_) => {}
}
}
agent.inner().pending_shell_exits.remove(&exit_key);
agent
.inner()
.shells
.remove_if(&exit_shell_id, |existing| {
existing.process_id == route_process_id
});
});
let _ = inner.pending_shell_exits.insert(counter, handle);
Ok(ShellHandle { shell_id })
}
pub async fn connect_terminal(&self, options: ConnectTerminalOptions) -> Result<u32> {
let ConnectTerminalOptions { base, on_data } = options;
let process_id = format!("terminal-{}", Uuid::new_v4());
let command = base
.command
.clone()
.unwrap_or_else(|| DEFAULT_SHELL_COMMAND.to_string());
let (data_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(SHELL_DATA_CHANNEL_CAPACITY);
let (stderr_tx, _) =
tokio::sync::broadcast::channel::<Vec<u8>>(SHELL_DATA_CHANNEL_CAPACITY);
if let Some(cb) = on_data {
install_output_callback(data_tx.clone(), cb);
}
if let Some(cb) = base.on_stderr {
install_output_callback(stderr_tx.clone(), cb);
}
let execute = ExecuteRequest {
process_id: process_id.clone(),
command: Some(command),
runtime: None,
entrypoint: None,
args: base.args.clone(),
env: base.env.clone().into_iter().collect(),
cwd: base.cwd.clone(),
wasm_permission_tier: None,
};
let mut events = self.transport().subscribe_events();
let response = self
.transport()
.request(self.vm_ownership(), RequestPayload::Execute(execute))
.await
.context("connect_terminal spawn failed")?;
let pid = match response {
ResponsePayload::ProcessStarted(ProcessStartedResponse { pid, .. }) => {
pid.context("connect_terminal: sidecar did not return a pid")?
}
ResponsePayload::Rejected(rejected) => return Err(rejected_to_error(rejected).into()),
_ => anyhow::bail!("unexpected response to connect_terminal"),
};
let route_process_id = process_id;
tokio::spawn(async move {
loop {
let (_scope, payload) = match events.recv().await {
Ok(value) => value,
Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => continue,
Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
};
match payload {
EventPayload::ProcessOutput(output) => {
if output.process_id != route_process_id {
continue;
}
match output.channel {
StreamChannel::Stdout => {
let _ = data_tx.send(output.chunk);
}
StreamChannel::Stderr => {
let _ = stderr_tx.send(output.chunk);
}
}
}
EventPayload::ProcessExited(exited) => {
if exited.process_id == route_process_id {
break;
}
}
EventPayload::VmLifecycle(_) | EventPayload::Structured(_) => {}
}
}
});
let _ = self.inner().acp_terminal_pids.insert(pid);
Ok(pid)
}
pub fn write_shell(
&self,
shell_id: &str,
data: StdinInput,
) -> std::result::Result<(), ClientError> {
let (process_id, spawned_rx) = self.shell_wire_handle(shell_id)?;
let chunk = stdin_chunk(data);
let agent = self.clone();
let ownership = self.vm_ownership();
tokio::spawn(async move {
wait_for_spawn(spawned_rx).await;
let payload = RequestPayload::WriteStdin(WriteStdinRequest { process_id, chunk });
if let Err(error) = agent.transport().request(ownership, payload).await {
tracing::warn!(?error, "write_shell failed");
}
});
Ok(())
}
pub fn on_shell_data(
&self,
shell_id: &str,
) -> std::result::Result<ByteStream, ClientError> {
self.inner()
.shells
.read(shell_id, |_, entry| entry.data_tx.subscribe())
.map(ByteStream::new)
.ok_or_else(|| ClientError::ShellNotFound(shell_id.to_string()))
}
pub fn on_shell_stderr(
&self,
shell_id: &str,
) -> std::result::Result<ByteStream, ClientError> {
self.inner()
.shells
.read(shell_id, |_, entry| entry.stderr_tx.subscribe())
.map(ByteStream::new)
.ok_or_else(|| ClientError::ShellNotFound(shell_id.to_string()))
}
pub fn resize_shell(
&self,
shell_id: &str,
cols: u16,
rows: u16,
) -> std::result::Result<(), ClientError> {
let _ = self.shell_wire_handle(shell_id)?;
tracing::warn!(
shell_id = %shell_id,
cols,
rows,
"resize_shell has no native winsize wire op; resize is a no-op"
);
Ok(())
}
pub fn close_shell(&self, shell_id: &str) -> std::result::Result<(), ClientError> {
let (process_id, spawned_rx) = self.shell_wire_handle(shell_id)?;
self.inner().shells.remove(shell_id);
let agent = self.clone();
let ownership = self.vm_ownership();
tokio::spawn(async move {
wait_for_spawn(spawned_rx).await;
let payload = RequestPayload::KillProcess(KillProcessRequest {
process_id,
signal: String::from("SIGTERM"),
});
if let Err(error) = agent.transport().request(ownership, payload).await {
tracing::warn!(?error, "close_shell kill failed");
}
});
Ok(())
}
fn shell_wire_handle(
&self,
shell_id: &str,
) -> std::result::Result<(String, tokio::sync::watch::Receiver<bool>), ClientError> {
self.inner()
.shells
.read(shell_id, |_, entry| {
(entry.process_id.clone(), entry.spawned_tx.subscribe())
})
.ok_or_else(|| ClientError::ShellNotFound(shell_id.to_string()))
}
}
/// Resolves once the spawn flag reads `true`, or once every sender has been dropped.
///
/// In the second case the spawn can no longer be confirmed, and the caller proceeds anyway.
async fn wait_for_spawn(mut spawned_rx: tokio::sync::watch::Receiver<bool>) {
    while !*spawned_rx.borrow() {
        if spawned_rx.changed().await.is_err() {
            // All senders are gone; stop waiting.
            return;
        }
    }
}