use std::path::{Path, PathBuf};
use anyhow::Result;
use clap::{ArgAction, CommandFactory, Parser, Subcommand};
use ktstr::cache::KernelMetadata;
use ktstr::cgroup::CgroupManager;
use ktstr::cli;
use ktstr::runner::Runner;
use ktstr::scenario;
use ktstr::topology::TestTopology;
// Top-level command-line definition for the `ktstr` binary.
// Plain `//` comments are used on purpose: with clap's derive macros,
// `///` doc comments are turned into `--help` text.
#[derive(Parser)]
#[command(name = "ktstr")]
#[command(about = "Run ktstr scheduler test scenarios on the host")]
struct Cli {
    // The selected subcommand; the field is not an `Option`, so clap
    // requires one to be given.
    #[command(subcommand)]
    command: Command,
}
// Top-level `ktstr` subcommands.
// Comments here are plain `//` on purpose: with `#[derive(Subcommand)]`,
// `///` doc comments (and field names/order) become part of the CLI and its
// `--help` text.
#[derive(Subcommand)]
enum Command {
// Run scenarios from `scenario::all_scenarios()` on the host inside a
// per-process parent cgroup, then print PASS/FAIL per scenario (or JSON).
Run {
// Forwarded to `cli::build_run_config`.
// NOTE(review): unit is not visible here — presumably seconds; confirm.
#[arg(long, default_value = "20")]
duration: u64,
// Forwarded to `cli::build_run_config`.
#[arg(long, default_value = "4")]
workers: usize,
// Comma-separated list; resolved by `cli::resolve_flags`.
#[arg(long, value_delimiter = ',')]
flags: Option<Vec<String>>,
// Scenario filter applied by `cli::filter_scenarios`.
// NOTE(review): `list` filters by plain substring match in `main`; confirm
// both commands select the same scenarios for the same filter.
#[arg(long)]
filter: Option<String>,
// Print results as pretty-printed JSON instead of the text summary.
#[arg(long)]
json: bool,
// Forwarded unchanged to `cli::build_run_config`.
#[arg(long)]
repro: bool,
// Forwarded unchanged to `cli::build_run_config`.
#[arg(long)]
probe_stack: Option<String>,
// Forwarded unchanged to `cli::build_run_config`.
#[arg(long)]
auto_repro: bool,
// Forwarded unchanged to `cli::build_run_config`.
#[arg(long)]
kernel_dir: Option<String>,
// Parsed by `cli::parse_work_type` into a work-type override.
#[arg(long)]
work_type: Option<String>,
},
// List scenarios with their category, description and profiles.
List {
// Keep only scenarios whose name contains this substring.
#[arg(long)]
filter: Option<String>,
// Print the list as pretty-printed JSON.
#[arg(long)]
json: bool,
},
// Print the host topology: CPU, LLC and NUMA node counts plus the CPUs of
// each LLC.
Topo,
// Call `CgroupManager::cleanup_all` on `parent_cgroup`.
// NOTE(review): `run` creates `/sys/fs/cgroup/ktstr-<pid>` (see `main`), so
// this default does not point at a cgroup left behind by `run` — confirm
// that is intended.
Cleanup {
#[arg(long, default_value = "/sys/fs/cgroup/ktstr")]
parent_cgroup: String,
},
// Kernel cache management: list, build, clean.
Kernel {
#[command(subcommand)]
command: KernelCommand,
},
// Check for KVM, then boot a kernel via `ktstr::run_shell`.
Shell {
// Optional kernel selector, resolved by `cli::resolve_kernel_image`.
#[arg(long)]
kernel: Option<String>,
// "sockets,cores,threads"; every count must be >= 1.
#[arg(long, default_value = "1,1,1")]
topology: String,
// Repeatable (`-i a -i b`); resolved by `cli::resolve_include_files`.
#[arg(short = 'i', long = "include-files", action = ArgAction::Append)]
include_files: Vec<PathBuf>,
// Memory size in MB; clap rejects values below 128.
#[arg(long = "memory-mb", value_parser = clap::value_parser!(u32).range(128..))]
memory_mb: Option<u32>,
// Forwarded to `ktstr::run_shell`.
#[arg(long)]
dmesg: bool,
// Forwarded to `ktstr::run_shell` as an optional string.
#[arg(long)]
exec: Option<String>,
},
// Print a completion script for the given (positional) shell to stdout.
Completions {
shell: clap_complete::Shell,
},
}
// `ktstr kernel` subcommands.
// Plain `//` comments on purpose: `///` would become `--help` text.
#[derive(Subcommand)]
enum KernelCommand {
// List cached kernels via `cli::kernel_list`.
List {
// Passed to `cli::kernel_list`; presumably selects JSON output — confirm.
#[arg(long)]
json: bool,
},
// Build a kernel and store it in the cache (see `kernel_build`).
// At most one of `version`, `--source`, `--git` may be given; with none of
// them, the latest stable release is downloaded.
Build {
// Release version; a value with fewer than two dots (e.g. "6.12") is
// treated as a prefix and resolved by `kernel_build`.
#[arg(conflicts_with_all = ["source", "git"])]
version: Option<String>,
// Path to a local kernel source tree.
#[arg(long, conflicts_with_all = ["version", "git"])]
source: Option<PathBuf>,
// Git repository URL to clone; requires `--ref`.
#[arg(long, requires = "git_ref", conflicts_with_all = ["version", "source"])]
git: Option<String>,
// Git ref to build, exposed as `--ref` (arg id stays `git_ref`); requires
// `--git`.
#[arg(long = "ref", requires = "git")]
git_ref: Option<String>,
// Rebuild even when a cache entry with a current kconfig exists.
#[arg(long)]
force: bool,
// Run `make mrproper` before building; only honoured with `--source`
// (otherwise `kernel_build` prints a note and skips it).
#[arg(long)]
clean: bool,
},
// Prune cached kernels via `cli::kernel_clean`.
Clean {
// NOTE(review): presumably the number of cache entries to keep — confirm
// in `cli::kernel_clean`.
#[arg(long)]
keep: Option<usize>,
// Forwarded to `cli::kernel_clean`.
#[arg(long)]
force: bool,
},
}
/// Drop guard that tears down the parent cgroup at `path` on scope exit.
///
/// Teardown is best-effort: `Drop` cannot report failures, so errors from
/// both the cgroup cleanup and the directory removal are discarded.
struct CgroupGuard {
    /// cgroupfs path of the parent cgroup owned by this guard.
    path: String,
}

