use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use crate::error::{VmRuntimeError, VmRuntimeResult};
/// Longest time `UffdHandler::shutdown` waits for the handler thread to finish before
/// giving up on joining it.
const SHUTDOWN_JOIN_BUDGET: Duration = Duration::from_millis(200);
/// Sleep between successive "has the thread finished?" checks in `join_with_deadline`.
const SHUTDOWN_POLL_INTERVAL: Duration = Duration::from_millis(5);
/// Inputs required to start a [`UffdHandler`].
#[derive(Debug, Clone)]
pub struct UffdConfig {
/// Filesystem path of the Unix socket the handler binds and accepts its single peer on.
pub socket_path: PathBuf,
/// Path of the memory file that is mapped read-only and used as the source of the pages
/// copied in on each fault.
pub mem_file_path: PathBuf,
}
/// Owner of the background thread that services userfaultfd events.
///
/// Dropping the value runs [`UffdHandler::shutdown`].
#[derive(Debug)]
pub struct UffdHandler {
// Path the listener socket was bound to; `shutdown` removes this file.
socket_path: PathBuf,
// Flags shared with the handler thread.
state: Arc<HandlerState>,
// Join handle of the handler thread; taken (left as `None`) by the first `shutdown` call.
join: Option<JoinHandle<()>>,
}
/// Flags shared between a [`UffdHandler`] and its handler thread.
#[derive(Debug)]
struct HandlerState {
// Initialised to `true` on start; cleared when the handler thread returns and at the end
// of `UffdHandler::shutdown`.
alive: AtomicBool,
// Set by `UffdHandler::shutdown`; the handler thread checks it to know when to stop.
shutdown: AtomicBool,
}
impl UffdHandler {
    /// Starts a handler for `config` by delegating to the platform-specific `start`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the platform implementation. On targets other than
    /// Linux that implementation always fails with `VmRuntimeError::Uffd`.
    pub fn start(config: UffdConfig) -> VmRuntimeResult<Self> {
        platform::start(config)
    }

    /// Returns the current value of the shared `alive` flag.
    ///
    /// The flag is cleared when the handler thread returns and when [`Self::shutdown`]
    /// completes.
    pub fn is_alive(&self) -> bool {
        self.state.alive.load(Ordering::SeqCst)
    }

