use std::collections::HashMap;
use std::path::Path;
use std::sync::{
Arc,
atomic::{AtomicU32, Ordering},
};
use std::time::Duration;
use microsandbox_protocol::{
codec::{self, MAX_FRAME_SIZE, RawFrame},
core::Ready,
message::{FLAG_TERMINAL, FRAME_HEADER_SIZE, Message, MessageType, PROTOCOL_VERSION},
};
use serde::Serialize;
use tokio::io::AsyncReadExt;
use tokio::net::UnixStream;
use tokio::sync::{Mutex, mpsc};
use tokio::task::JoinHandle;
use tokio::time::Instant;
use super::error::{AgentClientError, AgentClientResult};
/// Default time budget for connecting to the relay and completing its handshake.
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Value stamped into `Message::v` when talking to a legacy (v1) relay.
const LEGACY_PROTOCOL_VERSION: u8 = 1;
/// Size of the request-id window derived from a legacy relay's id offset
/// (one sixteenth of the `u32` id space).
const LEGACY_RELAY_ID_RANGE_STEP: u32 = u32::MAX >> 4;
/// Wire protocol flavour negotiated with the relay during the handshake.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentProtocol {
    /// Current protocol: the relay announces an explicit `[start, end)` id range.
    Current,
    /// Legacy v1 protocol: the relay announces only an id offset.
    LegacyV1,
}
/// Client side of an agent relay connection.
///
/// Multiplexes request/response and streaming exchanges over one Unix socket,
/// routing incoming frames to waiting callers by frame id.
pub struct AgentClient {
    // Socket write half; locked per frame so concurrent callers write whole frames.
    writer: Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
    // Next request id to hand out; wraps back to the start of the allowed range.
    next_id: AtomicU32,
    // Inclusive lower bound of usable request ids (id 0 is never handed out).
    id_min: u32,
    // Exclusive upper bound of usable request ids.
    id_max: u32,
    // Protocol flavour detected during the handshake.
    protocol: AgentProtocol,
    // Per-id reply channels; shared with the background reader task.
    pending: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<RawFrame>>>>,
    // Background task reading frames from the socket; aborted on drop.
    reader_handle: JoinHandle<()>,
    // Raw CBOR body of the `core.ready` frame received during the handshake.
    ready_body: Vec<u8>,
    // Decoded `core.ready` payload.
    ready: Ready,
}
impl AgentProtocol {
fn version(self) -> u8 {
match self {
Self::Current => PROTOCOL_VERSION,
Self::LegacyV1 => LEGACY_PROTOCOL_VERSION,
}
}
}
impl AgentClient {
    /// Connects to the agent relay listening on `sock_path`.
    ///
    /// The connect and handshake together may take up to
    /// [`DEFAULT_HANDSHAKE_TIMEOUT`].
    ///
    /// # Errors
    ///
    /// Same as [`AgentClient::connect_with_deadline`].
    pub async fn connect(sock_path: impl AsRef<Path>) -> AgentClientResult<Self> {
        Self::connect_with_timeout(sock_path, DEFAULT_HANDSHAKE_TIMEOUT).await
    }

    /// Connects to the agent relay listening on `sock_path`.
    ///
    /// The connect and handshake together may take up to `timeout`.
    ///
    /// A `timeout` too large to add to the current instant (for example
    /// `Duration::MAX`) is clamped to a deadline roughly 30 years away
    /// instead of panicking on overflow.
    ///
    /// # Errors
    ///
    /// Same as [`AgentClient::connect_with_deadline`].
    pub async fn connect_with_timeout(
        sock_path: impl AsRef<Path>,
        timeout: Duration,
    ) -> AgentClientResult<Self> {
        // Effectively "no deadline".
        const FAR_FUTURE: Duration = Duration::from_secs(30 * 365 * 24 * 60 * 60);

        let now = Instant::now();
        // `Instant + Duration` panics on overflow, so clamp absurdly large timeouts.
        let deadline = now
            .checked_add(timeout)
            .unwrap_or_else(|| now + FAR_FUTURE);
        Self::connect_with_deadline(sock_path, deadline).await
    }

