use std::fmt;
/// Errors returned by [`set_affinity`].
///
/// The payloads are plain integers, so the type is cheap to copy and can be
/// compared directly in tests and caller code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AffinityError {
    /// The current target has no affinity backend (only Linux and Windows are wired up).
    Unsupported,
    /// A requested core index cannot be expressed in the platform's affinity mask.
    ///
    /// `set_affinity` also reports an empty core list as `InvalidCore(0)`.
    InvalidCore(usize),
    /// The operating-system affinity call failed; carries the integer code
    /// produced by the platform backend.
    OsError(i32),
}
impl fmt::Display for AffinityError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AffinityError::Unsupported => {
write!(f, "affinity pinning not supported on this platform")
}
AffinityError::InvalidCore(c) => write!(f, "invalid core index: {c}"),
AffinityError::OsError(e) => write!(f, "os error: {e}"),
}
}
}
impl std::error::Error for AffinityError {}
/// Restricts the calling thread to the zero-based CPU core indices in `cores`.
///
/// The platform backend is chosen at compile time: Linux and Windows have real
/// implementations, every other target reports [`AffinityError::Unsupported`].
///
/// # Errors
///
/// - [`AffinityError::InvalidCore`]`(0)` if `cores` is empty.
/// - [`AffinityError::Unsupported`] on targets other than Linux and Windows.
/// - Any error produced by the selected platform backend.
pub fn set_affinity(cores: &[usize]) -> Result<(), AffinityError> {
    // An empty request cannot be turned into a mask; it is reported as core 0 being invalid.
    let [_, ..] = cores else {
        return Err(AffinityError::InvalidCore(0));
    };

    // Exactly one of the following statements survives `cfg` evaluation.
    #[cfg(target_os = "linux")]
    return linux::set_affinity_linux(cores);

    #[cfg(target_os = "windows")]
    return windows::set_affinity_windows(cores);

    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    return Err(AffinityError::Unsupported);
}
#[cfg(target_os = "linux")]
mod linux {
    use super::AffinityError;
    use core::ffi::{c_int, c_ulong};
    use std::io;
    use std::mem;

    /// Number of CPUs representable by glibc's fixed-size `cpu_set_t` (`CPU_SETSIZE`).
    const CPU_SETSIZE: usize = 1024;

    /// Width of one mask word. glibc's `cpu_set_t` is an array of `unsigned long`,
    /// and the kernel reads it that way, so the word type must be `c_ulong` for the
    /// bit layout to match on 32-bit (in particular big-endian) targets.
    const WORD_BITS: usize = c_ulong::BITS as usize;

    /// Layout-compatible mirror of glibc's `cpu_set_t` (1024 bits / 128 bytes).
    #[repr(C)]
    struct CpuSetT {
        bits: [c_ulong; CPU_SETSIZE / WORD_BITS],
    }

    unsafe extern "C" {
        // int sched_setaffinity(pid_t pid, size_t cpusetsize, const cpu_set_t *mask);
        fn sched_setaffinity(pid: c_int, cpusetsize: usize, mask: *const CpuSetT) -> c_int;
    }

    /// Restricts the calling thread to the CPUs listed in `cores`.
    ///
    /// # Errors
    ///
    /// - `InvalidCore(c)` for the first index `c >= CPU_SETSIZE` (checked before any syscall).
    /// - `OsError(errno)` if `sched_setaffinity` fails (e.g. `EINVAL` when the mask
    ///   contains no CPU the thread is allowed to run on).
    pub(super) fn set_affinity_linux(cores: &[usize]) -> Result<(), AffinityError> {
        let mut set = CpuSetT {
            bits: [0; CPU_SETSIZE / WORD_BITS],
        };
        for &c in cores {
            if c >= CPU_SETSIZE {
                return Err(AffinityError::InvalidCore(c));
            }
            let bit: c_ulong = 1 << (c % WORD_BITS);
            set.bits[c / WORD_BITS] |= bit;
        }
        // SAFETY: `set` is a live, fully initialised `cpu_set_t`-compatible value for
        // the whole call, `cpusetsize` is exactly its size, and the kernel only reads
        // through the pointer. pid 0 selects the calling thread.
        let rc = unsafe { sched_setaffinity(0, mem::size_of::<CpuSetT>(), &set) };
        if rc == 0 {
            Ok(())
        } else {
            // The syscall wrapper returns -1 and reports the real reason in errno;
            // read it immediately so callers see e.g. EINVAL instead of -1.
            let errno = io::Error::last_os_error().raw_os_error().unwrap_or(rc);
            Err(AffinityError::OsError(errno))
        }
    }
}
#[cfg(target_os = "windows")]
mod windows {
    use super::AffinityError;
    use std::io;

