#[cfg(unix)]
use std::os::unix::fs::FileExt;
use std::{
collections::BTreeMap,
fs::File,
io::{Read, Write},
os::unix::net::UnixStream,
path::{Path, PathBuf},
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
time::Duration,
};
use bytes::{Bytes, BytesMut};
use parking_lot::Mutex;
use thiserror::Error;
/// Tunables for a [`Pager`] and the server thread spawned for it.
#[derive(Debug, Clone)]
pub struct PagerConfig {
    /// Interval the server loop uses both as its wake-up timeout and as the
    /// minimum spacing between drift checks.
    pub poll_interval: Duration,
    /// Addresses that [`Pager::prewarm`] serves eagerly.
    pub prewarm: PrewarmList,
}
impl Default for PagerConfig {
    /// Returns a configuration with a one-second poll interval and an empty
    /// prewarm list.
    fn default() -> Self {
        let poll_interval = Duration::from_secs(1);
        let prewarm = PrewarmList::default();
        Self {
            poll_interval,
            prewarm,
        }
    }
}
/// Ordered list of guest addresses (IPAs) to fault in ahead of time.
#[derive(Debug, Clone, Default)]
pub struct PrewarmList {
    /// Addresses to prewarm, in the order they will be served.
    pub pages: Vec<u64>,
}

impl PrewarmList {
    /// Builds a list holding exactly the three given addresses, in argument
    /// order: kernel text, vCPU 0 stack, then the FDT.
    #[must_use]
    pub fn from_boot_critical(kernel_text_ipa: u64, vcpu0_stack_ipa: u64, fdt_ipa: u64) -> Self {
        let pages = Vec::from([kernel_text_ipa, vcpu0_stack_ipa, fdt_ipa]);
        Self { pages }
    }
}
/// A request for one page of guest memory, handed to a [`PageSource`].
#[derive(Debug, Clone)]
pub struct PageRequest {
    /// Guest address (IPA) of the requested page.
    pub ipa: u64,
    /// Number of bytes requested.
    pub page_size: u64,
}
/// Failures a [`PageSource`] can report.
#[derive(Debug, Error)]
pub enum PageSourceError {
    /// Fewer bytes were available than the request asked for.
    #[error("page source returned a short read: {got} bytes, expected {expected}")]
    Short {
        /// Bytes actually obtained.
        got: u64,
        /// Bytes that were requested.
        expected: u64,
    },
    /// The source could not be opened or connected. The sources in this file
    /// also use it for rejected requests and protocol violations.
    #[error("page source open: {0}")]
    Open(String),
    /// An underlying I/O operation failed.
    #[error("page source I/O: {0}")]
    Io(#[source] std::io::Error),
}
impl From<std::io::Error> for PageSourceError {
fn from(err: std::io::Error) -> Self {
Self::Io(err)
}
}
/// A backend that can produce the contents of a guest page on demand.
///
/// Implementations are shared across threads (`Send + Sync`) and must be
/// `Debug` so the owning pager can derive `Debug`.
pub trait PageSource: Send + Sync + std::fmt::Debug {
    /// Returns the bytes for the page described by `req`.
    ///
    /// # Errors
    ///
    /// Returns a [`PageSourceError`] when the page cannot be produced.
    fn fetch(&self, req: &PageRequest) -> Result<Bytes, PageSourceError>;
}
/// [`PageSource`] backed by a flat file holding guest RAM.
#[derive(Debug)]
pub struct FilePageSource {
    /// Open handle to the backing file.
    file: File,
    /// Guest address that corresponds to file offset zero.
    ram_start: u64,
}
impl FilePageSource {
    /// Opens `path` read-only as a page source whose offset zero maps to the
    /// guest address `ram_start`.
    ///
    /// # Errors
    ///
    /// Returns [`PageSourceError::Open`] (message prefixed with the path) if
    /// the file cannot be opened.
    pub fn open(path: &Path, ram_start: u64) -> Result<Self, PageSourceError> {
        match File::open(path) {
            Ok(file) => Ok(Self { file, ram_start }),
            Err(e) => Err(PageSourceError::Open(format!("{}: {e}", path.display()))),
        }
    }
}
impl PageSource for FilePageSource {
    /// Reads the page containing `req.ipa` from the backing file.
    ///
    /// The file offset is `req.ipa - ram_start`, rounded down to a multiple of
    /// `req.page_size`.
    ///
    /// # Errors
    ///
    /// * [`PageSourceError::Open`] if `req.ipa` is below `ram_start`, if
    ///   `req.page_size` is zero or not a power of two, or if it does not fit
    ///   in `usize`.
    /// * [`PageSourceError::Short`] if the file ends before a full page could
    ///   be read.
    /// * [`PageSourceError::Io`] for any other read failure.
    fn fetch(&self, req: &PageRequest) -> Result<Bytes, PageSourceError> {
        if req.ipa < self.ram_start {
            return Err(PageSourceError::Open(format!(
                "ipa {:#x} below ram_start {:#x}",
                req.ipa, self.ram_start
            )));
        }
        // The alignment mask below is only meaningful for a non-zero power of
        // two; zero would additionally underflow `page_size - 1`.
        if !req.page_size.is_power_of_two() {
            return Err(PageSourceError::Open(format!(
                "page_size {} is not a non-zero power of two",
                req.page_size
            )));
        }
        let offset = req.ipa - self.ram_start;
        let aligned = offset & !(req.page_size - 1);
        let len = usize::try_from(req.page_size)
            .map_err(|_| PageSourceError::Open("page_size > usize::MAX".into()))?;
        let mut buf = vec![0u8; len];
        #[cfg(unix)]
        let n = {
            // `read_at` may return fewer bytes than requested without being at
            // end-of-file, so keep reading until the page is full or the file
            // really ends. Interrupted reads are retried.
            let mut filled = 0usize;
            while filled < buf.len() {
                let pos = aligned.saturating_add(u64::try_from(filled).unwrap_or(u64::MAX));
                match self.file.read_at(&mut buf[filled..], pos) {
                    Ok(0) => break,
                    Ok(got) => filled += got,
                    Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(PageSourceError::Io(e)),
                }
            }
            filled
        };
        #[cfg(not(unix))]
        let n = {
            // No positional read available: hand back a zero-filled page.
            let _ = aligned;
            buf.fill(0);
            buf.len()
        };
        let n_u64 = u64::try_from(n).unwrap_or(u64::MAX);
        if n_u64 < req.page_size {
            return Err(PageSourceError::Short {
                got: n_u64,
                expected: req.page_size,
            });
        }
        Ok(Bytes::from(buf))
    }
}
/// [`PageSource`] that asks a peer on a Unix domain socket for page contents.
#[derive(Debug)]
pub struct UffdPageSource {
    /// Connected stream; the mutex serialises whole request/response exchanges.
    socket: Mutex<UnixStream>,
    /// Path the stream was connected to, kept for [`UffdPageSource::path`].
    uds_path: PathBuf,
}
impl UffdPageSource {
    /// Connects to the Unix domain socket at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`PageSourceError::Open`] if the connection attempt fails.
    pub fn connect(path: &Path) -> Result<Self, PageSourceError> {
        match UnixStream::connect(path) {
            Ok(stream) => Ok(Self {
                socket: Mutex::new(stream),
                uds_path: path.to_path_buf(),
            }),
            Err(e) => Err(PageSourceError::Open(format!(
                "connect({}): {e}",
                path.display()
            ))),
        }
    }

