#![cfg(test)]
use std::ffi::{OsStr, OsString};
use std::sync::{Mutex, MutexGuard};
use std::time::Duration;
use anyhow::Result;
use tempfile::TempDir;
use crate::assert::{AssertDetail, AssertResult, ScenarioStats};
use crate::scenario::Ctx;
use crate::scenario::flags::FlagDecl;
use super::entry::{KtstrTestEntry, Scheduler, SchedulerSpec, TopologyConstraints};
use crate::vmm::topology::Topology;
/// Process-wide lock serializing tests that read or mutate environment variables.
pub(crate) static ENV_LOCK: Mutex<()> = Mutex::new(());

/// Acquires [`ENV_LOCK`], tolerating poisoning.
///
/// A previous holder that panicked leaves the mutex poisoned; the guarded data is
/// `()`, so there is nothing to corrupt and the inner guard is returned anyway.
pub(crate) fn lock_env() -> MutexGuard<'static, ()> {
    match ENV_LOCK.lock() {
        Ok(held) => held,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// A throwaway cache root: a temporary directory plus an [`EnvVarGuard`] that
/// points `KTSTR_CACHE_DIR` at it.
///
/// Struct fields drop in declaration order, so `_guard` restores the previous
/// `KTSTR_CACHE_DIR` before `tmp` removes the directory. Keep this field order.
pub(crate) struct IsolatedCacheDir {
    _guard: EnvVarGuard,
    tmp: TempDir,
}

impl IsolatedCacheDir {
    /// Path of the temporary directory that `KTSTR_CACHE_DIR` currently names.
    pub(crate) fn path(&self) -> &std::path::Path {
        TempDir::path(&self.tmp)
    }
}

/// Creates a fresh temporary directory and sets `KTSTR_CACHE_DIR` to it for as
/// long as the returned value lives.
///
/// # Panics
///
/// Panics if the temporary directory cannot be created.
pub(crate) fn isolated_cache_dir() -> IsolatedCacheDir {
    let tmp = TempDir::new().expect("tempdir for isolated cache root");
    IsolatedCacheDir {
        // Evaluated before `tmp` is moved; the borrow ends when `set` returns.
        _guard: EnvVarGuard::set("KTSTR_CACHE_DIR", tmp.path()),
        tmp,
    }
}
/// Small fixed topology shared by evaluation tests (`Topology::new(1, 1, 2, 1)`).
pub(crate) const EVAL_TOPO: Topology = Topology::new(1, 1, 2, 1);

/// Test body that ignores its context and always reports a pass.
pub(crate) fn dummy_test_fn(_ctx: &Ctx) -> Result<AssertResult> {
    let verdict = AssertResult::pass();
    Ok(verdict)
}
/// RAII guard that sets or removes one environment variable and puts the
/// variable's previous state back when dropped.
///
/// The previous value is captured with [`std::env::var_os`], so non-UTF-8 values
/// are restored byte-for-byte, and an originally-absent variable is removed again.
///
/// The guard does not lock anything itself. Tests that mutate the environment
/// are expected to hold [`lock_env`] for the guard's whole lifetime, so that no
/// other test thread touches the environment concurrently.
#[must_use = "the environment variable is restored as soon as the guard is dropped; bind it to a named variable"]
pub(crate) struct EnvVarGuard {
    /// Name of the guarded variable, owned so the guard borrows nothing from the caller.
    key: OsString,
    /// Value present before the guard ran; `None` means the variable was unset.
    original: Option<OsString>,
}

impl EnvVarGuard {
    /// Sets `key` to `value`, remembering the previous value (or absence) for restoration.
    pub(crate) fn set(key: impl AsRef<OsStr>, value: impl AsRef<OsStr>) -> Self {
        let key = key.as_ref().to_owned();
        let original = std::env::var_os(&key);
        // SAFETY: `set_var` is unsound if another thread reads or writes the
        // environment concurrently. Callers uphold this by serializing
        // environment access (holding `lock_env()` while the guard is live).
        unsafe { std::env::set_var(&key, value) };
        EnvVarGuard { key, original }
    }

    /// Removes `key`, remembering the previous value (or absence) for restoration.
    pub(crate) fn remove(key: impl AsRef<OsStr>) -> Self {
        let key = key.as_ref().to_owned();
        let original = std::env::var_os(&key);
        // SAFETY: same contract as in `set`: environment access must be
        // serialized by the caller (e.g. via `lock_env()`).
        unsafe { std::env::remove_var(&key) };
        EnvVarGuard { key, original }
    }
}

impl Drop for EnvVarGuard {
    /// Restores the captured original: re-sets the old value, or removes the
    /// variable if it did not exist before the guard was created.
    fn drop(&mut self) {
        match &self.original {
            // SAFETY: the guard is dropped while the caller still serializes
            // environment access, the same invariant required at construction.
            Some(val) => unsafe { std::env::set_var(&self.key, val) },
            // SAFETY: as above.
            None => unsafe { std::env::remove_var(&self.key) },
        }
    }
}
/// Builds a test entry named `name` that runs [`dummy_test_fn`] with auto-repro
/// disabled; every other field comes from `KtstrTestEntry::DEFAULT`.
pub(crate) fn eevdf_entry(name: &'static str) -> KtstrTestEntry {
    KtstrTestEntry {
        auto_repro: false,
        func: dummy_test_fn,
        name,
        ..KtstrTestEntry::DEFAULT
    }
}
/// Scheduler fixture named `"test_sched"` whose binary is given as
/// `SchedulerSpec::Discover("test_sched_bin")`.
///
/// It declares no flags, sysctls, kernel args or scheduler args, no config file
/// and no cgroup parent. It uses `Assert::NO_OVERRIDES` and
/// `TopologyConstraints::DEFAULT`, and pins a topology of 1 NUMA node, 1 LLC,
/// 2 cores per LLC and 1 thread per core.
pub(crate) static SCHED_TEST: Scheduler = Scheduler {
    name: "test_sched",
    binary: SchedulerSpec::Discover("test_sched_bin"),
    config_file: None,
    cgroup_parent: None,
    flags: &[],
    sched_args: &[],
    sysctls: &[],
    kargs: &[],
    assert: crate::assert::Assert::NO_OVERRIDES,
    constraints: TopologyConstraints::DEFAULT,
    topology: Topology {
        numa_nodes: 1,
        llcs: 1,
        cores_per_llc: 2,
        threads_per_core: 1,
        nodes: None,
        distances: None,
    },
};
/// [`SCHED_TEST`] wrapped as a payload so test entries can reference it.
pub(crate) static SCHED_TEST_PAYLOAD: crate::test_support::Payload =
    crate::test_support::Payload::from_scheduler(&SCHED_TEST);

/// Builds a test entry named `name` that runs [`dummy_test_fn`] under
/// [`SCHED_TEST_PAYLOAD`] with auto-repro disabled; remaining fields come from
/// `KtstrTestEntry::DEFAULT`.
pub(crate) fn sched_entry(name: &'static str) -> KtstrTestEntry {
    KtstrTestEntry {
        scheduler: &SCHED_TEST_PAYLOAD,
        auto_repro: false,
        func: dummy_test_fn,
        name,
        ..KtstrTestEntry::DEFAULT
    }
}
/// Repro hook that never yields a repro command; the supplied output is ignored.
pub(crate) fn no_repro(_output: &str) -> Option<String> {
    Option::None
}
pub(crate) fn make_vm_result(
output: &str,
stderr: &str,
exit_code: i32,
timed_out: bool,
) -> crate::vmm::VmResult {
crate::vmm::VmResult {
success: !timed_out && exit_code == 0,
exit_code,
duration: std::time::Duration::from_secs(1),
timed_out,
output: output.to_string(),
stderr: stderr.to_string(),
monitor: None,
shm_data: None,
stimulus_events: Vec::new(),
verifier_stats: Vec::new(),
kvm_stats: None,
crash_message: None,
cleanup_duration: None,
}
}
/// Serializes a non-skipped `AssertResult` with the given verdict and details
/// (and default stats) to a JSON string.
///
/// # Panics
///
/// Panics if serialization fails, which is treated as an invariant violation.
pub(crate) fn build_assert_result_json(passed: bool, details: Vec<AssertDetail>) -> String {
    serde_json::to_string(&AssertResult {
        passed,
        skipped: false,
        details,
        stats: ScenarioStats::default(),
    })
    .expect("AssertResult must always serialize")
}
/// Builds a test entry that overrides only the fields relevant to validation
/// tests: memory size, run duration and per-cgroup worker count. Everything
/// else, including the test function, comes from `KtstrTestEntry::DEFAULT`.
pub(crate) fn validate_entry(
    name: &'static str,
    memory_mb: u32,
    duration: Duration,
    workers_per_cgroup: u32,
) -> KtstrTestEntry {
    KtstrTestEntry {
        workers_per_cgroup,
        duration,
        memory_mb,
        name,
        ..KtstrTestEntry::DEFAULT
    }
}
/// Renders a panic payload (as returned by `catch_unwind` or `JoinHandle::join`)
/// as text.
///
/// `panic!` payloads are either a `&'static str` (a plain literal) or a `String`
/// (a formatted message); both are returned verbatim. Any other payload falls
/// back to its `Debug` rendering.
pub(crate) fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    // Take an owned `String` out of the box by value instead of cloning it.
    let payload = match payload.downcast::<String>() {
        Ok(owned) => return *owned,
        Err(not_a_string) => not_a_string,
    };
    match payload.downcast_ref::<&'static str>() {
        Some(literal) => (*literal).to_owned(),
        None => format!("{payload:?}"),
    }
}
#[cfg(test)]
mod panic_payload_tests {
    use super::panic_payload_to_string;

