use std::collections::{HashMap, HashSet};
use std::os::fd::RawFd;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use microsandbox_protocol::codec::{self, MAX_FRAME_SIZE};
use microsandbox_protocol::exec::ExecSignal;
use microsandbox_protocol::message::{
FLAG_SESSION_START, FLAG_SHUTDOWN, FLAG_TERMINAL, FRAME_HEADER_SIZE, Message, MessageType,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt, unix::AsyncFd};
use tokio::net::UnixListener;
use tokio::net::unix::OwnedReadHalf;
use tokio::sync::{Mutex, mpsc, watch};
use crate::console::ConsoleSharedState;
use crate::{RuntimeError, RuntimeResult};
/// Maximum number of concurrently connected relay clients (slots).
const MAX_CLIENTS: u32 = 16;
/// Width of each client's slice of the u32 message-id space; frames are
/// routed back to slot `id / ID_RANGE_STEP` (see `ring_reader_task`).
const ID_RANGE_STEP: u32 = u32::MAX / MAX_CLIENTS;
/// Size in bytes of the big-endian length prefix preceding every frame.
const LEN_PREFIX_SIZE: usize = 4;
/// Per-client bookkeeping kept while a relay client is connected.
struct ClientState {
    // Session ids this client started (FLAG_SESSION_START) and has not yet
    // terminated (FLAG_TERMINAL); used to SIGKILL leftovers on disconnect.
    active_sessions: HashSet<u32>,
    // Channel feeding the per-client socket writer task.
    write_tx: mpsc::Sender<Bytes>,
}
/// Capacity of each client's outbound (relay -> client socket) write channel.
const CLIENT_WRITE_CHANNEL_CAPACITY: usize = 64;
/// Relay between agentd (reached through the shared rings/wake fds in
/// `ConsoleSharedState`) and local clients connecting over a Unix socket.
pub struct AgentRelay {
    // Shared rings and wake fds used to exchange frames with agentd.
    shared: Arc<ConsoleSharedState>,
    // Listener accepting local relay clients.
    listener: UnixListener,
    // Path the socket is bound at; the file is removed again on shutdown.
    sock_path: PathBuf,
    // Raw bytes of agentd's `core.ready` frame, captured by `wait_ready()`
    // and replayed to every client as part of its handshake.
    ready_frame: Option<Vec<u8>>,
}
/// One length-prefixed wire frame plus the header fields needed for routing.
struct RawFrame {
    // Complete frame bytes, including the 4-byte length prefix.
    data: Bytes,
    // Message id, parsed from frame bytes 4..8 (big-endian).
    id: u32,
    // Flag byte, parsed from frame byte 8.
    flags: u8,
}
impl AgentRelay {
    /// Bind the relay's Unix listener at `agent_sock_path`.
    ///
    /// A stale socket file from a previous run is removed first so the bind
    /// does not fail with "address in use", and parent directories are
    /// created if missing. `ready_frame` starts as `None` — it is populated
    /// by `wait_ready()` and must be set before `run()` is called.
    pub async fn new(
        agent_sock_path: &Path,
        shared: Arc<ConsoleSharedState>,
    ) -> RuntimeResult<Self> {
        if agent_sock_path.exists() {
            // Best-effort removal; any real problem surfaces in bind() below.
            let _ = std::fs::remove_file(agent_sock_path);
        }
        if let Some(parent) = agent_sock_path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let listener = UnixListener::bind(agent_sock_path)?;
        tracing::info!("agent relay listening on {}", agent_sock_path.display());
        Ok(Self {
            shared,
            listener,
            sock_path: agent_sock_path.to_path_buf(),
            ready_frame: None,
        })
    }

