#[cfg(unix)]
use std::collections::VecDeque;
#[cfg(windows)]
use std::ffi::OsString;
#[cfg(unix)]
use std::io;
#[cfg(unix)]
use std::os::fd::RawFd;
use std::sync::Arc;
#[cfg(unix)]
use std::sync::Mutex;
#[cfg(windows)]
use std::time::Duration;
use crossbeam_queue::ArrayQueue;
use microsandbox_utils::wake_pipe::WakePipe;
#[cfg(unix)]
use msb_krun::ConsolePortBackend;
#[cfg(windows)]
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[cfg(windows)]
use tokio::net::windows::named_pipe::{NamedPipeServer, PipeMode, ServerOptions};
/// Default number of chunks (not bytes) each ring in a [`ConsoleSharedState`] can hold.
const DEFAULT_QUEUE_CAPACITY: usize = 2 * 1024;
/// Size in bytes of the scratch buffer used for each named-pipe read.
#[cfg(windows)]
const NAMED_PIPE_BRIDGE_BUFFER_SIZE: usize = 8 * 1024;
/// Pause used by the named-pipe bridge while `rx_ring` is empty or a push target is full.
#[cfg(windows)]
const NAMED_PIPE_BRIDGE_TX_POLL_INTERVAL: Duration = Duration::from_millis(1);
/// Bounded queues and wake pipes shared between the agent console port and the host side.
///
/// Both rings store whole chunks (`Vec<u8>`), so their capacity counts chunks, not bytes.
pub struct ConsoleSharedState {
    /// Chunks coming from the agent console: pushed by `AgentConsoleBackend::write` (Unix)
    /// and by the named-pipe reader (Windows).
    pub tx_ring: ArrayQueue<Vec<u8>>,
    /// Chunks destined for the agent console: popped by `AgentConsoleBackend::read` (Unix)
    /// and by the named-pipe writer (Windows).
    pub rx_ring: ArrayQueue<Vec<u8>>,
    /// Signalled after every push to `tx_ring` performed in this module.
    pub tx_wake: WakePipe,
    /// Expected to be signalled by whoever pushes to `rx_ring`; readers drain it.
    pub rx_wake: WakePipe,
}
/// Console port backend that moves bytes through a [`ConsoleSharedState`].
///
/// On Windows the struct has no fields and `new` discards the shared state; the named-pipe
/// bridge in this module services the rings on that platform.
pub struct AgentConsoleBackend {
    /// Rings and wake pipes shared with the host side.
    #[cfg(unix)]
    shared: Arc<ConsoleSharedState>,
    /// Tail of an rx chunk that did not fit in the caller's buffer on a previous `read`.
    #[cfg(unix)]
    pending: Mutex<VecDeque<u8>>,
}
/// Owner of the background task that relays bytes between a Windows named pipe and a
/// `ConsoleSharedState`; dropping it aborts that task.
#[cfg(windows)]
pub(crate) struct AgentConsolePipeBridge {
    /// Tokio task running the bridge.
    task: tokio::task::JoinHandle<()>,
}
impl ConsoleSharedState {
    /// Creates shared state whose rings each hold `DEFAULT_QUEUE_CAPACITY` chunks.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_QUEUE_CAPACITY)
    }

    /// Creates shared state whose `tx_ring` and `rx_ring` each hold at most `capacity`
    /// chunks, together with a fresh wake pipe per direction.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero, because `ArrayQueue::new` rejects a zero capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        let tx_ring = ArrayQueue::new(capacity);
        let rx_ring = ArrayQueue::new(capacity);
        let tx_wake = WakePipe::new();
        let rx_wake = WakePipe::new();
        Self {
            tx_ring,
            rx_ring,
            tx_wake,
            rx_wake,
        }
    }
}
impl AgentConsoleBackend {
    /// Builds a backend over `shared`.
    ///
    /// On Unix the backend keeps `shared` and starts with no leftover rx bytes. On Windows
    /// the backend has no fields, so `shared` is released immediately.
    pub fn new(shared: Arc<ConsoleSharedState>) -> Self {
        #[cfg(windows)]
        {
            drop(shared);
            Self {}
        }
        #[cfg(unix)]
        {
            let pending = Mutex::new(VecDeque::new());
            Self { shared, pending }
        }
    }
}
#[cfg(windows)]
impl AgentConsolePipeBridge {
    /// Creates the named pipe `pipe_name` and spawns a task on `handle` that bridges it to
    /// `shared`.
    ///
    /// The pipe is created in byte mode as the first instance, so creation fails if a pipe
    /// with that name already exists. The spawned task waits for a single client to connect
    /// and logs a warning if the bridge ends with an error.
    ///
    /// # Errors
    ///
    /// Returns any error reported while creating the named-pipe server.
    pub(crate) fn spawn(
        pipe_name: impl Into<OsString>,
        shared: Arc<ConsoleSharedState>,
        handle: &tokio::runtime::Handle,
    ) -> std::io::Result<Self> {
        let pipe_name: OsString = pipe_name.into();

        let mut options = ServerOptions::new();
        options.first_pipe_instance(true).pipe_mode(PipeMode::Byte);

        // Creating the server registers it with Tokio's I/O driver, which must happen
        // inside `handle`'s runtime context.
        let server = {
            let _runtime_guard = handle.enter();
            options.create(&pipe_name)?
        };

        let task = handle.spawn(async move {
            let outcome = run_agent_console_pipe_bridge(server, shared).await;
            if let Err(error) = outcome {
                tracing::warn!(error = %error, "agent console named-pipe bridge stopped");
            }
        });

        Ok(Self { task })
    }
}
impl Default for ConsoleSharedState {
fn default() -> Self {
Self::new()
}
}
#[cfg(windows)]
impl Drop for AgentConsolePipeBridge {
    /// Requests cancellation of the bridge task without waiting for it to finish.
    fn drop(&mut self) {
        let Self { task } = self;
        task.abort();
    }
}
#[cfg(unix)]
impl ConsolePortBackend for AgentConsoleBackend {
    /// Copies console input destined for the agent into `buf`.
    ///
    /// Bytes left over from a previously popped chunk are returned first. Otherwise one chunk
    /// is popped from `rx_ring`, and whatever does not fit in `buf` is kept for the next call.
    /// A single call never mixes leftover bytes with a freshly popped chunk.
    ///
    /// `rx_wake` is drained *before* the queues are inspected, so a producer that pushes and
    /// wakes after this check leaves the pipe readable again. If data is still buffered when
    /// this call returns (leftover bytes or more queued chunks), the pipe is re-armed. A caller
    /// that waits on `read_wake_fd` is therefore not left sleeping while data is queued.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::WouldBlock` when no data is available.
    fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.shared.rx_wake.drain();