    /// Runs `body`, requires that it panics, and renders the captured payload.
    #[cfg(panic = "unwind")]
    fn render_panic(body: impl FnOnce() + std::panic::UnwindSafe) -> String {
        let payload = std::panic::catch_unwind(body).expect_err("closure must panic");
        panic_payload_to_string(payload)
    }

    /// A literal-only `panic!` produces a `&'static str` payload.
    #[test]
    #[cfg(panic = "unwind")]
    fn panic_payload_str_literal_recovered() {
        let text = render_panic(|| {
            panic!("a-string-literal");
        });
        assert_eq!(text, "a-string-literal");
    }

    /// A formatted `panic!` produces a `String` payload.
    #[test]
    #[cfg(panic = "unwind")]
    fn panic_payload_formatted_string_recovered() {
        let n: u32 = 42;
        let text = render_panic(|| {
            panic!("formatted-{n}-panic");
        });
        assert_eq!(text, "formatted-42-panic");
    }

    /// A payload that is neither string type still renders to some text.
    #[test]
    #[cfg(panic = "unwind")]
    fn panic_payload_non_string_falls_back_to_debug() {
        let rendered = render_panic(|| {
            std::panic::panic_any(127i32);
        });
        assert!(
            !rendered.is_empty(),
            "non-string payload must render to SOMETHING (non-empty)",
        );
    }
}
/// Standalone flag `flag_a` → `--flag-a`.
pub(crate) static FLAG_A: FlagDecl = FlagDecl {
    name: "flag_a",
    requires: &[],
    args: &["--flag-a"],
};