    /// Win32 `HANDLE`.
    type Handle = *mut core::ffi::c_void;
    /// Win32 `DWORD_PTR`: a pointer-sized unsigned integer.
    type DWordPtr = usize;

    unsafe extern "system" {
        fn GetCurrentThread() -> Handle;
        fn SetThreadAffinityMask(thread: Handle, mask: DWordPtr) -> DWordPtr;
    }

    /// Restricts the calling thread to the processors listed in `cores`.
    ///
    /// The mask is a single `DWORD_PTR`, so only indices below `usize::BITS` can be
    /// expressed.
    ///
    /// # Errors
    ///
    /// - `InvalidCore(c)` for the first index that does not fit in the mask.
    /// - `OsError(code)` carrying the `GetLastError` value if `SetThreadAffinityMask` fails.
    pub(super) fn set_affinity_windows(cores: &[usize]) -> Result<(), AffinityError> {
        let max_bits = usize::BITS as usize;
        let mut mask: DWordPtr = 0;
        for &c in cores {
            if c >= max_bits {
                return Err(AffinityError::InvalidCore(c));
            }
            mask |= 1usize << c;
        }
        // SAFETY: GetCurrentThread returns a pseudo-handle for the calling thread that
        // is always valid and needs no closing; SetThreadAffinityMask takes only that
        // handle and an integer mask by value.
        let previous = unsafe { SetThreadAffinityMask(GetCurrentThread(), mask) };
        if previous == 0 {
            // Read GetLastError right away, before any other call can overwrite it.
            let code = io::Error::last_os_error().raw_os_error().unwrap_or(0);
            Err(AffinityError::OsError(code))
        } else {
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_cores_is_rejected() {
        let err = set_affinity(&[]);
        assert!(matches!(err, Err(AffinityError::InvalidCore(0))));
    }

    #[test]
    fn out_of_range_core_returns_invalid() {
        let err = set_affinity(&[usize::MAX]);
        // Supported backends reject the index before making any syscall; other
        // targets bail out earlier with `Unsupported`.
        if cfg!(any(target_os = "linux", target_os = "windows")) {
            assert!(
                matches!(err, Err(AffinityError::InvalidCore(c)) if c == usize::MAX),
                "unexpected: {err:?}"
            );
        } else {
            assert!(
                matches!(err, Err(AffinityError::Unsupported)),
                "unexpected: {err:?}"
            );
        }
    }

    #[cfg(any(target_os = "linux", target_os = "windows"))]
    #[test]
    fn pinning_to_core_zero_succeeds_on_supported_platforms() {
        // Core 0 may be unavailable (e.g. restricted cpusets), so an OS error is tolerated.
        let result = set_affinity(&[0]);
        assert!(
            matches!(result, Ok(()) | Err(AffinityError::OsError(_))),
            "unexpected: {result:?}"
        );
    }

    #[cfg(not(any(target_os = "linux", target_os = "windows")))]
    #[test]
    fn unsupported_platform_returns_unsupported() {
        let result = set_affinity(&[0]);
        assert!(matches!(result, Err(AffinityError::Unsupported)));
    }

    #[test]
    fn display_messages_render() {
        let m1 = format!("{}", AffinityError::Unsupported);
        let m2 = format!("{}", AffinityError::InvalidCore(7));
        let m3 = format!("{}", AffinityError::OsError(42));
        assert!(m1.contains("not supported"));
        assert!(m2.contains("7"));
        assert!(m3.contains("42"));
    }

    #[test]
    fn debug_messages_render() {
        let s = format!("{:?}", AffinityError::OsError(1));
        assert!(s.contains("OsError"));
    }

    #[test]
    fn error_trait_reports_no_source() {
        let e: &dyn std::error::Error = &AffinityError::Unsupported;
        assert!(e.source().is_none());
    }
}