    /// Connects to the agent relay listening on `sock_path` and performs the
    /// handshake. Fails if the handshake has not completed by `deadline`.
    ///
    /// Every relay starts by sending 8 bytes:
    /// - A current relay sends a big-endian `[id_start, id_end)` request-id
    ///   range, followed by a framed `core.ready` message.
    /// - A legacy (v1) relay sends a 4-byte id offset, followed directly by
    ///   the ready frame's length prefix and frame.
    ///
    /// The legacy form is recognised when the second word looks like a
    /// plausible frame length.
    ///
    /// # Errors
    ///
    /// - [`AgentClientError::Connect`] if the socket cannot be connected.
    ///   This includes a connect that does not finish before `deadline`,
    ///   which is reported as `ErrorKind::TimedOut`.
    /// - [`AgentClientError::Handshake`] in any of these cases:
    ///   - the handshake bytes do not arrive in time;
    ///   - the announced id range contains no usable id;
    ///   - the first frame is not a valid `core.ready` message.
    pub async fn connect_with_deadline(
        sock_path: impl AsRef<Path>,
        deadline: Instant,
    ) -> AgentClientResult<Self> {
        let sock_path = sock_path.as_ref();

        // Bound the connect itself by the deadline too, not only the handshake reads.
        let stream = tokio::time::timeout_at(deadline, UnixStream::connect(sock_path))
            .await
            .map_err(|_| AgentClientError::Connect {
                path: sock_path.to_path_buf(),
                source: std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "timed out connecting to agent relay socket",
                ),
            })?
            .map_err(|source| AgentClientError::Connect {
                path: sock_path.to_path_buf(),
                source,
            })?;
        let (mut reader, writer) = stream.into_split();

        let mut range_buf = [0u8; 8];
        tokio::time::timeout_at(deadline, reader.read_exact(&mut range_buf))
            .await
            .map_err(|_| {
                AgentClientError::Handshake(
                    "read id range: timed out before relay sent bytes".into(),
                )
            })?
            .map_err(|e| AgentClientError::Handshake(format!("read id range: {e}")))?;
        let [s0, s1, s2, s3, e0, e1, e2, e3] = range_buf;
        let id_start_or_offset = u32::from_be_bytes([s0, s1, s2, s3]);
        let id_max_or_frame_len = u32::from_be_bytes([e0, e1, e2, e3]);

        let (id_min, id_max, ready_frame, protocol) =
            if looks_like_legacy_relay_handshake(id_start_or_offset, id_max_or_frame_len) {
                let id_offset = id_start_or_offset;
                let id_min = id_offset.saturating_add(1);
                let id_max = id_offset.saturating_add(LEGACY_RELAY_ID_RANGE_STEP);
                // An offset at the very top of the u32 space saturates into an
                // empty range, which would leave `alloc_id` with no usable id.
                if id_min >= id_max {
                    return Err(AgentClientError::Handshake(format!(
                        "invalid legacy relay id offset: {id_offset}"
                    )));
                }
                // The second word was the ready frame's length prefix.
                let ready_frame =
                    read_raw_frame_after_len_prefix(&mut reader, [e0, e1, e2, e3], deadline)
                        .await?;
                (id_min, id_max, ready_frame, AgentProtocol::LegacyV1)
            } else if first_request_id(id_start_or_offset) >= id_max_or_frame_len {
                // Id 0 is never handed out, so e.g. `[0, 1)` is empty as well.
                return Err(AgentClientError::Handshake(format!(
                    "invalid relay id range: start={id_start_or_offset}, end={id_max_or_frame_len}"
                )));
            } else {
                let ready_frame =
                    tokio::time::timeout_at(deadline, codec::read_raw_frame(&mut reader))
                        .await
                        .map_err(|_| {
                            AgentClientError::Handshake(
                                "read ready frame: timed out before relay sent frame".into(),
                            )
                        })?
                        .map_err(|e| {
                            AgentClientError::Handshake(format!("read ready frame: {e}"))
                        })?;
                (
                    id_start_or_offset,
                    id_max_or_frame_len,
                    ready_frame,
                    AgentProtocol::Current,
                )
            };