    /// Asks the handler thread to stop, removes the socket file, and waits up to
    /// `SHUTDOWN_JOIN_BUDGET` for the thread to exit.
    ///
    /// All steps are best-effort: failures to connect, to unlink the socket, or to join in
    /// time are ignored. The `alive` flag is cleared unconditionally at the end. Calling this
    /// more than once is allowed; later calls have no thread left to join.
    pub fn shutdown(&mut self) {
        // Publish the request first so the thread observes it as soon as it wakes up.
        self.state.shutdown.store(true, Ordering::SeqCst);
        let pending = self.join.take();

        // Unlinking the socket path does not interrupt a thread that is already blocked in
        // `accept()`. Without a wake-up the thread (and its listener fd) would outlive this
        // call and the join below would always run into its deadline. A throwaway connection
        // makes `accept()` return; if nobody is listening any more the connect simply fails
        // fast and the error is irrelevant.
        #[cfg(unix)]
        {
            if pending.as_ref().is_some_and(|handle| !handle.is_finished()) {
                let _ = std::os::unix::net::UnixStream::connect(&self.socket_path);
            }
        }

        let _ = std::fs::remove_file(&self.socket_path);
        if let Some(handle) = pending {
            // On timeout the handle is dropped, which detaches the thread.
            let _ = join_with_deadline(handle, SHUTDOWN_JOIN_BUDGET);
        }
        self.state.alive.store(false, Ordering::SeqCst);
    }
}
/// Dropping a handler shuts it down, so the socket file and the handler thread are cleaned
/// up even when the owner never calls `shutdown` explicitly.
impl Drop for UffdHandler {
fn drop(&mut self) {
self.shutdown();
}
}
/// Builds the JSON object that selects a UFFD memory backend.
///
/// The result has exactly two keys: `backend_type`, fixed to `"Uffd"`, and `backend_path`,
/// the serialized `socket_path`.
// NOTE(review): presumably this is the `mem_backend` section of a snapshot-load request —
// confirm against the caller.
pub fn snapshot_load_mem_backend_uffd(socket_path: &Path) -> serde_json::Value {
    let backend_type = "Uffd";
    let backend_path = socket_path;
    serde_json::json!({
        "backend_type": backend_type,
        "backend_path": backend_path,
    })
}
/// Waits for `handle`'s thread to finish, for at most `budget`.
///
/// The thread is polled with `is_finished`, sleeping `SHUTDOWN_POLL_INTERVAL` between checks.
///
/// Returns `Ok(())` once the thread has finished and been joined (the thread's own result,
/// including a panic payload, is discarded). Returns `Err(handle)` with the still-unjoined
/// handle when the budget runs out first.
fn join_with_deadline(handle: JoinHandle<()>, budget: Duration) -> Result<(), JoinHandle<()>> {
    let give_up_at = Instant::now() + budget;
    while !handle.is_finished() {
        if Instant::now() >= give_up_at {
            return Err(handle);
        }
        thread::sleep(SHUTDOWN_POLL_INTERVAL);
    }
    // The thread is done, so this join cannot block.
    let _ = handle.join();
    Ok(())
}
#[cfg(target_os = "linux")]
mod platform {
use super::*;
use std::fs::File;
use std::io::IoSliceMut;
use std::os::fd::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::ptr;
use nix::sys::socket::{ControlMessageOwned, MsgFlags, recvmsg};
use serde::Deserialize;
use userfaultfd::{Event, Uffd};
/// One entry of the JSON array received in the handshake message; it maps a range of
/// faulting addresses onto a range of the memory file.
#[derive(Debug, Deserialize)]
struct GuestRegionUffdMapping {
// First address of the region; fault addresses are compared against this.
base_host_virt_addr: u64,
// Length of the region in bytes.
size: u64,
// Offset within the memory file at which the region's contents start.
offset: u64,
// Deserialized but never read by this module.
#[allow(dead_code)] page_size_kib: Option<u64>,
}
/// Granularity used to align fault addresses and the number of bytes copied per fault.
const PAGE_SIZE: usize = 4096;
/// Size of the buffer that receives the handshake payload (the region-mapping JSON).
const HANDSHAKE_BUF_BYTES: usize = 65_536;
/// Binds the handler socket and spawns the thread that serves page faults.
///
/// A socket file left over at `socket_path` is removed before binding. The spawned thread is
/// named `microvm-uffd:<socket path>`, runs `run_handler`, clears the `alive` flag when that
/// returns, and reports a handler error on stderr.
///
/// # Errors
///
/// Returns `VmRuntimeError::Uffd` when the stale socket cannot be removed (for a reason other
/// than it not existing), when binding fails, or when the thread cannot be spawned. In the
/// last case the freshly bound socket file is removed again so no dead socket is left behind.
pub fn start(config: UffdConfig) -> VmRuntimeResult<UffdHandler> {
    let UffdConfig {
        socket_path,
        mem_file_path,
    } = config;

    match std::fs::remove_file(&socket_path) {
        Ok(()) => {}
        // Nothing to clean up.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
        Err(err) => {
            return Err(VmRuntimeError::Uffd(format!(
                "remove stale socket {}: {err}",
                socket_path.display()
            )));
        }
    }

    let listener = UnixListener::bind(&socket_path)
        .map_err(|e| VmRuntimeError::Uffd(format!("bind {}: {e}", socket_path.display())))?;

    let state = Arc::new(HandlerState {
        alive: AtomicBool::new(true),
        shutdown: AtomicBool::new(false),
    });
    let state_thread = Arc::clone(&state);
    // The thread needs its own copy of the socket path for logging; the memory file path is
    // not used on this side any more and is simply moved into the closure.
    let socket_path_thread = socket_path.clone();

    let spawned = thread::Builder::new()
        .name(format!("microvm-uffd:{}", socket_path.display()))
        .spawn(move || {
            let result = run_handler(listener, &mem_file_path, &state_thread);
            state_thread.alive.store(false, Ordering::SeqCst);
            if let Err(err) = result {
                eprintln!(
                    "[microvm-uffd] handler exited with error \
                     (socket={}, mem_file={}): {err}",
                    socket_path_thread.display(),
                    mem_file_path.display(),
                );
            }
        });

    let join = match spawned {
        Ok(join) => join,
        Err(e) => {
            // The listener was dropped together with the unspawned closure; do not leave
            // its socket file on disk with nobody listening.
            let _ = std::fs::remove_file(&socket_path);
            return Err(VmRuntimeError::Uffd(format!("spawn handler thread: {e}")));
        }
    };

    Ok(UffdHandler {
        socket_path,
        state,
        join: Some(join),
    })
}
/// Body of the handler thread: serves exactly one peer.
///
/// Accepts one connection, receives the handshake (userfaultfd descriptor plus region
/// mappings), maps the memory file read-only, and runs `fault_loop` until it returns. The
/// mapping is released when this function returns.
///
/// # Errors
///
/// Returns `VmRuntimeError::Uffd` when `accept` fails while no shutdown was requested, when
/// the handshake fails, or when the memory file cannot be opened, inspected, or mapped.
/// Returns `Ok(())` without doing any work when shutdown was requested before a peer was
/// served.
fn run_handler(
    listener: UnixListener,
    mem_file_path: &Path,
    state: &HandlerState,
) -> VmRuntimeResult<()> {
    let stream = match listener.accept() {
        Ok((stream, _addr)) => stream,
        Err(_) if state.shutdown.load(Ordering::SeqCst) => return Ok(()),
        Err(e) => return Err(VmRuntimeError::Uffd(format!("accept: {e}"))),
    };
    // Only a single peer is ever served, so stop listening immediately.
    drop(listener);

    // A connection that shows up after shutdown was requested must not be taken through the
    // handshake; exit quietly instead of reporting a bogus handshake error.
    if state.shutdown.load(Ordering::SeqCst) {
        return Ok(());
    }

    let (uffd, regions) = receive_handshake(&stream)?;

    let mem_file = File::open(mem_file_path).map_err(|e| {
        VmRuntimeError::Uffd(format!("open mem file {}: {e}", mem_file_path.display()))
    })?;
    let mem_len = mem_file
        .metadata()
        .map_err(|e| {
            VmRuntimeError::Uffd(format!("stat mem file {}: {e}", mem_file_path.display()))
        })?
        .len();
    // Checked conversion: a silently truncated length would under-map the file.
    let mem_size = usize::try_from(mem_len).map_err(|_| {
        VmRuntimeError::Uffd(format!(
            "mem file {} is too large to map ({mem_len} bytes)",
            mem_file_path.display()
        ))
    })?;

    // SAFETY: requests a new private, read-only mapping at a kernel-chosen address; the file
    // descriptor is open for the duration of the call and no existing memory is touched.
    let mapping = unsafe {
        libc::mmap(
            ptr::null_mut(),
            mem_size,
            libc::PROT_READ,
            libc::MAP_PRIVATE,
            mem_file.as_raw_fd(),
            0,
        )
    };
    if mapping == libc::MAP_FAILED {
        return Err(VmRuntimeError::Uffd(format!(
            "mmap mem file {}: {}",
            mem_file_path.display(),
            std::io::Error::last_os_error()
        )));
    }
    let mem_ptr = mapping as *const u8;
    // Unmaps the file when this function returns.
    let _mmap_guard = MmapGuard {
        ptr: mem_ptr,
        size: mem_size,
    };

    fault_loop(&uffd, &stream, mem_ptr, mem_size, &regions, state);
    Ok(())
}
/// Receives the single handshake message from the peer.
///
/// The message is expected to carry a file descriptor as `SCM_RIGHTS` ancillary data (used
/// as the userfaultfd) and a UTF-8 JSON array of region mappings as its payload.
///
/// # Errors
///
/// Returns `VmRuntimeError::Uffd` when `recvmsg` fails, the control messages cannot be
/// decoded, no descriptor was passed, or the payload is not valid UTF-8 / JSON. A descriptor
/// that was already received is closed on those later error paths.
fn receive_handshake(
    stream: &UnixStream,
) -> VmRuntimeResult<(Uffd, Vec<GuestRegionUffdMapping>)> {
    let mut buf = vec![0u8; HANDSHAKE_BUF_BYTES];
    let mut cmsg_buf = nix::cmsg_space!(RawFd);
    let mut iov = [IoSliceMut::new(&mut buf)];
    let msg = recvmsg::<()>(
        stream.as_raw_fd(),
        &mut iov,
        Some(&mut cmsg_buf),
        MsgFlags::empty(),
    )
    .map_err(|e| VmRuntimeError::Uffd(format!("recvmsg from firecracker: {e}")))?;

    let mut uffd_fd: Option<RawFd> = None;
    for cmsg in msg
        .cmsgs()
        .map_err(|e| VmRuntimeError::Uffd(format!("decode cmsg from firecracker: {e}")))?
    {
        if let ControlMessageOwned::ScmRights(fds) = cmsg
            && let Some(&fd) = fds.first()
        {
            uffd_fd = Some(fd);
        }
    }
    // Last use of `msg`; the payload buffer may be read once this borrow has ended.
    let json_len = msg.bytes;

    let uffd_fd = uffd_fd.ok_or_else(|| {
        VmRuntimeError::Uffd("firecracker did not send a userfaultfd fd".into())
    })?;
    // Take ownership before any further fallible step so the descriptor is closed (instead
    // of leaked) if the payload turns out to be malformed.
    // SAFETY: the descriptor was just received via SCM_RIGHTS, so this process owns it and
    // nothing else in this function closes it.
    let uffd = unsafe { Uffd::from_raw_fd(uffd_fd) };

    let json_str = std::str::from_utf8(&buf[..json_len])
        .map_err(|e| VmRuntimeError::Uffd(format!("region mapping payload not utf-8: {e}")))?;
    let regions: Vec<GuestRegionUffdMapping> = serde_json::from_str(json_str)
        .map_err(|e| VmRuntimeError::Uffd(format!("parse region mappings: {e}")))?;

    Ok((uffd, regions))
}
/// Polls the userfaultfd and services its events until it is time to stop.
///
/// The loop ends when the shutdown flag is set (checked at least every 100 ms), when the
/// peer's socket or the userfaultfd reports hang-up or an error condition, or when `poll` /
/// `read_event` fails. Page faults are resolved through `handle_pagefault`; `Remove` events
/// unregister the removed range; all other events are ignored.
///
/// `mem_base` / `mem_size` describe the mapped memory file and `regions` the mappings
/// received in the handshake; they are passed through to `handle_pagefault`.
fn fault_loop(
    uffd: &Uffd,
    stream: &UnixStream,
    mem_base: *const u8,
    mem_size: usize,
    regions: &[GuestRegionUffdMapping],
    state: &HandlerState,
) {
    let uffd_fd = uffd.as_raw_fd();
    let stream_fd = stream.as_raw_fd();
    loop {
        if state.shutdown.load(Ordering::SeqCst) {
            return;
        }
        let mut fds = [
            libc::pollfd {
                fd: uffd_fd,
                events: libc::POLLIN,
                revents: 0,
            },
            // No requested events: the socket is only watched for hang-up / error, which
            // `poll` reports regardless of `events`.
            libc::pollfd {
                fd: stream_fd,
                events: 0,
                revents: 0,
            },
        ];
        // Bounded wait so the shutdown flag is re-checked regularly.
        let timeout_ms: libc::c_int = 100;
        // SAFETY: `fds` is a live, correctly sized array of `pollfd` for the whole call.
        let rc = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as libc::nfds_t, timeout_ms) };
        if rc < 0 {
            let err = std::io::Error::last_os_error();
            if err.kind() == std::io::ErrorKind::Interrupted {
                continue;
            }
            eprintln!("[microvm-uffd] poll error: {err}");
            return;
        }
        if rc == 0 {
            continue;
        }

