#![cfg(test)]
#![allow(unused_imports)]
use super::super::affinity::*;
use super::super::config::*;
use super::super::spawn::*;
use super::super::types::*;
use super::*;
use std::collections::{BTreeMap, BTreeSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
/// CLOCK_MONOTONIC readings taken back-to-back must never decrease.
#[test]
fn clock_gettime_ns_monotonic_non_decreasing() {
    const N: usize = 1000;
    // Gather every reading first; the comparison pass runs afterwards.
    let mut samples: Vec<u64> = Vec::with_capacity(N);
    for i in 0..N {
        let Some(ns) = clock_gettime_ns(libc::CLOCK_MONOTONIC) else {
            panic!(
                "CLOCK_MONOTONIC must be readable on any Linux host; \
                 sample {i}/{N} returned None"
            );
        };
        samples.push(ns);
    }
    // Compare each adjacent pair; `i` names the index of the later sample.
    for (offset, pair) in samples.windows(2).enumerate() {
        let i = offset + 1;
        let (prev, curr) = (pair[0], pair[1]);
        assert!(
            curr >= prev,
            "CLOCK_MONOTONIC went backwards at sample {i}: \
             prev={prev} curr={curr} (delta={delta})",
            // Only evaluated on failure, where prev > curr, so no underflow.
            delta = prev - curr,
        );
    }
}
/// A 1x1 multiply stores the scalar product in C and folds it into `work_units`.
#[test]
fn matrix_multiply_1x1_produces_product() {
    // Layout: [A, B, C], each a single element; C starts zeroed.
    let mut data: Vec<u64> = vec![3, 5, 0];
    let mut work_units: u64 = 0;
    matrix_multiply(&mut data, 1, &mut work_units);
    assert_eq!(data[2], 15, "C = A * B for 1x1 matrix");
    assert_eq!(work_units, 15, "post-loop sink folds C[0] into work_units");
}
/// 2x2 product checked against a hand-computed reference.
#[test]
fn matrix_multiply_2x2_against_reference() {
    let size = 2;
    let stride = size * size;
    // Row-major A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]]; C is zeroed.
    let a: [u64; 4] = [1, 2, 3, 4];
    let b: [u64; 4] = [5, 6, 7, 8];
    let mut data = vec![0u64; 3 * stride];
    data[..stride].copy_from_slice(&a);
    data[stride..2 * stride].copy_from_slice(&b);
    let mut work_units = 0u64;
    matrix_multiply(&mut data, size, &mut work_units);
    // A * B = [[19, 22], [43, 50]].
    let expected: [u64; 4] = [19, 22, 43, 50];
    for (k, &want) in expected.iter().enumerate() {
        assert_eq!(data[2 * stride + k], want);
    }
}
/// diag(2, 3, 5) times the 3x3 identity must reproduce the diagonal matrix.
#[test]
fn matrix_multiply_3x3_diagonal() {
    let size = 3;
    let stride = size * size;
    let diag: [u64; 3] = [2, 3, 5];
    let mut data = vec![0u64; 3 * stride];
    for (k, &d) in diag.iter().enumerate() {
        let on_diag = k * size + k;
        data[on_diag] = d; // A = diag(2, 3, 5)
        data[stride + on_diag] = 1; // B = identity
    }
    let mut work_units = 0u64;
    matrix_multiply(&mut data, size, &mut work_units);
    let c = &data[2 * stride..3 * stride];
    for row in 0..size {
        for col in 0..size {
            let want = if row == col { diag[row] } else { 0 };
            assert_eq!(c[row * size + col], want);
        }
    }
}
/// A buffer whose length disagrees with `size` must trip an assertion in debug builds.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "assertion")]
fn matrix_multiply_mismatched_len_panics_in_debug() {
    // size = 2 implies three 2x2 matrices (12 slots); 5 slots is deliberately wrong.
    let mut work_units: u64 = 0;
    let mut data = vec![0u64; 5];
    matrix_multiply(&mut data, 2, &mut work_units);
}
/// `DirectIoBuf::alloc` returns an IO_BLOCK_SIZE-aligned, fully writable buffer.
#[test]
fn direct_io_buf_alloc_aligned() {
    let buf = DirectIoBuf::alloc()
        .expect("DirectIoBuf::alloc must succeed under normal allocator pressure");
    let addr = buf.as_ptr() as usize;
    let misalignment = addr % IO_BLOCK_SIZE;
    assert_eq!(
        misalignment, 0,
        "DirectIoBuf must be IO_BLOCK_SIZE-aligned (got addr={addr:#x})"
    );
    // SAFETY: assumes `DirectIoBuf` owns at least IO_BLOCK_SIZE writable bytes
    // starting at `as_ptr()` and that nothing else aliases them while `buf`
    // is alive — TODO confirm against `DirectIoBuf::alloc`.
    let bytes = unsafe { std::slice::from_raw_parts_mut(buf.as_ptr(), IO_BLOCK_SIZE) };
    bytes.fill(0xAA);
    assert!(
        bytes.iter().all(|&b| b == 0xAA),
        "round-trip pattern must persist across the buffer",
    );
}
/// Dropping an `IoBacking` that carries a `tempfile_path` removes that file.
#[test]
fn io_backing_tempfile_unlinked_on_drop() {
    use std::path::Path;
    // pid + tid keeps the name distinct from concurrently running tests.
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let name = format!(
        "ktstr_iobacking_unlink_{}_{}",
        std::process::id(),
        unsafe { libc::syscall(libc::SYS_gettid) },
    );
    let path: String = std::env::temp_dir().join(name).to_string_lossy().into_owned();
    let file = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(&path)
        .expect("create real tempfile for IoBacking test");
    assert!(Path::new(&path).exists(), "precondition: file exists");
    let backing = IoBacking {
        file,
        capacity_bytes: 0,
        tempfile_path: Some(path.clone()),
    };
    // Still present while the backing is alive.
    assert!(Path::new(&path).exists());
    drop(backing);
    assert!(
        !Path::new(&path).exists(),
        "IoBacking::Drop must unlink {path}",
    );
}
/// With `tempfile_path: None` (a real device), dropping `IoBacking` must leave the file alone.
#[test]
fn io_backing_none_path_no_unlink() {
    use std::path::Path;
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let name = format!(
        "ktstr_iobacking_nounlink_{}_{}",
        std::process::id(),
        unsafe { libc::syscall(libc::SYS_gettid) },
    );
    let path: String = std::env::temp_dir().join(name).to_string_lossy().into_owned();
    let file = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .truncate(true)
        .open(&path)
        .expect("create stand-in for /dev/vda");
    drop(IoBacking {
        file,
        capacity_bytes: 0,
        tempfile_path: None,
    });
    assert!(
        Path::new(&path).exists(),
        "IoBacking::Drop must NOT unlink when tempfile_path is None",
    );
    // Test-owned cleanup; failure to remove is not what this test checks.
    let _ = std::fs::remove_file(&path);
}
/// Dropping a `PhaseIoTempfile` removes its backing path.
#[test]
fn phase_io_tempfile_unlinked_on_drop() {
    use std::path::Path;
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let name = format!(
        "ktstr_phaseio_unlink_{}_{}",
        std::process::id(),
        unsafe { libc::syscall(libc::SYS_gettid) },
    );
    let path: String = std::env::temp_dir().join(name).to_string_lossy().into_owned();
    let file = std::fs::OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .open(&path)
        .expect("create real tempfile for PhaseIoTempfile test");
    assert!(Path::new(&path).exists(), "precondition");
    drop(PhaseIoTempfile {
        file,
        path: path.clone(),
    });
    assert!(
        !Path::new(&path).exists(),
        "PhaseIoTempfile::Drop must unlink {path}",
    );
}
/// `ensure_io_disk` opens on first use and leaves an existing backing untouched.
#[test]
fn ensure_io_disk_lazy_init() {
    use std::os::unix::io::AsRawFd;
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let tid = unsafe { libc::syscall(libc::SYS_gettid) } as libc::pid_t;
    let mut io_disk: Option<IoBacking> = None;
    assert!(
        ensure_io_disk(&mut io_disk, 0, tid),
        "first ensure_io_disk must succeed (host can open tempfile fallback)",
    );
    let first_fd = io_disk
        .as_ref()
        .expect("io_disk Some after first call")
        .file
        .as_raw_fd();
    // Second call on a populated slot.
    assert!(ensure_io_disk(&mut io_disk, 0, tid));
    let second_fd = io_disk.as_ref().unwrap().file.as_raw_fd();
    // NOTE(review): equal fds are strong but not conclusive evidence — a
    // close-then-reopen could get the same (lowest free) descriptor number.
    assert_eq!(
        first_fd, second_fd,
        "ensure_io_disk must be lazy-init — second call must not re-open",
    );
}
/// `ensure_io_buf` allocates on first use and reuses the buffer afterwards.
#[test]
fn ensure_io_buf_lazy_init() {
    let mut io_buf: Option<DirectIoBuf> = None;
    let first_ok = ensure_io_buf(&mut io_buf);
    assert!(
        first_ok,
        "first ensure_io_buf must succeed under normal allocator pressure",
    );
    let first_ptr = io_buf
        .as_ref()
        .expect("io_buf Some after first call")
        .as_ptr();
    assert!(ensure_io_buf(&mut io_buf));
    let second_ptr = io_buf.as_ref().unwrap().as_ptr();
    assert_eq!(
        first_ptr, second_ptr,
        "ensure_io_buf must be lazy-init — second call must not re-allocate",
    );
}
/// After doing some work, the calling thread's CPU time must be non-zero.
#[test]
fn thread_cpu_time_positive() {
    // Spin on this thread; black_box discourages discarding the result.
    let sum = (0..100_000u64).fold(0u64, |acc, i| acc.wrapping_add(i));
    std::hint::black_box(sum);
    let cpu_ns = super::thread_cpu_time_ns();
    assert!(cpu_ns > 0);
}
/// Setting `SchedPolicy::Normal` on the calling thread must succeed.
#[test]
fn set_sched_policy_normal_succeeds() {
    // Target this test's own thread (gettid), not `getpid()`. libtest runs
    // each test on a spawned thread, so `getpid()` names the harness main
    // thread that every test shares.
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let tid: libc::pid_t = unsafe { libc::syscall(libc::SYS_gettid) as libc::pid_t };
    if let Err(e) = set_sched_policy(tid, SchedPolicy::Normal) {
        panic!("setting SchedPolicy::Normal must succeed: {e:#}");
    }
}
/// SCHED_FIFO is accepted when the process holds CAP_SYS_NICE.
#[test]
#[ignore = "requires CAP_SYS_NICE (SCHED_FIFO)"]
fn set_sched_policy_fifo_returns_result() {
    // Act on this test's own thread so the libtest main thread (what
    // `getpid()` names) is never made real-time.
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let tid: libc::pid_t = unsafe { libc::syscall(libc::SYS_gettid) as libc::pid_t };
    let result = set_sched_policy(tid, SchedPolicy::Fifo(1));
    // Restore before asserting, so a failed assertion never skips cleanup.
    restore_normal(tid);
    if let Err(e) = result {
        panic!("SCHED_FIFO should succeed with CAP_SYS_NICE: {e:#}");
    }
}
/// SCHED_RR is accepted when the process holds CAP_SYS_NICE.
#[test]
#[ignore = "requires CAP_SYS_NICE (SCHED_RR)"]
fn set_sched_policy_rr_returns_result() {
    // Act on this test's own thread so the libtest main thread (what
    // `getpid()` names) is never made real-time.
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let tid: libc::pid_t = unsafe { libc::syscall(libc::SYS_gettid) as libc::pid_t };
    let result = set_sched_policy(tid, SchedPolicy::RoundRobin(1));
    // Restore before asserting, so a failed assertion never skips cleanup.
    restore_normal(tid);
    if let Err(e) = result {
        panic!("SCHED_RR should succeed with CAP_SYS_NICE: {e:#}");
    }
}
fn restore_normal(pid: libc::pid_t) {
let param = libc::sched_param { sched_priority: 0 };
unsafe { libc::sched_setscheduler(pid, libc::SCHED_OTHER, ¶m) };
}
/// `SchedPolicy::Batch` either installs SCHED_BATCH or fails with an error naming the syscall.
#[test]
fn set_sched_policy_batch_returns_valid_result() {
    // Act on this test's own thread. `getpid()` would name the libtest main
    // thread, which concurrently running scheduler tests also mutate.
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let tid: libc::pid_t = unsafe { libc::syscall(libc::SYS_gettid) as libc::pid_t };
    match set_sched_policy(tid, SchedPolicy::Batch) {
        Ok(()) => {
            // SAFETY: sched_getscheduler only reads scheduling state for `tid`.
            let pol = unsafe { libc::sched_getscheduler(tid) };
            // Restore before asserting, so a failed assertion never skips cleanup.
            restore_normal(tid);
            assert!(
                pol >= 0,
                "sched_getscheduler must return a valid policy, got {pol}",
            );
            // sched_getscheduler ORs in SCHED_RESET_ON_FORK when it is set; mask it off.
            assert_eq!(
                pol & !libc::SCHED_RESET_ON_FORK,
                libc::SCHED_BATCH,
                "SchedPolicy::Batch must install SCHED_BATCH, got {pol}",
            );
        }
        Err(e) => {
            let msg = format!("{e:#}");
            assert!(
                msg.contains("sched_setscheduler"),
                "error must name the syscall: {msg}"
            );
        }
    }
}
/// `SchedPolicy::Idle` either installs SCHED_IDLE or fails with an error naming the syscall.
#[test]
fn set_sched_policy_idle_returns_valid_result() {
    // Act on this test's own thread. Without CAP_SYS_NICE the kernel refuses
    // SCHED_IDLE -> SCHED_OTHER (RLIMIT_NICE), so restore_normal can fail.
    // With `getpid()` that would leave the libtest main thread SCHED_IDLE for
    // the rest of the run. With gettid, only this short-lived thread is affected.
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let tid: libc::pid_t = unsafe { libc::syscall(libc::SYS_gettid) as libc::pid_t };
    match set_sched_policy(tid, SchedPolicy::Idle) {
        Ok(()) => {
            // SAFETY: sched_getscheduler only reads scheduling state for `tid`.
            let pol = unsafe { libc::sched_getscheduler(tid) };
            // Best-effort restore before asserting.
            restore_normal(tid);
            assert!(
                pol >= 0,
                "sched_getscheduler must return a valid policy, got {pol}",
            );
            // sched_getscheduler ORs in SCHED_RESET_ON_FORK when it is set; mask it off.
            assert_eq!(
                pol & !libc::SCHED_RESET_ON_FORK,
                libc::SCHED_IDLE,
                "SchedPolicy::Idle must install SCHED_IDLE, got {pol}",
            );
        }
        Err(e) => {
            let msg = format!("{e:#}");
            assert!(
                msg.contains("sched_setscheduler"),
                "error must name the syscall: {msg}"
            );
        }
    }
}
/// A zero `deadline` is invalid and must yield an error naming that field.
#[test]
fn set_sched_policy_deadline_zero_deadline_rejected() {
    let pid: libc::pid_t = unsafe { libc::getpid() };
    let policy = SchedPolicy::Deadline {
        runtime: Duration::from_nanos(1024),
        deadline: Duration::ZERO,
        period: Duration::from_nanos(1_000_000),
    };
    let msg = format!(
        "{:#}",
        set_sched_policy(pid, policy).expect_err("zero deadline must be rejected")
    );
    assert!(msg.contains("deadline"), "error must name deadline field: {msg}");
    let explains_zero = msg.contains("must be > 0") || msg.contains("zero");
    assert!(explains_zero, "error must explain zero rejection: {msg}");
}
/// A `runtime` of 1023 ns (one below the 1024 ns floor) must be rejected.
#[test]
fn set_sched_policy_deadline_runtime_below_dl_scale_rejected() {
    let pid: libc::pid_t = unsafe { libc::getpid() };
    let policy = SchedPolicy::Deadline {
        runtime: Duration::from_nanos(1023),
        deadline: Duration::from_nanos(100_000),
        period: Duration::from_nanos(1_000_000),
    };
    let msg = format!(
        "{:#}",
        set_sched_policy(pid, policy).expect_err("runtime below DL_SCALE must be rejected")
    );
    assert!(msg.contains("runtime"), "error must name runtime field: {msg}");
    let names_floor = msg.contains("DL_SCALE") || msg.contains("1024");
    assert!(names_floor, "error must reference the floor: {msg}");
}
/// `runtime > deadline` must be rejected with an error naming both fields.
#[test]
fn set_sched_policy_deadline_runtime_exceeds_deadline_rejected() {
    let pid: libc::pid_t = unsafe { libc::getpid() };
    let policy = SchedPolicy::Deadline {
        runtime: Duration::from_nanos(200_000),
        deadline: Duration::from_nanos(100_000),
        period: Duration::from_nanos(1_000_000),
    };
    let msg = format!(
        "{:#}",
        set_sched_policy(pid, policy).expect_err("runtime > deadline must be rejected")
    );
    let names_both = msg.contains("runtime") && msg.contains("deadline");
    assert!(names_both, "error must name both fields: {msg}");
}
/// `deadline > period` must be rejected with an error naming both fields.
#[test]
fn set_sched_policy_deadline_deadline_exceeds_period_rejected() {
    let pid: libc::pid_t = unsafe { libc::getpid() };
    let policy = SchedPolicy::Deadline {
        runtime: Duration::from_nanos(1024),
        deadline: Duration::from_nanos(2_000_000),
        period: Duration::from_nanos(1_000_000),
    };
    let msg = format!(
        "{:#}",
        set_sched_policy(pid, policy).expect_err("deadline > period must be rejected")
    );
    let names_both = msg.contains("deadline") && msg.contains("period");
    assert!(names_both, "error must name both fields: {msg}");
}
/// A `deadline` whose nanosecond count exceeds i64::MAX must be rejected, blaming only `deadline`.
#[test]
fn set_sched_policy_deadline_top_bit_set_rejected() {
    let pid: libc::pid_t = unsafe { libc::getpid() };
    // 1e12 s = 1e21 ns, well past i64::MAX (~9.2e18).
    let policy = SchedPolicy::Deadline {
        runtime: Duration::from_nanos(1024),
        deadline: Duration::from_secs(1_000_000_000_000),
        period: Duration::from_nanos(1_000_000),
    };
    let msg = format!(
        "{:#}",
        set_sched_policy(pid, policy).expect_err("deadline exceeding i64::MAX must be rejected")
    );
    let names_bound = msg.contains("i64::MAX") || msg.contains("63 bits");
    assert!(
        msg.contains("deadline") && names_bound,
        "error must name deadline field and the bit-63 / i64::MAX bound: {msg}"
    );
    assert!(
        !msg.contains("period"),
        "deadline-only overflow error must not mention period: {msg}"
    );
}
/// `period == 0` must pass local validation; any failure must come from the kernel.
///
/// For SCHED_DEADLINE, the kernel treats a zero `sched_period` as equal to
/// `sched_deadline` (see sched_setattr(2)).
#[test]
#[ignore = "requires CAP_SYS_NICE (SCHED_DEADLINE)"]
fn set_sched_policy_deadline_period_zero_passes_validation() {
    // Act on this test's own thread. A 1024 ns / 200 us reservation would
    // heavily throttle the libtest main thread that `getpid()` names.
    // SAFETY: gettid(2) takes no arguments and cannot fail.
    let tid: libc::pid_t = unsafe { libc::syscall(libc::SYS_gettid) as libc::pid_t };
    let result = set_sched_policy(
        tid,
        SchedPolicy::Deadline {
            runtime: Duration::from_nanos(1024),
            deadline: Duration::from_nanos(200_000),
            period: Duration::ZERO,
        },
    );
    match result {
        Ok(()) => restore_normal(tid),
        Err(e) => {
            let msg = format!("{e:#}");
            assert!(
                msg.contains("sched_setattr"),
                "validation must have passed (error from kernel must name sched_setattr): {msg}"
            );
        }
    }
}
/// The first push into an empty reservoir is stored and counted.
#[test]
fn reservoir_push_empty_buf() {
    let mut count = 0_u64;
    let mut buf = Vec::new();
    reservoir_push(&mut buf, &mut count, 42, 10);
    assert_eq!(buf, vec![42]);
    assert_eq!(count, 1);
}
/// Below capacity, every pushed value is kept in arrival order.
#[test]
fn reservoir_push_under_cap() {
    let mut count = 0_u64;
    let mut buf = Vec::new();
    for value in (0..5).map(|step| step * 100) {
        reservoir_push(&mut buf, &mut count, value, 10);
    }
    assert_eq!(buf.len(), 5);
    assert_eq!(count, 5);
    assert_eq!(buf, vec![0, 100, 200, 300, 400]);
}
/// Filling exactly to capacity keeps every value.
#[test]
fn reservoir_push_at_cap() {
    const CAP: usize = 10;
    let mut count = 0_u64;
    let mut buf = Vec::new();
    for value in 0..10 {
        reservoir_push(&mut buf, &mut count, value, CAP);
    }
    assert_eq!(buf.len(), 10);
    assert_eq!(count, 10);
    for i in 0..10 {
        assert!(buf.contains(&i), "missing {i}");
    }
}
/// Pushing far past capacity never grows the buffer, but still counts every push.
#[test]
fn reservoir_push_over_cap_maintains_size() {
    let cap = 5;
    let mut count = 0_u64;
    let mut buf = Vec::new();
    for value in 0..1000 {
        reservoir_push(&mut buf, &mut count, value, cap);
    }
    assert_eq!(buf.len(), cap);
    assert_eq!(count, 1000);
}
/// A full reservoir should retain samples from both the first and the last quarter of the stream.
#[test]
fn reservoir_push_uniform_sampling() {
    let cap = 100;
    let total = 10_000u64;
    let mut count = 0_u64;
    let mut buf = Vec::new();
    for value in 0..total {
        reservoir_push(&mut buf, &mut count, value, cap);
    }
    assert_eq!(buf.len(), cap);
    assert_eq!(count, total);
    let (early_cut, late_cut) = (total / 4, total * 3 / 4);
    let has_early = buf.iter().any(|&v| v < early_cut);
    let has_late = buf.iter().any(|&v| v > late_cut);
    assert!(has_early, "reservoir should contain early values");
    assert!(has_late, "reservoir should contain late values");
}
/// With capacity 0 nothing is stored, yet the push counter still advances.
#[test]
fn reservoir_push_cap_zero() {
    let mut count = 0_u64;
    let mut buf = Vec::new();
    for value in 0..10 {
        reservoir_push(&mut buf, &mut count, value, 0);
    }
    assert!(buf.is_empty(), "cap=0 should never store samples");
    assert_eq!(count, 10, "count incremented regardless");
}
/// Capacity 1 stores the first value, then holds exactly one sample thereafter.
#[test]
fn reservoir_push_cap_one() {
    let mut count = 0_u64;
    let mut buf = Vec::new();
    reservoir_push(&mut buf, &mut count, 42, 1);
    assert_eq!(buf, vec![42]);
    assert_eq!(count, 1);
    for value in (1..100).map(|step| step * 100) {
        reservoir_push(&mut buf, &mut count, value, 1);
    }
    assert_eq!(buf.len(), 1);
    assert_eq!(count, 100);
}
/// When schedstat is readable, this thread has non-zero CPU time and timeslice count.
#[test]
fn read_schedstat_returns_finite_triple() {
    match read_schedstat(None) {
        None => {
            eprintln!("skipping: /proc/self/schedstat not available (CONFIG_SCHEDSTATS off)");
        }
        Some((cpu_time, _run_delay, timeslices)) => {
            assert!(cpu_time > 0);
            assert!(timeslices > 0);
        }
    }
}
/// The first three whitespace-separated fields are parsed; trailing tokens are ignored.
#[test]
fn parse_schedstat_line_happy_path() {
    let triple = parse_schedstat_line("100 200 300 999 extra").unwrap();
    assert_eq!(triple, (100, 200, 300));
}
/// Tabs and a trailing newline are accepted as separators.
#[test]
fn parse_schedstat_line_tab_and_newline_separators() {
    let triple = parse_schedstat_line("1\t2\t3\n").unwrap();
    assert_eq!(triple, (1, 2, 3));
}
/// Fewer than three fields (including empty or whitespace-only input) yields `None`.
#[test]
fn parse_schedstat_line_missing_field_returns_none() {
    for line in ["100 200", "100", "", " \t\n "] {
        assert!(parse_schedstat_line(line).is_none(), "input {line:?}");
    }
}
/// Any of the first three tokens that is not a valid u64 (text, negative, overflow) yields `None`.
#[test]
fn parse_schedstat_line_non_u64_token_returns_none() {
    let bad_lines = [
        "not-a-number 200 300",
        "100 abc 300",
        "100 200 nan",
        "-1 200 300",
        "99999999999999999999 2 3",
    ];
    for line in bad_lines {
        assert!(parse_schedstat_line(line).is_none(), "input {line:?}");
    }
}
/// Repeated calls to the warn-once helper must be harmless.
#[test]
fn warn_schedstat_unavailable_once_does_not_panic_on_repeat() {
    let mut calls = 0;
    while calls < 10 {
        warn_schedstat_unavailable_once();
        calls += 1;
    }
}
/// `resolve_alu_width` maps every input, including `Widest`, to a concrete width.
#[test]
fn alu_width_resolve_never_returns_widest() {
    let every_width = [
        AluWidth::Scalar,
        AluWidth::Vec128,
        AluWidth::Vec256,
        AluWidth::Vec512,
        AluWidth::Amx,
        AluWidth::Widest,
    ];
    for w in every_width {
        let resolved = resolve_alu_width(w);
        let is_widest = matches!(resolved, AluWidth::Widest);
        assert!(
            !is_widest,
            "resolve_alu_width({w:?}) returned Widest; \
             caller invariant violated",
        );
    }
}