impl Drop for CgroupGuard {
    fn drop(&mut self) {
        let parent = &self.path;
        // The manager is kept alive until the end of `drop`, after the
        // directory removal below.
        let manager = CgroupManager::new(parent);
        // Let the manager clean up first so the parent directory has a chance
        // to be empty by the time we try to remove it.
        let _ = manager.cleanup_all();
        let _ = std::fs::remove_dir(parent);
    }
}
fn kernel_build(
version: Option<String>,
source: Option<PathBuf>,
git: Option<String>,
git_ref: Option<String>,
force: bool,
clean: bool,
) -> Result<()> {
use ktstr::cache::CacheDir;
use ktstr::fetch;
let cache = CacheDir::new()?;
let tmp_dir = tempfile::TempDir::new()?;
let acquired = if let Some(ref src_path) = source {
fetch::local_source(src_path).map_err(|e| anyhow::anyhow!("{e}"))?
} else if let Some(ref url) = git {
let ref_name = git_ref.as_deref().expect("clap requires --ref with --git");
fetch::git_clone(url, ref_name, tmp_dir.path()).map_err(|e| anyhow::anyhow!("{e}"))?
} else {
let ver = match version {
Some(v) if v.matches('.').count() < 2 => {
fetch::fetch_version_for_prefix(&v).map_err(|e| anyhow::anyhow!("{e}"))?
}
Some(v) => v,
None => fetch::fetch_latest_stable_version().map_err(|e| anyhow::anyhow!("{e}"))?,
};
let (arch, _) = fetch::arch_info();
let cache_key = format!("{ver}-tarball-{arch}-kc{}", ktstr::cache_key_suffix());
if !force && let Some(entry) = cache.lookup(&cache_key) {
if entry.has_stale_kconfig(&cli::embedded_kconfig_hash()) {
eprintln!("ktstr: cached kernel has stale kconfig, rebuilding");
} else {
eprintln!("ktstr: cached kernel found: {}", entry.path.display());
eprintln!("ktstr: use --force to rebuild");
return Ok(());
}
}
let sp = cli::Spinner::start("Downloading kernel...");
let result =
fetch::download_tarball(&ver, tmp_dir.path()).map_err(|e| anyhow::anyhow!("{e}"));
sp.clear();
result?
};
if !force
&& (source.is_some() || git.is_some())
&& !acquired.is_dirty
&& let Some(entry) = cache.lookup(&acquired.cache_key)
{
if entry.has_stale_kconfig(&cli::embedded_kconfig_hash()) {
eprintln!("ktstr: cached kernel has stale kconfig, rebuilding");
} else {
eprintln!("ktstr: cached kernel found: {}", entry.path.display());
eprintln!("ktstr: use --force to rebuild");
return Ok(());
}
}
let source_dir = &acquired.source_dir;
if clean {
if source.is_none() {
eprintln!(
"ktstr: --clean is only meaningful with --source (downloaded sources start clean)"
);
} else {
eprintln!("ktstr: make mrproper");
cli::run_make(source_dir, &["mrproper"])?;
}
}
if !cli::has_sched_ext(source_dir) {
let sp = cli::Spinner::start("Configuring kernel...");
let result = cli::configure_kernel(source_dir, cli::EMBEDDED_KCONFIG);
if result.is_err() {
sp.clear();
} else {
sp.finish("Kernel configured");
}
result?;
}
let sp = cli::Spinner::start("Building kernel...");
let result = cli::make_kernel_with_output(source_dir, Some(&sp));
if result.is_err() {
sp.clear();
} else {
sp.finish("Kernel built");
}
result?;
cli::validate_kernel_config(source_dir)?;
if !acquired.is_temp {
eprintln!("ktstr: generating compile_commands.json");
cli::run_make(source_dir, &["compile_commands.json"])?;
}
let image_path = ktstr::kernel_path::find_image_in_dir(source_dir)
.ok_or_else(|| anyhow::anyhow!("no kernel image found in {}", source_dir.display()))?;
let vmlinux_path = source_dir.join("vmlinux");
let _stripped_dir;
let vmlinux_ref = if vmlinux_path.exists() {
let orig_mb = std::fs::metadata(&vmlinux_path)
.map(|m| m.len() as f64 / (1024.0 * 1024.0))
.unwrap_or(0.0);
match ktstr::cache::strip_vmlinux_debug(&vmlinux_path) {
Ok((dir, stripped_path)) => {
let stripped_mb = std::fs::metadata(&stripped_path)
.map(|m| m.len() as f64 / (1024.0 * 1024.0))
.unwrap_or(0.0);
eprintln!(
"ktstr: caching vmlinux ({orig_mb:.0} MB -> {stripped_mb:.0} MB, debug stripped)"
);
_stripped_dir = Some(dir);
Some(stripped_path)
}
Err(e) => {
eprintln!(
"ktstr: warning: vmlinux strip failed ({e:#}), caching unstripped ({orig_mb:.0} MB)"
);
_stripped_dir = None;
Some(vmlinux_path.clone())
}
}
} else {
eprintln!("ktstr: warning: vmlinux not found, BTF will not be cached");
_stripped_dir = None;
None
};
let vmlinux_ref = vmlinux_ref.as_deref();
if acquired.is_dirty {
eprintln!("ktstr: kernel built at {}", image_path.display());
eprintln!("ktstr: skipping cache (dirty tree)");
return Ok(());
}
let config_path = source_dir.join(".config");
let config_hash = if config_path.exists() {
let data = std::fs::read(&config_path)?;
Some(format!("{:08x}", crc32fast::hash(&data)))
} else {
None
};
let (arch, image_name) = fetch::arch_info();
let kconfig_hash = cli::embedded_kconfig_hash();
let metadata = KernelMetadata::new(
acquired.source_type.clone(),
arch.to_string(),
image_name.to_string(),
cli::now_iso8601(),
)
.with_version(acquired.version.clone())
.with_config_hash(config_hash)
.with_ktstr_kconfig_hash(Some(kconfig_hash))
.with_ktstr_git_hash(Some(ktstr::GIT_FULL_HASH.to_string()))
.with_git_hash(acquired.git_hash.clone())
.with_git_ref(acquired.git_ref.clone())
.with_source_tree_path(if source.is_some() {
Some(acquired.source_dir.clone())
} else {
None
});
let config_ref = config_path.exists().then_some(config_path.as_path());
let entry = cache.store(
&acquired.cache_key,
&image_path,
vmlinux_ref,
config_ref,
&metadata,
)?;
cli::success(&format!("\u{2713} Kernel cached: {}", acquired.cache_key));
eprintln!("ktstr: image: {}", entry.path.join(image_name).display());
Ok(())
}
/// Write a completion script for `shell` to stdout.
///
/// The script is generated from the clap definition of [`Cli`] under the
/// binary name `ktstr`.
fn run_completions(shell: clap_complete::Shell) {
    let mut out = std::io::stdout();
    clap_complete::generate(shell, &mut Cli::command(), "ktstr", &mut out);
}
fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.with_writer(std::io::stderr)
.init();
let args = Cli::parse();
match args.command {
Command::Run {
duration,
workers,
flags: flag_arg,
filter,
json,
repro,
probe_stack,
auto_repro,
kernel_dir,
work_type,
} => {
let parent_cgroup = format!("/sys/fs/cgroup/ktstr-{}", std::process::id());
let _guard = CgroupGuard {
path: parent_cgroup.clone(),
};
let active_flags = cli::resolve_flags(flag_arg)?;
let work_type_override = cli::parse_work_type(work_type.as_deref())?;
let config = cli::build_run_config(
parent_cgroup,
duration,
workers,
active_flags,
repro,
probe_stack,
auto_repro,
kernel_dir,
work_type_override,
);
let topo = TestTopology::from_system()?;
let runner = Runner::new(config, topo)?;
let scenarios = scenario::all_scenarios();
let refs = cli::filter_scenarios(&scenarios, filter.as_deref())?;
let results = runner.run_scenarios(&refs)?;
if json {
println!("{}", serde_json::to_string_pretty(&results)?);
} else {
for r in &results {
let status = if r.passed { "PASS" } else { "FAIL" };
println!("[{status}] {} ({:.1}s)", r.scenario_name, r.duration_s);
for d in &r.details {
println!(" {d}");
}
}
let passed = results.iter().filter(|r| r.passed).count();
let total = results.len();
println!("\n{passed}/{total} passed");
}
}
Command::List { filter, json } => {
let scenarios = scenario::all_scenarios();
let filtered: Vec<&scenario::Scenario> = scenarios
.iter()
.filter(|s| filter.as_ref().is_none_or(|f| s.name.contains(f.as_str())))
.collect();
if json {
let entries: Vec<serde_json::Value> = filtered
.iter()
.map(|s| {
let profiles: Vec<String> = s.profiles().iter().map(|p| p.name()).collect();
serde_json::json!({
"name": s.name,
"category": s.category,
"description": s.description,
"profiles": profiles,
})
})
.collect();
println!("{}", serde_json::to_string_pretty(&entries)?);
} else {
for s in &filtered {
let profiles: Vec<String> = s.profiles().iter().map(|p| p.name()).collect();
println!(
"{:<30} [{:<12}] {} (profiles: {})",
s.name,
s.category,
s.description,
profiles.join(", "),
);
}
println!("\n{} scenarios", filtered.len());
}
}
Command::Topo => {
let topo = TestTopology::from_system()?;
println!("CPUs: {}", topo.total_cpus());
println!("LLCs: {}", topo.num_llcs());
println!("NUMA nodes: {}", topo.num_numa_nodes());
for (i, llc) in topo.llcs().iter().enumerate() {
println!(" LLC {} (node {}): {:?}", i, llc.numa_node(), llc.cpus(),);
}
}
Command::Cleanup { parent_cgroup } => {
let cgroups = CgroupManager::new(&parent_cgroup);
cgroups.cleanup_all()?;
println!("cleaned up {parent_cgroup}");
}
Command::Kernel { command } => match command {
KernelCommand::List { json } => cli::kernel_list(json)?,
KernelCommand::Build {
version,
source,
git,
git_ref,
force,
clean,
} => kernel_build(version, source, git, git_ref, force, clean)?,
KernelCommand::Clean { keep, force } => cli::kernel_clean(keep, force)?,
},
Command::Shell {
kernel,
topology,
include_files,
memory_mb,
dmesg,
exec,
} => {
cli::check_kvm()?;
let kernel_path = cli::resolve_kernel_image(kernel.as_deref())?;
let parts: Vec<&str> = topology.split(',').collect();
anyhow::ensure!(
parts.len() == 3,
"invalid topology '{topology}': expected 'sockets,cores,threads' (e.g. '2,4,1')"
);
let sockets: u32 = parts[0]
.parse()
.map_err(|_| anyhow::anyhow!("invalid sockets value: '{}'", parts[0]))?;
let cores: u32 = parts[1]
.parse()
.map_err(|_| anyhow::anyhow!("invalid cores value: '{}'", parts[1]))?;
let threads: u32 = parts[2]
.parse()
.map_err(|_| anyhow::anyhow!("invalid threads value: '{}'", parts[2]))?;
anyhow::ensure!(
sockets > 0 && cores > 0 && threads > 0,
"invalid topology '{topology}': sockets, cores, and threads must all be >= 1"
);
let resolved_includes = cli::resolve_include_files(&include_files)?;
let include_refs: Vec<(&str, &Path)> = resolved_includes
.iter()
.map(|(a, p)| (a.as_str(), p.as_path()))
.collect();
ktstr::run_shell(
kernel_path,
sockets,
cores,
threads,
&include_refs,
memory_mb,
dmesg,
exec.as_deref(),
)?;
}
Command::Completions { shell } => {
run_completions(shell);
}
}
Ok(())
}