use std::collections::HashMap;
use std::os::fd::RawFd;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::mpsc::{sync_channel, Receiver, RecvTimeoutError, SyncSender};
use std::sync::{Arc, Condvar, Mutex, OnceLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use block2::RcBlock;
use dispatch2::{DispatchQueue, DispatchRetained, DispatchTime};
use objc2::rc::Retained;
use objc2::runtime::ProtocolObject;
use objc2::AllocAnyThread;
use objc2_foundation::NSError;
use objc2_virtualization::{
VZVirtioSocketDevice, VZVirtioSocketListener, VZVirtualMachine, VZVirtualMachineDelegate,
};
use crate::desktop::{self, Action, ResponseHeader};
use crate::error::Error;
use crate::vz::config::{build as build_vz_config, resolve_vsock_port, SerialSink};
use crate::vz::delegate::{DelegateState, VmetteDelegate};
use crate::vz::vsock::{ListenerMode, ListenerState, VsockLogger};
use crate::{cmdline, Config, ShareMount, WorkloadStrategy};
/// How long `AgentConn` waits for the listener to hand over the guest agent's vsock fd.
const AGENT_CONNECT_TIMEOUT: Duration = Duration::from_secs(60);
/// How long a single agent request waits for its matching response before giving up.
const AGENT_READ_TIMEOUT: Duration = Duration::from_secs(30);
/// Wrapper that lets an Objective-C object be moved into closures scheduled on the
/// session's dispatch queue (`exec_sync` / `exec_async` / `after`).
struct QueueBound<T>(Retained<T>);
// SAFETY: NOTE(review) — these impls hold for *any* `T`, so soundness rests on every use
// site following the framework's threading rules. In this file the wrapped objects are
// only messaged from closures running on the session queue; other threads merely clone
// the `Retained` (a retain). Keep it that way when adding new uses.
unsafe impl<T> Send for QueueBound<T> {}
// SAFETY: see the `Send` impl above.
unsafe impl<T> Sync for QueueBound<T> {}
/// Gives transparent access to the wrapped object's methods.
impl<T> std::ops::Deref for QueueBound<T> {
    type Target = T;
    fn deref(&self) -> &T {
        &self.0
    }
}
/// How a session finished.
#[derive(Debug, Clone)]
pub enum SessionEnd {
    /// The session ended with this exit code.
    Exited(i32),
    /// The configured timeout elapsed and the VM was stopped.
    TimedOut,
    /// The VM was stopped on request.
    Stopped,
    /// The VM failed (for example `vm.start` reported an error); carries a description.
    Error(String),
}

/// One-shot slot recording how a session ended.
///
/// The first recorded outcome wins. Threads can block on it until an outcome arrives.
pub(crate) struct EndSlot {
    end: Mutex<Option<SessionEnd>>,
    cv: Condvar,
}

impl EndSlot {
    /// Creates an empty, shareable slot.
    fn new() -> Arc<Self> {
        let slot = Self {
            end: Mutex::new(None),
            cv: Condvar::new(),
        };
        Arc::new(slot)
    }

    /// Records `e` unless an outcome is already present, then wakes every waiter.
    pub(crate) fn set(&self, e: SessionEnd) {
        let mut slot = self.end.lock().unwrap();
        if slot.is_none() {
            *slot = Some(e);
        }
        self.cv.notify_all();
    }

    /// Blocks until an outcome has been recorded and returns a copy of it.
    fn wait_end(&self) -> SessionEnd {
        let filled = self
            .cv
            .wait_while(self.end.lock().unwrap(), |slot| slot.is_none())
            .unwrap();
        Option::clone(&filled).expect("wait_while returns only once the slot is filled")
    }

    /// Reports whether an outcome has been recorded yet.
    fn is_set(&self) -> bool {
        let slot = self.end.lock().unwrap();
        slot.is_some()
    }
}
/// What a waiting request receives: the response frame, or the message of the error that
/// closed the agent stream.
type Reply = Result<(ResponseHeader, Vec<u8>), String>;
/// Multiplexes concurrent requests over one agent socket, matching responses to requests
/// by id.
struct Demux {
    // Held while a request frame is written so concurrent writers do not interleave.
    write: Mutex<()>,
    // Source of request ids (`fetch_add` wraps on overflow).
    next_id: AtomicU32,
    // Pending requests keyed by id; shared with the reader thread, which removes entries.
    waiters: Arc<Mutex<HashMap<u32, SyncSender<Reply>>>>,
    // First fatal stream error; once set, new requests fail immediately.
    poison: Arc<Mutex<Option<String>>>,
    // The agent socket. It is not closed here: `AgentConn`'s `Drop` closes it.
    fd: RawFd,
}
impl Demux {
    /// Wraps the connected agent socket `fd` and spawns a detached `demux_reader` thread
    /// that routes responses to waiting requests.
    fn start(fd: RawFd) -> Demux {
        let waiters: Arc<Mutex<HashMap<u32, SyncSender<Reply>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let poison: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let waiters_r = waiters.clone();
        let poison_r = poison.clone();
        std::thread::spawn(move || demux_reader(fd, waiters_r, poison_r));
        Demux {
            write: Mutex::new(()),
            next_id: AtomicU32::new(0),
            waiters,
            poison,
            fd,
        }
    }
    /// Returns the message that poisoned the connection, if any.
    fn poisoned(&self) -> Option<String> {
        self.poison.lock().unwrap().clone()
    }
    /// Records `msg` as the poison reason unless one is already set (first error wins).
    fn poison_with(&self, msg: String) {
        let mut p = self.poison.lock().unwrap();
        if p.is_none() {
            *p = Some(msg);
        }
    }
    /// Sends `action` under a fresh request id and blocks (up to `AGENT_READ_TIMEOUT`)
    /// for the response with the same id.
    ///
    /// # Errors
    /// `Error::Vsock` if the connection is already poisoned, if writing the request fails
    /// (which also poisons the connection), if no response arrives in time, or if the
    /// reader thread reports a stream error.
    fn request(&self, action: &Action) -> Result<(ResponseHeader, Vec<u8>), Error> {
        if let Some(msg) = self.poisoned() {
            return Err(Error::Vsock(msg));
        }
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let (tx, rx) = sync_channel::<Reply>(1);
        self.waiters.lock().unwrap().insert(id, tx);
        // Re-check after registering. The reader sets the poison *before* draining the
        // waiters, so either its drain sees this entry or this check sees the poison;
        // without it a request registered after the drain would wait for the full timeout.
        if let Some(msg) = self.poisoned() {
            self.waiters.lock().unwrap().remove(&id);
            return Err(Error::Vsock(msg));
        }
        {
            // Serialise writers so frames from concurrent requests do not interleave.
            let _w = self.write.lock().unwrap();
            let mut stream = FdStream(self.fd);
            if let Err(e) = desktop::send_action(&mut stream, id, action) {
                self.waiters.lock().unwrap().remove(&id);
                let msg = format!("agent request write failed: {e}");
                // A partially written frame leaves the stream unusable for everyone.
                self.poison_with(msg.clone());
                return Err(Error::Vsock(msg));
            }
        }
        match rx.recv_timeout(AGENT_READ_TIMEOUT) {
            Ok(Ok(resp)) => Ok(resp),
            Ok(Err(msg)) => Err(Error::Vsock(msg)),
            Err(RecvTimeoutError::Timeout) => {
                // Unregister; a late response for this id then finds no waiter and the
                // reader discards it.
                self.waiters.lock().unwrap().remove(&id);
                Err(Error::Vsock(
                    "timed out waiting for the guest agent response".into(),
                ))
            }
            Err(RecvTimeoutError::Disconnected) => {
                Err(Error::Vsock(self.poisoned().unwrap_or_else(|| {
                    "agent connection closed unexpectedly".into()
                })))
            }
        }
    }
}
/// Reader-thread loop: delivers each response frame read from `fd` to the request
/// waiting on its id.
///
/// On the first read error, it poisons the connection so new requests fail fast. It then
/// wakes every pending waiter with the same error message and returns.
fn demux_reader(
    fd: RawFd,
    waiters: Arc<Mutex<HashMap<u32, SyncSender<Reply>>>>,
    poison: Arc<Mutex<Option<String>>>,
) {
    let mut stream = FdStream(fd);
    let failure = loop {
        let (id, header, payload) = match desktop::read_response(&mut stream) {
            Ok(frame) => frame,
            Err(e) => break format!("agent stream closed: {e}"),
        };
        // Frames whose id has no waiter (e.g. the request already timed out) are dropped.
        if let Some(reply_to) = waiters.lock().unwrap().remove(&id) {
            let _ = reply_to.send(Ok((header, payload)));
        }
    };

    // Publish the poison before draining. A request that registers after the drain will
    // see the poison on its re-check instead of waiting forever.
    {
        let mut reason = poison.lock().unwrap();
        if reason.is_none() {
            *reason = Some(failure.clone());
        }
    }
    for (_, reply_to) in waiters.lock().unwrap().drain() {
        let _ = reply_to.send(Err(failure.clone()));
    }
}
/// Lazily established connection to the in-guest agent, shared (via `Arc`) by `Session`
/// and `SessionClient`.
pub(crate) struct AgentConn {
    // Requests are rejected unless this is `WorkloadStrategy::Agent`.
    workload: WorkloadStrategy,
    // Receives connection fds from the vsock listener; `None` when no Agent listener was set up.
    rx: Mutex<Option<Receiver<RawFd>>>,
    // The agent connection fd once received; closed in `Drop`.
    fd: Mutex<Option<RawFd>>,
    // Request multiplexer, created on the first request.
    demux: OnceLock<Demux>,
    // Serialises creation of `demux`.
    init: Mutex<()>,
}
impl AgentConn {
fn request(&self, action: &Action) -> Result<(ResponseHeader, Vec<u8>), Error> {
if self.workload != WorkloadStrategy::Agent {
return Err(Error::Vsock(
"request() is only valid for Agent-workload sessions".into(),
));
}
self.demux()?.request(action)
}
fn demux(&self) -> Result<&Demux, Error> {
if let Some(d) = self.demux.get() {
return Ok(d);
}
let _g = self.init.lock().unwrap();
if let Some(d) = self.demux.get() {
return Ok(d); }
let fd = self.fd()?;
let _ = self.demux.set(Demux::start(fd));
Ok(self.demux.get().unwrap())
}
fn fd(&self) -> Result<RawFd, Error> {
let mut cached = self.fd.lock().unwrap();
if let Some(fd) = *cached {
return Ok(fd);
}
let rx_guard = self.rx.lock().unwrap();
let rx = rx_guard
.as_ref()
.ok_or_else(|| Error::Vsock("no agent vsock channel (vsock disabled?)".into()))?;
let fd = rx
.recv_timeout(AGENT_CONNECT_TIMEOUT)
.map_err(|_| Error::Vsock("timed out waiting for the guest agent to connect".into()))?;
*cached = Some(fd);
Ok(fd)
}
}
impl Drop for AgentConn {
    /// Releases the agent connection fd and any connection fds that arrived but were never
    /// claimed.
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so read the fields without locking. Recover
        // from a poisoned mutex instead of panicking inside `drop`.
        let fd = *self.fd.get_mut().unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(fd) = fd {
            // A detached demux reader thread may be blocked in read(2) on this socket.
            // Shutting it down first makes that read return EOF so the thread exits, instead
            // of lingering on a descriptor number that close() frees for reuse.
            // SAFETY: in this file only this `Drop` closes the agent fd; shutdown on a
            // non-socket fails harmlessly with ENOTSOCK.
            unsafe {
                libc::shutdown(fd, libc::SHUT_RDWR);
                libc::close(fd);
            }
        }
        let rx = self
            .rx
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .take();
        if let Some(rx) = rx {
            for pending in rx.try_iter() {
                // SAFETY: fds delivered over this channel belong to the receiver.
                unsafe { libc::close(pending) };
            }
        }
    }
}
/// Asks the VM to stop, unless the session has already ended. The request runs
/// asynchronously on the VM's dispatch queue.
///
/// When the stop completes, the outcome is recorded as `SessionEnd::Stopped`. The
/// completion's error argument is ignored.
fn issue_stop(queue: &DispatchQueue, vm: &Retained<VZVirtualMachine>, end: &Arc<EndSlot>) {
    if !end.is_set() {
        let target = QueueBound(vm.clone());
        let slot = Arc::clone(end);
        queue.exec_async(move || {
            let on_stopped = RcBlock::new(move |_err: *mut NSError| {
                slot.set(SessionEnd::Stopped);
            });
            unsafe { target.stopWithCompletionHandler(&on_stopped) };
        });
    }
}
/// Removes the wrapped control directory, with all of its contents, when dropped.
struct ControlDirGuard(PathBuf);

