use std::{
collections::VecDeque,
io::Read,
sync::{Arc, Mutex as StdMutex},
time::SystemTime,
};
use bytes::Bytes;
use chrono::Utc;
use identity::{AgentHint, AuthorizationRequest, CheckOutcome, Principal, Tier};
use observe::{BrainEvent, Observer, PrincipalSummary};
use portable_pty::{native_pty_system, CommandBuilder, PtySize};
use tokio::sync::{broadcast, mpsc, Mutex};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use tonic::{Request, Response, Status, Streaming};
use tracing::{error, warn};
use uuid::Uuid;
use crate::{
graph::TerminalGraphSink,
pb::{
self, client_frame, server_frame, terminal_session_server::TerminalSession, ClientFrame,
CloseAck, InputChunk, OpenRequest, OutputChunk, ResizeAck, ResizeRequest, SendAck,
ServerFrame, SessionHandle, Sig, SignalAck, SignalRequest,
},
session::{Session, IN_MPSC_CAPACITY, OUT_BROADCAST_CAPACITY},
types::{SessionMeta, TermSize},
SessionRegistry, TerminalAuth,
};
/// Size in bytes of the scratch buffer `pump_reader` hands to each PTY read.
const PTY_READ_BUFFER_SIZE: usize = 8 * 1024;
/// Capacity of the mpsc channel that backs the `attach` and `interact`
/// response streams.
const STREAM_OUT_BUFFER: usize = 64;
/// gRPC service object for terminal sessions.
///
/// Cloning is cheap enough to hand a copy to spawned tasks: the registry,
/// observer and graph sink are all reference-counted.
#[derive(Clone)]
pub struct TerminalSvc {
// Shared table of live sessions, keyed by session id.
registry: Arc<SessionRegistry>,
// Optional auth configuration; when `None`, `authorize` lets every call
// through and yields no principal.
auth: Option<TerminalAuth>,
// Optional event sink; `publish` is a no-op when this is `None`.
observer: Option<Arc<dyn Observer>>,
// Optional sink that is told about session open/close
// (`record_open` / `record_close`).
graph_sink: Option<Arc<dyn TerminalGraphSink>>,
}
impl TerminalSvc {
/// Builds a service from its collaborators; the arguments are stored as-is.
pub fn new(
registry: Arc<SessionRegistry>,
auth: Option<TerminalAuth>,
observer: Option<Arc<dyn Observer>>,
graph_sink: Option<Arc<dyn TerminalGraphSink>>,
) -> Self {
Self {
registry,
auth,
observer,
graph_sink,
}
}
/// Returns the shared session registry.
pub fn registry(&self) -> &Arc<SessionRegistry> {
&self.registry
}
/// Opens a session by delegating straight to `open_inner`.
///
/// No call to `authorize` happens on this path; the caller supplies the
/// principal (if any) that is recorded on the session.
pub async fn open_via_pipeline(
&self,
request: OpenRequest,
principal: Option<Principal>,
) -> Result<SessionHandle, Status> {
self.open_inner(request, principal).await
}
/// Closes the session `id` by delegating straight to `close_inner`.
///
/// No call to `authorize` happens on this path.
pub async fn close_via_pipeline(&self, id: &str) -> Result<CloseAck, Status> {
self.close_inner(id).await
}
}
/// Converts an optional wire-format PTY size into a [`TermSize`].
///
/// A missing size yields `TermSize::default()`. Each dimension is narrowed to
/// `u16` with saturation: a value that does not fit becomes `u16::MAX` instead
/// of wrapping around (a plain `as u16` cast would turn 65536 columns into 0).
fn term_size_from_pb(pb: Option<pb::PtySize>) -> TermSize {
    /// Narrows any integer to `u16`, clamping out-of-range values to `u16::MAX`.
    fn saturate<T>(value: T) -> u16
    where
        u16: TryFrom<T>,
    {
        u16::try_from(value).unwrap_or(u16::MAX)
    }

    match pb {
        Some(s) => TermSize {
            rows: saturate(s.rows),
            cols: saturate(s.cols),
            pixel_width: saturate(s.pixel_width),
            pixel_height: saturate(s.pixel_height),
        },
        None => TermSize::default(),
    }
}
fn to_pty_size(s: TermSize) -> PtySize {
PtySize {
rows: s.rows,
cols: s.cols,
pixel_width: s.pixel_width,
pixel_height: s.pixel_height,
}
}
/// Current wall-clock time as a protobuf timestamp.
///
/// Returns `None` when the system clock reports a time before the Unix epoch.
fn timestamp_now() -> Option<prost_types::Timestamp> {
    SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .ok()
        .map(|since_epoch| prost_types::Timestamp {
            seconds: since_epoch.as_secs() as i64,
            nanos: since_epoch.subsec_nanos() as i32,
        })
}
/// Copies the user and agent identifiers out of a [`Principal`] into the
/// summary shape carried by observer events.
fn principal_summary(p: &Principal) -> PrincipalSummary {
    let user_id = p.user_id.0.clone();
    let agent_id = p.agent_id.0.clone();
    PrincipalSummary { user_id, agent_id }
}
/// Extracts the caller's API key from request metadata.
///
/// `x-api-key` wins when it is present and readable as a string. Otherwise the
/// `authorization` header is used: the bearer token if
/// `extract_bearer_from_value` finds one, else the raw header value. Returns
/// `None` when neither header yields a string.
fn api_key_from_metadata<T>(req: &Request<T>) -> Option<String> {
    let headers = req.metadata();

    let direct = headers.get("x-api-key").and_then(|value| value.to_str().ok());
    if let Some(key) = direct {
        return Some(key.to_owned());
    }

    let raw = headers.get("authorization")?.to_str().ok()?;
    let token = match brain::auth::extract_bearer_from_value(raw) {
        Some(bearer) => bearer,
        // No bearer prefix recognised: treat the whole header value as the key.
        None => raw,
    };
    Some(token.to_owned())
}
impl TerminalSvc {
/// Authenticates the caller and checks that it may perform `terminal.<verb_action>`.
///
/// Returns `Ok(None)` when the service has no auth configured, and
/// `Ok(Some(principal))` when the check is allowed.
///
/// # Errors
/// * `unauthenticated` - no API key in the metadata, the key is unknown or has
///   no agent id, or the principal lookup fails.
/// * `permission_denied` - the identity check denies the action or asks for
///   user confirmation.
async fn authorize<T>(
    &self,
    req: &Request<T>,
    verb_action: &str,
    modifiers: serde_json::Value,
) -> Result<Option<Principal>, Status> {
    let auth = match &self.auth {
        Some(configured) => configured,
        None => return Ok(None),
    };

    let Some(key) = api_key_from_metadata(req) else {
        return Err(Status::unauthenticated("missing api-key"));
    };
    let Some(agent_id) =
        brain::auth::find_key_ct(&auth.api_keys, &key).and_then(|k| k.agent_id.clone())
    else {
        return Err(Status::unauthenticated("unknown api-key"));
    };

    let hint = AgentHint::AgentId(agent_id.into());
    let principal = match auth.identity.principal_for(&hint).await {
        Ok(found) => found,
        Err(e) => {
            return Err(Status::unauthenticated(format!("principal lookup: {e}")));
        }
    };

    let authz_req = AuthorizationRequest::new("terminal", verb_action).with_modifiers(modifiers);
    let outcome = auth
        .identity
        .check(&principal, &authz_req, Tier::Execute)
        .await;

    // Both non-allow outcomes surface as the same gRPC code; only the text differs.
    let refusal = match outcome {
        CheckOutcome::Allow => return Ok(Some(principal)),
        CheckOutcome::EscalateToUser { reason } => {
            format!("terminal.{verb_action} requires user confirmation: {reason}")
        }
        CheckOutcome::Deny { reason } => format!("terminal.{verb_action} denied: {reason}"),
    };
    Err(Status::permission_denied(refusal))
}
/// Forwards `ev` to the observer, if one is configured.
///
/// Best-effort: the observer's result is ignored.
async fn publish(&self, ev: BrainEvent) {
    let Some(observer) = &self.observer else {
        return;
    };
    let _ = observer.publish(ev).await;
}
async fn open_inner(
&self,
r: OpenRequest,
principal: Option<Principal>,
) -> Result<SessionHandle, Status> {
let size = term_size_from_pb(r.initial_size);
let pty = native_pty_system();
let pair = pty
.openpty(to_pty_size(size))
.map_err(|e| Status::internal(format!("openpty: {e}")))?;
let mut cmd = if r.program.is_empty() {
CommandBuilder::new_default_prog()
} else {
CommandBuilder::new(&r.program)
};
for a in &r.args {
cmd.arg(a);
}
for (k, v) in &r.env {
cmd.env(k, v);
}
if !r.cwd.is_empty() {
cmd.cwd(&r.cwd);
}
let child = pair
.slave
.spawn_command(cmd)
.map_err(|e| Status::internal(format!("spawn: {e}")))?;
drop(pair.slave);
let master = Arc::new(Mutex::new(pair.master));
let (out_tx, out_anchor) = broadcast::channel::<Bytes>(OUT_BROADCAST_CAPACITY);
let (in_tx, mut in_rx) = mpsc::channel::<Bytes>(IN_MPSC_CAPACITY);
let replay: Arc<StdMutex<VecDeque<Bytes>>> = Arc::new(StdMutex::new(
VecDeque::with_capacity(OUT_BROADCAST_CAPACITY),
));
{
let reader_res = master.lock().await.try_clone_reader();
let reader = reader_res.map_err(|e| Status::internal(format!("clone_reader: {e}")))?;
let replay_for_pump = replay.clone();
tokio::task::spawn_blocking(move || pump_reader(reader, out_tx, replay_for_pump));
}
{
let writer_res = master.lock().await.take_writer();
let writer = writer_res.map_err(|e| Status::internal(format!("take_writer: {e}")))?;
tokio::task::spawn_blocking(move || {
let mut writer = writer;
while let Some(chunk) = in_rx.blocking_recv() {
use std::io::Write;
if writer.write_all(&chunk).is_err() {
break;
}
let _ = writer.flush();
}
});
}
let session_uuid = Uuid::new_v4();
let session_id = session_uuid.to_string();
let cwd = if r.cwd.is_empty() { None } else { Some(r.cwd) };
let meta = SessionMeta {
session_id: session_id.clone(),
program: r.program,
args: r.args,
cwd,
opened_at: Utc::now(),
client_id: if r.client_id.is_empty() {
None
} else {
Some(r.client_id)
},
size,
principal: principal.clone(),
};
let event = BrainEvent::TerminalSessionOpened {
id: session_uuid,
session_id: session_id.clone(),
program: meta.program.clone(),
args: meta.args.clone(),
cwd: meta.cwd.clone(),
principal: principal.as_ref().map(principal_summary),
ts: Utc::now(),
};
let graph_handles = if let Some(sink) = &self.graph_sink {
let principal_ref = principal.as_ref();
match sink
.record_open(
&session_id,
&meta.program,
&meta.args,
meta.cwd.as_deref(),
principal_ref,
)
.await
{
Ok(h) => Some(h),
Err(e) => {
warn!(session_id = %session_id, error = %e, "graph mirror record_open failed");
None
}
}
} else {
None
};
let session = Arc::new(Session {
meta,
out_anchor,
replay,
in_tx,
master,
child: Arc::new(Mutex::new(child)),
graph_handles,
});
self.registry.insert(session).await;
self.publish(event).await;
Ok(SessionHandle { session_id })
}
/// Removes session `id` from the registry, terminates its child and reports
/// how it ended.
///
/// The child is killed only if it has not already exited; `was_killed` is true
/// when that kill succeeded. `exit_code` is the child's exit code, or `-1` when
/// waiting on it fails. A `TerminalSessionClosed` event is published and then
/// the graph sink is notified (a failure there is only logged).
///
/// # Errors
/// `not_found` when no session with this id is registered.
async fn close_inner(&self, id: &str) -> Result<CloseAck, Status> {
    let key = id.to_string();
    let Some(session) = self.registry.remove(&key).await else {
        return Err(Status::not_found(format!("session '{id}' not found")));
    };

    // Keep the child lock only for the kill/wait sequence.
    let (exit_code, was_killed) = {
        let mut proc = session.child.lock().await;
        let was_killed = match proc.try_wait() {
            Ok(Some(_)) => false,
            _ => proc.kill().is_ok(),
        };
        let exit_code = match proc.wait() {
            Ok(status) => status.exit_code() as i32,
            Err(_) => -1,
        };
        (exit_code, was_killed)
    };

    // Falls back to a fresh UUID when the id does not parse as one.
    let event_id = Uuid::parse_str(id).unwrap_or_else(|_| Uuid::new_v4());
    let closed = BrainEvent::TerminalSessionClosed {
        id: event_id,
        session_id: key,
        exit_code,
        was_killed,
        principal: session.meta.principal.as_ref().map(principal_summary),
        ts: Utc::now(),
    };
    self.publish(closed).await;

    if let Some(sink) = self.graph_sink.as_ref() {
        if let Some(handles) = session.graph_handles.as_ref() {
            let recorded = sink.record_close(handles, id, exit_code, was_killed).await;
            if let Err(e) = recorded {
                warn!(session_id = %id, error = %e, "graph mirror record_close failed");
            }
        }
    }

    Ok(CloseAck {
        exit_code,
        was_killed,
    })
}
/// Fetches the live session registered under `id`.
///
/// # Errors
/// `not_found` when the registry has no such session.
async fn lookup(&self, id: &str) -> Result<Arc<Session>, Status> {
    let key = id.to_string();
    let found = self.registry.get(&key).await;
    match found {
        Some(session) => Ok(session),
        None => Err(Status::not_found(format!("session '{id}' not found"))),
    }
}
/// Queues `data` on the input channel of session `id`.
///
/// Returns the number of bytes queued, which is the length of `data`.
///
/// # Errors
/// `not_found` for an unknown session; `aborted` when the session's input
/// channel is closed.
async fn write_input_inner(&self, id: &str, data: Bytes) -> Result<u64, Status> {
    let session = self.lookup(id).await?;
    let byte_count = data.len() as u64;
    let queued = session.in_tx.send(data).await;
    match queued {
        Ok(()) => Ok(byte_count),
        Err(_) => Err(Status::aborted("session writer closed")),
    }
}
/// Resizes the PTY of session `id` to `size`.
///
/// # Errors
/// `not_found` for an unknown session; `internal` when the PTY rejects the
/// resize.
async fn resize_inner(&self, id: &str, size: TermSize) -> Result<(), Status> {
    let session = self.lookup(id).await?;
    let outcome = {
        // Scope the guard so the lock is released before the error is mapped.
        let pty = session.master.lock().await;
        pty.resize(to_pty_size(size))
    };
    match outcome {
        Ok(done) => Ok(done),
        Err(e) => Err(Status::internal(format!("resize: {e}"))),
    }
}
/// Delivers `sig` to session `id`.
///
/// * `Sigint` and `Sigquit` are sent as the control bytes `0x03` and `0x1c`
///   on the session's input channel.
/// * `Sigterm`, `Sighup` and `Sigkill` all call `kill` on the child.
/// * `Unspecified` is rejected.
///
/// # Errors
/// `not_found` for an unknown session, `invalid_argument` for `Unspecified`,
/// `aborted` when the input channel is closed, `internal` when `kill` fails.
async fn signal_inner(&self, id: &str, sig: Sig) -> Result<(), Status> {
    let session = self.lookup(id).await?;

    // Signals with no control-byte equivalent are resolved right here; the
    // remaining two fall through with the byte to write.
    let control: &'static [u8] = match sig {
        Sig::Unspecified => {
            return Err(Status::invalid_argument("signal must not be UNSPECIFIED"));
        }
        Sig::Sigterm | Sig::Sighup | Sig::Sigkill => {
            let mut proc = session.child.lock().await;
            return proc
                .kill()
                .map_err(|e| Status::internal(format!("kill: {e}")));
        }
        Sig::Sigint => b"\x03",
        Sig::Sigquit => b"\x1c",
    };