/// Flag `borrow` → `--borrow`. It shares its name with [`BORROW_LONG`] but
/// differs in args.
pub(crate) static BORROW: FlagDecl = FlagDecl {
    name: "borrow",
    requires: &[],
    args: &["--borrow"],
};

/// Flag `rebal` → `--rebal`.
pub(crate) static REBAL: FlagDecl = FlagDecl {
    name: "rebal",
    requires: &[],
    args: &["--rebal"],
};

/// Flag `llc` → `--llc`; the prerequisite of [`TEST_STEAL`].
pub(crate) static TEST_LLC: FlagDecl = FlagDecl {
    name: "llc",
    requires: &[],
    args: &["--llc"],
};

/// Flag `steal` → `--steal`; declares [`TEST_LLC`] as a requirement.
pub(crate) static TEST_STEAL: FlagDecl = FlagDecl {
    name: "steal",
    requires: &[&TEST_LLC],
    args: &["--steal"],
};

/// Flag named `borrow` (same name as [`BORROW`]) but emitting `--enable-borrow`.
pub(crate) static BORROW_LONG: FlagDecl = FlagDecl {
    name: "borrow",
    requires: &[],
    args: &["--enable-borrow"],
};

/// Short flag `a` → `-a`.
pub(crate) static TEST_A: FlagDecl = FlagDecl {
    name: "a",
    requires: &[],
    args: &["-a"],
};

/// Short flag `b` → `-b`.
pub(crate) static TEST_B: FlagDecl = FlagDecl {
    name: "b",
    requires: &[],
    args: &["-b"],
};