impl Drop for ControlDirGuard {
    fn drop(&mut self) {
        let Self(dir) = self;
        // Best effort: drop cannot report errors, and the directory may already be gone.
        let _ = std::fs::remove_dir_all(dir);
    }
}
/// Builds a fresh path `<temp_dir>/<prefix>-<pid>-<seq>-<nanos>`.
///
/// `seq` is a per-process counter, so two calls never collide even when they observe the
/// same clock reading. `nanos` is 0 if the clock is before the Unix epoch.
fn unique_temp_path(prefix: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };
    let pid = std::process::id();
    let name = format!("{prefix}-{pid}-{seq}-{nanos}");
    std::env::temp_dir().join(name)
}
fn make_control_dir() -> Result<PathBuf, Error> {
let dir = unique_temp_path("vmette-ctl");
std::fs::create_dir_all(&dir).map_err(Error::Io)?;
Ok(dir)
}
/// Deletes the wrapped scratch-image file when dropped.
struct ScratchFileGuard(PathBuf);

impl Drop for ScratchFileGuard {
    fn drop(&mut self) {
        let Self(image) = self;
        // Best effort: drop cannot report errors, and the file may already be gone.
        let _ = std::fs::remove_file(image);
    }
}
fn make_scratch_image(mib: u64) -> Result<PathBuf, Error> {
let path = unique_temp_path("vmette-scratch").with_extension("img");
let f = std::fs::File::create(&path).map_err(Error::Io)?;
f.set_len(mib.saturating_mul(1024 * 1024))
.map_err(Error::Io)?;
Ok(path)
}
/// Upper bound on how many bytes of guest console output `drain_capture` forwards.
const CAPTURE_CAP_BYTES: usize = 1024 * 1024;

