use std::fs;
use std::io::{Read, Write};
use std::os::unix::fs::OpenOptionsExt;
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd};
use std::os::unix::process::CommandExt;
use std::path::Path;
use std::process::{Child, Command, Stdio};
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
use crate::sync::Latch;
use nix::mount::{MsFlags, mount};
use nix::poll::{PollFd, PollFlags, PollTimeout, poll};
use nix::pty::openpty;
use nix::sys::reboot::{RebootMode, reboot};
use nix::sys::termios::{SetArg, cfmakeraw, tcgetattr, tcsetattr};
/// Second serial port. Receives panic output, init diagnostics (`write_com2`)
/// and serves as the interactive shell console when hvc0 is absent.
const COM2: &str = "/dev/ttyS1";
/// First serial port. Receives panic output and the forwarded `trace_pipe`
/// stream.
const COM1: &str = "/dev/ttyS0";
/// virtio console. Preferred shell console when present; also opened by
/// `hvc0_poll_loop`.
const HVC0: &str = "/dev/hvc0";
/// tracefs switch for the `sched_ext_dump` tracepoint ("1" at startup, "0"
/// during cleanup).
const TRACE_SCHED_EXT_DUMP_ENABLE: &str =
"/sys/kernel/tracing/events/sched_ext/sched_ext_dump/enable";
/// Global tracefs on/off switch; written "0" during cleanup.
const TRACE_TRACING_ON: &str = "/sys/kernel/tracing/tracing_on";
/// Consuming trace reader streamed to COM1 by the `trace-pipe` thread.
const TRACE_PIPE: &str = "/sys/kernel/tracing/trace_pipe";
/// sysfs attribute whose non-empty content is treated as "a sched_ext
/// scheduler is attached".
const SYSFS_SCHED_EXT_ROOT_OPS: &str = "/sys/kernel/sched_ext/root/ops";
/// Restarts the VM immediately and never returns.
///
/// `reboot` only comes back on failure (its success type is uninhabited); in
/// that case the calling thread parks forever so nothing after a fatal path
/// keeps running.
fn force_reboot() -> ! {
    let _ignored = reboot(RebootMode::RB_AUTOBOOT);
    loop {
        std::thread::park();
    }
}
/// PID of the spawned `/scheduler` process; 0 means "not started yet".
static SCHED_PID: AtomicI32 = AtomicI32::new(0);

/// Returns the scheduler's PID once `start_scheduler` has published it.
pub(crate) fn sched_pid() -> Option<libc::pid_t> {
    match SCHED_PID.load(Ordering::Acquire) {
        0 => None,
        pid => Some(pid),
    }
}
/// Temporarily replaces the process-wide SIGCHLD disposition and reinstalls
/// the previous one when dropped.
struct SigchldDispositionGuard {
    /// Disposition that was in effect before `install`.
    prev: libc::sighandler_t,
}

impl SigchldDispositionGuard {
    /// Installs `handler` for SIGCHLD and remembers what it replaced.
    ///
    /// # Panics
    /// Panics if `libc::signal` reports `SIG_ERR`.
    fn install(handler: libc::sighandler_t) -> Self {
        // SAFETY: SIGCHLD is a valid signal number and `handler` is a
        // disposition value chosen by the caller.
        let previous = unsafe { libc::signal(libc::SIGCHLD, handler) };
        assert_ne!(
            previous,
            libc::SIG_ERR,
            "failed to install SIGCHLD handler — libc::signal returned SIG_ERR; \
             check signum / handler validity",
        );
        SigchldDispositionGuard { prev: previous }
    }
}