        let ready_msg = codec::raw_frame_to_message(ready_frame.clone())
            .map_err(|e| AgentClientError::Handshake(format!("decode ready frame: {e}")))?;
        if ready_msg.t != MessageType::Ready {
            return Err(AgentClientError::Handshake(format!(
                "expected core.ready frame, got {}",
                ready_msg.t.as_str()
            )));
        }
        let ready: Ready = ready_msg
            .payload()
            .map_err(|e| AgentClientError::Handshake(format!("decode ready payload: {e}")))?;

        tracing::info!(
            id_min,
            id_max,
            protocol = ?protocol,
            ready_bytes = ready_frame.body.len(),
            boot_time_ns = ready.boot_time_ns,
            "agent client: connected to relay"
        );
        if protocol == AgentProtocol::LegacyV1 {
            tracing::warn!(
                "agent client: connected to a sandbox started before microsandbox 0.5; exec compatibility is temporary and filesystem/SFTP require stop/start"
            );
        }

        let pending: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<RawFrame>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let reader_handle = tokio::spawn(reader_loop(reader, Arc::clone(&pending)));
        let writer = Arc::new(Mutex::new(writer));
        Ok(Self {
            writer,
            next_id: AtomicU32::new(first_request_id(id_min)),
            id_min,
            id_max,
            protocol,
            pending,
            reader_handle,
            ready_body: ready_frame.body,
            ready,
        })
    }

    /// Connects to the agent of the sandbox called `name`.
    ///
    /// Each connection attempt may take up to [`DEFAULT_HANDSHAKE_TIMEOUT`].
    ///
    /// # Errors
    ///
    /// Same as [`AgentClient::connect_sandbox_with_timeout`].
    pub async fn connect_sandbox(name: &str) -> AgentClientResult<Self> {
        Self::connect_sandbox_with_timeout(name, DEFAULT_HANDSHAKE_TIMEOUT).await
    }

    /// Connects to the agent of the sandbox called `name`.
    ///
    /// Tries each existing candidate socket path in turn, giving every
    /// attempt up to `timeout`, and returns the first client that connects.
    ///
    /// # Errors
    ///
    /// - [`AgentClientError::InvalidSandboxName`] if `name` fails validation.
    /// - The last connection error if every existing candidate failed.
    /// - [`AgentClientError::SandboxNotFound`] if no candidate socket exists.
    pub async fn connect_sandbox_with_timeout(
        name: &str,
        timeout: Duration,
    ) -> AgentClientResult<Self> {
        if let Some(message) = crate::sandbox::sandbox_name_validation_message(name) {
            return Err(AgentClientError::InvalidSandboxName(message));
        }
        let mut last_error = None;
        for sock_path in crate::runtime::sandbox_agent_socket_path_candidates(name) {
            if !sock_path.exists() {
                continue;
            }
            match Self::connect_with_timeout(&sock_path, timeout).await {
                Ok(client) => return Ok(client),
                Err(error) => last_error = Some(error),
            }
        }
        match last_error {
            Some(error) => Err(error),
            None => Err(AgentClientError::SandboxNotFound(name.to_string())),
        }
    }

    /// Closes the connection.
    ///
    /// Dropping the client aborts the background reader task and releases the
    /// socket write half.
    pub async fn close(self) {
        drop(self);
    }
}
impl AgentClient {
    /// Sends a pre-encoded `body` under a fresh request id and waits for the
    /// first reply frame carrying that id.
    ///
    /// # Errors
    ///
    /// - The write error if the frame cannot be sent.
    /// - [`AgentClientError::ReaderClosed`] if the reader task stops before
    ///   any reply arrives.
    pub async fn request_raw(&self, flags: u8, body: Vec<u8>) -> AgentClientResult<RawFrame> {
        let id = self.alloc_id();
        let (tx, mut rx) = mpsc::unbounded_channel();
        self.pending.lock().await.insert(id, tx);
        if let Err(e) = self.write_frame(id, flags, &body).await {
            self.pending.lock().await.remove(&id);
            return Err(e);
        }
        let reply = rx.recv().await;
        // Only the first reply is consumed. If it was not terminal, the reader
        // task still holds a sender for `id`; unregister it so it does not linger.
        self.pending.lock().await.remove(&id);
        reply.ok_or(AgentClientError::ReaderClosed(id))
    }