/// Owns the write end of the console-capture pipe and closes it when dropped.
struct CapturePipe {
    write_fd: RawFd,
}

impl Drop for CapturePipe {
    fn drop(&mut self) {
        let fd = self.write_fd;
        // SAFETY: this struct is the only place in this file that closes `write_fd`.
        // NOTE(review): assumes the serial attachment built from this fd does not close it too.
        unsafe {
            libc::close(fd);
        }
    }
}
/// Forwards console output from the capture pipe's read end to `tx`.
///
/// At most `CAPTURE_CAP_BYTES` bytes are forwarded. The first time any byte is dropped,
/// a single truncation marker is appended.
///
/// The loop ends when any of these happens:
/// - EOF (every write end is closed);
/// - an unrecoverable read error;
/// - the session has ended and the pipe has stayed empty for about 500 ms
///   (25 polls of 20 ms).
///
/// Takes ownership of `read_fd` and closes it before returning. The sender is dropped
/// on return, which ends iteration on the receiver.
fn drain_capture(read_fd: RawFd, end: Arc<EndSlot>, tx: std::sync::mpsc::Sender<Vec<u8>>) {
    // Sleep between polls while the pipe is empty.
    const POLL_INTERVAL: Duration = Duration::from_millis(20);
    // Empty polls tolerated after the session has ended before giving up.
    const IDLE_POLLS_AFTER_END: u32 = 25;

    // Non-blocking reads let the loop notice the end of the session even while a write
    // end is still open.
    // SAFETY: fcntl on a descriptor owned by this function; only O_NONBLOCK is added.
    unsafe {
        let flags = libc::fcntl(read_fd, libc::F_GETFL);
        libc::fcntl(read_fd, libc::F_SETFL, flags | libc::O_NONBLOCK);
    }

    let mut buf = [0u8; 8192];
    let mut sent: usize = 0;
    let mut truncated = false;
    let mut idle_polls: u32 = 0;
    loop {
        // SAFETY: `buf` is a valid writable buffer of `buf.len()` bytes.
        let n = unsafe { libc::read(read_fd, buf.as_mut_ptr().cast(), buf.len()) };
        if n > 0 {
            idle_polls = 0;
            let n = n as usize;
            let take = n.min(CAPTURE_CAP_BYTES - sent);
            if take > 0 {
                let _ = tx.send(buf[..take].to_vec());
                sent += take;
            }
            // Emit the marker as soon as any byte is dropped. This includes the case where
            // earlier reads filled the cap exactly, which previously dropped data silently.
            if take < n && !truncated {
                truncated = true;
                let marker = format!("\n[output truncated at {CAPTURE_CAP_BYTES} bytes]\n");
                let _ = tx.send(marker.into_bytes());
            }
            continue;
        }
        if n == 0 {
            break; // EOF: every write end is closed.
        }
        // n < 0: only "no data yet" and signal interruptions are worth retrying.
        let err = std::io::Error::last_os_error();
        if !matches!(
            err.kind(),
            std::io::ErrorKind::WouldBlock | std::io::ErrorKind::Interrupted
        ) {
            break;
        }
        if end.is_set() {
            if idle_polls >= IDLE_POLLS_AFTER_END {
                break;
            }
            idle_polls += 1;
        }
        std::thread::sleep(POLL_INTERVAL);
    }
    // SAFETY: `read_fd` was handed to this function, which closes it exactly once.
    unsafe { libc::close(read_fd) };
}
/// A VM started by `Session::start`, plus the host-side resources it depends on.
///
/// Fields are dropped in declaration order. The `_`-prefixed fields exist only to keep
/// objects alive or to clean up on drop.
pub struct Session {
    vm: Retained<VZVirtualMachine>,
    // Dispatch queue the VM was created with; all VM calls are scheduled on it.
    queue: DispatchRetained<DispatchQueue>,
    // How the session ended, once known.
    end: Arc<EndSlot>,
    vsock_port: Option<u32>,
    // Kernel command line the VM was configured with.
    cmdline: String,
    agent: Arc<AgentConn>,
    _delegate: Retained<VmetteDelegate>,
    // Listener delegate and listener registered on the VM's socket device.
    _vsock_keepalive: Option<(Retained<VsockLogger>, Retained<VZVirtioSocketListener>)>,
    // Removes the control directory on drop.
    _control_dir: Option<ControlDirGuard>,
    // Removes the scratch image on drop.
    _scratch_file: Option<ScratchFileGuard>,
    // Closes the capture pipe's write end on drop.
    _capture: Option<CapturePipe>,
    // Captured console chunks; taken by `capture_rx()` or `wait_captured()`.
    capture_rx: Mutex<Option<std::sync::mpsc::Receiver<Vec<u8>>>>,
}
impl Session {
pub fn start(config: &Config) -> Result<Session, Error> {
let vsock_port = resolve_vsock_port(config.vsock_port);
let mut working = config.clone();
if config.shares.iter().any(|s| s.tag == crate::CTL_SHARE_TAG) {
return Err(Error::InvalidConfig(
"share tag \"ctl\" is reserved for the boot/exit channel".into(),
));
}
let ctl_dir = make_control_dir()?;
working.shares.push(ShareMount {
tag: crate::CTL_SHARE_TAG.into(),
path: ctl_dir.clone(),
});
let control_dir = Some(ControlDirGuard(ctl_dir.clone()));
let writable_root = match &config.rootfs {
Some(crate::Rootfs::Block(_)) => true,
Some(crate::Rootfs::Share(rs)) => !rs.read_only,
None => false,
};
let exit_code_file = if writable_root {
let p = ctl_dir.join(".vmette-exit");
let _ = std::fs::remove_file(&p);
Some(p)
} else {
None
};
let scratch_file = match (writable_root, config.scratch_mib) {
(true, Some(mib)) => Some(ScratchFileGuard(make_scratch_image(mib)?)),
_ => None,
};
let scratch_path = scratch_file.as_ref().map(|g| g.0.clone());
let mut cmdline = cmdline::build(&working, vsock_port);
let capture: Option<(RawFd, CapturePipe)> = if config.capture_output {
let mut fds: [libc::c_int; 2] = [0; 2];
if unsafe { libc::pipe(fds.as_mut_ptr()) } != 0 {
return Err(Error::Io(std::io::Error::last_os_error()));
}
cmdline = cmdline.replace("console=hvc0", "console=hvc1");
Some((fds[0], CapturePipe { write_fd: fds[1] }))
} else {
None
};
let sink = match &capture {
Some((_, p)) => SerialSink::Capture {
write_fd: p.write_fd,
},
None => SerialSink::Inherit,
};
let (cfg, scratch_dev) = build_vz_config(
&working,
&cmdline,
vsock_port,
scratch_path.as_deref(),
sink,
)?;
if let Some(guard) = &control_dir {
let params = crate::boot::from_config(config, scratch_dev.as_deref());
std::fs::write(guard.0.join("boot.env"), crate::boot::to_env(¶ms))
.map_err(Error::Io)?;
}
let queue = DispatchQueue::new("com.chamuka.vmette.session", None);
let vm = unsafe {
VZVirtualMachine::initWithConfiguration_queue(VZVirtualMachine::alloc(), &cfg, &queue)
};
let end = EndSlot::new();
let timed_out = Arc::new(AtomicBool::new(false));
let capture_rx = match &capture {
Some((read_fd, _)) => {
let (tx, rx) = std::sync::mpsc::channel::<Vec<u8>>();
let end_for_reader = end.clone();
let read_fd = *read_fd;
std::thread::spawn(move || drain_capture(read_fd, end_for_reader, tx));
Some(rx)
}
None => None,
};
let delegate = VmetteDelegate::new(DelegateState {
exit_code_file,
timed_out: timed_out.clone(),
end: end.clone(),
});
let mut agent_rx = None;
let vsock_keepalive = if let Some(port) = vsock_port {
let mode = match config.workload {
WorkloadStrategy::Agent => {
let (tx, rx) = sync_channel::<RawFd>(1);
agent_rx = Some(rx);
ListenerMode::Agent {
fd_tx: Mutex::new(Some(tx)),
}
}
WorkloadStrategy::OneShot => ListenerMode::Echo {
ready_handler: Arc::new(Mutex::new(None)),
},
};
let logger = VsockLogger::new(ListenerState { port, mode });
let listener = unsafe { VZVirtioSocketListener::new() };
unsafe {
listener.setDelegate(Some(ProtocolObject::from_ref(&*logger)));
}
Some((logger, listener))
} else {
None
};
let agent = Arc::new(AgentConn {
workload: config.workload,
rx: Mutex::new(agent_rx),
fd: Mutex::new(None),
demux: OnceLock::new(),
init: Mutex::new(()),
});
let setup_vm = QueueBound(vm.clone());
let setup_delegate = QueueBound(delegate.clone());
let setup_listener = vsock_keepalive
.as_ref()
.map(|(_, l)| (QueueBound(l.clone()), vsock_port.unwrap_or(0)));
queue.exec_sync(move || unsafe {
let proto: &ProtocolObject<dyn VZVirtualMachineDelegate> =
ProtocolObject::from_ref(&*setup_delegate.0);
setup_vm.setDelegate(Some(proto));
if let Some((listener, port)) = &setup_listener {
let sock_dev = setup_vm.socketDevices();
if let Some(dev) = sock_dev.firstObject() {
let dev: Retained<VZVirtioSocketDevice> = Retained::cast_unchecked(dev);
dev.setSocketListener_forPort(listener, *port);
}
}
});
if let Some(secs) = config.timeout_seconds {
let vm_for_timer = QueueBound(vm.clone());
let timed_out_setter = timed_out.clone();
let end_for_timer = end.clone();
let when = DispatchTime::try_from(Duration::from_secs(secs as u64))
.unwrap_or(DispatchTime::NOW);
let _ = queue.after(when, move || {
timed_out_setter.store(true, Ordering::SeqCst);
let end_for_stop = end_for_timer.clone();
let stop_cb = RcBlock::new(move |_err: *mut NSError| {
end_for_stop.set(SessionEnd::TimedOut);
});
unsafe { vm_for_timer.stopWithCompletionHandler(&stop_cb) };
});
}
let vm_for_start = QueueBound(vm.clone());
let end_for_start = end.clone();
queue.exec_async(move || {
let start_cb = RcBlock::new(move |err: *mut NSError| {
if !err.is_null() {
let err = unsafe { &*err };
end_for_start.set(SessionEnd::Error(format!(
"vm.start failed: {}",
err.localizedDescription()
)));
}
});
unsafe { vm_for_start.startWithCompletionHandler(&start_cb) };
});
Ok(Session {
vm,
queue,
end,
vsock_port,
cmdline,
agent,
_delegate: delegate,
_vsock_keepalive: vsock_keepalive,
_control_dir: control_dir,
_scratch_file: scratch_file,
_capture: capture.map(|(_, p)| p),
capture_rx: Mutex::new(capture_rx),
})
}
pub fn vsock_port(&self) -> Option<u32> {
self.vsock_port
}
pub fn cmdline(&self) -> &str {
&self.cmdline
}
pub fn wait(&self) -> SessionEnd {
self.end.wait_end()
}
pub fn wait_captured(&self) -> crate::RunOutput {
let end = self.end.wait_end();
let mut out = Vec::new();
if let Some(rx) = self.capture_rx.lock().unwrap().take() {
for chunk in rx {
out.extend_from_slice(&chunk);
}
}
let exit_code = match end {
SessionEnd::Exited(code) => code,
SessionEnd::TimedOut => 124,
SessionEnd::Stopped => 0,
SessionEnd::Error(_) => 1,
};
crate::RunOutput {
exit_code,
output: String::from_utf8_lossy(&out).replace("\r\n", "\n"),
}
}
pub fn capture_rx(&self) -> Option<std::sync::mpsc::Receiver<Vec<u8>>> {
self.capture_rx.lock().unwrap().take()
}
pub fn stop(&self) {
issue_stop(&self.queue, &self.vm, &self.end);
}
pub fn request(&self, action: &Action) -> Result<(ResponseHeader, Vec<u8>), Error> {
self.agent.request(action)
}
pub fn client(&self) -> SessionClient {
SessionClient {
agent: self.agent.clone(),
}
}
pub fn stop_handle(&self) -> StopHandle {
StopHandle {
vm: QueueBound(self.vm.clone()),
queue: self.queue.clone(),
end: self.end.clone(),
}
}
}
/// Cheaply cloneable handle for sending agent requests. It shares the session's agent
/// connection through an `Arc`.
#[derive(Clone)]
pub struct SessionClient {
    agent: Arc<AgentConn>,
}