impl Drop for SigchldDispositionGuard {
    fn drop(&mut self) {
        // SAFETY: `self.prev` came from `libc::signal` for SIGCHLD, so it is a
        // valid disposition to put back.
        let _ = unsafe { libc::signal(libc::SIGCHLD, self.prev) };
    }
}
/// Runs `f` with SIGCHLD at `SIG_DFL` and restores the previous disposition
/// afterwards, including when `f` unwinds.
///
/// init ignores SIGCHLD so children are auto-reaped. Callers that need a
/// child's exit status (e.g. via `Command::status`) wrap the spawn in this
/// helper so `wait` can still collect it.
fn with_sigchld_default<F: FnOnce() -> R, R>(f: F) -> R {
    let restore = SigchldDispositionGuard::install(libc::SIG_DFL);
    let result = f();
    drop(restore);
    result
}
/// Whether `/proc/<pid>` currently exists, i.e. the process has not been
/// reaped yet.
fn proc_pid_alive(pid: u32) -> bool {
    let entry = Path::new("/proc").join(pid.to_string());
    entry.exists()
}
/// Formats `value` as exactly 16 lowercase, zero-padded hex digits into `buf`
/// and returns the whole buffer.
///
/// This function does no allocation or formatting machinery, so it is usable
/// from the fatal-signal handler.
fn u64_to_hex_asm(value: u64, buf: &mut [u8; 16]) -> &[u8] {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut rest = value;
    // Fill from the least-significant nibble at the end of the buffer.
    for slot in buf.iter_mut().rev() {
        *slot = DIGITS[(rest & 0xf) as usize];
        rest >>= 4;
    }
    buf.as_slice()
}
fn write_all_asm(fd: libc::c_int, bytes: &[u8]) {
let mut off = 0;
while off < bytes.len() {
let n = unsafe {
libc::write(
fd,
bytes.as_ptr().add(off) as *const libc::c_void,
bytes.len() - off,
)
};
if n <= 0 {
return;
}
off += n as usize;
}
}
/// Last-resort `SA_SIGINFO` handler for SIGSEGV / SIGBUS / SIGILL (installed
/// by `install_fatal_signal_handlers`).
///
/// Writes `<prefix><16 hex digits of si_addr>\n` to /dev/ttyS1 and then
/// /dev/ttyS0. It uses only raw libc calls and stack buffers: no allocation,
/// no locks, no Rust formatting. It then requests a restart via reboot(2) and
/// `_exit(1)`s if that call returns.
unsafe extern "C" fn fatal_signal_handler(
sig: libc::c_int,
info: *mut libc::siginfo_t,
_ctx: *mut libc::c_void,
) {
// Static message per signal, so nothing has to be formatted here.
let prefix: &[u8] = match sig {
libc::SIGSEGV => b"PANIC: fatal signal SIGSEGV at addr 0x",
libc::SIGBUS => b"PANIC: fatal signal SIGBUS at addr 0x",
libc::SIGILL => b"PANIC: fatal signal SIGILL at addr 0x",
_ => b"PANIC: fatal signal (unknown) at addr 0x",
};
// si_addr carries the faulting address for these signals; report 0 when
// no siginfo was supplied.
let addr: u64 = if info.is_null() {
0
} else {
let p = unsafe { (*info).si_addr() };
p as u64
};
let mut hex_buf = [0u8; 16];
let hex = u64_to_hex_asm(addr, &mut hex_buf);
// Open each port fresh with O_NONBLOCK so a stuck port cannot block the
// handler. A port that fails to open is skipped.
for path in [c"/dev/ttyS1", c"/dev/ttyS0"] {
let fd = unsafe { libc::open(path.as_ptr(), libc::O_WRONLY | libc::O_NONBLOCK) };
if fd < 0 {
continue;
}
write_all_asm(fd, prefix);
write_all_asm(fd, hex);
write_all_asm(fd, b"\n");
// Wait for queued output to be transmitted before closing, so the
// message is not lost to the reboot below.
unsafe {
libc::tcdrain(fd);
libc::close(fd);
}
}
// Restart the VM. `_exit` is only reached if reboot(2) returns (i.e. fails).
unsafe {
libc::reboot(libc::LINUX_REBOOT_CMD_RESTART);
libc::_exit(1);
}
}
/// Installs `fatal_signal_handler` for SIGSEGV, SIGBUS and SIGILL.
///
/// All three signals are blocked while the handler runs (`sa_mask`). The
/// disposition is reset to the default after the first delivery
/// (`SA_RESETHAND`), so a second fault is not routed to the handler again.
/// When an alternate signal stack can be mapped, `SA_ONSTACK` is set so the
/// handler can still run if the faulting stack is exhausted.
///
/// NOTE(review): `sigaltstack` applies only to the calling thread; threads
/// spawned later get whatever alternate stack (if any) the Rust runtime gives
/// them. The return values of `sigaltstack`/`sigaction` are ignored, and the
/// mapped stack is never unmapped (intentionally process-lifetime).
fn install_fatal_signal_handlers() {
let mut act: libc::sigaction = unsafe { std::mem::zeroed() };
act.sa_sigaction = fatal_signal_handler as *const () as usize;
act.sa_flags = libc::SA_SIGINFO | libc::SA_RESETHAND;
unsafe {
libc::sigemptyset(&mut act.sa_mask);
libc::sigaddset(&mut act.sa_mask, libc::SIGSEGV);
libc::sigaddset(&mut act.sa_mask, libc::SIGBUS);
libc::sigaddset(&mut act.sa_mask, libc::SIGILL);
}
// At least 64 KiB so the handler has headroom beyond the libc minimum.
let stack_size = libc::SIGSTKSZ.max(65536);
let stack = unsafe {
libc::mmap(
std::ptr::null_mut(),
stack_size,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
-1,
0,
)
};
// Without an alternate stack the handler simply runs on the faulting stack.
if stack != libc::MAP_FAILED {
let ss = libc::stack_t {
ss_sp: stack,
ss_flags: 0,
ss_size: stack_size,
};
unsafe {
libc::sigaltstack(&ss, std::ptr::null_mut());
}
act.sa_flags |= libc::SA_ONSTACK;
}
for sig in [libc::SIGSEGV, libc::SIGBUS, libc::SIGILL] {
let _ = unsafe { libc::sigaction(sig, &act, std::ptr::null_mut()) };
}
}
/// PID-1 entry point for the ktstr guest. Never returns: every path ends in
/// `force_reboot`.
///
/// Order of work:
/// 1. Install crash reporting and mount the pseudo-filesystems.
/// 2. Verify the initramfs was fully extracted.
/// 3. Signal readiness to the host.
/// 4. Route stdio to the bulk port and initialize tracing.
/// 5. Dispatch one of three modes:
///    - disk-template (`KTSTR_MODE=disk_template`): format `/dev/vda` and report the exit code;
///    - shell (`KTSTR_MODE=shell`): run `/exec_cmd` once, or an interactive PTY shell;
///    - default test mode: start probes, the scheduler and trace forwarding, run the
///      test described by `/args`, tear everything down and report the exit code.
pub(crate) fn ktstr_guest_init() -> ! {
    /// How many times `send_sys_rdy` is attempted, 100 ms apart (~10 s).
    const SYS_RDY_ATTEMPTS: u32 = 100;

    let t0 = std::time::Instant::now();
    install_fatal_signal_handlers();
    // Rather than letting PID 1 die on a panic, report it on both serial
    // ports and reboot.
    std::panic::set_hook(Box::new(|info| {
        let bt = std::backtrace::Backtrace::force_capture();
        let msg = format!("PANIC: {info}\n{bt}\n");
        let _ = fs::write(COM2, &msg);
        let _ = fs::write(COM1, &msg);
        let _ = std::io::stdout().flush();
        let _ = std::io::stderr().flush();
        force_reboot();
    }));
    // Auto-reap children; spawns that need an exit status temporarily restore
    // SIG_DFL via `with_sigchld_default`.
    // SAFETY: setting SIGCHLD to SIG_IGN has no memory-safety preconditions.
    unsafe {
        libc::signal(libc::SIGCHLD, libc::SIG_IGN);
    }
    mount_filesystems();
    let t_mounts = t0.elapsed();
    // NOTE(review): the marker is presumably packed last in the initramfs,
    // so its absence means extraction stopped early.
    if !Path::new("/.ktstr_init_ok").exists() {
        if let Ok(raw) = rmesg::logs_raw(rmesg::Backend::Default, false) {
            let _ = fs::write(COM2, &raw);
            let _ = fs::write(COM1, &raw);
        }
        let msg = "FATAL: initramfs extraction incomplete — kernel ran out of \
                   memory during cpio extraction. This indicates a bug in ktstr's \
                   memory estimation. Please report this issue. As a workaround, \
                   try `--memory N` with a larger value.";
        let _ = fs::write(COM2, msg);
        let _ = fs::write(COM1, msg);
        eprintln!("{msg}");
        force_reboot();
    }
    // Tell the host we are up, retrying every 100 ms. Only the outcome is
    // recorded here: no tracing subscriber exists yet, so a `tracing::warn!`
    // at this point would be discarded. The warning is emitted once the
    // subscriber is installed below.
    let mut sys_rdy_sent = false;
    for attempt in 0..SYS_RDY_ATTEMPTS {
        if crate::vmm::guest_comms::send_sys_rdy() {
            sys_rdy_sent = true;
            break;
        }
        // No point in sleeping after the final attempt.
        if attempt + 1 < SYS_RDY_ATTEMPTS {
            std::thread::sleep(std::time::Duration::from_millis(100));
        }
    }
    auto_mount_data_disks();
    let _ = fs::write("/proc/sys/kernel/bpf_stats_enabled", "1");
    let shell_mode = shell_mode_requested();
    if !shell_mode {
        crate::vmm::guest_comms::send_lifecycle(crate::vmm::wire::LifecyclePhase::InitStarted, "");
    }
    redirect_stdio_to_bulk_port();
    let t_stdio = t0.elapsed();
    // Forward RUST_LOG from the kernel command line so EnvFilter picks it up.
    if let Ok(cmdline) = fs::read_to_string("/proc/cmdline")
        && let Some(val) = cmdline
            .split_whitespace()
            .find_map(|s| s.strip_prefix("RUST_LOG="))
    {
        // SAFETY: NOTE(review): the stdio forwarder threads already exist;
        // this relies on them not reading the environment.
        unsafe { std::env::set_var("RUST_LOG", val) };
    }
    let t_pre_subscriber = t0.elapsed();
    tracing_subscriber::fmt()
        .with_writer(std::io::stderr)
        .with_ansi(false)
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("warn")),
        )
        .init();
    let t_subscriber = t0.elapsed();
    tracing::debug!(
        mount_ms = t_mounts.as_millis() as u64,
        stdio_ms = t_stdio.as_millis() as u64,
        pre_subscriber_ms = t_pre_subscriber.as_millis() as u64,
        subscriber_ms = t_subscriber.as_millis() as u64,
        "guest_init_timing",
    );
    if !sys_rdy_sent {
        tracing::warn!(
            "ktstr-init: send_sys_rdy retry budget exhausted (10 s); \
             monitor will rely on its 5 s pre-sample timeout + \
             data_valid gate"
        );
    }
    // SAFETY: NOTE(review): same environment assumption as above.
    unsafe {
        std::env::set_var("PATH", build_include_path());
        std::env::set_var("KTSTR_GUEST_INIT", "1");
    }
    if disk_template_mode_requested() {
        let _span = tracing::debug_span!("disk_template_mode").entered();
        let code = run_disk_template_mode();
        let _ = std::io::stdout().flush();
        let _ = std::io::stderr().flush();
        crate::vmm::guest_comms::send_exit(code);
        force_reboot();
    }
    if shell_mode {
        let _shell_span = tracing::debug_span!("shell_mode").entered();
        let console_dev = shell_console_device();
        redirect_all_stdio_to(console_dev);
        {
            let _s = tracing::debug_span!("busybox_install").entered();
            let _ = Command::new("/bin/busybox")
                .args(["--install", "-s", "/bin"])
                .status();
        }
        mount_devpts();
        if let Some(cmd) = shell_exec_cmd() {
            tracing::debug!(cmd = %cmd, "shell exec mode");
            // Disable output post-processing (e.g. NL -> CRNL) so the command's
            // bytes reach the console unmodified.
            let stdout_fd = unsafe { BorrowedFd::borrow_raw(1) };
            if let Ok(mut termios) = tcgetattr(stdout_fd) {
                termios
                    .output_flags
                    .remove(nix::sys::termios::OutputFlags::OPOST);
                let _ = tcsetattr(stdout_fd, SetArg::TCSANOW, &termios);
            }
            let status = with_sigchld_default(|| {
                Command::new("/bin/busybox")
                    .args(["sh", "-c", &cmd])
                    .status()
            });
            let code = match status {
                Ok(s) => s.code().unwrap_or(1),
                Err(e) => {
                    eprintln!("ktstr-init: exec failed: {e}");
                    1
                }
            };
            crate::vmm::guest_comms::send_exec_exit(code);
            let _ = std::io::stdout().flush();
            let _ = std::io::stderr().flush();
            // SAFETY: tcdrain only waits for queued output on fds 1 and 2.
            unsafe {
                libc::tcdrain(1);
                libc::tcdrain(2);
            }
            force_reboot();
        }
        let kernel_version = fs::read_to_string("/proc/version")
            .ok()
            .and_then(|v| v.split_whitespace().nth(2).map(|s| s.to_string()))
            .unwrap_or_else(|| "unknown".to_string());
        let mem_mb = fs::read_to_string("/proc/meminfo").ok().and_then(|s| {
            s.lines()
                .find(|l| l.starts_with("MemTotal:"))
                .and_then(|l| l.split_whitespace().nth(1))
                .and_then(|kb| kb.parse::<u64>().ok())
                .map(|kb| kb / 1024)
        });
        println!("ktstr shell");
        println!(" kernel: {kernel_version}");
        if let Some(mb) = mem_mb {
            println!(" memory: {mb} MB");
        }
        print_topology_line();
        print_includes_line();
        println!(" tools: busybox (ls, ps, top, dmesg, ip, vi, ...)");
        println!(" mounts: /proc /sys /dev /sys/fs/cgroup /sys/fs/bpf /tmp");
        println!(" /sys/kernel/debug /sys/kernel/tracing /dev/pts");
        println!(" type `exit` for clean shutdown, Ctrl+A X to force-kill");
        let _ = std::io::stdout().flush();
        tracing::debug!("spawning interactive shell with PTY");
        spawn_shell_with_pty();
        force_reboot();
    }
    // Test mode: `/args` holds one argument per line; argv[0] is "/init".
    let args: Vec<String> = {
        let content = fs::read_to_string("/args").unwrap_or_default();
        let mut a = vec!["/init".to_string()];
        a.extend(content.lines().map(|s| s.to_string()));
        a
    };
    tracing::debug!(args = ?args, "parsed /args");
    crate::test_support::propagate_rust_env_from_cmdline();
    let _s_phase2b = tracing::debug_span!("phase2b_probe_phase_a").entered();
    let probe_phase_a = crate::test_support::start_probe_phase_a(&args);
    let probes_active = probe_phase_a.is_some();
    drop(_s_phase2b);
    let _s_phase3 = tracing::debug_span!("phase3_scheduler_start").entered();
    create_cgroup_parent_from_sched_args();
    exec_shell_script("/sched_enable");
    // Handles that let scheduler-failure paths flush probe output before rebooting.
    let probe_drain = probe_phase_a.as_ref().map(|pa| ProbeDrain {
        stop: pa.pipeline.stop.clone(),
        output_done: pa.pipeline.output_done.clone(),
    });
    let (mut sched_child, sched_log_path) = start_scheduler(probe_drain);
    drop(_s_phase3);
    let _s_phase4 = tracing::debug_span!("phase4_vc_poll").entered();
    let (trace_stop, trace_handle) = start_trace_pipe();
    let vc_poll_stop = start_hvc0_poll(trace_stop.clone());
    drop(_s_phase4);
    let suppress_com2 = Arc::new(AtomicBool::new(probes_active));
    let probe_output_done = probe_phase_a
        .as_ref()
        .map(|pa| pa.pipeline.output_done.clone());
    let sched_exit_stop = start_sched_exit_monitor(
        sched_child.as_ref().map(|c| c.id()),
        sched_log_path.as_deref(),
        suppress_com2,
        probe_output_done,
    );
    let _s_phase5 = tracing::debug_span!("phase5_dispatch").entered();
    tracing::debug!("dispatching test");
    crate::vmm::guest_comms::send_lifecycle(crate::vmm::wire::LifecyclePhase::PayloadStarting, "");
    let code = if let Some(pa) = probe_phase_a {
        crate::test_support::maybe_dispatch_vm_test_with_phase_a(&args, pa).unwrap_or(1)
    } else {
        crate::test_support::maybe_dispatch_vm_test_with_args(&args).unwrap_or(1)
    };
    drop(_s_phase5);
    let _ = std::io::stdout().flush();
    let _ = std::io::stderr().flush();
    crate::test_support::try_flush_profraw();
    let _s_phase6 = tracing::debug_span!("phase6_cleanup").entered();
    if let Some(ref mut child) = sched_child {
        let _ = child.kill();
        let _ = child.wait();
        if let Some(ref log_path) = sched_log_path {
            dump_sched_output(log_path);
        }
    }
    exec_shell_script("/sched_disable");
    if let Some(ref stop) = vc_poll_stop {
        stop.store(true, Ordering::Release);
    }
    if let Some(ref handle) = sched_exit_stop {
        handle.stop.store(true, Ordering::Release);
        handle.wake();
    }
    let _ = fs::write(TRACE_SCHED_EXT_DUMP_ENABLE, "0");
    if let Some(ref stop) = trace_stop {
        stop.store(true, Ordering::Release);
    }
    let _ = fs::write(TRACE_TRACING_ON, "0");
    if let Some(handle) = trace_handle {
        let _ = handle.join();
    }
    // Let COM1 finish transmitting before the exit code is sent.
    if let Ok(com1) = fs::OpenOptions::new().write(true).open(COM1) {
        // SAFETY: tcdrain on a descriptor we own for the duration of the call.
        unsafe {
            libc::tcdrain(com1.as_raw_fd());
        }
    }
    let _ = std::io::stdout().flush();
    let _ = std::io::stderr().flush();
    crate::vmm::guest_comms::send_exit(code as i32);
    if let Ok(com2) = fs::OpenOptions::new().write(true).open(COM2) {
        // SAFETY: as above.
        unsafe {
            libc::tcdrain(com2.as_raw_fd());
        }
    }
    force_reboot()
}
/// Size of each read from the stdout/stderr pipes before it is forwarded.
const STDIO_CHUNK_BYTES: usize = 4 * 1024;