    let delivered = session.in_tx.send(Bytes::from_static(control)).await;
    delivered.map_err(|_| Status::aborted("session writer closed"))
}
}
#[tonic::async_trait]
impl TerminalSession for TerminalSvc {
/// Server stream returned by `attach`: output chunks or a terminal error status.
type AttachStream = ReceiverStream<Result<OutputChunk, Status>>;
/// Server half of the bidirectional `interact` stream.
type InteractStream = ReceiverStream<Result<ServerFrame, Status>>;
/// Unary RPC: authorizes `terminal.open` (passing the requested program as a
/// modifier) and opens a new session.
async fn open(&self, req: Request<OpenRequest>) -> Result<Response<SessionHandle>, Status> {
    let modifiers = serde_json::json!({"program": req.get_ref().program.as_str()});
    let principal = self.authorize(&req, "open", modifiers).await?;
    let handle = self.open_inner(req.into_inner(), principal).await?;
    Ok(Response::new(handle))
}
/// Unary RPC: authorizes `terminal.close` and closes the named session.
async fn close(&self, req: Request<SessionHandle>) -> Result<Response<CloseAck>, Status> {
    self.authorize(&req, "close", serde_json::Value::Null)
        .await?;
    let handle = req.into_inner();
    let ack = self.close_inner(&handle.session_id).await?;
    Ok(Response::new(ack))
}
/// Server-streaming RPC: authorizes `terminal.attach` and streams the
/// session's output.
///
/// The stream first carries every chunk of the snapshot returned by
/// `attach_snapshot`, then chunks received from the accompanying receiver.
/// Sequence numbers start at 1 and increase by one per chunk. The stream ends
/// with an empty `eof` chunk when the broadcast channel closes, or with a
/// `resource_exhausted` status when this subscriber lagged and chunks were
/// dropped.
async fn attach(
    &self,
    req: Request<SessionHandle>,
) -> Result<Response<Self::AttachStream>, Status> {
    self.authorize(&req, "attach", serde_json::Value::Null)
        .await?;
    let id = req.into_inner().session_id;
    let session = self.lookup(&id).await?;
    let (backlog, mut live) = session.attach_snapshot();
    let (tx, out) = mpsc::channel::<Result<OutputChunk, Status>>(STREAM_OUT_BUFFER);

    let pump = async move {
        /// Builds one output chunk stamped with the current time.
        fn chunk(data: Vec<u8>, seq: u64, eof: bool) -> OutputChunk {
            OutputChunk {
                data,
                ts: timestamp_now(),
                seq,
                eof,
            }
        }

        let mut seq = 0u64;
        for bytes in backlog {
            seq += 1;
            if tx.send(Ok(chunk(bytes.to_vec(), seq, false))).await.is_err() {
                // The client went away; nothing left to do.
                return;
            }
        }

        loop {
            let received = live.recv().await;
            match received {
                Ok(bytes) => {
                    seq += 1;
                    if tx.send(Ok(chunk(bytes.to_vec(), seq, false))).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    warn!(
                        session = %id,
                        dropped = n,
                        "attach stream lagged — bumped subscriber"
                    );
                    let lagged =
                        Status::resource_exhausted(format!("attach lagged: {n} chunks dropped"));
                    let _ = tx.send(Err(lagged)).await;
                    break;
                }
                Err(broadcast::error::RecvError::Closed) => {
                    let _ = tx.send(Ok(chunk(Vec::new(), seq + 1, true))).await;
                    break;
                }
            }
        }
    };
    tokio::spawn(supervise("attach-pump", pump));

    Ok(Response::new(ReceiverStream::new(out)))
}
/// Client-streaming RPC: authorizes `terminal.send`, writes every chunk to the
/// session it names and acknowledges the total byte count.
///
/// The first stream error, chunk without a `session_id`, or failed write
/// aborts the call with that error.
async fn send(&self, req: Request<Streaming<InputChunk>>) -> Result<Response<SendAck>, Status> {
    self.authorize(&req, "send", serde_json::Value::Null)
        .await?;
    let mut chunks = req.into_inner();
    let mut bytes_written = 0u64;
    loop {
        let chunk = match chunks.next().await {
            Some(item) => item?,
            None => break,
        };
        if chunk.session_id.is_empty() {
            return Err(Status::invalid_argument("input chunk missing session_id"));
        }
        let written = self
            .write_input_inner(&chunk.session_id, Bytes::from(chunk.data))
            .await?;
        bytes_written += written;
    }
    Ok(Response::new(SendAck { bytes_written }))
}
/// Unary RPC: authorizes `terminal.resize` and resizes the named session's PTY.
async fn resize(&self, req: Request<ResizeRequest>) -> Result<Response<ResizeAck>, Status> {
    self.authorize(&req, "resize", serde_json::Value::Null)
        .await?;
    let request = req.into_inner();
    let new_size = term_size_from_pb(request.size);
    match self.resize_inner(&request.session_id, new_size).await {
        Ok(()) => Ok(Response::new(ResizeAck {})),
        Err(status) => Err(status),
    }
}
/// Unary RPC: authorizes `terminal.signal` and delivers the signal to the
/// named session.
///
/// A signal number that does not map to a known `Sig` is treated as
/// `Sig::Unspecified`.
async fn signal(&self, req: Request<SignalRequest>) -> Result<Response<SignalAck>, Status> {
    self.authorize(&req, "signal", serde_json::Value::Null)
        .await?;
    let request = req.into_inner();
    let sig = match Sig::try_from(request.signal) {
        Ok(known) => known,
        Err(_) => Sig::Unspecified,
    };
    self.signal_inner(&request.session_id, sig).await?;
    Ok(Response::new(SignalAck {}))
}
/// Bidirectional RPC: drives one terminal session over a single stream.
///
/// After authorizing `terminal.interact`, a background task consumes client
/// frames:
/// * `Open` - opens a session and binds it to this stream. The `Handle` frame
///   is queued *before* the output forwarder is started, so the client always
///   sees the handle ahead of any output. A second `Open` is answered with an
///   error frame; a failed open ends the stream.
/// * `Input` / `Resize` / `Signal` / `Close` - act on the session chosen by
///   `pick_id` (the frame's own `session_id` if set, else the bound session).
///   When no session can be resolved, each of them answers with an
///   `"Interact: <Kind> before Open"` error frame. `Close` always ends the
///   stream.
///
/// Operation failures are reported as error frames, not as a stream error.
/// When the loop ends, the bound session is closed if it is still registered
/// and the output forwarder is aborted.
///
/// NOTE(review): a frame carrying its own `session_id` can address a session
/// this stream did not open, and `Open` frames are covered only by the
/// `interact` authorization (no per-program modifier as in the unary `open`).
/// Confirm both are intended.
async fn interact(
    &self,
    req: Request<Streaming<ClientFrame>>,
) -> Result<Response<Self::InteractStream>, Status> {
    let principal = self
        .authorize(&req, "interact", serde_json::Value::Null)
        .await?;
    let mut input = req.into_inner();
    let (tx, out) = mpsc::channel::<Result<ServerFrame, Status>>(STREAM_OUT_BUFFER);
    let svc = self.clone();
    tokio::spawn(supervise("interact-bidi", async move {
        let mut bound_id: Option<String> = None;
        let mut output_task: Option<tokio::task::JoinHandle<()>> = None;
        while let Some(frame_res) = input.next().await {
            let frame = match frame_res {
                Ok(f) => f,
                Err(_) => break,
            };
            // Frames without a payload carry nothing to act on.
            let Some(k) = frame.k else {
                continue;
            };
            match k {
                client_frame::K::Open(open_req) => {
                    if bound_id.is_some() {
                        let _ = tx
                            .send(Ok(error_frame("Interact: session already opened")))
                            .await;
                        continue;
                    }
                    match svc.open_inner(open_req, principal.clone()).await {
                        Ok(handle) => {
                            let session_id = handle.session_id.clone();
                            bound_id = Some(session_id.clone());
                            // Queue the handle first: the forwarder shares this
                            // channel and could otherwise get output in ahead
                            // of it.
                            let _ = tx
                                .send(Ok(ServerFrame {
                                    k: Some(server_frame::K::Handle(handle)),
                                }))
                                .await;
                            output_task = Some(spawn_output_forwarder(
                                svc.clone(),
                                session_id,
                                tx.clone(),
                            ));
                        }
                        Err(s) => {
                            let _ = tx.send(Ok(error_frame(s.message()))).await;
                            break;
                        }
                    }
                }
                client_frame::K::Input(chunk) => {
                    match pick_id(&bound_id, &chunk.session_id) {
                        Some(id) => {
                            match svc.write_input_inner(&id, Bytes::from(chunk.data)).await {
                                Ok(n) => {
                                    let _ = tx
                                        .send(Ok(ServerFrame {
                                            k: Some(server_frame::K::Ack(SendAck {
                                                bytes_written: n,
                                            })),
                                        }))
                                        .await;
                                }
                                Err(s) => {
                                    let _ = tx.send(Ok(error_frame(s.message()))).await;
                                }
                            }
                        }
                        None => {
                            let _ = tx
                                .send(Ok(error_frame("Interact: Input before Open")))
                                .await;
                        }
                    }
                }
                client_frame::K::Resize(r) => match pick_id(&bound_id, &r.session_id) {
                    Some(id) => {
                        let size = term_size_from_pb(r.size);
                        if let Err(s) = svc.resize_inner(&id, size).await {
                            let _ = tx.send(Ok(error_frame(s.message()))).await;
                        }
                    }
                    None => {
                        let _ = tx
                            .send(Ok(error_frame("Interact: Resize before Open")))
                            .await;
                    }
                },
                client_frame::K::Signal(sg) => match pick_id(&bound_id, &sg.session_id) {
                    Some(id) => {
                        let sig = Sig::try_from(sg.signal).unwrap_or(Sig::Unspecified);
                        if let Err(s) = svc.signal_inner(&id, sig).await {
                            let _ = tx.send(Ok(error_frame(s.message()))).await;
                        }
                    }
                    None => {
                        let _ = tx
                            .send(Ok(error_frame("Interact: Signal before Open")))
                            .await;
                    }
                },
                client_frame::K::Close(handle) => {
                    match pick_id(&bound_id, &handle.session_id) {
                        Some(id) => match svc.close_inner(&id).await {
                            Ok(ack) => {
                                let _ = tx
                                    .send(Ok(ServerFrame {
                                        k: Some(server_frame::K::Closed(ack)),
                                    }))
                                    .await;
                            }
                            Err(s) => {
                                let _ = tx.send(Ok(error_frame(s.message()))).await;
                            }
                        },
                        None => {
                            let _ = tx
                                .send(Ok(error_frame("Interact: Close before Open")))
                                .await;
                        }
                    }
                    break;
                }
            }
        }
        // The stream is over: make sure the session this stream opened does
        // not outlive it.
        if let Some(id) = bound_id {
            if svc.registry.get(&id).await.is_some() {
                let _ = svc.close_inner(&id).await;
            }
        }
        if let Some(t) = output_task {
            t.abort();
        }
    }));
    Ok(Response::new(ReceiverStream::new(out)))
}
}
fn error_frame(msg: impl Into<String>) -> ServerFrame {
ServerFrame {
k: Some(server_frame::K::Error(msg.into())),
}
}
/// Runs `body` to completion, catching a panic instead of letting it unwind
/// out of the task. A caught panic is logged together with the task `name`.
async fn supervise(name: &'static str, body: impl std::future::Future<Output = ()> + Send) {
    use futures::FutureExt;
    let guarded = std::panic::AssertUnwindSafe(body).catch_unwind();
    if guarded.await.is_err() {
        error!(task = name, "terminal PTY pump panicked");
    }
}
/// Chooses which session a frame addresses.
///
/// A non-empty `per_frame` id always wins; otherwise the stream's bound id (if
/// any) is used.
fn pick_id(bound: &Option<String>, per_frame: &str) -> Option<String> {
    match per_frame {
        "" => bound.clone(),
        explicit => Some(explicit.to_owned()),
    }
}
/// Spawns the task that copies a session's output onto an `interact` stream.
///
/// The task looks the session up in the registry (answering with an error
/// frame if it is gone), sends every chunk of the snapshot returned by
/// `attach_snapshot`, then forwards chunks from the accompanying receiver.
/// Sequence numbers start at 1. It stops when the client side of `tx` is gone,
/// after sending an empty `eof` chunk once the broadcast channel closes, or
/// after sending an error frame when this subscriber lagged.
fn spawn_output_forwarder(
    svc: TerminalSvc,
    session_id: String,
    tx: mpsc::Sender<Result<ServerFrame, Status>>,
) -> tokio::task::JoinHandle<()> {
    /// Builds one `Output` frame stamped with the current time.
    fn output_frame(data: Vec<u8>, seq: u64, eof: bool) -> ServerFrame {
        let chunk = OutputChunk {
            data,
            ts: timestamp_now(),
            seq,
            eof,
        };
        ServerFrame {
            k: Some(server_frame::K::Output(chunk)),
        }
    }

    let forward = async move {
        let session = match svc.registry.get(&session_id).await {
            Some(found) => found,
            None => {
                let gone = format!("Interact: session '{session_id}' vanished");
                let _ = tx.send(Ok(error_frame(gone))).await;
                return;
            }
        };
        let (backlog, mut live) = session.attach_snapshot();

        let mut seq = 0u64;
        for bytes in backlog {
            seq += 1;
            let frame = output_frame(bytes.to_vec(), seq, false);
            if tx.send(Ok(frame)).await.is_err() {
                return;
            }
        }

        loop {
            let received = live.recv().await;
            match received {
                Ok(bytes) => {
                    seq += 1;
                    let frame = output_frame(bytes.to_vec(), seq, false);
                    if tx.send(Ok(frame)).await.is_err() {
                        break;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(n)) => {
                    warn!(
                        session = %session_id,
                        dropped = n,
                        "interact stream lagged — bumped subscriber"
                    );
                    let lagged = error_frame(format!("interact lagged: {n} chunks dropped"));
                    let _ = tx.send(Ok(lagged)).await;
                    break;
                }
                Err(broadcast::error::RecvError::Closed) => {
                    let _ = tx.send(Ok(output_frame(Vec::new(), seq + 1, true))).await;
                    break;
                }
            }
        }
    };

    tokio::spawn(supervise("interact-output-forwarder", forward))
}
/// Blocking loop that copies PTY output to subscribers.
///
/// Each successful read is appended to the `replay` buffer (evicting the
/// oldest chunk once `OUT_BROADCAST_CAPACITY` chunks are held) and sent on
/// `out_tx`. The loop ends on end-of-file or on a read error, except that a
/// read interrupted by a signal (`ErrorKind::Interrupted`) is retried rather
/// than treated as the end of the session's output.
///
/// # Panics
/// Panics if the `replay` mutex is poisoned.
fn pump_reader(
    mut reader: Box<dyn Read + Send>,
    out_tx: broadcast::Sender<Bytes>,
    replay: Arc<StdMutex<VecDeque<Bytes>>>,
) {
    let mut buf = vec![0u8; PTY_READ_BUFFER_SIZE];
    loop {
        let n = match reader.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => n,
            // Interrupted is documented as non-fatal: retry the read.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(_) => break,
        };
        let chunk = Bytes::copy_from_slice(&buf[..n]);

        // The replay append and the broadcast happen under one lock, as in the
        // original ordering. NOTE(review): presumably so that a subscriber
        // taking a snapshot cannot see a chunk twice or miss one - confirm
        // against `Session::attach_snapshot`.
        let mut guard = replay.lock().expect("replay mutex poisoned");
        if guard.len() >= OUT_BROADCAST_CAPACITY {
            guard.pop_front();
        }
        guard.push_back(chunk.clone());
        // A send error only means there is no subscriber right now.
        let _ = out_tx.send(chunk);
    }
}