impl SessionClient {
    /// Sends `action` to the guest agent and waits for the response, like `Session::request`.
    pub fn request(&self, action: &Action) -> Result<(ResponseHeader, Vec<u8>), Error> {
        let conn: &AgentConn = &self.agent;
        conn.request(action)
    }
}
/// Detached handle that can stop a session's VM without borrowing the `Session`.
pub struct StopHandle {
    vm: QueueBound<VZVirtualMachine>,
    queue: DispatchRetained<DispatchQueue>,
    end: Arc<EndSlot>,
}

impl StopHandle {
    /// Requests that the VM stop; does nothing if the session has already ended.
    pub fn stop(&self) {
        let Self { vm, queue, end } = self;
        issue_stop(queue, &vm.0, end);
    }
}
/// Minimal `Read`/`Write` adapter over a borrowed raw file descriptor.
///
/// It never closes the fd; each call issues a single read(2)/write(2).
struct FdStream(RawFd);

impl std::io::Read for FdStream {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // SAFETY: `buf` is a valid, writable region of `buf.len()` bytes.
        let got = unsafe { libc::read(self.0, buf.as_mut_ptr().cast(), buf.len()) };
        // read(2) returns -1 on failure; errno is still current here.
        usize::try_from(got).map_err(|_| std::io::Error::last_os_error())
    }
}

impl std::io::Write for FdStream {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // SAFETY: `buf` is a valid, readable region of `buf.len()` bytes.
        let put = unsafe { libc::write(self.0, buf.as_ptr().cast(), buf.len()) };
        // write(2) returns -1 on failure; errno is still current here.
        usize::try_from(put).map_err(|_| std::io::Error::last_os_error())
    }

    fn flush(&mut self) -> std::io::Result<()> {
        // Unbuffered: nothing to flush.
        Ok(())
    }
}