/// Replaces fds 1 and 2 with pipes whose read ends are drained by detached
/// threads (`ktstr-stdout-fwd`, `ktstr-stderr-fwd`) into
/// `send_stdout_chunk` / `send_stderr_chunk`.
///
/// Best effort: each failure is reported on the current stderr and leaves
/// the affected descriptor unchanged. A forwarder exits when its pipe's write
/// side is fully closed or a non-EINTR read error occurs.
///
/// NOTE(review): this assumes fds 0-2 are already open. If `pipe2` handed
/// back fd 1 or 2 itself, the matching `dup2` would be a no-op and dropping
/// the original handle would close that descriptor.
fn redirect_stdio_to_bulk_port() {
    use std::io::Read;
    use std::os::unix::io::{AsRawFd, FromRawFd};

    /// Creates a pipe with both original ends marked close-on-exec.
    ///
    /// `dup2` clears FD_CLOEXEC on its target, so fds 1/2 are still inherited
    /// by spawned children. The extra pipe handles (notably the read ends
    /// held by the forwarder threads) are not inherited, and a concurrent
    /// fork on another thread cannot leak them either.
    fn make_pipe() -> Option<(std::fs::File, std::fs::File)> {
        let mut fds = [0i32; 2];
        // SAFETY: `fds` has room for the two descriptors pipe2 writes.
        let r = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_CLOEXEC) };
        if r < 0 {
            return None;
        }
        // SAFETY: pipe2 succeeded; both descriptors are new and owned here.
        let read_end = unsafe { std::fs::File::from_raw_fd(fds[0]) };
        // SAFETY: as above.
        let write_end = unsafe { std::fs::File::from_raw_fd(fds[1]) };
        Some((read_end, write_end))
    }

    /// Spawns a detached thread that copies `read_end` to `sender` in
    /// `STDIO_CHUNK_BYTES` pieces until EOF or a non-EINTR read error.
    fn spawn_forwarder(mut read_end: std::fs::File, name: &'static str, sender: fn(&[u8]) -> bool) {
        let _ = std::thread::Builder::new()
            .name(name.into())
            .spawn(move || {
                let mut buf = [0u8; STDIO_CHUNK_BYTES];
                loop {
                    match read_end.read(&mut buf) {
                        Ok(0) => break,
                        Ok(n) => {
                            // Delivery failures are ignored; there is nowhere better to report them.
                            let _ = sender(&buf[..n]);
                        }
                        Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                        Err(_) => break,
                    }
                }
            });
    }

    let Some((stdout_r, stdout_w)) = make_pipe() else {
        eprintln!("ktstr-init: redirect_stdio_to_bulk_port: pipe(stdout) failed");
        return;
    };
    let Some((stderr_r, stderr_w)) = make_pipe() else {
        eprintln!("ktstr-init: redirect_stdio_to_bulk_port: pipe(stderr) failed");
        return;
    };
    // errno is captured right after each dup2 so the two results cannot mix.
    // SAFETY: dup2 on descriptors owned by this function onto fds 1/2.
    let (rc1, err1, rc2, err2) = unsafe {
        let r1 = libc::dup2(stdout_w.as_raw_fd(), 1);
        let e1 = std::io::Error::last_os_error();
        let r2 = libc::dup2(stderr_w.as_raw_fd(), 2);
        let e2 = std::io::Error::last_os_error();
        (r1, e1, r2, e2)
    };
    if rc1 < 0 {
        eprintln!("ktstr-init: redirect_stdio_to_bulk_port: dup2(stdout) failed: {err1}");
    }
    if rc2 < 0 {
        eprintln!("ktstr-init: redirect_stdio_to_bulk_port: dup2(stderr) failed: {err2}");
    }
    spawn_forwarder(stdout_r, "ktstr-stdout-fwd", |b| {
        crate::vmm::guest_comms::send_stdout_chunk(b)
    });
    spawn_forwarder(stderr_r, "ktstr-stderr-fwd", |b| {
        crate::vmm::guest_comms::send_stderr_chunk(b)
    });
    // `stdout_w` / `stderr_w` drop here; fds 1 and 2 now hold the only write ends.
}
/// Whether the kernel command line selects shell mode (`KTSTR_MODE=shell`).
/// An unreadable `/proc/cmdline` counts as "no".
fn shell_mode_requested() -> bool {
    match fs::read_to_string("/proc/cmdline") {
        Ok(cmdline) => cmdline_contains_token(&cmdline, "KTSTR_MODE=shell"),
        Err(_) => false,
    }
}
/// Whether the kernel command line selects disk-template mode
/// (`KTSTR_MODE=disk_template`). An unreadable `/proc/cmdline` counts as "no".
fn disk_template_mode_requested() -> bool {
    let Ok(cmdline) = fs::read_to_string("/proc/cmdline") else {
        return false;
    };
    cmdline_contains_token(&cmdline, "KTSTR_MODE=disk_template")
}
/// True when `token` appears as a whole whitespace-separated word of
/// `cmdline`. Prefix or substring matches do not count.
fn cmdline_contains_token(cmdline: &str, token: &str) -> bool {
    for word in cmdline.split_whitespace() {
        if word == token {
            return true;
        }
    }
    false
}
/// Formats `/dev/vda` as btrfs for the disk-template flow.
///
/// Returns mkfs's exit code, or 1 if it could not be spawned or was killed
/// by a signal.
///
/// stdio is already routed to the bulk port by `ktstr_guest_init` before this
/// mode is dispatched, so it is not redirected again here. A second redirect
/// would swap out the live pipes and leave two forwarder threads per stream
/// draining concurrently, which can reorder output.
fn run_disk_template_mode() -> i32 {
    const MKFS: &str = "/bin/mkfs.btrfs";
    tracing::info!(mkfs = MKFS, target = "/dev/vda", "running mkfs.btrfs");
    // SIGCHLD must not be ignored while waiting, or `status()` cannot collect the exit code.
    let status = with_sigchld_default(|| {
        Command::new(MKFS)
            .args(["-f", "--quiet", "/dev/vda"])
            .status()
    });
    match status {
        Ok(s) => s.code().unwrap_or(1),
        Err(e) => {
            eprintln!("ktstr-init: failed to spawn {MKFS}: {e}");
            1
        }
    }
}
/// Command to run non-interactively in shell mode. It is read from
/// `/exec_cmd` and whitespace-trimmed. Returns `None` when the file is
/// missing, unreadable or blank.
fn shell_exec_cmd() -> Option<String> {
    let raw = fs::read_to_string("/exec_cmd").ok()?;
    let cmd = raw.trim();
    if cmd.is_empty() {
        None
    } else {
        Some(cmd.to_string())
    }
}
/// Value of the first `key=value` word on the kernel command line.
///
/// Everything after the first `=` is returned verbatim. Returns `None` if
/// the key is absent or `/proc/cmdline` is unreadable.
fn cmdline_val(key: &str) -> Option<String> {
    let cmdline = fs::read_to_string("/proc/cmdline").ok()?;
    for word in cmdline.split_whitespace() {
        if let Some(value) = word.strip_prefix(key).and_then(|rest| rest.strip_prefix('=')) {
            return Some(value.to_owned());
        }
    }
    None
}
fn build_include_path() -> String {
use std::collections::BTreeSet;
use std::os::unix::fs::PermissionsExt;
let include_dir = std::path::Path::new("/include-files");
let mut dirs = BTreeSet::new();
if include_dir.is_dir() {
for entry in walkdir::WalkDir::new(include_dir).follow_links(true) {
let Ok(entry) = entry else { continue };
if entry.file_type().is_file()
&& entry
.metadata()
.is_ok_and(|m| m.permissions().mode() & 0o111 != 0)
&& let Some(parent) = entry.path().parent()
{
dirs.insert(parent.to_string_lossy().to_string());
}
}
}
let mut path_parts: Vec<String> = dirs.into_iter().collect();
path_parts.push("/bin".to_string());
path_parts.join(":")
}
/// Points fds 0, 1 and 2 at the device at `path`, opened read-write.
///
/// Does nothing if the device cannot be opened. Each failed `dup2` is
/// reported on the (possibly already redirected) stderr.
fn redirect_all_stdio_to(path: &str) {
    use std::os::unix::io::{AsRawFd, IntoRawFd};
    let Ok(dev) = fs::OpenOptions::new().read(true).write(true).open(path) else {
        return;
    };
    let fd = dev.as_raw_fd();
    // errno is captured right after each dup2 so the results cannot mix.
    // SAFETY: dup2 from a descriptor owned by `dev` onto the standard fds.
    let (rc0, err0, rc1, err1, rc2, err2) = unsafe {
        let r0 = libc::dup2(fd, 0);
        let e0 = std::io::Error::last_os_error();
        let r1 = libc::dup2(fd, 1);
        let e1 = std::io::Error::last_os_error();
        let r2 = libc::dup2(fd, 2);
        let e2 = std::io::Error::last_os_error();
        (r0, e0, r1, e1, r2, e2)
    };
    if rc0 < 0 {
        eprintln!("ktstr-init: redirect_all_stdio_to({path}): dup2(stdin) failed: {err0}");
    }
    if rc1 < 0 {
        eprintln!("ktstr-init: redirect_all_stdio_to({path}): dup2(stdout) failed: {err1}");
    }
    if rc2 < 0 {
        eprintln!("ktstr-init: redirect_all_stdio_to({path}): dup2(stderr) failed: {err2}");
    }
    // If one of the standard fds was closed, open() hands that very number
    // back. dup2(fd, fd) is then a no-op, and dropping `dev` would close the
    // descriptor that was just installed. Keep it open in that case.
    if fd <= 2 {
        let _ = dev.into_raw_fd();
    }
}
/// Console for the interactive shell: hvc0 when present, otherwise COM2.
fn shell_console_device() -> &'static str {
    let hvc0_present = Path::new(HVC0).exists();
    if hvc0_present {
        HVC0
    } else {
        COM2
    }
}
/// Creates `/dev/pts` and mounts devpts on it.
///
/// A failure is reported on stderr but is otherwise non-fatal.
fn mount_devpts() {
    const TARGET: &str = "/dev/pts";
    mkdir_p(TARGET);
    if let Err(err) = mount(
        Some("devpts"),
        TARGET,
        Some("devpts"),
        MsFlags::empty(),
        None::<&str>,
    ) {
        eprintln!("ktstr-init: mount devpts on /dev/pts: {err}");
    }
}
fn spawn_shell_with_pty() {
let pty = match openpty(None, None) {
Ok(p) => p,
Err(e) => {
eprintln!("ktstr-init: openpty failed: {e}");
return;
}
};
let slave_fd = pty.slave.as_raw_fd();
if let (Some(cols), Some(rows)) = (cmdline_val("KTSTR_COLS"), cmdline_val("KTSTR_ROWS"))
&& let (Ok(cols), Ok(rows)) = (cols.parse::<u16>(), rows.parse::<u16>())
{
let ws = libc::winsize {
ws_row: rows,
ws_col: cols,
ws_xpixel: 0,
ws_ypixel: 0,
};
unsafe {
libc::ioctl(slave_fd, libc::TIOCSWINSZ, &ws);
}
}
let term = cmdline_val("KTSTR_TERM").unwrap_or_else(|| "linux".to_string());
let colorterm = cmdline_val("KTSTR_COLORTERM");
let child = unsafe {
let mut cmd = Command::new("/bin/busybox");
cmd.arg("sh")
.env("TERM", &term)
.env("PS1", "\x1b[2m^Ax=quit\x1b[0m \\w # ");
if let Some(ref ct) = colorterm {
cmd.env("COLORTERM", ct);
}
cmd.stdin(Stdio::from(OwnedFd::from_raw_fd(libc::dup(slave_fd))))
.stdout(Stdio::from(OwnedFd::from_raw_fd(libc::dup(slave_fd))))
.stderr(Stdio::from(OwnedFd::from_raw_fd(libc::dup(slave_fd))))
.pre_exec(move || {
if libc::setsid() < 0 {
return Err(std::io::Error::last_os_error());
}
if libc::ioctl(slave_fd, libc::TIOCSCTTY, 0) < 0 {
return Err(std::io::Error::last_os_error());
}
Ok(())
})
.spawn()
};
drop(pty.slave);
let mut child = match child {
Ok(c) => c,
Err(e) => {
eprintln!("ktstr-init: spawn shell: {e}");
return;
}
};
let child_pid = child.id();
let stdin_fd = unsafe { BorrowedFd::borrow_raw(0) };
if let Ok(mut termios) = tcgetattr(stdin_fd) {
cfmakeraw(&mut termios);
let _ = tcsetattr(stdin_fd, SetArg::TCSANOW, &termios);
}
proxy_serial_pty(&pty.master, child_pid);
match child.wait() {
Ok(status) => {
tracing::debug!(?status, "shell exited");
}
Err(e) if e.raw_os_error() == Some(libc::ECHILD) => {}
Err(e) => {
eprintln!("ktstr-init: wait for shell: {e}");
}
}
}
fn proxy_serial_pty(master: &OwnedFd, child_pid: u32) {
let stdin_fd = unsafe { BorrowedFd::borrow_raw(0) };
let stdout_fd = unsafe { BorrowedFd::borrow_raw(1) };
let master_fd = master.as_fd();
let mut buf = [0u8; 4096];
loop {
let mut pollfds = [
PollFd::new(stdin_fd, PollFlags::POLLIN),
PollFd::new(master_fd, PollFlags::POLLIN),
];
match poll(&mut pollfds, PollTimeout::from(200u16)) {
Ok(0) => {
if !Path::new(&format!("/proc/{child_pid}")).exists() {
break;
}
continue;
}
Ok(_) => {}
Err(nix::errno::Errno::EINTR) => continue,
Err(_) => break,
}
if let Some(revents) = pollfds[0].revents() {
if revents.contains(PollFlags::POLLIN) {
match nix::unistd::read(stdin_fd, &mut buf) {
Ok(0) => break,
Ok(n) => {
let _ = nix::unistd::write(master_fd, &buf[..n]);
}
Err(nix::errno::Errno::EINTR) => {}
Err(_) => break,
}
}
if revents.intersects(PollFlags::POLLERR | PollFlags::POLLHUP) {
break;
}
}
if let Some(revents) = pollfds[1].revents() {
if revents.intersects(PollFlags::POLLERR | PollFlags::POLLHUP) {
break;
}
if revents.contains(PollFlags::POLLIN) {
match nix::unistd::read(master_fd, &mut buf) {
Ok(0) => break,
Ok(n) => {
let _ = nix::unistd::write(stdout_fd, &buf[..n]);
}
Err(nix::errno::Errno::EINTR) => {}
Err(_) => break,
}
}
}
}
}
/// Prints the shell banner's topology line.
///
/// It prefers the `KTSTR_TOPO` layout from the kernel command line and falls
/// back to the online vCPU count. It prints nothing if neither is available.
fn print_topology_line() {
    let plural = |k: u32| if k == 1 { "" } else { "s" };
    match parse_topo_from_cmdline() {
        Some((numa_nodes, llcs, cores, threads)) => {
            let total = llcs * cores * threads;
            // The NUMA count is only mentioned for multi-node layouts.
            let numa_prefix = if numa_nodes > 1 {
                format!("{numa_nodes} NUMA nodes, ")
            } else {
                String::new()
            };
            println!(
                " topology: {numa_prefix}{llcs} LLC{}, {cores} core{}, {threads} thread{} ({total} vCPU{})",
                plural(llcs),
                plural(cores),
                plural(threads),
                plural(total),
            );
        }
        None => {
            if let Some(count) = count_online_cpus() {
                println!(" topology: {count} vCPU{}", plural(count));
            }
        }
    }
}
/// Parses `KTSTR_TOPO=<numa>,<llcs>,<cores>,<threads>` from the kernel
/// command line.
///
/// Returns `None` unless there are exactly four fields and every field is a
/// valid `u32`.
fn parse_topo_from_cmdline() -> Option<(u32, u32, u32, u32)> {
    let raw = cmdline_val("KTSTR_TOPO")?;
    let fields = raw
        .split(',')
        .map(|field| field.parse::<u32>().ok())
        .collect::<Option<Vec<u32>>>()?;
    match fields.as_slice() {
        &[numa, llcs, cores, threads] => Some((numa, llcs, cores, threads)),
        _ => None,
    }
}
/// Counts online CPUs from `/sys/devices/system/cpu/online`.
///
/// The file holds a cpulist such as `0-3,8,10-11`. Returns `None` if the
/// file cannot be read, a field is not a number, a range is reversed, or the
/// total overflows `u32`. Checked arithmetic keeps malformed input from
/// underflowing (`end < start`) or overflowing the counter.
fn count_online_cpus() -> Option<u32> {
    let content = fs::read_to_string("/sys/devices/system/cpu/online").ok()?;
    content.trim().split(',').try_fold(0u32, |count, range| {
        let width = match range.split_once('-') {
            Some((start, end)) => {
                let first: u32 = start.parse().ok()?;
                let last: u32 = end.parse().ok()?;
                last.checked_sub(first)?.checked_add(1)?
            }
            None => {
                // A single CPU id; parsed only to validate it.
                range.parse::<u32>().ok()?;
                1
            }
        };
        count.checked_add(width)
    })
}
/// Prints the shell banner's list of files under `/include-files`.
///
/// Files are sorted by name and shown with their full path. Files with any
/// execute bit are tagged ` (executable)`. Symlinks are not followed. Prints
/// nothing if the directory is missing or contains no regular files.
fn print_includes_line() {
    use std::os::unix::fs::PermissionsExt;

    let root = Path::new("/include-files");
    if !root.is_dir() {
        return;
    }
    let files: Vec<(String, bool)> = walkdir::WalkDir::new(root)
        .min_depth(1)
        .sort_by_file_name()
        .into_iter()
        .filter_map(Result::ok)
        .filter(|entry| entry.file_type().is_file())
        .map(|entry| {
            let rel = entry.path().strip_prefix(root).unwrap_or(entry.path());
            let name = rel.to_string_lossy().to_string();
            let executable = entry
                .metadata()
                .is_ok_and(|meta| meta.permissions().mode() & 0o111 != 0);
            (name, executable)
        })
        .collect();
    for (idx, (name, executable)) in files.iter().enumerate() {
        let marker = if *executable { " (executable)" } else { "" };
        let path = format!("/include-files/{name}{marker}");
        // Only the first entry carries the heading.
        if idx == 0 {
            println!(" includes: {path}");
        } else {
            println!(" {path}");
        }
    }
}
/// Mounts the pseudo-filesystems init relies on, then creates the `/dev/fd`
/// and `/dev/std{in,out,err}` symlinks into `/proc/self/fd`.
///
/// Every mount point is created first. A mount failure is reported on stderr
/// only when the entry is flagged as required; nothing here aborts.
fn mount_filesystems() {
    // (target, source, fstype, required)
    const MOUNTS: [(&str, &str, &str, bool); 10] = [
        ("/proc", "proc", "proc", true),
        ("/sys", "sys", "sysfs", true),
        ("/dev", "dev", "devtmpfs", true),
        ("/sys/kernel/debug", "debugfs", "debugfs", false),
        ("/sys/kernel/tracing", "tracefs", "tracefs", false),
        ("/sys/fs/bpf", "bpffs", "bpf", false),
        ("/sys/fs/cgroup", "none", "cgroup2", false),
        ("/tmp", "tmpfs", "tmpfs", true),
        ("/dev/shm", "tmpfs", "tmpfs", false),
        ("/run", "tmpfs", "tmpfs", false),
    ];
    // (link target, link path)
    const FD_LINKS: [(&str, &str); 4] = [
        ("/proc/self/fd", "/dev/fd"),
        ("/proc/self/fd/0", "/dev/stdin"),
        ("/proc/self/fd/1", "/dev/stdout"),
        ("/proc/self/fd/2", "/dev/stderr"),
    ];
    for (target, source, fstype, required) in MOUNTS {
        mkdir_p(target);
        let outcome = mount(
            Some(source),
            target,
            Some(fstype),
            MsFlags::empty(),
            None::<&str>,
        );
        match outcome {
            Err(e) if required => eprintln!("ktstr-init: mount {fstype} on {target}: {e}"),
            _ => {}
        }
    }
    for (original, link) in FD_LINKS {
        let _ = std::os::unix::fs::symlink(original, link);
    }
}
/// Mounts `/dev/vda` when the kernel command line asks for it.
///
/// Command-line keys:
/// - `KTSTR_DISK0_FS`: filesystem type; only `btrfs` is accepted;
/// - `KTSTR_DISK0_MOUNT`: mount point, default `/mnt/disk0`;
/// - `KTSTR_DISK0_RO=1`: mount read-only.
///
/// Problems are reported on COM2 and stderr; they are never fatal.
fn auto_mount_data_disks() {
    let Some(fstype) = cmdline_val("KTSTR_DISK0_FS") else {
        return;
    };
    if fstype != "btrfs" {
        let msg = format!(
            "ktstr-init: KTSTR_DISK0_FS={fstype} not recognized; \
             skipping auto-mount of /dev/vda"
        );
        let _ = fs::write(COM2, &msg);
        eprintln!("{msg}");
        return;
    }
    let ro = matches!(cmdline_val("KTSTR_DISK0_RO").as_deref(), Some("1"));
    let mount_point =
        cmdline_val("KTSTR_DISK0_MOUNT").unwrap_or_else(|| "/mnt/disk0".to_string());
    mkdir_p(&mount_point);
    let mut flags = MsFlags::empty();
    flags.set(MsFlags::MS_RDONLY, ro);
    if let Err(e) = mount(
        Some("/dev/vda"),
        mount_point.as_str(),
        Some(fstype.as_str()),
        flags,
        None::<&str>,
    ) {
        let msg = format!(
            "ktstr-init: mount {fstype} on {mount_point} \
             (ro={ro}): {e}"
        );
        let _ = fs::write(COM2, &msg);
        eprintln!("{msg}");
    }
}
/// `mkdir -p path` with mode 0755 for newly created directories.
///
/// Errors are deliberately ignored. An existing directory is fine, and
/// callers surface real problems when they use the path.
fn mkdir_p(path: &str) {
    use std::os::unix::fs::DirBuilderExt;
    let mut builder = fs::DirBuilder::new();
    builder.recursive(true);
    builder.mode(0o755);
    let _ = builder.create(path);
}
/// Writes `msg` plus a newline to COM2. If COM2 cannot be opened, the message
/// goes to stderr with a `[COM1 fallback]` tag instead.
fn write_com2(msg: &str) {
    match fs::OpenOptions::new().write(true).open(COM2) {
        Ok(mut port) => {
            let _ = writeln!(port, "{msg}");
        }
        Err(_) => eprintln!("ktstr-init [COM1 fallback]: {msg}"),
    }
}
/// When `/sched_args` contains `--cell-parent-cgroup <path>`, creates
/// `/sys/fs/cgroup<path>` and enables the cpuset and cpu controllers along
/// the chain of cgroups leading to it.
///
/// Only the first occurrence that has a following value is honored. A
/// missing `/sched_args` is not an error.
#[tracing::instrument]
fn create_cgroup_parent_from_sched_args() {
    let Ok(sched_args) = fs::read_to_string("/sched_args") else {
        return;
    };
    let words: Vec<&str> = sched_args.split_whitespace().collect();
    let parent = words
        .windows(2)
        .find(|pair| pair[0] == "--cell-parent-cgroup")
        .map(|pair| pair[1]);
    if let Some(path) = parent {
        let cgroup_dir = format!("/sys/fs/cgroup{path}");
        mkdir_p(&cgroup_dir);
        enable_subtree_controllers_to(&cgroup_dir);
    }
}
/// Writes `+cpuset +cpu` into `cgroup.subtree_control` of every cgroup from
/// the root (`/sys/fs/cgroup`) down to, but excluding, `leaf`, top-down.
///
/// Does nothing if `leaf` is the root itself or lies outside it. Write
/// failures are reported on COM2 and do not stop the walk.
fn enable_subtree_controllers_to(leaf: &str) {
    let root = Path::new("/sys/fs/cgroup");
    let leaf = Path::new(leaf);
    if leaf == root || !leaf.starts_with(root) {
        return;
    }
    let chain: Vec<&Path> = leaf
        .ancestors()
        .skip(1)
        .take_while(|dir| dir.starts_with(root))
        .collect();
    // Parents must delegate a controller before their children can enable it.
    for dir in chain.into_iter().rev() {
        let control = dir.join("cgroup.subtree_control");
        if let Err(err) = fs::write(&control, "+cpuset +cpu") {
            write_com2(&format!(
                "ktstr-init: write {} +cpuset +cpu: {}",
                control.display(),
                err
            ));
        }
    }
}
/// Outcome of the post-spawn liveness check on the scheduler process.
#[derive(Debug)]
enum StartupStatus {
    /// The process exited within the startup window.
    Died,
    /// The process was still running when the window elapsed.
    Alive,
}