        // The deque only holds plain bytes, so the contents of a poisoned lock are still
        // usable. Recovering avoids turning an unrelated panic into a permanently dead console.
        let mut pending = self
            .pending
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let n = if pending.is_empty() {
            let Some(chunk) = self.shared.rx_ring.pop() else {
                return Err(io::ErrorKind::WouldBlock.into());
            };
            let n = chunk.len().min(buf.len());
            buf[..n].copy_from_slice(&chunk[..n]);
            // Keep the tail of a chunk larger than `buf` for the next call.
            pending.extend(&chunk[n..]);
            n
        } else {
            let n = pending.len().min(buf.len());
            // The deque may wrap around, so copy from its two contiguous halves.
            let (head, tail) = pending.as_slices();
            let from_head = n.min(head.len());
            buf[..from_head].copy_from_slice(&head[..from_head]);
            buf[from_head..n].copy_from_slice(&tail[..n - from_head]);
            pending.drain(..n);
            n
        };

        // The drain above consumed the wake for any data still buffered; re-arm it.
        if !pending.is_empty() || !self.shared.rx_ring.is_empty() {
            self.shared.rx_wake.wake();
        }
        Ok(n)
    }

    /// Queues a copy of `buf` as one chunk on `tx_ring` and signals `tx_wake`.
    ///
    /// An empty `buf` returns `Ok(0)` without queuing anything, so consumers never receive
    /// empty chunks or spurious wakes.
    ///
    /// # Errors
    ///
    /// Returns `io::ErrorKind::WouldBlock` when `tx_ring` is full; nothing is queued then.
    fn write(&self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        self.shared
            .tx_ring
            .push(buf.to_vec())
            .map_err(|_| io::Error::from(io::ErrorKind::WouldBlock))?;
        self.shared.tx_wake.wake();
        Ok(buf.len())
    }

    /// Returns the pollable descriptor of `rx_wake`. It reports readable after
    /// `rx_wake.wake()` until the next `read` drains it.
    fn read_wake_fd(&self) -> RawFd {
        self.shared.rx_wake.as_raw_fd()
    }
}
/// Waits for a client on `server`, then relays bytes in both directions until either
/// direction finishes.
///
/// Both directions run as futures inside the current task instead of as spawned sub-tasks.
/// Aborting the task that runs this function (as `AgentConsolePipeBridge`'s `Drop` does)
/// therefore cancels both directions. Spawned sub-tasks would only have been detached,
/// because dropping a `JoinHandle` does not cancel its task. They would have kept polling
/// the rings and holding the pipe forever.
///
/// # Errors
///
/// Returns the error from `connect`, or from whichever direction fails first. End-of-file
/// from the client ends the bridge with `Ok(())`.
#[cfg(windows)]
async fn run_agent_console_pipe_bridge(
    server: NamedPipeServer,
    shared: Arc<ConsoleSharedState>,
) -> std::io::Result<()> {
    server.connect().await?;
    tracing::debug!("agent console named-pipe bridge connected");

    let (reader, writer) = tokio::io::split(server);
    let guest_to_host = bridge_guest_to_host(reader, Arc::clone(&shared));
    let host_to_guest = bridge_host_to_guest(writer, shared);

    // Whichever direction completes first decides the result; the other future is dropped,
    // which cancels it.
    tokio::select! {
        result = guest_to_host => result,
        result = host_to_guest => result,
    }
}
/// Copies bytes read from the named pipe into `shared.tx_ring`, one chunk per read.
///
/// A full ring is waited on rather than dropping data (via `push_queue_lossless`), and
/// `tx_wake` is signalled after every chunk. Returns `Ok(())` once the pipe reports EOF.
#[cfg(windows)]
async fn bridge_guest_to_host(
    mut reader: tokio::io::ReadHalf<NamedPipeServer>,
    shared: Arc<ConsoleSharedState>,
) -> std::io::Result<()> {
    let mut scratch = vec![0u8; NAMED_PIPE_BRIDGE_BUFFER_SIZE];
    loop {
        match reader.read(&mut scratch).await? {
            0 => break Ok(()),
            len => {
                let chunk = scratch[..len].to_vec();
                push_queue_lossless(&shared.tx_ring, chunk).await;
                shared.tx_wake.wake();
            }
        }
    }
}
/// Writes every chunk popped from `shared.rx_ring` to the named pipe.
///
/// Each batch of queued chunks is followed by one flush. While the ring is empty, `rx_wake`
/// is drained and the ring is re-checked after `NAMED_PIPE_BRIDGE_TX_POLL_INTERVAL`. The loop
/// never returns `Ok`: it ends only on a write or flush error, or when cancelled.
#[cfg(windows)]
async fn bridge_host_to_guest(
    mut writer: tokio::io::WriteHalf<NamedPipeServer>,
    shared: Arc<ConsoleSharedState>,
) -> std::io::Result<()> {
    loop {
        let mut needs_flush = false;
        while let Some(chunk) = shared.rx_ring.pop() {
            writer.write_all(&chunk).await?;
            needs_flush = true;
        }

        if !needs_flush {
            shared.rx_wake.drain();
            tokio::time::sleep(NAMED_PIPE_BRIDGE_TX_POLL_INTERVAL).await;
            continue;
        }

        writer.flush().await?;
    }
}
/// Pushes `chunk` onto `queue`, never dropping it.
///
/// While the queue is full, it sleeps `NAMED_PIPE_BRIDGE_TX_POLL_INTERVAL` between
/// attempts. It waits indefinitely if nothing ever pops from `queue`.
#[cfg(windows)]
async fn push_queue_lossless(queue: &ArrayQueue<Vec<u8>>, chunk: Vec<u8>) {
    let mut item = chunk;
    while let Err(rejected) = queue.push(item) {
        // `push` hands the value back on failure; retry with it after a short pause.
        item = rejected;
        tokio::time::sleep(NAMED_PIPE_BRIDGE_TX_POLL_INTERVAL).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(unix)]
    #[test]
    fn backend_write_and_read_roundtrip() {
        let shared = Arc::new(ConsoleSharedState::new());
        let backend = AgentConsoleBackend::new(Arc::clone(&shared));
        assert_eq!(backend.write(b"hello").unwrap(), 5);
        let chunk = shared.tx_ring.pop().unwrap();
        assert_eq!(chunk, b"hello");
        shared.rx_ring.push(b"world".to_vec()).unwrap();
        shared.rx_wake.wake();
        let mut buf = [0u8; 16];
        let n = backend.read(&mut buf).unwrap();
        assert_eq!(&buf[..n], b"world");
    }
    #[cfg(unix)]
    #[test]
    fn backend_read_empty_returns_would_block() {
        let shared = Arc::new(ConsoleSharedState::new());
        let backend = AgentConsoleBackend::new(shared);
        let mut buf = [0u8; 16];
        let err = backend.read(&mut buf).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
    }
    #[cfg(unix)]
    #[test]
    fn backend_write_full_returns_would_block() {
        let shared = Arc::new(ConsoleSharedState::with_capacity(1));
        let backend = AgentConsoleBackend::new(shared);
        assert!(backend.write(b"a").is_ok());
        let err = backend.write(b"b").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
    }
    #[cfg(unix)]
    #[test]
    fn backend_read_drains_rx_wake_pipe() {
        let shared = Arc::new(ConsoleSharedState::new());
        let backend = AgentConsoleBackend::new(Arc::clone(&shared));
        shared.rx_ring.push(b"ping".to_vec()).unwrap();
        shared.rx_wake.wake();
        let mut pollfd = libc::pollfd {
            fd: backend.read_wake_fd(),
            events: libc::POLLIN,
            revents: 0,
        };
        let ret = unsafe { libc::poll(&mut pollfd, 1, 0) };
        assert_eq!(ret, 1, "wake pipe should be readable before read()");
        assert_ne!(pollfd.revents & libc::POLLIN, 0);
        let mut buf = [0u8; 8];
        let n = backend.read(&mut buf).unwrap();
        assert_eq!(&buf[..n], b"ping");
        pollfd.revents = 0;
        let ret = unsafe { libc::poll(&mut pollfd, 1, 0) };
        assert_eq!(ret, 0, "wake pipe should be drained by read()");
    }
    /// A chunk larger than the caller's buffer is delivered across several reads, in order,
    /// before the next queued chunk.
    #[cfg(unix)]
    #[test]
    fn backend_read_splits_oversized_chunk() {
        let shared = Arc::new(ConsoleSharedState::new());
        let backend = AgentConsoleBackend::new(Arc::clone(&shared));
        shared.rx_ring.push(b"abcdefghij".to_vec()).unwrap();
        shared.rx_ring.push(b"XY".to_vec()).unwrap();
        let mut buf = [0u8; 4];
        let mut reads: Vec<Vec<u8>> = Vec::new();
        loop {
            match backend.read(&mut buf) {
                Ok(n) => reads.push(buf[..n].to_vec()),
                Err(err) => {
                    assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
                    break;
                }
            }
        }
        let expected: Vec<Vec<u8>> = vec![
            b"abcd".to_vec(),
            b"efgh".to_vec(),
            b"ij".to_vec(),
            b"XY".to_vec(),
        ];
        assert_eq!(reads, expected);
    }
    #[cfg(windows)]
    #[tokio::test]
    async fn named_pipe_bridge_exchanges_agent_bytes() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::windows::named_pipe::ClientOptions;
        let pipe_name = unique_named_pipe("console-bridge");
        let shared = Arc::new(ConsoleSharedState::new());
        let _bridge = AgentConsolePipeBridge::spawn(
            &pipe_name,
            Arc::clone(&shared),
            &tokio::runtime::Handle::current(),
        )
        .unwrap();
        let mut client = ClientOptions::new().open(&pipe_name).unwrap();
        client.write_all(b"guest-ready").await.unwrap();
        tokio::time::timeout(Duration::from_secs(1), async {
            loop {
                if let Some(bytes) = shared.tx_ring.pop() {
                    assert_eq!(bytes, b"guest-ready");
                    return;
                }
                tokio::time::sleep(Duration::from_millis(1)).await;
            }
        })
        .await
        .unwrap();
        shared.rx_ring.push(b"host-ack".to_vec()).unwrap();
        shared.rx_wake.wake();
        let mut buf = [0u8; 8];
        tokio::time::timeout(Duration::from_secs(1), client.read_exact(&mut buf))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(&buf, b"host-ack");
    }
    /// Returns a named-pipe path that is unique within this test process.
    ///
    /// The counter must be a `static`. A fresh `AtomicU64` per call always yields 0, so
    /// every caller passing the same `name` would collide on one pipe.
    #[cfg(windows)]
    fn unique_named_pipe(name: &str) -> String {
        use std::sync::atomic::{AtomicU64, Ordering};

        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        format!(r"\\.\pipe\msb-runtime-{name}-{}-{id}", std::process::id())
    }
}