use std::collections::BTreeSet;
use anyhow::{Context, Result};
/// Caller-facing description of how CPU affinity should be chosen.
///
/// This is the unresolved "intent" form. The resolver visible in this file,
/// `resolve_affinity`, consumes [`ResolvedAffinity`] instead. The mapping from
/// an intent to a resolved mode is not visible here.
/// NOTE(review): confirm where that translation lives.
///
/// Serialized by serde with snake_case variant names (e.g. `random_subset`,
/// `llc_aligned`, `smt_sibling_pair`). Keep the declaration order stable:
/// non-self-describing serde formats may encode variants by index.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AffinityIntent {
    /// The default intent.
    /// NOTE(review): by name, leaves the inherited affinity untouched — confirm.
    #[default]
    Inherit,
    /// Choose `count` CPUs out of `from`; built by [`AffinityIntent::random_subset`].
    /// NOTE(review): presumably resolves to [`ResolvedAffinity::Random`] — confirm.
    RandomSubset { from: BTreeSet<usize>, count: usize },
    /// NOTE(review): by name, CPUs sharing a last-level cache; the resolution
    /// logic is not visible in this file — confirm.
    LlcAligned,
    /// NOTE(review): by name, a placement spanning cgroups; the resolution
    /// logic is not visible in this file — confirm.
    CrossCgroup,
    /// NOTE(review): by name, a single CPU chosen at resolution time
    /// (cf. [`ResolvedAffinity::SingleCpu`]) — confirm.
    SingleCpu,
    /// An explicit CPU set; built by [`AffinityIntent::exact`].
    Exact(BTreeSet<usize>),
    /// NOTE(review): by name, a pair of SMT (hyperthread) siblings; the
    /// resolution logic is not visible in this file — confirm.
    SmtSiblingPair,
}
impl AffinityIntent {
pub fn exact(cpus: impl IntoIterator<Item = usize>) -> Self {
AffinityIntent::Exact(cpus.into_iter().collect())
}
pub fn random_subset(from: impl IntoIterator<Item = usize>, count: usize) -> Self {
AffinityIntent::RandomSubset {
from: from.into_iter().collect(),
count,
}
}
}
/// A concrete affinity mode, ready to be turned into a CPU set by
/// `resolve_affinity`.
///
/// Serialized by serde with snake_case variant names. Keep the declaration
/// order stable: non-self-describing serde formats may encode variants by index.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResolvedAffinity {
    /// No affinity: `resolve_affinity` yields `Ok(None)`.
    None,
    /// Exactly this CPU set, returned unchanged by `resolve_affinity`.
    Fixed(BTreeSet<usize>),
    /// `count` distinct CPUs sampled at random from `from`.
    ///
    /// In `resolve_affinity`:
    /// - `count == 0` is an error.
    /// - An empty `from` yields `Ok(None)`.
    /// - `count` is clamped to `from.len()`.
    Random { from: BTreeSet<usize>, count: usize },
    /// A single CPU, resolved to the one-element set `{cpu}`.
    SingleCpu(usize),
}
/// Turns a [`ResolvedAffinity`] into the concrete CPU set to apply.
///
/// - `None` → `Ok(None)` (leave affinity unset).
/// - `Fixed(cpus)` → a copy of `cpus`.
/// - `SingleCpu(cpu)` → the one-element set `{cpu}`.
/// - `Random { from, count }` → `min(count, from.len())` distinct CPUs sampled
///   at random from `from`. An empty `from` logs at debug level and yields
///   `Ok(None)`.
///
/// # Errors
///
/// Returns an error for `Random` with `count == 0`.
pub(crate) fn resolve_affinity(mode: &ResolvedAffinity) -> Result<Option<BTreeSet<usize>>> {
    use rand::seq::IndexedRandom;

    // Deterministic modes resolve immediately; only `Random` needs more work.
    let (pool, requested) = match mode {
        ResolvedAffinity::None => return Ok(None),
        ResolvedAffinity::Fixed(cpus) => return Ok(Some(cpus.clone())),
        ResolvedAffinity::SingleCpu(cpu) => return Ok(Some(BTreeSet::from([*cpu]))),
        ResolvedAffinity::Random { from, count } => (from, *count),
    };

    // Reject rather than coerce: a zero count almost always means the caller
    // never filled the field in.
    if requested == 0 {
        anyhow::bail!(
            "ResolvedAffinity::Random.count must be > 0; a zero count \
             previously silently coerced to 1, masking caller bugs"
        );
    }

    if pool.is_empty() {
        tracing::debug!(
            count = requested,
            "resolve_affinity: empty Random pool, leaving affinity unset"
        );
        return Ok(None);
    }

    // A BTreeSet is not indexable, so copy the pool into a Vec for sampling.
    let candidates: Vec<usize> = pool.iter().copied().collect();
    let take = requested.min(candidates.len());
    let chosen = candidates
        .sample(&mut rand::rng(), take)
        .copied()
        .collect::<BTreeSet<usize>>();
    Ok(Some(chosen))
}
/// Returns the CPU the calling thread was running on when the call was made.
///
/// Best-effort: if the underlying `sched_getcpu` call fails, this returns `0`.
/// That fallback is indistinguishable from genuinely running on CPU 0. The
/// value can also be stale by the time the caller uses it, because the
/// scheduler may migrate the thread at any moment.
pub(crate) fn sched_getcpu() -> usize {
    match nix::sched::sched_getcpu() {
        Ok(cpu) => cpu,
        Err(_) => 0,
    }
}
/// Restricts the thread identified by `pid` to exactly the CPUs in `cpus`.
///
/// On Linux, `sched_setaffinity(2)` acts on a single thread: `pid` is a thread
/// id (TID). Passing a process id therefore affects only that process's main
/// thread, not its other threads.
///
/// # Errors
///
/// Returns an error if any of the following holds:
/// - `pid` is not positive. This includes `0`, which the raw syscall would
///   treat as "the calling thread".
/// - `cpus` is empty.
/// - A CPU id does not fit in a [`nix::sched::CpuSet`].
/// - The syscall itself fails, e.g. the set contains no CPU the thread is
///   permitted to use.
pub fn set_thread_affinity(pid: libc::pid_t, cpus: &BTreeSet<usize>) -> Result<()> {
    use nix::sched::{CpuSet, sched_setaffinity};
    use nix::unistd::Pid;

    if pid <= 0 {
        anyhow::bail!("sched_setaffinity: invalid pid {pid} (must be > 0)");
    }
    // The kernel always rejects an empty mask with a bare EINVAL. Fail early
    // with a message that names the actual problem.
    if cpus.is_empty() {
        anyhow::bail!("sched_setaffinity pid={pid}: empty CPU set");
    }

    let mut cpu_set = CpuSet::new();
    for &cpu in cpus {
        cpu_set.set(cpu).with_context(|| {
            format!(
                "CPU {cpu} out of range (CpuSet holds CPUs 0..{})",
                CpuSet::count()
            )
        })?;
    }
    sched_setaffinity(Pid::from_raw(pid), &cpu_set)
        .with_context(|| format!("sched_setaffinity pid={pid}"))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeSet;

    /// Kernel thread id (TID) of the calling thread.
    fn current_tid() -> libc::pid_t {
        // SAFETY: SYS_gettid takes no arguments, cannot fail, and has no side
        // effects.
        let tid = unsafe { libc::syscall(libc::SYS_gettid) };
        // Kernel TIDs are pid_t values, so narrowing the c_long result is lossless.
        tid as libc::pid_t
    }

    /// CPUs the calling thread is currently allowed to run on.
    fn current_affinity() -> BTreeSet<usize> {
        use nix::sched::{CpuSet, sched_getaffinity};
        use nix::unistd::Pid;
        // Pid 0 selects the calling thread.
        let set = sched_getaffinity(Pid::from_raw(0)).expect("sched_getaffinity(self)");
        (0..CpuSet::count())
            .filter(|&cpu| set.is_set(cpu).unwrap_or(false))
            .collect()
    }

    #[test]
    fn resolve_affinity_none() {
        let r = resolve_affinity(&ResolvedAffinity::None).unwrap();
        assert!(r.is_none());
    }

    #[test]
    fn resolve_affinity_fixed() {
        let cpus: BTreeSet<usize> = [0, 1, 2].into_iter().collect();
        let r = resolve_affinity(&ResolvedAffinity::Fixed(cpus.clone())).unwrap();
        assert_eq!(r, Some(cpus));
    }

    #[test]
    fn resolve_affinity_single_cpu() {
        let r = resolve_affinity(&ResolvedAffinity::SingleCpu(5)).unwrap();
        assert_eq!(r, Some([5].into_iter().collect()));
    }

    #[test]
    fn resolved_affinity_single_cpu_debug_format() {
        let dbg = format!("{:?}", ResolvedAffinity::SingleCpu(7));
        assert!(
            dbg.contains("SingleCpu"),
            "Debug output must name the variant, got: {dbg}"
        );
        assert!(
            dbg.contains('7'),
            "Debug output must include the CPU id payload, got: {dbg}"
        );
    }

    #[test]
    fn resolve_affinity_random() {
        let from: BTreeSet<usize> = (0..8).collect();
        let r = resolve_affinity(&ResolvedAffinity::Random { from, count: 3 }).unwrap();
        let cpus = r.unwrap();
        assert_eq!(cpus.len(), 3);
        assert!(cpus.iter().all(|c| *c < 8));
    }

    #[test]
    fn resolve_affinity_random_clamps_count() {
        let from: BTreeSet<usize> = [0, 1].into_iter().collect();
        let r = resolve_affinity(&ResolvedAffinity::Random { from, count: 10 }).unwrap();
        assert_eq!(r.unwrap().len(), 2);
    }

    #[test]
    fn resolve_affinity_random_single_cpu_pool() {
        let from: BTreeSet<usize> = [7].into_iter().collect();
        let r = resolve_affinity(&ResolvedAffinity::Random { from, count: 1 }).unwrap();
        assert_eq!(r.unwrap(), [7].into_iter().collect());
    }

    #[test]
    fn affinity_mode_debug_shows_cpus() {
        let a = ResolvedAffinity::Fixed([0, 1, 7].into_iter().collect());
        let s = format!("{:?}", a);
        assert!(s.contains("0"), "must show CPU 0");
        assert!(s.contains("1"), "must show CPU 1");
        assert!(s.contains("7"), "must show CPU 7");
        let b = ResolvedAffinity::Fixed([3, 4].into_iter().collect());
        let s2 = format!("{:?}", b);
        assert!(s2.contains("3"), "must show CPU 3");
        assert_ne!(
            s, s2,
            "different CPU sets must produce different debug output"
        );
    }

    #[test]
    fn affinity_mode_clone_preserves_cpus() {
        let cpus: BTreeSet<usize> = [2, 5, 7].into_iter().collect();
        let a = ResolvedAffinity::Random {
            from: cpus.clone(),
            count: 2,
        };
        let b = a.clone();
        match b {
            ResolvedAffinity::Random { from, count } => {
                assert_eq!(from, cpus, "cloned from set must match original");
                assert_eq!(count, 2, "cloned count must match original");
            }
            _ => panic!("clone must preserve variant"),
        }
    }

    #[test]
    fn resolve_affinity_random_zero_count_rejected() {
        let from: BTreeSet<usize> = (0..4).collect();
        let err = resolve_affinity(&ResolvedAffinity::Random { from, count: 0 }).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("count") && msg.contains("> 0"),
            "error must name the field: {msg}"
        );
    }

    #[test]
    fn resolve_affinity_random_empty_pool_is_none() {
        let from: BTreeSet<usize> = BTreeSet::new();
        let r = resolve_affinity(&ResolvedAffinity::Random { from, count: 1 }).unwrap();
        assert!(r.is_none(), "empty Random pool must resolve to no affinity");
    }

    #[test]
    fn sched_getcpu_valid() {
        // `available_parallelism` is NOT an upper bound on CPU ids: it shrinks
        // under cgroup quotas and under affinity masks that need not start at
        // CPU 0. The CPU we run on must, however, be in our own affinity mask.
        let allowed = current_affinity();
        let cpu = sched_getcpu();
        assert!(
            allowed.contains(&cpu),
            "cpu {cpu} not in this thread's affinity mask {allowed:?}"
        );
    }

    #[test]
    fn set_thread_affinity_pins_calling_thread() {
        // Pin the *test* thread, not the main thread (getpid() == main TID), to
        // a CPU it is already allowed on. CPU 0 is not guaranteed to be usable
        // in containers or under taskset.
        let tid = current_tid();
        let original = current_affinity();
        let target = *original
            .iter()
            .next()
            .expect("thread must be allowed on at least one CPU");
        let want: BTreeSet<usize> = [target].into_iter().collect();

        let result = set_thread_affinity(tid, &want);
        let pinned = current_affinity();
        // Restore before asserting so a failure does not leave the thread pinned.
        set_thread_affinity(tid, &original).expect("restoring original affinity");

        assert!(
            result.is_ok(),
            "pinning to allowed CPU {target} should succeed: {result:?}"
        );
        assert_eq!(pinned, want, "affinity mask must be exactly {{{target}}}");
    }

    #[test]
    fn set_thread_affinity_rejects_empty_set() {
        let empty = BTreeSet::new();
        assert!(
            set_thread_affinity(current_tid(), &empty).is_err(),
            "an empty CPU set must be rejected"
        );
    }

    #[test]
    fn set_thread_affinity_rejects_nonpositive_pid() {
        let cpus = current_affinity();
        let err = set_thread_affinity(0, &cpus).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("invalid pid"), "unexpected error: {msg}");
    }
}