    /// Sends a pre-encoded `body` under a fresh request id and returns that id
    /// together with a receiver for every frame the relay sends back under it.
    ///
    /// The route stays registered until a terminal frame arrives or the
    /// receiver is dropped.
    ///
    /// # Errors
    ///
    /// Returns the write error if the frame cannot be sent.
    pub async fn stream_raw(
        &self,
        flags: u8,
        body: Vec<u8>,
    ) -> AgentClientResult<(u32, mpsc::UnboundedReceiver<RawFrame>)> {
        let id = self.alloc_id();
        let (tx, rx) = mpsc::unbounded_channel();
        self.pending.lock().await.insert(id, tx);
        if let Err(e) = self.write_frame(id, flags, &body).await {
            self.pending.lock().await.remove(&id);
            return Err(e);
        }
        Ok((id, rx))
    }

    /// Writes a frame under an existing `id` without registering a reply route.
    ///
    /// Typically used for follow-up frames on a stream opened with
    /// [`AgentClient::stream_raw`].
    pub async fn send_raw(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
        self.write_frame(id, flags, body).await
    }

    /// Returns the raw body of the `core.ready` frame received during the handshake.
    pub fn ready_bytes(&self) -> &[u8] {
        &self.ready_body
    }

    /// Returns the protocol flavour detected during the handshake.
    pub fn protocol(&self) -> AgentProtocol {
        self.protocol
    }