// Flag lists handed to code under test. `FLAGS_STEAL_LLC` and `FLAGS_LLC_STEAL`
// hold the same two flags in opposite orders (dependent first vs. prerequisite
// first).
pub(crate) static FLAGS_A: &[&FlagDecl] = &[&FLAG_A];
pub(crate) static FLAGS_BORROW_REBAL: &[&FlagDecl] = &[&BORROW, &REBAL];
pub(crate) static FLAGS_STEAL_LLC: &[&FlagDecl] = &[&TEST_STEAL, &TEST_LLC];
pub(crate) static FLAGS_BORROW_LONG: &[&FlagDecl] = &[&BORROW_LONG];
pub(crate) static FLAGS_AB: &[&FlagDecl] = &[&TEST_A, &TEST_B];
pub(crate) static FLAGS_LLC_STEAL: &[&FlagDecl] = &[&TEST_LLC, &TEST_STEAL];
/// `EnvVarGuard::set` over an existing variable must put the old value back on drop.
#[test]
fn env_var_guard_set_restores_original_value_on_drop() {
    const KEY: &str = "KTSTR_TEST_ENV_VAR_GUARD_SET_RESTORES_ORIGINAL";
    let _env_lock = lock_env();
    let current = || std::env::var(KEY).ok();
    unsafe { std::env::set_var(KEY, "pre-existing") };

    let guard = EnvVarGuard::set(KEY, "overwritten");
    assert_eq!(current().as_deref(), Some("overwritten"));
    drop(guard);

    assert_eq!(
        current().as_deref(),
        Some("pre-existing"),
        "EnvVarGuard::set must restore the pre-existing value on drop"
    );
    unsafe { std::env::remove_var(KEY) };
}
/// `EnvVarGuard::set` on a variable that did not exist must remove it again on drop.
#[test]
fn env_var_guard_set_restores_absent_key_on_drop() {
    const KEY: &str = "KTSTR_TEST_ENV_VAR_GUARD_SET_RESTORES_ABSENT";
    let _env_lock = lock_env();
    let is_absent = || std::env::var(KEY).is_err();
    unsafe { std::env::remove_var(KEY) };
    assert!(is_absent(), "setup: key must start absent");

    let guard = EnvVarGuard::set(KEY, "transient");
    assert_eq!(std::env::var(KEY).ok().as_deref(), Some("transient"));
    drop(guard);

    assert!(
        is_absent(),
        "EnvVarGuard::set on an absent key must restore absence, not leave the value behind"
    );
}
/// `EnvVarGuard::remove` must hide the variable while live and restore it on drop.
#[test]
fn env_var_guard_remove_restores_original_value_on_drop() {
    const KEY: &str = "KTSTR_TEST_ENV_VAR_GUARD_REMOVE_RESTORES_ORIGINAL";
    let _env_lock = lock_env();
    let current = || std::env::var(KEY);
    unsafe { std::env::set_var(KEY, "pre-existing") };

    let guard = EnvVarGuard::remove(KEY);
    assert!(
        current().is_err(),
        "EnvVarGuard::remove must remove the key for the guard's lifetime"
    );
    drop(guard);

    assert_eq!(
        current().ok().as_deref(),
        Some("pre-existing"),
        "EnvVarGuard::remove must restore the pre-existing value on drop"
    );
    unsafe { std::env::remove_var(KEY) };
}
/// A non-UTF-8 original value must come back byte-for-byte after the guard drops.
#[cfg(unix)]
#[test]
fn env_var_guard_restores_non_utf8_original_exactly() {
    use std::os::unix::ffi::OsStrExt;
    const KEY: &str = "KTSTR_TEST_ENV_VAR_GUARD_RESTORES_NON_UTF8";
    let original_bytes: &[u8] = b"foo\xFFbar";
    let _env_lock = lock_env();
    unsafe { std::env::set_var(KEY, OsStr::from_bytes(original_bytes)) };

    let guard = EnvVarGuard::set(KEY, "replacement");
    assert_eq!(
        std::env::var_os(KEY).as_deref(),
        Some(OsStr::new("replacement")),
        "EnvVarGuard::set must overwrite the key while the guard is live",
    );
    drop(guard);

    let restored = std::env::var_os(KEY).expect("restored value must be present after drop");
    assert_eq!(
        restored.as_bytes(),
        original_bytes,
        "EnvVarGuard must restore non-UTF-8 originals byte-for-byte",
    );
    unsafe { std::env::remove_var(KEY) };
}
/// Serializes stderr captures: fd 2 is process-wide, so only one capture may
/// redirect it at a time.
pub(crate) static STDERR_CAPTURE_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Drop guard holding a duplicate of the original stderr fd; dropping it
/// `dup2`s that fd back onto fd 2 and then closes the duplicate.
pub(crate) struct StderrRestoreGuard {
    saved: Option<std::os::fd::OwnedFd>,
}