        if fds[1].revents & (libc::POLLHUP | libc::POLLERR | libc::POLLNVAL) != 0 {
            return;
        }
        let uffd_revents = fds[0].revents;
        if uffd_revents & libc::POLLHUP != 0 {
            return;
        }
        // An error condition is sticky: `poll` would return immediately on every iteration
        // without ever setting POLLIN, turning this loop into a busy spin. Stop instead.
        if uffd_revents & (libc::POLLERR | libc::POLLNVAL) != 0 {
            eprintln!(
                "[microvm-uffd] poll reported an error condition on the userfaultfd \
                 (revents={uffd_revents:#x})"
            );
            return;
        }
        if uffd_revents & libc::POLLIN == 0 {
            continue;
        }

        match uffd.read_event() {
            Ok(Some(Event::Pagefault { addr, .. })) => {
                handle_pagefault(uffd, addr as u64, regions, mem_base, mem_size);
            }
            Ok(Some(Event::Remove { start, end })) => {
                let size = (end as usize).saturating_sub(start as usize);
                if size > 0 {
                    // Best-effort: a failed unregister is ignored.
                    let _ = uffd.unregister(start as *mut _, size);
                }
            }
            Ok(Some(_)) => {}
            Ok(None) => {
                continue;
            }
            Err(e) => {
                eprintln!("[microvm-uffd] read_event error: {e}");
                return;
            }
        }
    }
}
fn handle_pagefault(
uffd: &Uffd,
fault_addr: u64,
regions: &[GuestRegionUffdMapping],
mem_base: *const u8,
mem_size: usize,
) {
let aligned = (fault_addr as usize) & !(PAGE_SIZE - 1);
for region in regions {
let start = region.base_host_virt_addr as usize;
let end = start.saturating_add(region.size as usize);
if aligned < start || aligned >= end {
continue;
}
let in_region = aligned - start;
let src_offset = (region.offset as usize).saturating_add(in_region);
if src_offset.saturating_add(PAGE_SIZE) > mem_size {
eprintln!(
"[microvm-uffd] fault page out of mem-file bounds: \
src_offset={src_offset}, mem_size={mem_size}"
);
return;
}
let src = unsafe { mem_base.add(src_offset) };
let dst = aligned as *mut libc::c_void;
match unsafe { uffd.copy(src.cast(), dst, PAGE_SIZE, true) } {
Ok(_) => {}
Err(userfaultfd::Error::CopyFailed(errno)) if errno as i32 == libc::EEXIST => {
}
Err(e) => {
eprintln!("[microvm-uffd] UFFDIO_COPY failed at 0x{aligned:x}: {e}");
}
}
return;
}
let zero = [0u8; PAGE_SIZE];
let dst = aligned as *mut libc::c_void;
let _ = unsafe { uffd.copy(zero.as_ptr().cast(), dst, PAGE_SIZE, true) };
}
/// Owns a memory mapping and unmaps it on drop.
struct MmapGuard {
    /// Start of the mapping; a null pointer means "nothing to unmap".
    ptr: *const u8,
    /// Length of the mapping in bytes, as passed to `mmap`.
    size: usize,
}

