use std::collections::BTreeSet;
use anyhow::{Context, Result};
/// A worker's CPU-affinity *request*, expressed before it is bound to a
/// concrete set of CPU ids.
///
/// NOTE(review): the code that turns an intent into a concrete affinity is not
/// in this file (the diagnostics in `resolve_affinity` point at
/// `crate::scenario::resolve_affinity_for_cgroup`). The per-variant notes below
/// describe only what is visible here and should be confirmed against that
/// resolver.
///
/// Serialized with `snake_case` variant names (e.g. `random_subset`,
/// `llc_aligned`). Keep the variant order stable: non-self-describing serde
/// formats encode enum variants by index.
#[derive(Clone, Debug, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AffinityIntent {
    /// The default intent. Presumably "keep whatever affinity is inherited";
    /// TODO confirm in the resolver.
    #[default]
    Inherit,
    /// Presumably "pick `count` CPUs out of `from`"; built by
    /// [`AffinityIntent::random_subset`]. `count` is not validated here.
    RandomSubset { from: BTreeSet<usize>, count: usize },
    /// Topology-derived intent (the name suggests alignment to a last-level-cache
    /// domain); the resolution rules live outside this file.
    LlcAligned,
    /// Intent whose meaning depends on the cgroup layout; the resolution rules
    /// live outside this file.
    CrossCgroup,
    /// Request for a single CPU; which CPU is chosen is up to the resolver.
    SingleCpu,
    /// An explicit CPU set; built by [`AffinityIntent::exact`].
    Exact(BTreeSet<usize>),
    /// The name suggests a pair of SMT sibling hardware threads; the resolution
    /// rules live outside this file.
    SmtSiblingPair,
}
impl AffinityIntent {
pub fn exact(cpus: impl IntoIterator<Item = usize>) -> Self {
AffinityIntent::Exact(cpus.into_iter().collect())
}
pub fn random_subset(from: impl IntoIterator<Item = usize>, count: usize) -> Self {
AffinityIntent::RandomSubset {
from: from.into_iter().collect(),
count,
}
}
}
/// A worker's CPU affinity after resolution: either "no pinning" or a concrete
/// description that `resolve_affinity` expands into a CPU set.
///
/// Serialized with `snake_case` variant names (`none`, `fixed`, `random`,
/// `single_cpu`). Variant order is part of the encoding for
/// non-self-describing serde formats (variants are encoded by index), so the
/// variants should not be reshuffled.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResolvedAffinity {
/// Default: no pinning. `resolve_affinity` maps this to `Ok(None)`.
#[default]
None,
/// Pin to exactly this set. `resolve_affinity` returns a clone of it; an empty
/// set is passed through unchanged.
Fixed(BTreeSet<usize>),
/// Pin to a random subset of `from`. On each call, `resolve_affinity` draws
/// `min(count, from.len())` distinct CPUs. It errors if `count` is 0 or `from`
/// is empty. Constructing the variant itself performs no validation.
Random { from: BTreeSet<usize>, count: usize },
/// Pin to one CPU; `resolve_affinity` yields the one-element set `{cpu}`.
SingleCpu(usize),
}
impl ResolvedAffinity {
pub fn fixed(cpus: impl IntoIterator<Item = usize>) -> Self {
ResolvedAffinity::Fixed(cpus.into_iter().collect())
}
pub fn random(from: impl IntoIterator<Item = usize>, count: usize) -> Self {
ResolvedAffinity::Random {
from: from.into_iter().collect(),
count,
}
}
pub const fn single_cpu(cpu: usize) -> Self {
ResolvedAffinity::SingleCpu(cpu)
}
}
/// Expand a [`ResolvedAffinity`] into the CPU set a worker should be pinned to.
///
/// Returns:
/// * `Ok(None)` for [`ResolvedAffinity::None`] (no pinning requested);
/// * a clone of the stored set for [`ResolvedAffinity::Fixed`];
/// * the one-element set `{cpu}` for [`ResolvedAffinity::SingleCpu`];
/// * for [`ResolvedAffinity::Random`], a uniformly sampled subset of `from`
///   holding `min(count, from.len())` distinct CPUs. It is drawn with
///   `rand::rng()`, so the result varies between calls.
///
/// # Errors
///
/// Only the `Random` variant can fail: when `count == 0`, or when `from` is
/// empty.
pub(crate) fn resolve_affinity(mode: &ResolvedAffinity) -> Result<Option<BTreeSet<usize>>> {
    use rand::seq::IndexedRandom;

    let (from, count) = match mode {
        ResolvedAffinity::None => return Ok(None),
        ResolvedAffinity::Fixed(cpus) => return Ok(Some(cpus.clone())),
        ResolvedAffinity::SingleCpu(cpu) => return Ok(Some(BTreeSet::from([*cpu]))),
        ResolvedAffinity::Random { from, count } => (from, *count),
    };

    // A zero-sized sample is a caller bug, not a request for "no CPUs".
    if count == 0 {
        anyhow::bail!(
            "ResolvedAffinity::Random.count must be > 0; a zero count \
             previously silently coerced to 1, masking caller bugs"
        );
    }
    if from.is_empty() {
        anyhow::bail!(
            "ResolvedAffinity::Random.from is empty with count={count}; \
             a worker cannot be pinned to an empty CPU pool. The \
             resolution step that produced this Random must reject \
             the empty set up-front (e.g. via the bail paths in \
             `crate::scenario::resolve_affinity_for_cgroup`) — \
             forwarding an unsatisfiable sample request would \
             silently drop the affinity constraint",
            count = count,
        );
    }

    // `sample` operates on slices, so materialise the (sorted) pool first,
    // then clamp the request to what the pool can actually supply.
    let candidates: Vec<usize> = from.iter().copied().collect();
    let take = count.min(candidates.len());
    let picked: BTreeSet<usize> = candidates
        .sample(&mut rand::rng(), take)
        .copied()
        .collect();
    Ok(Some(picked))
}
/// CPU the calling thread was running on at the moment of the call.
///
/// The thread may migrate immediately afterwards, so treat the value as a
/// snapshot. Best-effort: if `sched_getcpu` fails, the error is discarded and
/// `0` is returned. Callers therefore cannot distinguish "on CPU 0" from
/// "lookup failed".
pub(crate) fn sched_getcpu() -> usize {
    // `usize::default()` is 0, which is the fallback value on failure.
    nix::sched::sched_getcpu().unwrap_or_default()
}
/// Pin the thread identified by `pid` to exactly the CPUs in `cpus`.
///
/// `pid` is handed to `sched_setaffinity(2)`, which interprets it as a thread
/// ID. Non-positive values are rejected up front, including `0` (which the
/// syscall would otherwise treat as "the calling thread"), so callers must name
/// the target thread explicitly.
///
/// # Errors
///
/// * `pid <= 0`.
/// * `cpus` is empty. The kernel would refuse an empty mask with `EINVAL`;
///   failing here gives a clearer diagnostic.
/// * A CPU id does not fit in the fixed-size `cpu_set_t` bitmap
///   (`libc::CPU_SETSIZE` bits).
/// * `sched_setaffinity(2)` itself fails, e.g. no permission over `pid`, or
///   none of the requested CPUs is usable by the thread.
pub fn set_thread_affinity(pid: libc::pid_t, cpus: &BTreeSet<usize>) -> Result<()> {
    use nix::sched::{CpuSet, sched_setaffinity};
    use nix::unistd::{Pid, SysconfVar, sysconf};

    if pid <= 0 {
        anyhow::bail!("sched_setaffinity: invalid pid {pid} (must be > 0)");
    }
    if cpus.is_empty() {
        anyhow::bail!(
            "sched_setaffinity pid={pid}: empty CPU set; an empty affinity \
             mask can never be satisfied"
        );
    }

    let cpuset_bitmap_width: usize = libc::CPU_SETSIZE as usize;
    let mut cpu_set = CpuSet::new();
    for &cpu in cpus {
        cpu_set.set(cpu).with_context(|| {
            // The online-CPU count only enriches this diagnostic, so it is
            // queried lazily on the error path instead of on every call.
            let online_cpus = match sysconf(SysconfVar::_NPROCESSORS_ONLN) {
                Ok(Some(n)) => n.to_string(),
                _ => String::from("unavailable"),
            };
            format!(
                "CPU {cpu} out of range: cpu_set bitmap holds CPU IDs \
                 0..{cpuset_bitmap_width} (libc CPU_SETSIZE) and host \
                 reports {online_cpus} online CPUs (sysconf \
                 _SC_NPROCESSORS_ONLN). Either the cpuset spec was \
                 resolved against a stale topology or the bitmap cap \
                 needs raising on this build."
            )
        })?;
    }

    sched_setaffinity(Pid::from_raw(pid), &cpu_set)
        .with_context(|| format!("sched_setaffinity pid={pid}"))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeSet;

    /// Kernel thread id of the calling thread. `set_thread_affinity` rejects
    /// pid 0, so tests that pin themselves must pass their own TID rather than
    /// the process id (which names only the main thread).
    fn current_tid() -> libc::pid_t {
        // SAFETY: SYS_gettid takes no arguments, touches no memory and cannot fail.
        let raw = unsafe { libc::syscall(libc::SYS_gettid) };
        // Kernel thread ids always fit in pid_t.
        raw as libc::pid_t
    }

    /// Affinity mask of the calling thread (pid 0 means "self" to the syscall).
    fn current_affinity() -> nix::sched::CpuSet {
        nix::sched::sched_getaffinity(nix::unistd::Pid::from_raw(0))
            .expect("sched_getaffinity on the calling thread")
    }

    /// Restores the calling thread's affinity mask when dropped. A test that
    /// re-pins itself (or panics midway) therefore never leaks a narrowed mask
    /// to anything that later runs on the same thread.
    struct AffinityGuard(nix::sched::CpuSet);

    impl Drop for AffinityGuard {
        fn drop(&mut self) {
            // Best-effort restore; an error here cannot be reported usefully.
            let _ = nix::sched::sched_setaffinity(nix::unistd::Pid::from_raw(0), &self.0);
        }
    }

    #[test]
    fn resolve_affinity_none() {
        let r = resolve_affinity(&ResolvedAffinity::None).unwrap();
        assert!(r.is_none());
    }

    #[test]
    fn resolve_affinity_fixed() {
        let cpus: BTreeSet<usize> = [0, 1, 2].into_iter().collect();
        let r = resolve_affinity(&ResolvedAffinity::Fixed(cpus.clone())).unwrap();
        assert_eq!(r, Some(cpus));
    }

    #[test]
    fn resolve_affinity_single_cpu() {
        let r = resolve_affinity(&ResolvedAffinity::SingleCpu(5)).unwrap();
        assert_eq!(r, Some([5].into_iter().collect()));
    }

    #[test]
    fn resolved_affinity_single_cpu_debug_format() {
        let dbg = format!("{:?}", ResolvedAffinity::SingleCpu(7));
        assert!(
            dbg.contains("SingleCpu"),
            "Debug output must name the variant, got: {dbg}"
        );
        assert!(
            dbg.contains('7'),
            "Debug output must include the CPU id payload, got: {dbg}"
        );
    }

    #[test]
    fn resolve_affinity_random() {
        let from: BTreeSet<usize> = (0..8).collect();
        let r = resolve_affinity(&ResolvedAffinity::Random { from, count: 3 }).unwrap();
        let cpus = r.unwrap();
        assert_eq!(cpus.len(), 3);
        assert!(cpus.iter().all(|c| *c < 8));
    }

    #[test]
    fn resolve_affinity_random_draws_only_from_pool() {
        // Sparse, non-contiguous pool: every pick must be a member of it.
        let from: BTreeSet<usize> = [10, 20, 30].into_iter().collect();
        let r = resolve_affinity(&ResolvedAffinity::Random {
            from: from.clone(),
            count: 2,
        })
        .unwrap()
        .unwrap();
        assert_eq!(r.len(), 2);
        assert!(r.is_subset(&from), "{r:?} must be a subset of {from:?}");
    }

    #[test]
    fn resolve_affinity_random_clamps_count() {
        let from: BTreeSet<usize> = [0, 1].into_iter().collect();
        let r = resolve_affinity(&ResolvedAffinity::Random { from, count: 10 }).unwrap();
        assert_eq!(r.unwrap().len(), 2);
    }

    #[test]
    fn resolve_affinity_random_single_cpu_pool() {
        let from: BTreeSet<usize> = [7].into_iter().collect();
        let r = resolve_affinity(&ResolvedAffinity::Random { from, count: 1 }).unwrap();
        assert_eq!(r.unwrap(), [7].into_iter().collect());
    }

    #[test]
    fn affinity_mode_debug_shows_cpus() {
        let a = ResolvedAffinity::Fixed([0, 1, 7].into_iter().collect());
        let s = format!("{:?}", a);
        assert!(s.contains("0"), "must show CPU 0");
        assert!(s.contains("1"), "must show CPU 1");
        assert!(s.contains("7"), "must show CPU 7");
        let b = ResolvedAffinity::Fixed([3, 4].into_iter().collect());
        let s2 = format!("{:?}", b);
        assert!(s2.contains("3"), "must show CPU 3");
        assert_ne!(
            s, s2,
            "different CPU sets must produce different debug output"
        );
    }

    #[test]
    fn affinity_mode_clone_preserves_cpus() {
        let cpus: BTreeSet<usize> = [2, 5, 7].into_iter().collect();
        let a = ResolvedAffinity::Random {
            from: cpus.clone(),
            count: 2,
        };
        let b = a.clone();
        match b {
            ResolvedAffinity::Random { from, count } => {
                assert_eq!(from, cpus, "cloned from set must match original");
                assert_eq!(count, 2, "cloned count must match original");
            }
            _ => panic!("clone must preserve variant"),
        }
    }

    #[test]
    fn resolve_affinity_random_zero_count_rejected() {
        let from: BTreeSet<usize> = (0..4).collect();
        let err = resolve_affinity(&ResolvedAffinity::Random { from, count: 0 }).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("count") && msg.contains("> 0"),
            "error must name the field: {msg}"
        );
    }

    #[test]
    fn resolve_affinity_random_empty_pool_bails() {
        let from: BTreeSet<usize> = BTreeSet::new();
        let err = resolve_affinity(&ResolvedAffinity::Random { from, count: 1 }).unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("empty") && msg.contains("count=1"),
            "diagnostic must name the empty pool and the count: got {msg}",
        );
        assert!(
            msg.contains("resolve_affinity_for_cgroup"),
            "diagnostic must point to the upstream resolver so callers \
             learn where the empty pool should have been rejected: got {msg}",
        );
    }

    #[test]
    fn sched_getcpu_valid() {
        // Check membership in this thread's affinity mask rather than
        // `cpu < available_parallelism()`. The latter is a *count* of usable
        // CPUs (it honours cpusets), so under a restricted cpuset such as
        // {4, 5} a perfectly valid CPU id 4 would fail `cpu < 2`.
        let allowed = current_affinity();
        let cpu = sched_getcpu();
        assert!(
            allowed.is_set(cpu).unwrap_or(false),
            "sched_getcpu returned CPU {cpu}, which is not in this thread's affinity mask"
        );
    }

    #[test]
    fn set_thread_affinity_pins_current_thread() {
        // Pin only this test's own thread (not the main thread, whose mask is
        // inherited by later-spawned test threads), and only to a CPU it is
        // already allowed on, so the test works under restricted cpusets.
        let original = current_affinity();
        let _restore = AffinityGuard(original);
        let target = (0..nix::sched::CpuSet::count())
            .find(|&cpu| original.is_set(cpu).unwrap_or(false))
            .expect("calling thread has an empty affinity mask");
        let cpus: BTreeSet<usize> = [target].into_iter().collect();
        if let Err(e) = set_thread_affinity(current_tid(), &cpus) {
            panic!("pinning to allowed CPU {target} should succeed: {e:#}");
        }
        let after = current_affinity();
        for cpu in 0..nix::sched::CpuSet::count() {
            assert_eq!(
                after.is_set(cpu).unwrap_or(false),
                cpu == target,
                "mask bit {cpu} after pinning to CPU {target}"
            );
        }
    }

    #[test]
    fn set_thread_affinity_rejects_nonpositive_pid() {
        let cpus: BTreeSet<usize> = [0].into_iter().collect();
        for pid in [0, -1] {
            let err = set_thread_affinity(pid, &cpus).unwrap_err();
            let msg = format!("{err}");
            assert!(msg.contains("invalid pid"), "pid={pid}: got {msg}");
        }
    }

    #[test]
    fn set_thread_affinity_rejects_out_of_range_cpu() {
        let _restore = AffinityGuard(current_affinity());
        // One past the last representable bit of a cpu_set_t.
        let too_big = nix::sched::CpuSet::count();
        let cpus: BTreeSet<usize> = [too_big].into_iter().collect();
        let err = set_thread_affinity(current_tid(), &cpus).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("out of range"), "got {msg}");
    }

    #[test]
    fn set_thread_affinity_rejects_empty_cpu_set() {
        let _restore = AffinityGuard(current_affinity());
        let result = set_thread_affinity(current_tid(), &BTreeSet::new());
        assert!(result.is_err(), "an empty CPU set can never be satisfied");
    }

    #[test]
    fn affinity_intent_constructors_and_default() {
        assert!(matches!(AffinityIntent::default(), AffinityIntent::Inherit));
        match AffinityIntent::exact([3_usize, 1, 3]) {
            AffinityIntent::Exact(set) => {
                assert_eq!(set, [1_usize, 3].into_iter().collect::<BTreeSet<_>>());
            }
            other => panic!("exact() built {other:?}"),
        }
        match AffinityIntent::random_subset([2_usize, 0, 2], 1) {
            AffinityIntent::RandomSubset { from, count } => {
                assert_eq!(from, [0_usize, 2].into_iter().collect::<BTreeSet<_>>());
                assert_eq!(count, 1);
            }
            other => panic!("random_subset() built {other:?}"),
        }
    }

    #[test]
    fn resolved_affinity_constructors_match_direct_variants() {
        let from_ctor = ResolvedAffinity::fixed([0_usize, 1, 2]);
        let from_variant = ResolvedAffinity::Fixed([0_usize, 1, 2].into_iter().collect());
        assert_eq!(from_ctor, from_variant);
        let from_ctor = ResolvedAffinity::random([0_usize, 1, 2, 3], 2);
        let from_variant = ResolvedAffinity::Random {
            from: [0_usize, 1, 2, 3].into_iter().collect(),
            count: 2,
        };
        assert_eq!(from_ctor, from_variant);
        let from_ctor = ResolvedAffinity::single_cpu(5);
        let from_variant = ResolvedAffinity::SingleCpu(5);
        assert_eq!(from_ctor, from_variant);
    }

    // Compile-time check that `single_cpu` stays usable in const contexts.
    const _: ResolvedAffinity = ResolvedAffinity::single_cpu(7);

    #[test]
    fn resolved_affinity_default_is_none_and_serde_roundtrip_per_variant() {
        let d: ResolvedAffinity = Default::default();
        assert_eq!(d, ResolvedAffinity::None);
        let variants = [
            ResolvedAffinity::None,
            ResolvedAffinity::Fixed([0_usize, 1, 5].into_iter().collect()),
            ResolvedAffinity::Fixed(BTreeSet::new()),
            ResolvedAffinity::Random {
                from: [0_usize, 1, 2, 3, 4].into_iter().collect(),
                count: 2,
            },
            ResolvedAffinity::Random {
                from: BTreeSet::new(),
                count: 0,
            },
            ResolvedAffinity::SingleCpu(7),
        ];
        for original in &variants {
            let bytes = serde_json::to_vec(original).expect("serialize");
            let restored: ResolvedAffinity = serde_json::from_slice(&bytes).expect("deserialize");
            assert_eq!(restored, *original, "roundtrip drift for {original:?}");
        }
    }
}