    /// Synchronously block until agentd emits its `core.ready` frame on the
    /// tx ring, or fail after a 60-second timeout.
    ///
    /// Frames arriving before `core.ready` are decoded, logged at debug
    /// level, and discarded. On success the raw ready-frame bytes are cached
    /// in `self.ready_frame` so `run()` can replay them to each client
    /// during its handshake.
    ///
    /// NOTE(review): this drains the ring and then blocks in poll(2) on the
    /// wake fd — it is meant to run before the async relay loop starts, not
    /// on a tokio worker thread.
    pub fn wait_ready(&mut self) -> RuntimeResult<()> {
        const READY_TIMEOUT_SECS: i32 = 60;
        let mut buf = BytesMut::new();
        let deadline =
            std::time::Instant::now() + std::time::Duration::from_secs(READY_TIMEOUT_SECS as u64);
        loop {
            // Consume pending wake notifications, then pull every queued
            // chunk off the tx ring into the reassembly buffer.
            self.shared.tx_wake.drain();
            while let Some(chunk) = self.shared.tx_ring.pop() {
                buf.extend_from_slice(&chunk);
            }
            while let Some(frame) = try_extract_frame(&mut buf) {
                let raw_data = frame.data.to_vec();
                let msg = decode_frame(raw_data.clone())?;
                if msg.t == MessageType::Ready {
                    tracing::info!("agent relay: received core.ready from agentd");
                    // Keep the raw bytes: run() prepends them to each
                    // client's handshake.
                    self.ready_frame = Some(raw_data);
                    return Ok(());
                }
                tracing::debug!(
                    "agent relay: discarding pre-ready frame type={:?} id={}",
                    msg.t,
                    msg.id
                );
            }
            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
            if remaining.is_zero() {
                return Err(RuntimeError::Custom(
                    "agent relay: timed out waiting for core.ready from agentd".into(),
                ));
            }
            // Sleep until the wake fd becomes readable or the remaining
            // budget expires, then loop to re-drain the ring.
            let timeout_ms = remaining.as_millis().min(i32::MAX as u128) as i32;
            poll_fd_readable_timeout(self.shared.tx_wake.as_raw_fd(), timeout_ms);
        }
    }