    /// Path of the socket this source was connected to.
    #[must_use]
    pub fn path(&self) -> &Path {
        self.uds_path.as_path()
    }
}
impl PageSource for UffdPageSource {
    /// Performs one request/response exchange over the socket.
    ///
    /// Wire format as implemented here: the request is `ipa` then `page_size`,
    /// each a little-endian `u64`. The reply is a 16-byte header that must
    /// echo both values, followed by exactly `page_size` payload bytes.
    ///
    /// # Errors
    ///
    /// * [`PageSourceError::Io`] if any socket read or write fails.
    /// * [`PageSourceError::Open`] if the echoed header does not match the
    ///   request, or if `page_size` does not fit in `usize`.
    fn fetch(&self, req: &PageRequest) -> Result<Bytes, PageSourceError> {
        // Hold the lock for the whole exchange so replies cannot interleave.
        let mut stream = self.socket.lock();
        for field in [req.ipa, req.page_size] {
            stream.write_all(&field.to_le_bytes())?;
        }

        let mut reply = [0u8; 16];
        stream.read_exact(&mut reply)?;
        let (ipa_half, size_half) = reply.split_at(8);
        let echo_ipa = u64::from_le_bytes(ipa_half.try_into().unwrap_or([0; 8]));
        let echo_size = u64::from_le_bytes(size_half.try_into().unwrap_or([0; 8]));

        let header_matches = echo_ipa == req.ipa && echo_size == req.page_size;
        if !header_matches {
            return Err(PageSourceError::Open(format!(
                "Uffd protocol violation: expected ipa={:#x} size={} got ipa={:#x} size={}",
                req.ipa, req.page_size, echo_ipa, echo_size
            )));
        }

        let Ok(want) = usize::try_from(req.page_size) else {
            return Err(PageSourceError::Open("page_size > usize::MAX".into()));
        };
        let mut payload = BytesMut::zeroed(want);
        stream.read_exact(&mut payload)?;
        Ok(payload.freeze())
    }
}
/// Live, atomically updated counters kept by a pager.
#[derive(Debug, Default)]
pub struct PagerStats {
    /// Number of faults served successfully.
    pub faults: AtomicU64,
    /// Number of prewarm entries served successfully.
    pub prewarmed: AtomicU64,
    /// Number of recorded exception-port re-installs.
    pub port_reinstalls: AtomicU64,
    /// Number of recorded forwarded exceptions.
    pub forwarded_exceptions: AtomicU64,
}

impl PagerStats {
    /// Copies every counter into a plain-value snapshot.
    ///
    /// Each counter is loaded independently with relaxed ordering, so the
    /// snapshot is not a single atomic view across all four fields.
    pub fn snapshot(&self) -> PagerStatsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        PagerStatsSnapshot {
            faults: read(&self.faults),
            prewarmed: read(&self.prewarmed),
            port_reinstalls: read(&self.port_reinstalls),
            forwarded_exceptions: read(&self.forwarded_exceptions),
        }
    }
}