    /// Returns `true` when connected to a legacy (v1) relay.
    pub fn is_legacy_protocol(&self) -> bool {
        self.protocol == AgentProtocol::LegacyV1
    }
}
impl AgentClient {
pub async fn request<T: Serialize>(
&self,
t: MessageType,
payload: &T,
) -> AgentClientResult<Message> {
let flags = t.flags();
let body = encode_message_body(self.protocol.version(), t, payload)?;
let frame = self.request_raw(flags, body).await?;
Ok(codec::raw_frame_to_message(frame)?)
}
pub async fn stream<T: Serialize>(
&self,
t: MessageType,
payload: &T,
) -> AgentClientResult<(u32, mpsc::UnboundedReceiver<Message>)> {
let flags = t.flags();
let body = encode_message_body(self.protocol.version(), t, payload)?;
let (id, raw_rx) = self.stream_raw(flags, body).await?;
let (tx, rx) = mpsc::unbounded_channel();
tokio::spawn(decode_stream_task(raw_rx, tx));
Ok((id, rx))
}
pub async fn send<T: Serialize>(
&self,
id: u32,
t: MessageType,
payload: &T,
) -> AgentClientResult<()> {
let flags = t.flags();
let body = encode_message_body(self.protocol.version(), t, payload)?;
self.write_frame(id, flags, &body).await
}
pub fn ready(&self) -> AgentClientResult<Ready> {
Ok(self.ready.clone())
}
}
impl AgentClient {
fn alloc_id(&self) -> u32 {
loop {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
if id != 0 && id >= self.id_min && id < self.id_max {
return id;
}
self.next_id
.store(first_request_id(self.id_min), Ordering::Relaxed);
}
}
async fn write_frame(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
let mut buf = Vec::with_capacity(4 + 5 + body.len());
codec::encode_raw_to_buf(
&RawFrame {
id,
flags,
body: body.to_vec(),
},
&mut buf,
)?;
let mut writer = self.writer.lock().await;
tokio::io::AsyncWriteExt::write_all(&mut *writer, &buf).await?;
Ok(())
}
}
/// Heuristically detects a legacy (v1) relay handshake.
///
/// A legacy relay sends the ready frame's length prefix where a current relay
/// sends the end of its id range. The second word is therefore treated as
/// "legacy" when it is a plausible frame length. The first word is not used.
fn looks_like_legacy_relay_handshake(_id_min: u32, id_max: u32) -> bool {
    let plausible_frame_len = FRAME_HEADER_SIZE as u32..=MAX_FRAME_SIZE;
    plausible_frame_len.contains(&id_max)
}
/// Returns the first id to hand out for a range starting at `id_min`.
///
/// Id 0 is skipped, so a range that starts at 0 begins handing out ids at 1.
fn first_request_id(id_min: u32) -> u32 {
    if id_min == 0 { 1 } else { id_min }
}
/// Reads the rest of a legacy (v1) ready frame whose 4-byte big-endian length
/// prefix, `len_buf`, has already been consumed from `reader`.
///
/// The frame begins with a big-endian `u32` id and a flags byte. The body
/// starts after `FRAME_HEADER_SIZE` bytes. The read must finish before
/// `deadline`.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] in either of these cases:
/// - the length is outside `FRAME_HEADER_SIZE..=MAX_FRAME_SIZE`;
/// - the frame cannot be read in time.
async fn read_raw_frame_after_len_prefix<R: AsyncReadExt + Unpin>(
    reader: &mut R,
    len_buf: [u8; 4],
    deadline: Instant,
) -> AgentClientResult<RawFrame> {
    let frame_len = u32::from_be_bytes(len_buf);
    if frame_len > MAX_FRAME_SIZE {
        return Err(AgentClientError::Handshake(format!(
            "legacy ready frame too large: {frame_len} bytes (max {MAX_FRAME_SIZE})"
        )));
    }
    if frame_len < FRAME_HEADER_SIZE as u32 {
        return Err(AgentClientError::Handshake(format!(
            "legacy ready frame too short: {frame_len} bytes"
        )));
    }

    let mut frame = vec![0u8; frame_len as usize];
    match tokio::time::timeout_at(deadline, reader.read_exact(&mut frame)).await {
        Err(_elapsed) => {
            return Err(AgentClientError::Handshake(
                "read legacy ready frame: timed out before relay sent frame".into(),
            ));
        }
        Ok(Err(e)) => {
            return Err(AgentClientError::Handshake(format!(
                "read legacy ready frame: {e}"
            )));
        }
        Ok(Ok(_)) => {}
    }

    let id = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
    let flags = frame[4];
    let body = frame.split_off(FRAME_HEADER_SIZE);
    Ok(RawFrame { id, flags, body })
}
async fn reader_loop(
mut reader: tokio::net::unix::OwnedReadHalf,
pending: Arc<Mutex<HashMap<u32, mpsc::UnboundedSender<RawFrame>>>>,
) {
loop {
let frame = match codec::read_raw_frame(&mut reader).await {
Ok(frame) => frame,
Err(e) => {
tracing::debug!("agent client: reader EOF or error: {e}");
break;
}
};
let id = frame.id;
let is_terminal = (frame.flags & FLAG_TERMINAL) != 0;
let mut map = pending.lock().await;
if let Some(tx) = map.get(&id) {
if tx.send(frame).is_err() {
map.remove(&id);
} else if is_terminal {
map.remove(&id);
}
} else {
tracing::trace!("agent client: no pending handler for id={id}");
}
}
let mut map = pending.lock().await;
map.clear();
}
/// Decodes raw stream frames into [`Message`]s and forwards them to `tx`.
///
/// Frames that fail to decode are logged and skipped. The task ends when the
/// raw channel closes or the consumer drops its receiver.
async fn decode_stream_task(
    mut raw_rx: mpsc::UnboundedReceiver<RawFrame>,
    tx: mpsc::UnboundedSender<Message>,
) {
    while let Some(raw) = raw_rx.recv().await {
        let message = match codec::raw_frame_to_message(raw) {
            Ok(message) => message,
            Err(e) => {
                tracing::warn!("agent client: failed to decode frame in stream: {e}");
                continue;
            }
        };
        if tx.send(message).is_err() {
            break;
        }
    }
}
/// Builds a `t` message around `payload`, stamps it with protocol `version`,
/// and serializes it to CBOR.
///
/// The id field is left at 0 here; the frame id is carried separately by the
/// frame writer.
fn encode_message_body<T: Serialize>(
    version: u8,
    t: MessageType,
    payload: &T,
) -> AgentClientResult<Vec<u8>> {
    let message = {
        let mut draft = Message::with_payload(t, 0, payload)?;
        draft.v = version;
        draft
    };
    let mut encoded = Vec::new();
    ciborium::into_writer(&message, &mut encoded)
        .map_err(microsandbox_protocol::ProtocolError::from)?;
    Ok(encoded)
}
#[cfg(test)]
mod tests {
    // Handshake and request tests against an in-process fake relay listening
    // on a temporary Unix socket.
    use microsandbox_protocol::core::Ready;
    use microsandbox_protocol::exec::ExecRequest;
    use tokio::io::AsyncWriteExt;
    use tokio::net::UnixListener;
    use tokio::sync::oneshot;