impl Drop for StderrRestoreGuard {
    fn drop(&mut self) {
        let Some(original) = self.saved.take() else {
            return;
        };
        // Best effort: Drop cannot report failure, so a failed restore is ignored.
        _ = nix::unistd::dup2_stderr(&original);
    }
}
/// Runs `f` with the process's stderr (fd 2) redirected into an anonymous
/// temporary file. Returns `f`'s result together with every byte written to
/// fd 2 while `f` ran.
///
/// Captures are serialized through [`STDERR_CAPTURE_LOCK`] because fd 2 is
/// process-wide. A poisoned lock (a previous holder panicked) is recovered
/// rather than propagated. Because the redirect applies to the whole process,
/// bytes written to fd 2 during the capture window by threads that do not go
/// through this helper are captured too.
///
/// The original stderr is restored by a [`StderrRestoreGuard`], so it is put
/// back even if `f` panics; the panic then keeps unwinding out of this function.
///
/// # Panics
///
/// Panics if the temp file cannot be created, if `dup`/`dup2` fails, or if the
/// captured bytes cannot be read back. Do not call `capture_stderr` from inside
/// `f`: the lock is not reentrant and the nested call would not return.
pub(crate) fn capture_stderr<R>(f: impl FnOnce() -> R) -> (R, Vec<u8>) {
    use std::io::{Read, Seek, SeekFrom, Write};
    let _lock = STDERR_CAPTURE_LOCK
        .lock()
        .unwrap_or_else(|e| e.into_inner());
    let mut sink = tempfile::tempfile().expect("create stderr-capture tempfile");
    // Defensive flush before redirecting (Rust's stderr handle is unbuffered).
    std::io::stderr().flush().ok();
    // Keep a duplicate of the real stderr so it can be restored afterwards.
    let saved = nix::unistd::dup(std::io::stderr()).expect("dup(stderr)");
    nix::unistd::dup2_stderr(&sink).expect("dup2_stderr(sink)");
    // From here on fd 2 refers to `sink`; the guard restores it on every exit path.
    let guard = StderrRestoreGuard { saved: Some(saved) };
    let result = f();
    std::io::stderr().flush().ok();
    // Restore the real stderr before reading, so nothing further lands in the sink.
    drop(guard);
    // fd 2 shared the sink's file offset while redirected; rewind before reading.
    sink.seek(SeekFrom::Start(0)).expect("rewind sink");
    let mut bytes = Vec::new();
    sink.read_to_end(&mut bytes).expect("read sink");
    (result, bytes)
}
/// Several threads capture stderr at once. Each must see its own marker and
/// none of the others', which proves that `STDERR_CAPTURE_LOCK` serializes
/// the redirects.
#[test]
fn capture_stderr_serializes_concurrent_callers() {
    const N: usize = 8;
    let markers: Vec<String> = (0..N)
        .map(|i| format!("KTSTR_CAPTURE_LOCK_MARKER_{i:03}"))
        .collect();
    std::thread::scope(|scope| {
        let workers: Vec<_> = markers
            .iter()
            .enumerate()
            .map(|(i, mine)| {
                let all = &markers;
                scope.spawn(move || {
                    let (_, bytes) = capture_stderr(|| {
                        eprintln!("{mine}");
                    });
                    let captured = String::from_utf8_lossy(&bytes);
                    assert!(
                        captured.contains(mine.as_str()),
                        "thread {i}: own marker missing from captured output: {captured:?}",
                    );
                    let foreign = all
                        .iter()
                        .enumerate()
                        .filter(|&(j, _)| j != i)
                        .map(|(_, m)| m);
                    for other in foreign {
                        assert!(
                            !captured.contains(other.as_str()),
                            "thread {i}: foreign marker '{other}' leaked into captured \
                             output — STDERR_CAPTURE_LOCK failed to serialize \
                             concurrent callers: {captured:?}",
                        );
                    }
                })
            })
            .collect();
        for worker in workers {
            worker.join().expect("capture thread panicked");
        }
    });
}