/// Point-in-time copy of [`PagerStats`] as plain integers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PagerStatsSnapshot {
    /// Value of [`PagerStats::faults`] when the snapshot was taken.
    pub faults: u64,
    /// Value of [`PagerStats::prewarmed`] when the snapshot was taken.
    pub prewarmed: u64,
    /// Value of [`PagerStats::port_reinstalls`] when the snapshot was taken.
    pub port_reinstalls: u64,
    /// Value of [`PagerStats::forwarded_exceptions`] when the snapshot was taken.
    pub forwarded_exceptions: u64,
}
/// Errors surfaced by the pager and its server thread.
#[derive(Debug, Error)]
pub enum PagerError {
    /// The faulting address is not inside any registered region.
    #[error("postcopy region missing: ipa={ipa:#x}")]
    OutOfRegion {
        /// The address that was looked up.
        ipa: u64,
    },
    /// The page source failed to produce the page.
    #[error("page source: {0}")]
    Source(#[from] PageSourceError),
    /// A Mach call failed; the string carries the call name and return code.
    #[error("mach error: {0}")]
    Mach(String),
    /// The server thread could not be spawned.
    #[error("pager server thread spawn: {0}")]
    Spawn(#[source] std::io::Error),
}
/// Half-open guest address range `[base, base + size)`.
#[derive(Debug, Clone, Copy)]
struct Region {
    /// First address inside the range.
    base: u64,
    /// Length of the range in bytes.
    size: u64,
}

impl Region {
    /// Returns `true` when `ipa` lies inside the range.
    ///
    /// The exclusive end saturates at `u64::MAX` instead of wrapping.
    fn contains(&self, ipa: u64) -> bool {
        let end = self.base.saturating_add(self.size);
        (self.base..end).contains(&ipa)
    }
}
/// Owner-side handle to the pager; all state lives in a shared inner value.
#[derive(Debug)]
pub struct Pager {
    /// State shared with every [`PagerHandle`] created from this pager.
    inner: Arc<PagerInner>,
}
/// State shared between a [`Pager`] and its [`PagerHandle`]s.
#[derive(Debug)]
struct PagerInner {
    /// Backend that produces page contents.
    source: Box<dyn PageSource>,
    /// Configuration captured at construction time.
    config: PagerConfig,
    /// Registered regions, keyed by their base address.
    regions: Mutex<BTreeMap<u64, Region>>,
    /// Activity counters.
    stats: PagerStats,
    /// Set once shutdown has been requested.
    shutdown: AtomicBool,
}
impl Pager {
#[must_use]
pub fn new(source: Box<dyn PageSource>, config: PagerConfig) -> Self {
Self {
inner: Arc::new(PagerInner {
source,
config,
regions: Mutex::new(BTreeMap::new()),
stats: PagerStats::default(),
shutdown: AtomicBool::new(false),
}),
}
}
pub fn register_region(&self, base: u64, size: u64) {
let mut regs = self.inner.regions.lock();
regs.insert(base, Region { base, size });
}
pub fn prewarm(&self) -> Result<(), PagerError> {
let pages = self.inner.config.prewarm.pages.clone();
for ipa in pages {
self.serve_fault(ipa, host_page_size())?;
self.inner.stats.prewarmed.fetch_add(1, Ordering::Relaxed);
}
Ok(())
}
pub fn serve_fault(&self, ipa: u64, page_size: u64) -> Result<Bytes, PagerError> {
if !self.contains_ipa(ipa) {
return Err(PagerError::OutOfRegion { ipa });
}
let aligned = ipa & !(page_size - 1);
let req = PageRequest {
ipa: aligned,
page_size,
};
let bytes = self.inner.source.fetch(&req)?;
self.inner.stats.faults.fetch_add(1, Ordering::Relaxed);
Ok(bytes)
}
#[must_use]
pub fn contains_ipa(&self, ipa: u64) -> bool {
let regs = self.inner.regions.lock();
regs.values().any(|r| r.contains(ipa))
}
#[must_use]
pub fn stats(&self) -> PagerStatsSnapshot {
self.inner.stats.snapshot()
}
pub fn request_shutdown(&self) {
self.inner.shutdown.store(true, Ordering::SeqCst);
}
#[must_use]
pub fn handle(&self) -> PagerHandle {
PagerHandle(Arc::clone(&self.inner))
}
#[doc(hidden)]
pub fn record_port_reinstall(&self) {
self.inner
.stats
.port_reinstalls
.fetch_add(1, Ordering::Relaxed);
}
#[doc(hidden)]
pub fn record_forwarded(&self) {
self.inner
.stats
.forwarded_exceptions
.fetch_add(1, Ordering::Relaxed);
}
}
/// Cheaply clonable view of a pager's shared state, used by the server thread.
#[derive(Debug, Clone)]
pub struct PagerHandle(Arc<PagerInner>);

impl PagerHandle {
    /// Returns `true` once shutdown has been requested on the owning pager.
    pub fn is_shutting_down(&self) -> bool {
        let PagerHandle(shared) = self;
        shared.shutdown.load(Ordering::SeqCst)
    }

    /// Returns the poll interval from the pager's configuration.
    #[must_use]
    pub fn poll_interval(&self) -> Duration {
        let PagerHandle(shared) = self;
        shared.config.poll_interval
    }
}
/// Page size, in bytes, used when prewarming.
///
/// This is a fixed 16 KiB constant; nothing is queried from the operating
/// system at run time.
#[must_use]
pub fn host_page_size() -> u64 {
    const KIB: u64 = 1024;
    16 * KIB
}
/// macOS implementation of the pager server thread.
///
/// With the `pager-live-mach` feature enabled the thread runs
/// `live::run_live_server`, which installs a Mach exception port; without it
/// the thread runs `run_skeleton_loop`, a polling loop that performs no work.
#[cfg(target_os = "macos")]
mod mach_imp {
use std::thread::{self, JoinHandle};
#[cfg(not(feature = "pager-live-mach"))]
use std::time::Instant;
#[cfg(not(feature = "pager-live-mach"))]
use tracing::{debug, info};
#[cfg(not(feature = "pager-live-mach"))]
use super::PagerHandle;
use super::{Pager, PagerError};
/// Spawns the server thread (named `squib-pager`) for `pager`.
///
/// The thread owns a `PagerHandle` and runs until the handle reports
/// shutdown. The returned join handle yields the server loop's result.
///
/// # Errors
///
/// Returns `PagerError::Spawn` if the OS thread cannot be created.
pub fn spawn_server(pager: &Pager) -> Result<JoinHandle<Result<(), PagerError>>, PagerError> {
let handle = pager.handle();
thread::Builder::new()
.name("squib-pager".into())
.spawn(move || -> Result<(), PagerError> {
// Exactly one of the two blocks below is compiled in, and it is the
// closure's tail expression.
#[cfg(feature = "pager-live-mach")]
{
live::run_live_server(&handle)
}
#[cfg(not(feature = "pager-live-mach"))]
{
run_skeleton_loop(&handle);
Ok(())
}
})
.map_err(PagerError::Spawn)
}
/// Placeholder server loop used when `pager-live-mach` is disabled.
///
/// Parks the thread for one poll interval at a time and emits a debug log
/// whenever at least one poll interval has elapsed since the previous log.
/// Nothing in this file unparks the thread: `Pager::request_shutdown` only
/// stores the flag, so shutdown is observed on the next wake-up.
#[cfg(not(feature = "pager-live-mach"))]
fn run_skeleton_loop(handle: &PagerHandle) {
info!(
poll_interval = ?handle.poll_interval(),
"squib-pager server starting (drift-poll skeleton; build with `--features pager-live-mach` to install a real exception port)"
);
let mut last_drift_check = Instant::now();
while !handle.is_shutting_down() {
thread::park_timeout(handle.poll_interval());
// `park_timeout` may return early, hence the explicit elapsed-time check.
if last_drift_check.elapsed() >= handle.poll_interval() {
debug!("squib-pager drift check (no-op skeleton)");
last_drift_check = Instant::now();
}
}
debug!("squib-pager server exiting (shutdown requested)");
}
/// Live Mach exception-port server (feature `pager-live-mach`).
#[cfg(feature = "pager-live-mach")]
mod live {
use std::time::Instant;
use mach2::{
exception_types::{
EXC_MASK_BAD_ACCESS, EXCEPTION_DEFAULT, exception_behavior_array_t,
exception_flavor_array_t, exception_mask_array_t, exception_mask_t,
},
kern_return::KERN_SUCCESS,
mach_port::{mach_port_allocate, mach_port_deallocate},
mach_types::exception_handler_array_t,
message::{
MACH_MSG_TIMEOUT_NONE, MACH_RCV_MSG, MACH_RCV_TIMEOUT, MACH_RCV_TOO_LARGE,
mach_msg, mach_msg_header_t,
},
port::{MACH_PORT_RIGHT_RECEIVE, mach_port_t},
task::{task_get_exception_ports, task_swap_exception_ports},
thread_status::THREAD_STATE_NONE,
traps::mach_task_self,
};
use tracing::{debug, info, warn};
use super::super::{PagerError, PagerHandle};
// Exception classes this server claims: bad-access only.
const SQUIB_EXC_MASK: exception_mask_t = EXC_MASK_BAD_ACCESS as exception_mask_t;
// Capacity of the scratch arrays handed to the exception-port calls.
const MAX_EXCEPTION_PORTS: usize = 32;
/// Allocates a receive port, installs it as the task's exception port for
/// `SQUIB_EXC_MASK`, then loops until shutdown: receive one message (with the
/// poll interval as timeout), log the outcome, and run a drift check once per
/// poll interval.
///
/// # Errors
///
/// Returns `PagerError::Mach` if the port cannot be allocated or installed.
/// Failures inside the loop are only logged.
pub(super) fn run_live_server(handle: &PagerHandle) -> Result<(), PagerError> {
let our_port = allocate_receive_port()
.map_err(|kr| PagerError::Mach(format!("mach_port_allocate: kr={kr}")))?;
// If installation fails, the `?` below returns without releasing `our_port`.
// NOTE(review): only a receive right is allocated; no send right is inserted
// before the port is handed to task_swap_exception_ports. Confirm against the
// Mach documentation whether a send right is required for the call to succeed.
install_exception_port(our_port).map_err(|kr| {
PagerError::Mach(format!("task_swap_exception_ports install: kr={kr}"))
})?;
info!(
port = our_port,
"squib-pager live: exception port installed"
);
let mut last_drift_check = Instant::now();
while !handle.is_shutting_down() {
// Durations too large for u32 milliseconds clamp to u32::MAX. A sub-millisecond
// interval becomes 0, which `recv_one` treats as "no timeout".
let timeout_ms =
u32::try_from(handle.poll_interval().as_millis()).unwrap_or(u32::MAX);
let kr = recv_one(our_port, timeout_ms);
match kr {
KERN_SUCCESS => {
// The message is only logged; no reply is sent and the stats counters are
// not touched. NOTE(review): presumably the faulting thread stays blocked
// until an exception reply is sent — confirm.
debug!("squib-pager live: received exception (no-reply forwarder)");
}
rc if rc == MACH_RCV_TIMEOUT => {
// Nothing arrived within the poll interval; fall through to the drift check.
}
rc if rc == MACH_RCV_TOO_LARGE => {
warn!("squib-pager live: oversize exception message dropped");
}
rc => {
warn!(kr = rc, "squib-pager live: unexpected mach_msg return");
}
}
if last_drift_check.elapsed() >= handle.poll_interval() {
// A failed drift check is logged and retried on the next interval.
if let Err(e) = drift_check_live(our_port) {
warn!(error = ?e, "squib-pager live: drift check failed");
}
last_drift_check = Instant::now();
}
}
// Best-effort cleanup. The handlers that were in place before installation are
// not restored (they were never kept).
// NOTE(review): `our_port` holds a receive right; confirm that
// mach_port_deallocate is the right call for releasing it.
// SAFETY (review): passes the current task and a port name this function
// allocated; no pointers are involved.
let dealloc = unsafe { mach_port_deallocate(mach_task_self(), our_port) };
if dealloc != KERN_SUCCESS {
warn!(
kr = dealloc,
"squib-pager live: mach_port_deallocate failed (best-effort)"
);
}
debug!("squib-pager live: server exiting (shutdown requested)");
Ok(())
}
/// Allocates a new port with a receive right in the current task.
///
/// Returns the port name, or the raw kernel return code on failure.
fn allocate_receive_port() -> Result<mach_port_t, i32> {
let mut port: mach_port_t = 0;
// SAFETY (review): `port` is a live local that the call writes the new name into.
let kr = unsafe {
mach_port_allocate(mach_task_self(), MACH_PORT_RIGHT_RECEIVE, &raw mut port)
};
if kr == KERN_SUCCESS {
Ok(port)
} else {
Err(kr)
}
}
/// Installs `port` as the task exception port for `SQUIB_EXC_MASK` using
/// `EXCEPTION_DEFAULT` behavior and `THREAD_STATE_NONE` flavor.
///
/// The previous handlers written back by the swap call land in local scratch
/// arrays and are discarded when the function returns.
///
/// Returns the raw kernel return code on failure.
fn install_exception_port(port: mach_port_t) -> Result<(), i32> {
let mut masks: [exception_mask_t; MAX_EXCEPTION_PORTS] = [0; MAX_EXCEPTION_PORTS];
let mut handlers: [mach_port_t; MAX_EXCEPTION_PORTS] = [0; MAX_EXCEPTION_PORTS];
let mut behaviors: [u32; MAX_EXCEPTION_PORTS] = [0; MAX_EXCEPTION_PORTS];
let mut flavors: [i32; MAX_EXCEPTION_PORTS] = [0; MAX_EXCEPTION_PORTS];
// In/out: initialised to the capacity of the arrays above.
let mut count: u32 = MAX_EXCEPTION_PORTS as u32;
// SAFETY (review): every pointer refers to a live local array of
// MAX_EXCEPTION_PORTS elements and `count` is initialised to that capacity.
// The element types are assumed to match the mach2 array typedefs — confirm.
let kr = unsafe {
task_swap_exception_ports(
mach_task_self(),
SQUIB_EXC_MASK,
port,
EXCEPTION_DEFAULT.cast_signed(),
THREAD_STATE_NONE,
masks.as_mut_ptr() as exception_mask_array_t,
&raw mut count,
handlers.as_mut_ptr() as exception_handler_array_t,
behaviors.as_mut_ptr() as exception_behavior_array_t,
flavors.as_mut_ptr() as exception_flavor_array_t,
)
};
if kr == KERN_SUCCESS { Ok(()) } else { Err(kr) }
}
/// Queries the task's current exception ports for `SQUIB_EXC_MASK` and
/// re-installs `our_port` if any returned entry overlapping that mask names a
/// different handler.
///
/// A re-install is not counted in the pager's `port_reinstalls` statistic,
/// because this function has no access to it.
///
/// Returns the raw kernel return code if the query or the re-install fails.
fn drift_check_live(our_port: mach_port_t) -> Result<(), i32> {
let mut masks: [exception_mask_t; MAX_EXCEPTION_PORTS] = [0; MAX_EXCEPTION_PORTS];
let mut handlers: [mach_port_t; MAX_EXCEPTION_PORTS] = [0; MAX_EXCEPTION_PORTS];
let mut behaviors: [u32; MAX_EXCEPTION_PORTS] = [0; MAX_EXCEPTION_PORTS];
let mut flavors: [i32; MAX_EXCEPTION_PORTS] = [0; MAX_EXCEPTION_PORTS];
// In/out: capacity on entry; read back below as the number of valid entries.
let mut count: u32 = MAX_EXCEPTION_PORTS as u32;
// SAFETY (review): same buffer arrangement as `install_exception_port`.
let kr = unsafe {
task_get_exception_ports(
mach_task_self(),
SQUIB_EXC_MASK,
masks.as_mut_ptr() as exception_mask_array_t,
&raw mut count,
handlers.as_mut_ptr() as exception_handler_array_t,
behaviors.as_mut_ptr() as exception_behavior_array_t,
flavors.as_mut_ptr() as exception_flavor_array_t,
)
};
if kr != KERN_SUCCESS {
return Err(kr);
}
// Indexing assumes the returned `count` never exceeds MAX_EXCEPTION_PORTS;
// a larger value would panic on the array access.
let drifted = (0..count as usize)
.any(|i| (masks[i] & SQUIB_EXC_MASK) != 0 && handlers[i] != our_port);
if drifted {
debug!("squib-pager live: drift detected, re-installing");
install_exception_port(our_port)?;
}
Ok(())
}
/// Attempts to receive one message on `port` and returns the raw `mach_msg`
/// return code.
///
/// A `timeout_ms` of zero selects a receive without `MACH_RCV_TIMEOUT`
/// (timeout `MACH_MSG_TIMEOUT_NONE`); any other value is a timed receive.
///
/// NOTE(review): the receive buffer is a bare `mach_msg_header_t`, so any
/// message with a body presumably comes back as `MACH_RCV_TOO_LARGE` — confirm
/// that this is intended for the current "no-reply forwarder" stage.
fn recv_one(port: mach_port_t, timeout_ms: u32) -> i32 {
let mut header = mach_msg_header_t::default();
let option = if timeout_ms == 0 {
MACH_RCV_MSG
} else {
MACH_RCV_MSG | MACH_RCV_TIMEOUT
};
let timeout = if timeout_ms == 0 {
MACH_MSG_TIMEOUT_NONE
} else {
timeout_ms
};
// SAFETY (review): the buffer pointer refers to the live local `header`, and
// the receive limit passed is exactly its size.
unsafe {
mach_msg(
(&raw mut header).cast(),
option,
0,
size_of::<mach_msg_header_t>() as u32,
port,
timeout,
0,
)
}
}
}
}
// On macOS the server entry point lives in `mach_imp`; it is re-exported here
// under the same name the non-macOS stub uses.
#[cfg(target_os = "macos")]
pub use mach_imp::spawn_server as spawn_mach_server;
/// Non-macOS stand-in for the Mach server: spawns a thread named
/// `squib-pager-stub` that returns `Ok(())` immediately.
///
/// # Errors
///
/// Returns [`PagerError::Spawn`] if the thread cannot be created.
#[cfg(not(target_os = "macos"))]
pub fn spawn_mach_server(
    _pager: &Pager,
) -> Result<std::thread::JoinHandle<Result<(), PagerError>>, PagerError> {
    let builder = std::thread::Builder::new().name("squib-pager-stub".into());
    match builder.spawn(|| Ok(())) {
        Ok(join) => Ok(join),
        Err(e) => Err(PagerError::Spawn(e)),
    }
}
// Unit tests for region bookkeeping, fault serving, both page sources, and
// the shutdown/statistics plumbing.
#[cfg(test)]
mod tests {
use std::io::Write as _;
use bytes::Bytes;
use tempfile::TempDir;
use super::*;
// Page source that ignores the request and always returns the same bytes.
#[derive(Debug)]
struct FakeSource {
bytes: Bytes,
}
impl PageSource for FakeSource {
fn fetch(&self, _req: &PageRequest) -> Result<Bytes, PageSourceError> {
Ok(self.bytes.clone())
}
}
// An address inside a registered region matches; the first address past its end does not.
#[test]
fn test_should_register_and_match_a_postcopy_region() {
let pager = Pager::new(
Box::new(FakeSource {
bytes: Bytes::from(vec![0u8; 16 * 1024]),
}),
PagerConfig::default(),
);
pager.register_region(0x8000_0000, 0x1000_0000);
assert!(pager.contains_ipa(0x8000_1234));
assert!(!pager.contains_ipa(0x9000_0000));
}
// A fault inside the region returns the source's bytes and counts one fault.
#[test]
fn test_should_serve_fault_inside_registered_region() {
let bytes = Bytes::from(vec![0xAB; 16 * 1024]);
let pager = Pager::new(
Box::new(FakeSource {
bytes: bytes.clone(),
}),
PagerConfig::default(),
);
pager.register_region(0x8000_0000, 0x1000_0000);
let got = pager.serve_fault(0x8000_1234, 16 * 1024).unwrap();
assert_eq!(got, bytes);
assert_eq!(pager.stats().faults, 1);
}
// A fault outside every region is rejected and reports the original address.
#[test]
fn test_should_reject_fault_outside_region() {
let pager = Pager::new(
Box::new(FakeSource {
bytes: Bytes::from(vec![0u8; 16 * 1024]),
}),
PagerConfig::default(),
);
pager.register_region(0x8000_0000, 0x1000_0000);
let err = pager.serve_fault(0xC000_0000, 16 * 1024).unwrap_err();
assert!(matches!(err, PagerError::OutOfRegion { ipa: 0xC000_0000 }));
}
// Page source that records every requested address and returns zeroed pages.
#[derive(Debug, Default)]
struct CaptureSource {
seen: Mutex<Vec<u64>>,
}
impl PageSource for CaptureSource {
fn fetch(&self, req: &PageRequest) -> Result<Bytes, PageSourceError> {
self.seen.lock().push(req.ipa);
Ok(Bytes::from(vec![0u8; req.page_size as usize]))
}
}
// Adapter so the test can keep its own `Arc` to the `CaptureSource` while the
// pager owns a boxed source.
#[derive(Debug)]
struct WrapperSource {
inner: Arc<CaptureSource>,
}
impl PageSource for WrapperSource {
fn fetch(&self, req: &PageRequest) -> Result<Bytes, PageSourceError> {
self.inner.fetch(req)
}
}
// The address passed to the source is rounded down to a 16 KiB boundary.
#[test]
fn test_should_align_request_to_page_boundary() {
let src = Arc::new(CaptureSource::default());
let pager = Pager::new(
Box::new(WrapperSource { inner: src.clone() }),
PagerConfig::default(),
);
pager.register_region(0x8000_0000, 0x1000_0000);
pager.serve_fault(0x8000_1234, 16 * 1024).unwrap();
let seen = src.seen.lock().clone();
assert_eq!(seen, vec![0x8000_0000]);
}
// Prewarming three in-region addresses bumps both `prewarmed` and `faults` to three.
#[test]
fn test_should_prewarm_each_page_on_demand() {
let bytes = Bytes::from(vec![0xCD; 16 * 1024]);
let cfg = PagerConfig {
prewarm: PrewarmList::from_boot_critical(
0x8000_2000, 0x9000_0000, 0x9F00_0000, ),
..Default::default()
};
let pager = Pager::new(Box::new(FakeSource { bytes }), cfg);
pager.register_region(0x8000_0000, 0x2000_0000);
pager.prewarm().unwrap();
let stats = pager.stats();
assert_eq!(stats.prewarmed, 3);
assert_eq!(stats.faults, 3);
}
// The file source reads at `ipa - ram_start`: only the second 16 KiB page of
// the file holds 0x77, and that is the page requested.
#[test]
fn test_file_page_source_reads_at_offset_from_ram_start() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("x.mem");
let mut bytes = vec![0u8; 64 * 1024];
for byte in &mut bytes[16 * 1024..32 * 1024] {
*byte = 0x77;
}
std::fs::write(&path, &bytes).unwrap();
let src = FilePageSource::open(&path, 0x8000_0000).unwrap();
let got = src
.fetch(&PageRequest {
ipa: 0x8000_4000, page_size: 16 * 1024,
})
.unwrap();
assert!(got.iter().all(|&b| b == 0x77));
}
// Requests below `ram_start` are rejected with the `Open` variant.
#[test]
fn test_file_page_source_rejects_below_ram_start() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("x.mem");
std::fs::write(&path, vec![0u8; 16 * 1024]).unwrap();
let src = FilePageSource::open(&path, 0x8000_0000).unwrap();
let err = src
.fetch(&PageRequest {
ipa: 0x4000_0000,
page_size: 16 * 1024,
})
.unwrap_err();
assert!(matches!(err, PageSourceError::Open(_)));
}
// End-to-end socket exchange against an in-process peer: the peer echoes the
// 16-byte request header and then sends `size` bytes of 0xEE.
#[test]
fn test_uffd_page_source_round_trips_protocol() {
let dir = TempDir::new().unwrap();
let sock_path = dir.path().join("pager.sock");
// Bind before connecting so the client's `connect` cannot race the listener.
let listener = std::os::unix::net::UnixListener::bind(&sock_path).unwrap();
let server = std::thread::spawn(move || {
let (mut sock, _) = listener.accept().unwrap();
let mut hdr = [0u8; 16];
sock.read_exact(&mut hdr).unwrap();
let ipa = u64::from_le_bytes(hdr[0..8].try_into().unwrap());
let size = u64::from_le_bytes(hdr[8..16].try_into().unwrap());
sock.write_all(&hdr).unwrap();
let payload = vec![0xEEu8; size as usize];
sock.write_all(&payload).unwrap();
(ipa, size)
});
let src = UffdPageSource::connect(&sock_path).unwrap();
let got = src
.fetch(&PageRequest {
ipa: 0x8000_0000,
page_size: 16 * 1024,
})
.unwrap();
assert!(got.iter().all(|&b| b == 0xEE));
// The values the peer decoded are returned by the thread but not asserted on.
let _ = server.join().unwrap();
}
// The shutdown flag set on the pager is visible through a previously created handle.
#[test]
fn test_should_request_shutdown_via_handle() {
let pager = Pager::new(
Box::new(FakeSource {
bytes: Bytes::from(vec![0u8; 16 * 1024]),
}),
PagerConfig::default(),
);
let h = pager.handle();
assert!(!h.is_shutting_down());
pager.request_shutdown();
assert!(h.is_shutting_down());
}
// The two hidden recorder methods increment their respective counters.
#[test]
fn test_stats_record_port_reinstall_and_forwarded() {
let pager = Pager::new(
Box::new(FakeSource {
bytes: Bytes::from(vec![0u8; 16 * 1024]),
}),
PagerConfig::default(),
);
pager.record_port_reinstall();
pager.record_port_reinstall();
pager.record_forwarded();
let s = pager.stats();
assert_eq!(s.port_reinstalls, 2);
assert_eq!(s.forwarded_exceptions, 1);
}
}