/// Outcome of waiting for sched_ext to report an attached scheduler.
#[derive(Debug, PartialEq, Eq)]
enum ScxAttachStatus {
    /// The ops attribute became non-empty.
    Attached,
    /// The attribute was readable but stayed empty until the deadline.
    Timeout,
    /// The attribute was never readable before the deadline.
    SysfsAbsent,
}

impl ScxAttachStatus {
    /// True only for `Attached`.
    fn is_attached(&self) -> bool {
        *self == ScxAttachStatus::Attached
    }
}
/// Waits up to `timeout` for `/sys/kernel/sched_ext/root/ops` to become
/// non-empty, re-checking at least every `interval`.
///
/// Return values:
/// - `Attached` as soon as the attribute has content;
/// - `Timeout` if it was readable but stayed empty;
/// - `SysfsAbsent` if it was never readable.
///
/// Between checks the loop blocks in `poll(POLLPRI)` on a held-open fd, so a
/// `sysfs_notify` wakes it early. kernfs signals POLLPRI whenever the fd's
/// last-read event counter is stale, so an fd that has never been read (or
/// not since the last notification) polls ready immediately. Without
/// re-arming, the loop would busy-spin until the deadline. The fd is
/// therefore re-read at offset 0 before every poll. If that read fails (e.g.
/// the node went away), the loop falls back to plain sleeping.
fn poll_scx_attached(
    interval: std::time::Duration,
    timeout: std::time::Duration,
) -> ScxAttachStatus {
    let start = std::time::Instant::now();
    let mut ever_read_ok = false;
    let mut attr_fd: Option<OwnedFd> = {
        // SAFETY: NUL-terminated literal path and valid open(2) flags.
        let raw = unsafe {
            libc::open(
                c"/sys/kernel/sched_ext/root/ops".as_ptr(),
                libc::O_RDONLY | libc::O_CLOEXEC,
            )
        };
        if raw < 0 {
            None
        } else {
            // SAFETY: `raw` is a fresh descriptor owned by nobody else.
            Some(unsafe { OwnedFd::from_raw_fd(raw) })
        }
    };
    let interval_ms_clamped = interval.as_millis().min(i32::MAX as u128) as i32;
    let mut arm_buf = [0u8; 64];
    loop {
        if let Ok(contents) = fs::read_to_string(SYSFS_SCHED_EXT_ROOT_OPS) {
            ever_read_ok = true;
            if !contents.trim().is_empty() {
                return ScxAttachStatus::Attached;
            }
        }
        let now = std::time::Instant::now();
        if now.duration_since(start) >= timeout {
            return if ever_read_ok {
                ScxAttachStatus::Timeout
            } else {
                ScxAttachStatus::SysfsAbsent
            };
        }
        let remaining_ms = (start + timeout - now)
            .as_millis()
            .min(interval_ms_clamped as u128) as i32;
        // Re-arm POLLPRI by reading the attribute through the polled fd.
        let rearm_failed = match attr_fd.as_ref() {
            // SAFETY: `arm_buf` is valid for `arm_buf.len()` bytes; fd is open.
            Some(fd) => unsafe {
                libc::pread(
                    fd.as_raw_fd(),
                    arm_buf.as_mut_ptr().cast(),
                    arm_buf.len(),
                    0,
                ) < 0
            },
            None => false,
        };
        if rearm_failed {
            attr_fd = None;
        }
        match attr_fd.as_ref() {
            Some(fd) => {
                let mut pfd = libc::pollfd {
                    fd: fd.as_raw_fd(),
                    events: libc::POLLPRI,
                    revents: 0,
                };
                // SAFETY: one valid pollfd for the duration of the call.
                let _ = unsafe { libc::poll(&mut pfd, 1, remaining_ms) };
            }
            None => {
                std::thread::sleep(std::time::Duration::from_millis(remaining_ms.max(0) as u64));
            }
        }
    }
}
fn poll_startup(
child: &mut Child,
interval: std::time::Duration,
timeout: std::time::Duration,
) -> StartupStatus {
let pid = child.id();
let pidfd =
unsafe { libc::syscall(libc::SYS_pidfd_open, pid as libc::c_int, 0u32) as libc::c_int };
if pidfd < 0 {
let start = std::time::Instant::now();
loop {
if !proc_pid_alive(pid) {
return StartupStatus::Died;
}
let now = std::time::Instant::now();
if now >= start + timeout {
return StartupStatus::Alive;
}
let remaining = (start + timeout) - now;
std::thread::sleep(remaining.min(interval));
}
}
let start = std::time::Instant::now();
let result = loop {
let now = std::time::Instant::now();
if now >= start + timeout {
break if proc_pid_alive(pid) {
StartupStatus::Alive
} else {
StartupStatus::Died
};
}
let remaining_ms = (start + timeout - now).as_millis().min(i32::MAX as u128) as i32;
let mut pfd = libc::pollfd {
fd: pidfd,
events: libc::POLLIN,
revents: 0,
};
let rc = unsafe { libc::poll(&mut pfd, 1, remaining_ms) };
if rc > 0 && pfd.revents & libc::POLLIN != 0 {
break StartupStatus::Died;
}
};
unsafe {
libc::close(pidfd);
}
result
}
/// Handles that let scheduler-failure paths flush the probe pipeline before
/// forcing a reboot.
struct ProbeDrain {
    /// Set to ask the pipeline to stop.
    stop: Arc<AtomicBool>,
    /// Waited on after `stop` is set; presumably released once the pipeline
    /// has emitted its output (TODO confirm against the pipeline).
    output_done: Arc<crate::sync::Latch>,
}