// The raw pointer field suppresses the automatic `Send` implementation; this opts back in.
unsafe impl Send for MmapGuard {}

impl Drop for MmapGuard {
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // The result of `munmap` is deliberately not inspected.
        unsafe { libc::munmap(self.ptr as *mut libc::c_void, self.size) };
    }
}
}
#[cfg(not(target_os = "linux"))]
mod platform {
    use super::*;

    /// Stand-in for targets other than Linux: never starts a handler and always returns
    /// `VmRuntimeError::Uffd` with an explanatory message.
    pub fn start(_config: UffdConfig) -> VmRuntimeResult<UffdHandler> {
        let reason = "userfaultfd is Linux-only; \
                      this build target has no UFFD handler implementation";
        Err(VmRuntimeError::Uffd(reason.into()))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::os::unix::net::UnixListener;
/// The backend description is an object with exactly the two expected keys and values.
#[test]
fn snapshot_load_mem_backend_uffd_returns_expected_shape() {
    let value = snapshot_load_mem_backend_uffd(Path::new("/tmp/some-vm/uffd.sock"));

    assert_eq!(value["backend_type"], "Uffd");
    assert_eq!(value["backend_path"], "/tmp/some-vm/uffd.sock");

    let key_count = value.as_object().expect("object").len();
    assert_eq!(key_count, 2);
}
/// A socket path containing non-ASCII characters is emitted unchanged.
#[test]
fn snapshot_load_mem_backend_uffd_preserves_non_ascii_paths() {
    let value = snapshot_load_mem_backend_uffd(Path::new("/var/run/microvm/vm-αβγ/uffd.sock"));
    let backend_path = value["backend_path"].as_str().expect("string path");
    assert_eq!(backend_path, "/var/run/microvm/vm-αβγ/uffd.sock");
}
/// Starting a handler succeeds even when a socket file from an earlier listener is still
/// present at the configured path.
#[cfg(target_os = "linux")]
#[test]
fn start_removes_stale_socket_before_binding() {
    let dir = tempfile::tempdir().expect("tempdir");
    let socket_path = dir.path().join("uffd.sock");
    let mem_file_path = dir.path().join("mem.bin");
    std::fs::write(&mem_file_path, b"unused-for-this-test").expect("write mem file");

    // Binding and immediately dropping a listener leaves its socket file behind.
    drop(UnixListener::bind(&socket_path).expect("seed stale socket"));
    assert!(socket_path.exists(), "stale socket file should exist");

    let config = UffdConfig {
        socket_path: socket_path.clone(),
        mem_file_path,
    };
    let handler =
        UffdHandler::start(config).expect("start with stale socket present must succeed");

    assert!(socket_path.exists());
    assert!(handler.is_alive());
    drop(handler);
}
/// Binding inside a directory that does not exist is reported as a `Uffd` error that
/// mentions the bind step.
#[cfg(target_os = "linux")]
#[test]
fn start_errors_when_socket_parent_dir_does_not_exist() {
    let dir = tempfile::tempdir().expect("tempdir");
    let mem_file_path = dir.path().join("mem.bin");
    std::fs::write(&mem_file_path, b"x").unwrap();

    let config = UffdConfig {
        // The "nope" directory is never created.
        socket_path: dir.path().join("nope").join("uffd.sock"),
        mem_file_path,
    };
    let err =
        UffdHandler::start(config).expect_err("missing parent dir must surface as Uffd error");

    match err {
        VmRuntimeError::Uffd(msg) => {
            assert!(msg.contains("bind"), "expected bind failure, got: {msg}");
        }
        other => panic!("unexpected error variant: {other:?}"),
    }
}
// Starts a handler that never receives a connection, then checks that `shutdown` returns
// within two seconds, clears `is_alive`, and removes the socket file.
#[cfg(target_os = "linux")]
#[test]
fn shutdown_unblocks_accept_and_drops_listener() {
let dir = tempfile::tempdir().expect("tempdir");
let socket_path = dir.path().join("uffd.sock");
let mem_file_path = dir.path().join("mem.bin");
std::fs::write(&mem_file_path, b"unused").unwrap();
let mut handler = UffdHandler::start(UffdConfig {
socket_path: socket_path.clone(),
mem_file_path,
})
.expect("start");
assert!(handler.is_alive());
assert!(socket_path.exists());
// Time the shutdown call itself; it must not hang on the idle handler thread.
let start = Instant::now();
handler.shutdown();
let elapsed = start.elapsed();
assert!(
elapsed < Duration::from_secs(2),
"shutdown took too long: {elapsed:?}"
);
assert!(!handler.is_alive(), "is_alive must flip off after shutdown");
assert!(
!socket_path.exists(),
"shutdown must remove the socket file"
);
}
// Dropping a handler without calling `shutdown` explicitly must still remove the socket file.
#[cfg(target_os = "linux")]
#[test]
fn drop_runs_shutdown() {
let dir = tempfile::tempdir().expect("tempdir");
let socket_path = dir.path().join("uffd.sock");
let mem_file_path = dir.path().join("mem.bin");
std::fs::write(&mem_file_path, b"unused").unwrap();
// Inner scope: the handler is dropped at the closing brace.
{
let handler = UffdHandler::start(UffdConfig {
socket_path: socket_path.clone(),
mem_file_path,
})
.expect("start");
assert!(handler.is_alive());
}
assert!(
!socket_path.exists(),
"Drop must remove the socket file via shutdown()"
);
}
// Intentionally empty and ignored placeholder; per the `ignore` reason, the end-to-end path
// is covered by integration tests elsewhere.
#[cfg(target_os = "linux")]
#[test]
#[ignore = "requires CAP_SYS_PTRACE / unprivileged_userfaultfd; \
exercised by the firecracker adapter integration tests"]
fn end_to_end_pagefault_servicing() {
}
}