    use super::*;

    #[tokio::test]
    async fn connect_decodes_ready_payload() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            // Current handshake: [start, end) id range, then the ready frame.
            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
            socket
                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
                .await
                .unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });
        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();
        assert_eq!(client.protocol(), AgentProtocol::Current);
        let decoded = client.ready().unwrap();
        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
        let raw_msg: Message = ciborium::from_reader(client.ready_bytes()).unwrap();
        assert_eq!(raw_msg.t, MessageType::Ready);
    }

    #[tokio::test]
    async fn connect_accepts_legacy_relay_handshake() {
        assert_accepts_legacy_relay_handshake(0).await;
        assert_accepts_legacy_relay_handshake(268_435_455).await;
    }

    #[tokio::test]
    async fn legacy_relay_requests_use_v1_and_legacy_id_range() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        let id_offset = 268_435_455u32;
        let (frame_tx, frame_rx) = oneshot::channel();
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            // Legacy handshake: id offset immediately followed by the ready frame.
            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
            let frame = codec::read_raw_frame(&mut socket).await.unwrap();
            frame_tx.send(frame).unwrap();
        });
        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();
        let request = ExecRequest {
            cmd: "/bin/true".into(),
            args: Vec::new(),
            env: Vec::new(),
            cwd: None,
            user: None,
            tty: false,
            rows: 24,
            cols: 80,
            rlimits: Vec::new(),
        };
        let (id, _rx) = client
            .stream(MessageType::ExecRequest, &request)
            .await
            .unwrap();
        let frame = frame_rx.await.unwrap();
        let message = codec::raw_frame_to_message(frame).unwrap();
        assert_eq!(id, id_offset + 1);
        assert_eq!(message.id, id_offset + 1);
        assert_eq!(message.v, LEGACY_PROTOCOL_VERSION);
        assert_eq!(message.t, MessageType::ExecRequest);
    }

    /// Spins up a fake legacy relay announcing `id_offset` and checks that the
    /// client detects the legacy protocol and decodes the ready payload.
    async fn assert_accepts_legacy_relay_handshake(id_offset: u32) {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });
        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();
        assert_eq!(client.protocol(), AgentProtocol::LegacyV1);
        let decoded = client.ready().unwrap();
        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
    }
}
impl Drop for AgentClient {
fn drop(&mut self) {
self.reader_handle.abort();
}
}