/// Stops the probe pipeline, if one is running, and blocks until its
/// `output_done` latch is released.
fn drain_probe_pipeline(drain: Option<&ProbeDrain>) {
    if let Some(ProbeDrain { stop, output_done }) = drain {
        stop.store(true, Ordering::Release);
        output_done.wait();
    }
}
/// Spawns `/scheduler` (if present) and verifies that it survives startup
/// and attaches as the sched_ext root scheduler.
///
/// Process setup:
/// - arguments come from `/sched_args` (whitespace-separated);
/// - stdout/stderr go to `/tmp/sched.log`, or `/dev/null` if it cannot be created;
/// - `RUST_LOG` is inherited (default `info`) with
///   `scx_utils::libbpf_logger=warn` appended.
///
/// Returns `(None, None)` when there is no `/scheduler`, otherwise the child
/// and its log path. The failure cases are:
/// - spawn error;
/// - exit within 1 s;
/// - not attached within 3 s.
///
/// Each failure sends the scheduler output (the log, or the spawn error)
/// framed by the SCHED_OUTPUT markers, a lifecycle phase and exit code 1,
/// drains the probe pipeline if present, and reboots. Those paths never
/// return.
#[tracing::instrument(skip(probe_drain))]
fn start_scheduler(probe_drain: Option<ProbeDrain>) -> (Option<Child>, Option<String>) {
if !Path::new("/scheduler").exists() {
return (None, None);
}
let sched_args = fs::read_to_string("/sched_args")
.unwrap_or_default()
.trim()
.to_string();
let args: Vec<&str> = if sched_args.is_empty() {
vec![]
} else {
sched_args.split_whitespace().collect()
};
let log_path = "/tmp/sched.log";
// stdout and stderr share one open file description (via try_clone).
let log_file = fs::File::create(log_path).ok();
let stdout = match log_file.as_ref().and_then(|f| f.try_clone().ok()) {
Some(f) => Stdio::from(f),
None => Stdio::null(),
};
let stderr = match log_file {
Some(f) => Stdio::from(f),
None => Stdio::null(),
};
// Keep the caller's filter but quiet libbpf's logger inside the scheduler.
let sched_rust_log = match std::env::var("RUST_LOG") {
Ok(existing) => format!("{existing},scx_utils::libbpf_logger=warn"),
Err(_) => "info,scx_utils::libbpf_logger=warn".to_string(),
};
let child = Command::new("/scheduler")
.args(&args)
.env("RUST_LOG", &sched_rust_log)
.stdout(stdout)
.stderr(stderr)
.spawn();
match child {
Ok(mut child) => {
// Publish the PID for `sched_pid()`.
SCHED_PID.store(child.id() as i32, Ordering::Release);
// Stage 1: give the process 1 s to crash during startup.
match poll_startup(
&mut child,
std::time::Duration::from_millis(50),
std::time::Duration::from_secs(1),
) {
StartupStatus::Died => {
dump_sched_output(log_path);
crate::vmm::guest_comms::send_lifecycle(
crate::vmm::wire::LifecyclePhase::SchedulerDied,
"",
);
crate::vmm::guest_comms::send_exit(1);
drain_probe_pipeline(probe_drain.as_ref());
force_reboot();
}
StartupStatus::Alive => {
// Stage 2: up to 3 s (50 ms steps) for sched_ext to show it attached.
let status = poll_scx_attached(
std::time::Duration::from_millis(50),
std::time::Duration::from_secs(3),
);
if !status.is_attached() {
dump_sched_output(log_path);
let reason = match status {
ScxAttachStatus::Timeout => "timeout",
ScxAttachStatus::SysfsAbsent => "sched_ext sysfs absent",
ScxAttachStatus::Attached => unreachable!(),
};
crate::vmm::guest_comms::send_lifecycle(
crate::vmm::wire::LifecyclePhase::SchedulerNotAttached,
reason,
);
crate::vmm::guest_comms::send_exit(1);
drain_probe_pipeline(probe_drain.as_ref());
force_reboot();
}
(Some(child), Some(log_path.to_string()))
}
}
}
Err(e) => {
// No log exists yet; send the spawn error inside the usual markers.
eprintln!("ktstr-init: spawn scheduler: {e}");
crate::vmm::guest_comms::send_sched_log(crate::verifier::SCHED_OUTPUT_START.as_bytes());
send_sched_log_text(&format!("failed to spawn: {e}"));
crate::vmm::guest_comms::send_sched_log(crate::verifier::SCHED_OUTPUT_END.as_bytes());
crate::vmm::guest_comms::send_lifecycle(
crate::vmm::wire::LifecyclePhase::SchedulerDied,
"",
);
crate::vmm::guest_comms::send_exit(1);
drain_probe_pipeline(probe_drain.as_ref());
force_reboot();
}
}
}
const SCHED_LOG_CHUNK_BYTES: usize = 64 * 1024;
fn dump_sched_output(log_path: &str) {
crate::vmm::guest_comms::send_sched_log(crate::verifier::SCHED_OUTPUT_START.as_bytes());
send_sched_log_file(log_path);
crate::vmm::guest_comms::send_sched_log(crate::verifier::SCHED_OUTPUT_END.as_bytes());
}
/// Streams the file at `path` to the host in pieces of at most
/// `SCHED_LOG_CHUNK_BYTES`. A missing or unreadable file sends nothing.
///
/// The file is read as raw bytes. Scheduler output is forwarded verbatim, so
/// a stray non-UTF-8 byte must not cause the whole log to be dropped, which
/// `read_to_string` would do. Chunk boundaries are byte offsets, so
/// receivers cannot assume each chunk is valid UTF-8 in any case.
fn send_sched_log_file(path: &str) {
    let Ok(bytes) = fs::read(path) else {
        return;
    };
    for chunk in bytes.chunks(SCHED_LOG_CHUNK_BYTES) {
        crate::vmm::guest_comms::send_sched_log(chunk);
    }
}
/// Sends `s` as a single sched-log payload, truncated to the first
/// `SCHED_LOG_CHUNK_BYTES` bytes.
fn send_sched_log_text(s: &str) {
    let bytes = s.as_bytes();
    let head = bytes.get(..SCHED_LOG_CHUNK_BYTES).unwrap_or(bytes);
    crate::vmm::guest_comms::send_sched_log(head);
}
/// If the `sched_ext_dump` tracepoint exists, enables it and starts a
/// `trace-pipe` thread that copies `trace_pipe` to COM1.
///
/// Returns `(stop flag, join handle)`. Both are `None` when the tracepoint
/// is absent; the handle is also `None` if the thread could not be spawned.
/// The thread exits early if `trace_pipe` or COM1 cannot be opened.
///
/// After `stop` is set, the thread keeps reading until a fixed 5 s drain
/// deadline passes, or until poll/the pipe reports an error or hang-up. An
/// empty pipe does not end the drain early, so joining can take up to ~5 s.
/// NOTE(review): confirm that shutdown latency is intended.
fn start_trace_pipe() -> (Option<Arc<AtomicBool>>, Option<std::thread::JoinHandle<()>>) {
if Path::new(TRACE_SCHED_EXT_DUMP_ENABLE).exists() {
let _ = fs::write(TRACE_SCHED_EXT_DUMP_ENABLE, "1");
let stop = Arc::new(AtomicBool::new(false));
let stop_clone = stop.clone();
let handle = std::thread::Builder::new()
.name("trace-pipe".into())
.spawn(move || {
use std::os::unix::fs::OpenOptionsExt;
// Non-blocking so the inner read loop can drain until EAGAIN.
let Ok(mut trace) = fs::OpenOptions::new()
.read(true)
.custom_flags(libc::O_NONBLOCK)
.open(TRACE_PIPE)
else {
return;
};
let Ok(mut com1) = fs::OpenOptions::new().write(true).open(COM1) else {
return;
};
let mut buf = [0u8; 4096];
let mut drain_deadline = None;
loop {
// Start the 5 s drain window the first time `stop` is observed.
if drain_deadline.is_none() && stop_clone.load(Ordering::Acquire) {
drain_deadline =
Some(std::time::Instant::now() + std::time::Duration::from_secs(5));
}
if drain_deadline.is_some_and(|d| std::time::Instant::now() >= d) {
break;
}
// 200 ms timeout so the stop flag and deadline are rechecked regularly.
let mut pollfds = [PollFd::new(trace.as_fd(), PollFlags::POLLIN)];
match poll(&mut pollfds, PollTimeout::from(200u16)) {
Ok(0) => continue,
Ok(_) => {}
Err(nix::errno::Errno::EINTR) => continue,
Err(_) => break,
}
if let Some(revents) = pollfds[0].revents() {
if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
break;
}
if !revents.contains(PollFlags::POLLIN) {
if revents.contains(PollFlags::POLLHUP) {
break;
}
continue;
}
}
// Read everything currently available; COM1 write errors are ignored.
loop {
match trace.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
let _ = com1.write_all(&buf[..n]);
}
Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(_) => break,
}
}
}
})
.ok();
(Some(stop), handle)
} else {
(None, None)
}
}
/// Process-wide latch, created lazily. The hvc0 poll loop sets it when the
/// host sends `SIGNAL_BPF_WRITE_DONE`.
static BPF_MAP_WRITE_DONE_LATCH: OnceLock<Arc<Latch>> = OnceLock::new();
/// Return a shared handle to the BPF-map-write-done latch, creating it on
/// first call. Every caller receives a clone of the same `Arc`.
pub(crate) fn bpf_map_write_done_latch() -> Arc<Latch> {
    let shared = BPF_MAP_WRITE_DONE_LATCH.get_or_init(|| Arc::new(Latch::new()));
    Arc::clone(shared)
}
/// Spawn a detached `hvc0-poll` thread running `hvc0_poll_loop` and return its
/// stop flag.
///
/// `trace_stop` is forwarded to the loop so a host shutdown request can also
/// stop the trace-pipe forwarder. The result is always `Some`, even when the
/// thread could not be spawned; spawn failure is deliberately ignored.
fn start_hvc0_poll(trace_stop: Option<Arc<AtomicBool>>) -> Option<Arc<AtomicBool>> {
    let flag = Arc::new(AtomicBool::new(false));
    let worker_flag = Arc::clone(&flag);
    let builder = std::thread::Builder::new().name("hvc0-poll".into());
    // Dropping the JoinHandle detaches the thread.
    let _ = builder.spawn(move || hvc0_poll_loop(&worker_flag, trace_stop.as_deref()));
    Some(flag)
}
fn hvc0_poll_loop(stop: &AtomicBool, trace_stop: Option<&AtomicBool>) {
use std::os::unix::io::AsRawFd;
let hvc0 = match fs::OpenOptions::new()
.read(true)
.custom_flags(libc::O_NONBLOCK)
.open(HVC0)
{
Ok(f) => f,
Err(e) => {
write_com2(&format!(
"ktstr-init: hvc0 poll loop disabled — open {HVC0}: {e}"
));
return;
}
};
let poll_timeout_ms: PollTimeout = 1000u16.into();
while !stop.load(Ordering::Acquire) {
let borrowed = unsafe { BorrowedFd::borrow_raw(hvc0.as_raw_fd()) };
let mut fds = [PollFd::new(borrowed, PollFlags::POLLIN)];
match poll(&mut fds, poll_timeout_ms) {
Ok(0) => continue,
Ok(_) => {}
Err(nix::errno::Errno::EINTR) => continue,
Err(_) => break,
}
if let Some(revents) = fds[0].revents() {
if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
break;
}
if !revents.contains(PollFlags::POLLIN) {
if revents.contains(PollFlags::POLLHUP) {
break;
}
continue;
}
}
let mut buf = [0u8; 16];
let mut hvc_ref: &fs::File = &hvc0;
let n = 'read_retry: loop {
match hvc_ref.read(&mut buf) {
Ok(n) => break 'read_retry Some(n),
Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
Err(e) => {
tracing::warn!(
err = %e,
"ktstr-init: hvc0 read failed; aborting poll loop"
);
break 'read_retry None;
}
}
};
let Some(n) = n else { break };
if buf[..n].contains(&crate::vmm::virtio_console::SIGNAL_VC_DUMP) {
let _ = fs::write("/proc/sysrq-trigger", "D");
}
if buf[..n].contains(&crate::vmm::virtio_console::SIGNAL_BPF_WRITE_DONE) {
bpf_map_write_done_latch().set();
}
if buf[..n].contains(&crate::vmm::virtio_console::SIGNAL_VC_SHUTDOWN) {
eprintln!("ktstr-init: shutdown request received, draining");
if let Some(ts) = trace_stop {
ts.store(true, Ordering::Release);
}
let _ = fs::write(TRACE_TRACING_ON, "0");
let _ = std::io::stdout().flush();
let _ = std::io::stderr().flush();
if let Ok(f) = fs::OpenOptions::new().write(true).open(COM1) {
unsafe {
libc::tcdrain(std::os::unix::io::AsRawFd::as_raw_fd(&f));
}
}
if let Ok(f) = fs::OpenOptions::new().write(true).open(COM2) {
unsafe {
libc::tcdrain(std::os::unix::io::AsRawFd::as_raw_fd(&f));
}
}
break;
}
}
}
/// Handle returned by `start_sched_exit_monitor` for stopping the
/// `sched-exit-mon` thread.
pub(crate) struct SchedExitStop {
    /// Flag the monitor thread checks at the top of each loop iteration.
    pub(crate) stop: Arc<AtomicBool>,
    /// Duplicate of the monitor's eventfd. Writing to it makes the monitor's
    /// `poll` return. `None` when the eventfd or its dup could not be created.
    wake_fd: Option<OwnedFd>,
}
impl SchedExitStop {
    /// Bump the eventfd counter by 1 to wake the monitor out of `poll`.
    ///
    /// Does nothing when no eventfd is available; write errors are ignored.
    pub(crate) fn wake(&self) {
        let Some(fd) = self.wake_fd.as_ref() else {
            return;
        };
        let increment = 1u64.to_ne_bytes();
        // SAFETY: `fd` is an open descriptor owned by `self` for the whole
        // call, and `increment` is a live 8-byte buffer of the given length.
        let _ = unsafe {
            libc::write(
                fd.as_raw_fd(),
                increment.as_ptr().cast::<libc::c_void>(),
                increment.len(),
            )
        };
    }
}
/// Allocate the eventfd pair used to wake the `sched-exit-mon` thread.
///
/// Returns `(monitor_fd, writer_fd)`. The monitor thread polls and drains the
/// first; `SchedExitStop::wake` writes to the second, which is a dup of the
/// same eventfd. Both are `None` when either the eventfd or its dup cannot be
/// created, which makes the monitor fall back to a 250 ms poll timeout.
fn sched_exit_wake_eventfds() -> (Option<OwnedFd>, Option<OwnedFd>) {
    // SAFETY: eventfd takes no pointers; a non-negative result is a new fd.
    let raw = unsafe { libc::eventfd(0, libc::EFD_NONBLOCK | libc::EFD_CLOEXEC) };
    if raw < 0 {
        let err = std::io::Error::last_os_error();
        tracing::warn!(
            err = %err,
            "ktstr-init: sched-exit-mon eventfd allocation failed; \
             falling back to 250 ms stop poll cadence"
        );
        return (None, None);
    }
    // SAFETY: `raw` was just returned by eventfd and nothing else owns it.
    let monitor_fd = unsafe { OwnedFd::from_raw_fd(raw) };
    match monitor_fd.try_clone() {
        Ok(writer_fd) => (Some(monitor_fd), Some(writer_fd)),
        Err(e) => {
            tracing::warn!(
                err = %e,
                "ktstr-init: sched-exit-mon eventfd dup failed; \
                 falling back to 250 ms stop poll cadence"
            );
            // Without a writer nobody can ever signal the eventfd. Keeping the
            // monitor end would make the thread poll with an infinite timeout
            // and never notice `stop`. Drop it so the 250 ms fallback the
            // warning announces actually applies.
            (None, None)
        }
    }
}

