use crate::error::CpuAffinityError;
#[cfg(target_os = "linux")]
use std::{fs, io};
#[cfg(target_os = "linux")]
const CPU_SETSIZE: usize = 1024;
/// Kernel thread identifier used by the affinity functions.
///
/// On Linux this is `pid_t`: the value produced by `current_thread_id` (via `gettid`)
/// and passed to `sched_setaffinity` by `set_cpu_affinity`.
#[cfg(target_os = "linux")]
pub type ThreadId = libc::pid_t;
/// Placeholder thread identifier on non-Linux targets, where the affinity functions in
/// this module return `CpuAffinityError::NotSupported`.
#[cfg(not(target_os = "linux"))]
pub type ThreadId = i32;
/// Summary of a CPU list such as the contents of `/sys/devices/system/cpu/online`
/// (e.g. `0-3,8-11`), or of the `sysconf` fallback.
#[cfg(target_os = "linux")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CpuListInfo {
    /// Highest CPU id appearing in the list.
    max_id: usize,
    /// Sum of the widths of all ranges in the list; overlapping ranges are not
    /// de-duplicated, so they are counted more than once.
    count: usize,
}
#[cfg(target_os = "linux")]
pub fn current_thread_id() -> Result<ThreadId, CpuAffinityError> {
let thread_id = unsafe { libc::syscall(libc::SYS_gettid) };
if thread_id <= 0 {
return Err(CpuAffinityError::Io(io::Error::last_os_error()));
}
Ok(thread_id as ThreadId)
}
/// Thread ids are only exposed on Linux; on every other target this always fails
/// with [`CpuAffinityError::NotSupported`].
#[cfg(not(target_os = "linux"))]
pub fn current_thread_id() -> Result<ThreadId, CpuAffinityError> {
    Result::Err(CpuAffinityError::NotSupported)
}
/// Restricts thread `thread_id` to run only on the CPUs yielded by `cpus`.
///
/// Uses `sched_setaffinity(2)`; a `thread_id` of `0` refers to the calling thread.
/// Duplicate CPU ids are accepted and simply set the same bit of the mask again.
/// Only an upper bound is validated here: an offline CPU whose id is below the highest
/// online id is passed through to the kernel unchanged.
///
/// # Errors
///
/// * [`CpuAffinityError::InvalidCpu`] if an id is greater than [`max_cpu_id`] or does not
///   fit in a `cpu_set_t`.
/// * [`CpuAffinityError::EmptyCpuList`] if `cpus` yields nothing.
/// * [`CpuAffinityError::Io`] if looking up the online CPUs or the syscall itself fails.
#[cfg(target_os = "linux")]
pub fn set_cpu_affinity(
    thread_id: ThreadId,
    cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
    // SAFETY: `cpu_set_t` is a plain array of integers, so all-zero bytes is a valid
    // (empty) CPU set.
    let mut cpu_set: libc::cpu_set_t = unsafe { std::mem::zeroed() };
    // Resolved lazily on the first CPU, and only once: `max_cpu_id` reads and parses sysfs.
    // (`Option::get_or_insert` evaluates its argument eagerly, which previously re-ran the
    // lookup for every CPU.) Staying lazy also means an empty `cpus` still reports
    // `EmptyCpuList` instead of a lookup error.
    let mut max_online: Option<usize> = None;
    let mut has_cpus = false;
    for cpu in cpus {
        let max_cpu = match max_online {
            Some(max) => max,
            None => {
                let max = max_cpu_id()?;
                max_online = Some(max);
                max
            }
        };
        if cpu > max_cpu {
            return Err(CpuAffinityError::InvalidCpu { cpu, max: max_cpu });
        }
        // `CPU_SET` can only address bits inside the fixed-size `cpu_set_t`.
        if cpu >= CPU_SETSIZE {
            return Err(CpuAffinityError::InvalidCpu {
                cpu,
                max: CPU_SETSIZE - 1,
            });
        }
        // SAFETY: `cpu < CPU_SETSIZE` was checked above and `cpu_set` is initialized.
        // Setting an already-set bit is harmless, so duplicates need no special case.
        unsafe { libc::CPU_SET(cpu, &mut cpu_set) };
        has_cpus = true;
    }
    if !has_cpus {
        return Err(CpuAffinityError::EmptyCpuList);
    }
    // SAFETY: `cpu_set` is a valid, initialized set and the size passed matches its type.
    let result = unsafe {
        libc::sched_setaffinity(thread_id, std::mem::size_of::<libc::cpu_set_t>(), &cpu_set)
    };
    if result != 0 {
        return Err(CpuAffinityError::Io(io::Error::last_os_error()));
    }
    Ok(())
}
/// CPU affinity is only implemented for Linux; on every other target this always fails
/// with [`CpuAffinityError::NotSupported`] without inspecting its arguments.
#[cfg(not(target_os = "linux"))]
pub fn set_cpu_affinity(
    _thread_id: ThreadId,
    _cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
    Result::Err(CpuAffinityError::NotSupported)
}
#[cfg(target_os = "linux")]
pub fn set_current_thread_affinity(
cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
set_cpu_affinity(0, cpus)
}
/// Not available outside Linux: always returns [`CpuAffinityError::NotSupported`].
#[cfg(not(target_os = "linux"))]
pub fn set_current_thread_affinity(
    _cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
    Result::Err(CpuAffinityError::NotSupported)
}
pub fn set_thread_affinity(
thread_id: ThreadId,
cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
set_cpu_affinity(thread_id, cpus)
}
/// Calls `visit` once for every online CPU id, in the order the kernel lists them.
///
/// The list comes from `/sys/devices/system/cpu/online`. It is validated in full before
/// `visit` runs, so a malformed file never causes a partial visit. If the file is
/// unreadable or malformed, this falls back to `0..=count - 1`, where `count` is the
/// `sysconf` online-CPU count.
///
/// # Errors
///
/// Returns [`CpuAffinityError::Io`] only when the sysfs list is unusable and the `sysconf`
/// fallback fails as well.
#[cfg(target_os = "linux")]
pub fn for_each_online_cpu(mut visit: impl FnMut(usize)) -> Result<(), CpuAffinityError> {
    let sysfs = fs::read_to_string("/sys/devices/system/cpu/online").ok();
    let valid_list = sysfs
        .as_deref()
        .map(str::trim)
        .filter(|list| parse_cpu_list(list).is_some());
    match valid_list {
        Some(list) => {
            // Already validated, so this second pass visits every range without stopping early.
            let _ = fold_cpu_list(list, |first, last| (first..=last).for_each(&mut visit));
        }
        None => {
            let info = sysconf_cpu_info()?;
            (0..=info.max_id).for_each(visit);
        }
    }
    Ok(())
}
/// Online-CPU enumeration is Linux-only; elsewhere `visit` is never called and this
/// returns [`CpuAffinityError::NotSupported`].
#[cfg(not(target_os = "linux"))]
pub fn for_each_online_cpu(_visit: impl FnMut(usize)) -> Result<(), CpuAffinityError> {
    Result::Err(CpuAffinityError::NotSupported)
}
/// Returns the highest online CPU id, taken from the sysfs online list or, failing that,
/// from `sysconf` (as online count minus one).
///
/// # Errors
///
/// Returns [`CpuAffinityError::Io`] if the `sysconf` fallback is needed and fails.
#[cfg(target_os = "linux")]
pub fn max_cpu_id() -> Result<usize, CpuAffinityError> {
    online_cpu_info().map(|info| info.max_id)
}
/// Summarizes the online CPUs.
///
/// Prefers the parsed `/sys/devices/system/cpu/online` list. If the file cannot be read or
/// does not parse, it falls back to [`sysconf_cpu_info`].
#[cfg(target_os = "linux")]
fn online_cpu_info() -> Result<CpuListInfo, CpuAffinityError> {
    let from_sysfs = fs::read_to_string("/sys/devices/system/cpu/online")
        .ok()
        .and_then(|list| parse_cpu_list(list.trim()));
    match from_sysfs {
        Some(info) => Ok(info),
        None => sysconf_cpu_info(),
    }
}
#[cfg(target_os = "linux")]
fn sysconf_cpu_info() -> Result<CpuListInfo, CpuAffinityError> {
let count = unsafe { libc::sysconf(libc::_SC_NPROCESSORS_ONLN) };
if count <= 0 {
return Err(CpuAffinityError::Io(io::Error::last_os_error()));
}
let count = count as usize;
Ok(CpuListInfo {
max_id: count - 1,
count,
})
}
/// Validates and summarizes a CPU list such as `0-3,8-11` without visiting its ranges.
///
/// Returns `None` if the list is empty or any entry is malformed.
#[cfg(target_os = "linux")]
fn parse_cpu_list(input: &str) -> Option<CpuListInfo> {
    let ignore_range = |_: usize, _: usize| ();
    fold_cpu_list(input, ignore_range)
}
/// Walks a comma-separated CPU list (e.g. ` 0-3, 8 ,10-11\n`), calling `range(start, end)`
/// for each inclusive range in order, and returns its summary.
///
/// Whitespace around the list and around each entry is ignored. Returns `None` for an
/// empty list, for any malformed entry, or if the total count overflows `usize`.
/// Parsing stops at the first bad entry, so `range` may already have been called for the
/// entries before it; use [`parse_cpu_list`] first when partial visits are unacceptable.
#[cfg(target_os = "linux")]
fn fold_cpu_list(input: &str, mut range: impl FnMut(usize, usize)) -> Option<CpuListInfo> {
    let list = input.trim();
    if list.is_empty() {
        return None;
    }
    let empty = CpuListInfo {
        max_id: 0,
        count: 0,
    };
    list.split(',').try_fold(empty, |acc, entry| {
        let (start, end) = parse_cpu_range(entry.trim())?;
        let width = end.checked_sub(start)?.checked_add(1)?;
        let count = acc.count.checked_add(width)?;
        // Reported only once this entry has been fully validated.
        range(start, end);
        Some(CpuListInfo {
            max_id: acc.max_id.max(end),
            count,
        })
    })
}
/// Parses one CPU-list entry: a single id (`"5"`) or an inclusive range (`"2-7"`).
///
/// Whitespace around the dash is tolerated. Whitespace around the whole entry is not,
/// so callers trim it first. Returns `None` for malformed or empty entries and for ranges
/// whose start exceeds their end.
#[cfg(target_os = "linux")]
fn parse_cpu_range(input: &str) -> Option<(usize, usize)> {
    let (first, last) = if let Some((lo, hi)) = input.split_once('-') {
        (
            lo.trim().parse::<usize>().ok()?,
            hi.trim().parse::<usize>().ok()?,
        )
    } else {
        // No dash: a lone id, i.e. a one-element range. An empty entry fails to parse here.
        let only = input.parse::<usize>().ok()?;
        (only, only)
    };
    if first > last {
        return None;
    }
    Some((first, last))
}
/// CPU topology queries are Linux-only; elsewhere this always returns
/// [`CpuAffinityError::NotSupported`].
#[cfg(not(target_os = "linux"))]
pub fn max_cpu_id() -> Result<usize, CpuAffinityError> {
    Result::Err(CpuAffinityError::NotSupported)
}
/// Returns the number of online CPUs, taken from the sysfs online list or, failing that,
/// from `sysconf`.
///
/// # Errors
///
/// Returns [`CpuAffinityError::Io`] if the `sysconf` fallback is needed and fails.
#[cfg(target_os = "linux")]
pub fn cpu_count() -> Result<usize, CpuAffinityError> {
    online_cpu_info().map(|info| info.count)
}
/// CPU counting is Linux-only; elsewhere this always returns
/// [`CpuAffinityError::NotSupported`].
#[cfg(not(target_os = "linux"))]
pub fn cpu_count() -> Result<usize, CpuAffinityError> {
    Result::Err(CpuAffinityError::NotSupported)
}
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use super::{
        CpuListInfo, ThreadId, cpu_count, current_thread_id, for_each_online_cpu, max_cpu_id,
        parse_cpu_list, set_cpu_affinity,
    };
    use crate::CpuAffinityError;
    use core::mem::size_of;
    use std::{sync::mpsc, thread};

    #[test]
    fn parses_single_cpu() {
        assert_eq!(
            parse_cpu_list("0"),
            Some(CpuListInfo {
                max_id: 0,
                count: 1
            })
        );
    }

    #[test]
    fn parses_dense_range() {
        assert_eq!(
            parse_cpu_list("0-3"),
            Some(CpuListInfo {
                max_id: 3,
                count: 4
            })
        );
    }

    #[test]
    fn parses_sparse_ranges() {
        assert_eq!(
            parse_cpu_list("0-3,8-11"),
            Some(CpuListInfo {
                max_id: 11,
                count: 8
            })
        );
    }

    /// Sysfs content ends in a newline, and entries may be padded with spaces.
    #[test]
    fn parses_list_with_surrounding_whitespace() {
        assert_eq!(
            parse_cpu_list(" 0-3 , 8 \n"),
            Some(CpuListInfo {
                max_id: 8,
                count: 5
            })
        );
    }

    #[test]
    fn rejects_invalid_ranges() {
        assert_eq!(parse_cpu_list("3-1"), None);
        assert_eq!(parse_cpu_list(""), None);
        assert_eq!(parse_cpu_list("0-3,"), None);
    }

    #[test]
    fn rejects_malformed_entries() {
        for input in ["a", "1-", "-1", "1-2-3", "0,,1", ","] {
            assert_eq!(parse_cpu_list(input), None, "input: {input:?}");
        }
    }

    #[test]
    fn set_cpu_affinity_rejects_empty_cpu_list_before_syscall() {
        assert!(matches!(
            set_cpu_affinity(current_thread_id().unwrap(), []),
            Err(CpuAffinityError::EmptyCpuList)
        ));
    }

    /// An id no machine can have must be rejected by validation, never reaching the
    /// syscall (so this test cannot change the test thread's affinity).
    #[test]
    fn set_cpu_affinity_rejects_out_of_range_cpu_before_syscall() {
        assert!(matches!(
            set_cpu_affinity(current_thread_id().unwrap(), [usize::MAX]),
            Err(CpuAffinityError::InvalidCpu { cpu, .. }) if cpu == usize::MAX
        ));
    }

    /// Enumeration and the summary functions read the same source, so they must agree.
    #[test]
    fn online_cpu_enumeration_agrees_with_summary() {
        let mut cpus = Vec::new();
        for_each_online_cpu(|cpu| cpus.push(cpu)).unwrap();
        assert_eq!(cpus.len(), cpu_count().unwrap());
        assert_eq!(cpus.iter().copied().max(), Some(max_cpu_id().unwrap()));
    }

    #[test]
    fn thread_id_matches_linux_pid_type() {
        assert_eq!(size_of::<ThreadId>(), size_of::<libc::pid_t>());
    }

    #[test]
    fn current_thread_id_is_positive() {
        assert!(current_thread_id().unwrap() > 0);
    }

    #[test]
    fn spawned_thread_has_distinct_thread_id() {
        let main_thread_id = current_thread_id().unwrap();
        let (sender, receiver) = mpsc::channel();
        let worker = thread::spawn(move || {
            sender.send(current_thread_id().unwrap()).unwrap();
        });
        let worker_thread_id = receiver.recv().unwrap();
        worker.join().unwrap();
        assert_ne!(main_thread_id, worker_thread_id);
    }
}