    /// Main relay loop: accept local clients, bridge them onto the shared
    /// agent rings, and stop when `shutdown` flips to `true`.
    ///
    /// Each client claims a slot in `[0, MAX_CLIENTS)`; the slot's
    /// `id_offset` (`slot * ID_RANGE_STEP`) partitions the u32 message-id
    /// space between clients. The handshake written to a new client is that
    /// offset (big-endian) followed by the cached `core.ready` frame.
    /// Returns an error if `wait_ready()` was not called first.
    pub async fn run(
        self,
        mut shutdown: watch::Receiver<bool>,
        drain_tx: mpsc::Sender<()>,
    ) -> RuntimeResult<()> {
        let ready_frame = self.ready_frame.ok_or_else(|| {
            RuntimeError::Custom("agent relay: run() called before wait_ready()".into())
        })?;
        // slot -> per-client state (session tracking + writer channel).
        let clients: Arc<Mutex<HashMap<u32, ClientState>>> = Arc::new(Mutex::new(HashMap::new()));
        // Frames from all client readers funnel through this channel into
        // the rx ring (see ring_writer_task).
        let (agent_tx, agent_rx) = mpsc::channel::<Vec<u8>>(256);
        let used_slots: Arc<Mutex<HashSet<u32>>> = Arc::new(Mutex::new(HashSet::new()));
        let shared_for_writer = Arc::clone(&self.shared);
        let ring_writer_handle = tokio::spawn(ring_writer_task(shared_for_writer, agent_rx));
        let clients_for_reader = Arc::clone(&clients);
        let shared_for_reader = Arc::clone(&self.shared);
        let ring_reader_handle =
            tokio::spawn(ring_reader_task(shared_for_reader, clients_for_reader));
        loop {
            tokio::select! {
                accept_result = self.listener.accept() => {
                    match accept_result {
                        Ok((stream, _addr)) => {
                            // Claim the lowest free slot, if any.
                            let slot = {
                                let mut slots = used_slots.lock().await;
                                let mut found = None;
                                for i in 0..MAX_CLIENTS {
                                    if !slots.contains(&i) {
                                        slots.insert(i);
                                        found = Some(i);
                                        break;
                                    }
                                }
                                found
                            };
                            let slot = match slot {
                                Some(s) => s,
                                None => {
                                    tracing::error!("agent relay: max clients reached, rejecting connection");
                                    drop(stream);
                                    continue;
                                }
                            };
                            let id_offset = slot * ID_RANGE_STEP;
                            tracing::info!(
                                "agent relay: client connected slot={slot} id_offset={id_offset}"
                            );
                            let (reader_half, mut writer_half) = stream.into_split();
                            // Handshake: 4-byte big-endian id offset, then
                            // the cached core.ready frame.
                            let mut handshake = Vec::with_capacity(4 + ready_frame.len());
                            handshake.extend_from_slice(&id_offset.to_be_bytes());
                            handshake.extend_from_slice(&ready_frame);
                            if let Err(e) = writer_half.write_all(&handshake).await {
                                tracing::error!(
                                    "agent relay: handshake write failed slot={slot}: {e}"
                                );
                                // Roll back the slot claim; this client never
                                // became active.
                                used_slots.lock().await.remove(&slot);
                                continue;
                            }
                            let (write_tx, mut write_rx) =
                                mpsc::channel::<Bytes>(CLIENT_WRITE_CHANNEL_CAPACITY);
                            // Writer task: drains the per-client channel into
                            // the socket until a write fails or the channel
                            // closes (ClientState dropped on disconnect).
                            tokio::spawn(async move {
                                while let Some(data) = write_rx.recv().await {
                                    if let Err(e) = writer_half.write_all(&data).await {
                                        tracing::error!(
                                            "agent relay: client writer slot={slot} failed: {e}"
                                        );
                                        break;
                                    }
                                }
                            });
                            {
                                let mut map = clients.lock().await;
                                map.insert(slot, ClientState {
                                    active_sessions: HashSet::new(),
                                    write_tx,
                                });
                            }
                            let agent_tx_clone = agent_tx.clone();
                            let clients_clone = Arc::clone(&clients);
                            let used_slots_clone = Arc::clone(&used_slots);
                            let drain_tx_clone = drain_tx.clone();
                            // The reader task owns all cleanup for this slot
                            // (state removal, session kill, slot release).
                            tokio::spawn(client_reader_task(
                                slot,
                                reader_half,
                                agent_tx_clone,
                                clients_clone,
                                used_slots_clone,
                                drain_tx_clone,
                            ));
                        }
                        Err(e) => {
                            tracing::error!("agent relay: accept error: {e}");
                        }
                    }
                }
                _ = shutdown.changed() => {
                    if *shutdown.borrow() {
                        tracing::info!("agent relay: shutdown signal received");
                        break;
                    }
                }
            }
        }
        // Teardown: remove the socket file and stop the ring bridge tasks.
        let _ = std::fs::remove_file(&self.sock_path);
        ring_writer_handle.abort();
        ring_reader_handle.abort();
        Ok(())
    }
}
/// Block the current thread until `fd` is readable or `timeout_ms` elapses.
///
/// A negative `timeout_ms` means wait indefinitely, per poll(2). Returns on
/// readability (`ret > 0`) or timeout (`ret == 0`) alike — the caller is
/// expected to re-check its own state/deadline afterwards. Retries on EINTR
/// with the *remaining* portion of the timeout, so a burst of signals cannot
/// extend the total wait beyond `timeout_ms` (the original code restarted
/// the full timeout after every EINTR). Any other poll() failure is logged
/// and treated as "done waiting".
fn poll_fd_readable_timeout(fd: RawFd, timeout_ms: i32) {
    // Only finite timeouts get a deadline; negative = infinite per poll(2).
    let deadline = (timeout_ms >= 0)
        .then(|| std::time::Instant::now() + std::time::Duration::from_millis(timeout_ms as u64));
    let mut remaining_ms = timeout_ms;
    loop {
        let mut pfd = libc::pollfd {
            fd,
            events: libc::POLLIN,
            revents: 0,
        };
        // SAFETY: `pfd` is a valid, initialized pollfd and we pass nfds == 1.
        let ret = unsafe { libc::poll(&mut pfd, 1, remaining_ms) };
        if ret >= 0 {
            return;
        }
        let errno = std::io::Error::last_os_error();
        if errno.raw_os_error() != Some(libc::EINTR) {
            tracing::error!("agent relay: poll() failed: {errno}");
            return;
        }
        // EINTR: recompute the time budget instead of restarting the full
        // timeout; give up once the deadline has passed.
        if let Some(deadline) = deadline {
            let left = deadline.saturating_duration_since(std::time::Instant::now());
            if left.is_zero() {
                return;
            }
            remaining_ms = left.as_millis().min(i32::MAX as u128) as i32;
        }
    }
}
/// Try to pull one complete length-prefixed frame off the front of `buf`.
///
/// Returns `None` when the buffer does not yet hold a whole frame. An
/// oversized length clears the entire buffer (framing is lost and cannot be
/// resynced); an undersized frame is discarded individually. On success the
/// returned `RawFrame` includes the length prefix in `data`, with `id` and
/// `flags` pre-parsed from the header.
fn try_extract_frame(buf: &mut BytesMut) -> Option<RawFrame> {
    // Need at least the 4-byte big-endian length prefix.
    let prefix: [u8; LEN_PREFIX_SIZE] = buf.get(..LEN_PREFIX_SIZE)?.try_into().ok()?;
    let frame_len = u32::from_be_bytes(prefix) as usize;
    if frame_len > MAX_FRAME_SIZE as usize {
        tracing::error!(
            "agent relay: frame too large ({frame_len} bytes), clearing {} bytes of buffer",
            buf.len()
        );
        buf.clear();
        return None;
    }
    let total_len = LEN_PREFIX_SIZE + frame_len;
    if buf.len() < total_len {
        // Whole frame not buffered yet; wait for more data.
        return None;
    }
    if frame_len < FRAME_HEADER_SIZE {
        tracing::error!("agent relay: frame too short ({frame_len} bytes), discarding");
        let _ = buf.split_to(total_len);
        return None;
    }
    let data = buf.split_to(total_len).freeze();
    // Header layout after the prefix: id (4 bytes, big-endian), flags (1 byte).
    let id = u32::from_be_bytes([data[4], data[5], data[6], data[7]]);
    let flags = data[8];
    Some(RawFrame { data, id, flags })
}
/// Decode a complete wire frame into a protocol `Message`.
///
/// Both a codec error and an incomplete frame (which should not happen for
/// frames produced by `try_extract_frame`) are surfaced as `RuntimeError::Custom`.
fn decode_frame(mut buf: Vec<u8>) -> RuntimeResult<Message> {
    match codec::try_decode_from_buf(&mut buf) {
        Ok(Some(msg)) => Ok(msg),
        Ok(None) => Err(RuntimeError::Custom("decode frame: incomplete frame".into())),
        Err(e) => Err(RuntimeError::Custom(format!("decode frame: {e}"))),
    }
}
/// Bridge clients -> agentd: forward each frame received on `rx` into the
/// shared rx ring, waking agentd after every successful push.
///
/// When the ring is full, the push is retried up to `MAX_PUSH_ATTEMPTS`
/// times with a short sleep in between (the ring hands the rejected buffer
/// back, so no copy is needed per retry); after the last attempt the frame
/// is dropped with an error log. The task exits when `rx` closes.
async fn ring_writer_task(shared: Arc<ConsoleSharedState>, mut rx: mpsc::Receiver<Vec<u8>>) {
    // Named bounds replace the previous magic numbers (50 / attempt == 49),
    // which silently depended on each other.
    const MAX_PUSH_ATTEMPTS: usize = 50;
    const RETRY_DELAY: std::time::Duration = std::time::Duration::from_millis(1);
    while let Some(frame_bytes) = rx.recv().await {
        let mut data = frame_bytes;
        for attempt in 1..=MAX_PUSH_ATTEMPTS {
            match shared.rx_ring.push(data) {
                Ok(()) => {
                    shared.rx_wake.wake();
                    break;
                }
                Err(returned) => {
                    if attempt == MAX_PUSH_ATTEMPTS {
                        tracing::error!("agent relay: rx_ring full after retries, dropping frame");
                        break;
                    }
                    // The ring returned our buffer; keep it for the retry.
                    data = returned;
                    tokio::time::sleep(RETRY_DELAY).await;
                }
            }
        }
    }
    tracing::debug!("agent relay: ring writer task exiting");
}
/// Bridge agentd -> clients: wait on the tx wake fd, drain the tx ring,
/// reassemble frames, and forward each frame to the client that owns its
/// id range. Runs until the AsyncFd errors.
async fn ring_reader_task(
    shared: Arc<ConsoleSharedState>,
    clients: Arc<Mutex<HashMap<u32, ClientState>>>,
) {
    let wake_fd = shared.tx_wake.as_raw_fd();
    let async_fd = match AsyncFd::new(wake_fd) {
        Ok(fd) => fd,
        Err(e) => {
            tracing::error!("agent relay: failed to create AsyncFd for tx_wake: {e}");
            return;
        }
    };
    // Reassembly buffer for frames that span multiple ring chunks; `frames`
    // is reused across iterations to avoid reallocating.
    let mut buf = BytesMut::new();
    let mut frames = Vec::new();
    loop {
        let mut guard = match async_fd.readable().await {
            Ok(g) => g,
            Err(e) => {
                tracing::error!("agent relay: AsyncFd readable error: {e}");
                break;
            }
        };
        // Clear tokio's cached readiness *before* draining, so a wake that
        // arrives after the drain makes the next readable() fire again.
        guard.clear_ready();
        shared.tx_wake.drain();
        while let Some(chunk) = shared.tx_ring.pop() {
            buf.extend_from_slice(&chunk);
        }
        while let Some(frame) = try_extract_frame(&mut buf) {
            frames.push(frame);
        }
        for frame in frames.drain(..) {
            // The id space is partitioned into ID_RANGE_STEP-sized slices,
            // one per slot; clamp in case an id lands past the last slot.
            let client_slot = frame.id / ID_RANGE_STEP;
            let client_slot = client_slot.min(MAX_CLIENTS - 1);
            let is_terminal = (frame.flags & FLAG_TERMINAL) != 0;
            // Look up the client and clone its sender while holding the
            // lock; the potentially-slow send happens after release.
            let writer_result = {
                let mut map = clients.lock().await;
                if let Some(client) = map.get_mut(&client_slot) {
                    if is_terminal {
                        // A terminal frame ends this session's bookkeeping.
                        client.active_sessions.remove(&frame.id);
                    }
                    Ok(client.write_tx.clone())
                } else {
                    Err(frame.id)
                }
            };
            match writer_result {
                Ok(write_tx) => {
                    if write_tx.send(frame.data).await.is_err() {
                        tracing::error!("agent relay: write channel closed for slot={client_slot}");
                    }
                }
                Err(id) => {
                    // No client on that slot (disconnected or never
                    // connected); the frame is dropped.
                    tracing::debug!(
                        "agent relay: no client for slot={client_slot} id={id} (frame dropped)"
                    );
                }
            }
        }
    }
    tracing::debug!("agent relay: ring reader task exiting");
}
/// Read one complete length-prefixed frame from `reader`.
///
/// EOF while reading the length prefix is reported as a `Custom` error
/// ("unexpected EOF"); any other I/O failure — including EOF mid-body,
/// which propagates through `?` — surfaces as `RuntimeError::Io`. The
/// returned `RawFrame` keeps the prefix in `data`, with `id` and `flags`
/// parsed out of the header.
async fn read_raw_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> RuntimeResult<RawFrame> {
    // Length prefix: 4 bytes, big-endian.
    let mut len_buf = [0u8; LEN_PREFIX_SIZE];
    if let Err(e) = reader.read_exact(&mut len_buf).await {
        return Err(if e.kind() == std::io::ErrorKind::UnexpectedEof {
            RuntimeError::Custom("agent relay: unexpected EOF".into())
        } else {
            RuntimeError::Io(e)
        });
    }
    // Sanity-check the declared length before allocating for it.
    let frame_len = u32::from_be_bytes(len_buf);
    if frame_len > MAX_FRAME_SIZE {
        return Err(RuntimeError::Custom(format!(
            "agent relay: frame too large: {frame_len} bytes (max {MAX_FRAME_SIZE})"
        )));
    }
    let frame_len = frame_len as usize;
    if frame_len < FRAME_HEADER_SIZE {
        return Err(RuntimeError::Custom(format!(
            "agent relay: frame too short: {frame_len} bytes"
        )));
    }
    // Assemble prefix + body into one buffer so RawFrame::data matches the
    // wire layout produced by try_extract_frame.
    let total = LEN_PREFIX_SIZE + frame_len;
    let mut data = vec![0u8; total];
    data[..LEN_PREFIX_SIZE].copy_from_slice(&len_buf);
    reader.read_exact(&mut data[LEN_PREFIX_SIZE..]).await?;
    // Header right after the prefix: id (4 bytes, big-endian), flags (1 byte).
    let header = &data[LEN_PREFIX_SIZE..];
    let id = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
    let flags = header[4];
    Ok(RawFrame {
        data: Bytes::from(data),
        id,
        flags,
    })
}
/// Per-client read loop: forward frames from the client socket into the
/// agent ring, track session lifecycles, and on disconnect send SIGKILL
/// for every session the client left running before releasing its slot.
async fn client_reader_task(
    slot: u32,
    mut reader: OwnedReadHalf,
    agent_tx: mpsc::Sender<Vec<u8>>,
    clients: Arc<Mutex<HashMap<u32, ClientState>>>,
    used_slots: Arc<Mutex<HashSet<u32>>>,
    drain_tx: mpsc::Sender<()>,
) {
    loop {
        let frame = match read_raw_frame(&mut reader).await {
            Ok(f) => f,
            Err(_) => {
                // Any read error (including clean EOF) counts as disconnect.
                tracing::info!("agent relay: client disconnected slot={slot}");
                break;
            }
        };
        let is_session_start = (frame.flags & FLAG_SESSION_START) != 0;
        let is_terminal = (frame.flags & FLAG_TERMINAL) != 0;
        let is_shutdown = (frame.flags & FLAG_SHUTDOWN) != 0;
        if is_shutdown {
            tracing::info!("agent relay: client slot={slot} sent core.shutdown, notifying drain");
            // try_send: losing the notification is fine if one is queued.
            let _ = drain_tx.try_send(());
        }
        // Keep active_sessions in sync so disconnect cleanup below knows
        // which sessions still need killing.
        if is_session_start || is_terminal {
            let mut map = clients.lock().await;
            if let Some(client) = map.get_mut(&slot) {
                if is_session_start {
                    client.active_sessions.insert(frame.id);
                }
                if is_terminal {
                    client.active_sessions.remove(&frame.id);
                }
            }
        }
        if agent_tx.send(frame.data.to_vec()).await.is_err() {
            tracing::error!("agent relay: ring writer channel closed");
            break;
        }
    }
    // Disconnect: drop this client's state (which also closes its writer
    // task's channel) and take ownership of any still-active session ids.
    let active_sessions = {
        let mut map = clients.lock().await;
        if let Some(client) = map.remove(&slot) {
            client.active_sessions
        } else {
            HashSet::new()
        }
    };
    if !active_sessions.is_empty() {
        tracing::info!(
            "agent relay: cleaning up {} active sessions for slot={slot}",
            active_sessions.len()
        );
        for session_id in active_sessions {
            // Build a SIGKILL (signal 9) ExecSignal message for the
            // orphaned session and queue it to agentd.
            let kill_msg = match Message::with_payload(
                MessageType::ExecSignal,
                session_id,
                &ExecSignal { signal: 9 },
            ) {
                Ok(msg) => msg,
                Err(e) => {
                    tracing::error!(
                        "agent relay: failed to encode SIGKILL for session {session_id}: {e}"
                    );
                    continue;
                }
            };
            let mut buf = Vec::new();
            if let Err(e) = codec::encode_to_buf(&kill_msg, &mut buf) {
                tracing::error!(
                    "agent relay: failed to encode SIGKILL frame for session {session_id}: {e}"
                );
                continue;
            }
            if agent_tx.send(buf).await.is_err() {
                tracing::error!("agent relay: ring writer channel closed during cleanup");
                break;
            }
        }
    }
    // Release the slot only after cleanup, so a new client cannot reuse
    // this id range while kill frames are still being queued.
    used_slots.lock().await.remove(&slot);
    tracing::debug!("agent relay: slot={slot} released");
}