/// Spawn the `sched-exit-mon` thread, which watches scheduler process
/// `sched_pid` and reports its exit to the host.
///
/// Returns `None` when `sched_pid` is `None`. Otherwise returns a
/// [`SchedExitStop`] whose `stop` flag, followed by `wake()`, ends the monitor
/// early. The handle is returned even if the thread fails to spawn or
/// `pidfd_open` fails; the latter is reported on stderr and disables
/// monitoring.
///
/// When the scheduler exits, the thread does the following:
/// - if `suppress_com2` is set, it waits on `probe_output_done` (when given);
/// - otherwise it dumps the log at `log_path` (when given) between the
///   sched-output markers;
/// - then it sends a sched-exit message with code 1.
///
/// Exit is detected when the pidfd polls readable, which `pidfd_open(2)`
/// signals on termination, or when `/proc/<pid>` no longer exists. Checking
/// only `/proc` would spin the infinite-timeout poll while the exited
/// process is an unreaped zombie, because the pidfd stays readable.
fn start_sched_exit_monitor(
    sched_pid: Option<u32>,
    log_path: Option<&str>,
    suppress_com2: Arc<AtomicBool>,
    probe_output_done: Option<Arc<crate::sync::Latch>>,
) -> Option<SchedExitStop> {
    let pid = sched_pid?;
    let proc_path = format!("/proc/{pid}");
    let log_path = log_path.map(|s| s.to_string());
    let stop = Arc::new(AtomicBool::new(false));
    let stop_clone = stop.clone();
    let (monitor_fd, writer_fd) = sched_exit_wake_eventfds();
    std::thread::Builder::new()
        .name("sched-exit-mon".into())
        .spawn(move || {
            // SAFETY: pidfd_open takes (pid, flags) by value and returns a new
            // fd or -1.
            let raw_pidfd = unsafe {
                libc::syscall(libc::SYS_pidfd_open, pid as libc::c_int, 0u32) as libc::c_int
            };
            if raw_pidfd < 0 {
                eprintln!(
                    "ktstr-init: pidfd_open failed for sched pid {pid}: {} \
                     — sched exit monitor disabled",
                    std::io::Error::last_os_error(),
                );
                return;
            }
            // SAFETY: `raw_pidfd` is a fresh descriptor owned only here.
            // Dropping the OwnedFd closes it on every exit path.
            let pidfd = unsafe { OwnedFd::from_raw_fd(raw_pidfd) };
            let stop_fd = monitor_fd.as_ref().map_or(-1, |f| f.as_raw_fd());
            // With a wake eventfd the poll can block indefinitely; without
            // one, re-check `stop` every 250 ms. poll() ignores a negative fd.
            let poll_timeout: i32 = if stop_fd >= 0 { -1 } else { 250 };
            while !stop_clone.load(Ordering::Acquire) {
                let mut pfds = [
                    libc::pollfd {
                        fd: pidfd.as_raw_fd(),
                        events: libc::POLLIN,
                        revents: 0,
                    },
                    libc::pollfd {
                        fd: stop_fd,
                        events: libc::POLLIN,
                        revents: 0,
                    },
                ];
                // SAFETY: `pfds` is a valid, writable array of `pfds.len()` entries.
                let rc = unsafe {
                    libc::poll(pfds.as_mut_ptr(), pfds.len() as libc::nfds_t, poll_timeout)
                };
                let pidfd_ready =
                    rc > 0 && (pfds[0].revents & (libc::POLLIN | libc::POLLHUP)) != 0;
                let exited = pidfd_ready || !Path::new(&proc_path).exists();
                if exited {
                    if suppress_com2.load(Ordering::Acquire) {
                        if let Some(ref done) = probe_output_done {
                            done.wait();
                        }
                    } else if let Some(ref path) = log_path {
                        dump_sched_output(path);
                    }
                    let exit_code: i32 = 1;
                    crate::vmm::guest_comms::send_sched_exit(exit_code);
                    return;
                }
                if stop_fd >= 0 {
                    // Reset the eventfd counter so the next poll blocks again.
                    let mut buf = [0u8; 8];
                    // SAFETY: `buf` is a live, writable 8-byte buffer and
                    // `stop_fd` is kept open by `monitor_fd`.
                    let _ = unsafe {
                        libc::read(stop_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
                    };
                }
            }
        })
        .ok();
    Some(SchedExitStop {
        stop,
        wake_fd: writer_fd,
    })
}
#[tracing::instrument]
fn exec_shell_script(path: &str) {
let content = match fs::read_to_string(path) {
Ok(c) => c,
Err(_) => return,
};
for line in content.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
exec_shell_line(line);
}
}
/// Execute one line of the minimal script dialect understood by ktstr-init.
///
/// Only `echo VALUE > PATH` is supported. VALUE is trimmed and one pair of
/// matching outer `"` or `'` quotes is removed, as a shell would do. The
/// result is then written to PATH with a trailing newline, so
/// `echo "1" > /sys/...` writes `1\n` rather than `"1"\n`. Write failures
/// and any other command are reported on stderr.
fn exec_shell_line(line: &str) {
    /// Strip one pair of matching outer `"…"` or `'…'` quotes, if present.
    fn strip_matching_quotes(s: &str) -> &str {
        for quote in ['"', '\''] {
            if let Some(inner) = s.strip_prefix(quote).and_then(|r| r.strip_suffix(quote)) {
                return inner;
            }
        }
        s
    }

    let echo = line
        .strip_prefix("echo ")
        .and_then(|rest| rest.split_once(" > "));
    if let Some((value, path)) = echo {
        let value = strip_matching_quotes(value.trim());
        let path = path.trim();
        if let Err(e) = fs::write(path, format!("{value}\n")) {
            eprintln!("ktstr-init: echo '{value}' > {path}: {e}");
        }
        return;
    }
    eprintln!("ktstr-init: unsupported command: {line}");
}
// Host-side unit tests for the init helpers. Tests that touch process-global
// state (the SIGCHLD disposition, the SCHED_PID atomic, the SCHED_PID env var)
// take SIGCHLD_TEST_LOCK so they do not race each other when the test harness
// runs tests on parallel threads.
#[cfg(test)]
mod tests {
use super::*;
// mkdir_p must create every missing component of a nested path.
#[test]
fn mkdir_p_creates_nested() {
let base = std::env::temp_dir().join("ktstr-rust-init-test-mkdir");
let _ = fs::remove_dir_all(&base);
let nested = base.join("a/b/c");
mkdir_p(nested.to_str().unwrap());
assert!(nested.exists());
let _ = fs::remove_dir_all(&base);
}
// mkdir_p on an already-existing directory must not panic.
#[test]
fn mkdir_p_existing_is_noop() {
let tmp = std::env::temp_dir();
mkdir_p(tmp.to_str().unwrap());
}
// `echo VALUE > PATH` writes VALUE plus a trailing newline to PATH.
#[test]
fn exec_shell_line_echo_redirect() {
let tmp = std::env::temp_dir().join("ktstr-rust-init-echo-test");
let path = tmp.to_str().unwrap();
exec_shell_line(&format!("echo 42 > {path}"));
let content = fs::read_to_string(&tmp).unwrap();
assert_eq!(content, "42\n");
let _ = fs::remove_file(&tmp);
}
// Unsupported input is reported, not panicked on.
#[test]
fn exec_shell_line_unsupported_input_no_panic() {
exec_shell_line("# this is a comment");
}
// The host's kernel command line is expected not to carry the ktstr mode tokens.
#[test]
fn shell_mode_not_requested_in_test() {
assert!(!shell_mode_requested());
}
#[test]
fn disk_template_mode_not_requested_in_test() {
assert!(!disk_template_mode_requested());
}
// NOTE(review): despite the name, this only asserts that both mode tokens are
// found in either order; the dispatch precedence itself is decided elsewhere
// and is not exercised here.
#[test]
fn disk_template_dispatch_precedes_shell_when_both_present() {
let cmdline = "ro KTSTR_MODE=disk_template KTSTR_MODE=shell console=ttyS0";
assert!(cmdline_contains_token(cmdline, "KTSTR_MODE=disk_template"));
assert!(cmdline_contains_token(cmdline, "KTSTR_MODE=shell"));
let cmdline_reversed = "ro KTSTR_MODE=shell KTSTR_MODE=disk_template console=ttyS0";
assert!(cmdline_contains_token(
cmdline_reversed,
"KTSTR_MODE=disk_template"
));
assert!(cmdline_contains_token(cmdline_reversed, "KTSTR_MODE=shell"));
}
// Tokens must match whole words: no prefix or suffix extensions, and no match
// in an empty command line.
#[test]
fn cmdline_contains_token_exact_match_not_prefix() {
assert!(cmdline_contains_token(
"KTSTR_MODE=shell",
"KTSTR_MODE=shell"
));
assert!(!cmdline_contains_token(
"KTSTR_MODE=shell_extended",
"KTSTR_MODE=shell"
));
assert!(!cmdline_contains_token(
"prefix_KTSTR_MODE=shell",
"KTSTR_MODE=shell"
));
assert!(!cmdline_contains_token("", "KTSTR_MODE=shell"));
}
#[test]
fn count_online_cpus_returns_some() {
let count = count_online_cpus();
assert!(count.is_some());
assert!(count.unwrap() >= 1);
}
#[test]
fn parse_topo_from_cmdline_not_present_on_host() {
assert!(parse_topo_from_cmdline().is_none());
}
// A child that exits immediately must be reported as Died well before the
// 1 s startup timeout elapses.
#[test]
fn poll_startup_detects_early_death_quickly() {
let mut child = std::process::Command::new("/bin/true")
.spawn()
.expect("spawn /bin/true");
let start = std::time::Instant::now();
let status = poll_startup(
&mut child,
std::time::Duration::from_millis(10),
std::time::Duration::from_secs(1),
);
let elapsed = start.elapsed();
assert!(
matches!(status, StartupStatus::Died),
"expected Died, got {status:?}"
);
assert!(
elapsed < std::time::Duration::from_millis(500),
"early death must be detected fast, took {elapsed:?}"
);
}
// A still-running child is reported Alive only after the full timeout, with
// bounded overshoot.
#[test]
fn poll_startup_reports_alive_after_timeout() {
let mut child = std::process::Command::new("/bin/sleep")
.arg("5")
.spawn()
.expect("spawn /bin/sleep");
let start = std::time::Instant::now();
let status = poll_startup(
&mut child,
std::time::Duration::from_millis(20),
std::time::Duration::from_millis(100),
);
let elapsed = start.elapsed();
// Kill and reap before asserting so the sleep child never outlives the
// test, even when an assertion fails.
let _ = child.kill();
let _ = child.wait();
assert!(
matches!(status, StartupStatus::Alive),
"expected Alive, got {status:?}"
);
assert!(
elapsed >= std::time::Duration::from_millis(100),
"Alive must wait the full timeout, took only {elapsed:?}"
);
assert!(
elapsed < std::time::Duration::from_millis(300),
"Alive should not overshoot timeout significantly, took {elapsed:?}"
);
}
// Serializes tests that mutate process-wide SIGCHLD disposition or SCHED_PID.
static SIGCHLD_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
// Test-local guard: installs a SIGCHLD handler and restores the previous one
// on drop. Unlike SigchldDispositionGuard, it does not check for SIG_ERR.
struct SigchldGuard {
prev: libc::sighandler_t,
}
impl SigchldGuard {
fn install(handler: libc::sighandler_t) -> Self {
let prev = unsafe { libc::signal(libc::SIGCHLD, handler) };
Self { prev }
}
}
impl Drop for SigchldGuard {
fn drop(&mut self) {
unsafe {
libc::signal(libc::SIGCHLD, self.prev);
}
}
}
// Under SIG_IGN a bare Command::status fails with ECHILD. with_sigchld_default
// must recover the real exit status and restore SIG_IGN afterwards.
#[test]
fn with_sigchld_default_captures_real_exit_status() {
let _guard = SIGCHLD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let _restore = SigchldGuard::install(libc::SIG_IGN);
let bare = Command::new("/bin/true").status();
assert!(
bare.is_err(),
"under SIG_IGN, Command::status must fail with ECHILD; got {bare:?}",
);
let wrapped = with_sigchld_default(|| Command::new("/bin/true").status());
let status = wrapped.expect("with_sigchld_default must capture status");
assert_eq!(
status.code(),
Some(0),
"/bin/true must exit 0 under helper; got {status:?}",
);
// signal() returns the previous disposition, which must be SIG_IGN again.
let after = unsafe { libc::signal(libc::SIGCHLD, libc::SIG_IGN) };
assert_eq!(
after,
libc::SIG_IGN,
"with_sigchld_default must restore SIG_IGN after closure returns",
);
}
#[test]
fn with_sigchld_default_captures_nonzero_exit_status() {
let _guard = SIGCHLD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let _restore = SigchldGuard::install(libc::SIG_IGN);
let wrapped = with_sigchld_default(|| Command::new("/bin/false").status());
let status = wrapped.expect("with_sigchld_default must capture status");
assert_eq!(
status.code(),
Some(1),
"/bin/false must surface non-zero code under helper; got {status:?}",
);
}
// poll_startup must still classify an exited child as Died when SIGCHLD is
// ignored.
#[test]
fn poll_startup_detects_death_under_sigchld_ignore() {
let _guard = SIGCHLD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let _restore = SigchldGuard::install(libc::SIG_IGN);
let mut child = std::process::Command::new("/bin/true")
.spawn()
.expect("spawn /bin/true");
let status = poll_startup(
&mut child,
std::time::Duration::from_millis(10),
std::time::Duration::from_secs(1),
);
assert!(
matches!(status, StartupStatus::Died),
"under SIG_IGN, an exited child must be observed as Died (was {status:?})",
);
}
#[test]
fn poll_startup_reports_alive_under_sigchld_ignore() {
let _guard = SIGCHLD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let _restore = SigchldGuard::install(libc::SIG_IGN);
let mut child = std::process::Command::new("/bin/sleep")
.arg("5")
.spawn()
.expect("spawn /bin/sleep");
let status = poll_startup(
&mut child,
std::time::Duration::from_millis(20),
std::time::Duration::from_millis(100),
);
let _ = child.kill();
// Switch to SIG_DFL before wait() so the killed child can still be waited
// on; the guard restores the disposition captured at install on drop.
unsafe {
libc::signal(libc::SIGCHLD, libc::SIG_DFL);
}
let _ = child.wait();
assert!(
matches!(status, StartupStatus::Alive),
"under SIG_IGN, a running child must be observed as Alive (was {status:?})",
);
}
// 0 is the "unset" sentinel; any other stored value is returned as Some.
// The previous global value is restored at the end.
#[test]
fn sched_pid_side_channel_roundtrips() {
let _guard = SIGCHLD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
let snapshot = SCHED_PID.load(Ordering::Acquire);
SCHED_PID.store(0, Ordering::Release);
assert_eq!(sched_pid(), None, "0 must read as None (sentinel)");
SCHED_PID.store(12345, Ordering::Release);
assert_eq!(
sched_pid(),
Some(12345),
"writer must publish via the atomic side channel",
);
SCHED_PID.store(snapshot, Ordering::Release);
}
// Publishing through the atomic must not leak a SCHED_PID environment variable.
#[test]
fn sched_pid_does_not_publish_via_env_var() {
let _guard = SIGCHLD_TEST_LOCK.lock().unwrap_or_else(|e| e.into_inner());
// remove_var mutates process-global state, hence the unsafe block.
unsafe { std::env::remove_var("SCHED_PID") };
let snapshot = SCHED_PID.load(Ordering::Acquire);
SCHED_PID.store(99999, Ordering::Release);
assert_eq!(sched_pid(), Some(99999));
assert!(
std::env::var("SCHED_PID").is_err(),
"atomic side channel must not publish via env var",
);
SCHED_PID.store(snapshot, Ordering::Release);
}
}