use crate::NR_CPU_IDS;
use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use bitvec::prelude::*;
use sscanf::sscanf;
use std::fmt;
use std::ops::BitAndAssign;
use std::ops::BitOrAssign;
use std::ops::BitXorAssign;
// Per-thread override for the width reported by `mask_width()`.
// Only compiled for tests and the `testutils` feature. A value of 0 means
// "no override"; any other value replaces the real width on this thread.
#[cfg(any(test, feature = "testutils"))]
thread_local! {
    static MASK_WIDTH_OVERRIDE: std::cell::Cell<usize> =
        const { std::cell::Cell::<usize>::new(0) };
}
/// Returns the number of CPU slots masks in this module are sized for.
///
/// In test / `testutils` builds a non-zero per-thread override takes
/// precedence; otherwise the value is read from the crate-level
/// `NR_CPU_IDS`.
fn mask_width() -> usize {
    #[cfg(any(test, feature = "testutils"))]
    {
        match MASK_WIDTH_OVERRIDE.with(|cell| cell.get()) {
            // Zero is the "no override" sentinel: fall through to the real value.
            0 => {}
            forced => return forced,
        }
    }
    *NR_CPU_IDS
}
/// Forces the mask width seen by the current thread to `width`.
///
/// Only available in test / `testutils` builds. Passing `0` removes the
/// override, because `mask_width()` treats zero as "not set".
#[cfg(any(test, feature = "testutils"))]
pub fn set_cpumask_test_width(width: usize) {
    MASK_WIDTH_OVERRIDE.with(|cell| {
        cell.replace(width);
    });
}
/// A set of CPU ids stored as a bit vector: bit `n` set means CPU `n` is in
/// the mask.
///
/// `new` and `from_str` size the storage to `mask_width()` bits;
/// `from_vec` and `from_bitvec` adopt whatever storage the caller supplies.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Cpumask {
    /// Backing bits, packed least-significant-bit first into `u64` words.
    mask: BitVec<u64, Lsb0>,
}
impl Cpumask {
/// Validates that `cpu` can be addressed in this mask.
///
/// A CPU is valid when it is below both `mask_width()` and the length of
/// the backing bit vector. The second bound matters for masks built with
/// `from_vec` / `from_bitvec`, whose storage may be shorter than
/// `mask_width()`; without it the subsequent bit write would panic instead
/// of reporting an error.
///
/// # Errors
///
/// Returns an error naming the offending CPU and the applicable limit.
fn check_cpu(&self, cpu: usize) -> Result<()> {
    let limit = mask_width().min(self.mask.len());
    if cpu >= limit {
        bail!("Invalid CPU {} passed, max {}", cpu, limit);
    }
    Ok(())
}
pub fn new() -> Cpumask {
Cpumask {
mask: bitvec![u64, Lsb0; 0; mask_width()],
}
}
/// Parses a mask from its textual form.
///
/// Accepted inputs:
/// * `"none"` — every CPU cleared,
/// * `"all"` — every CPU set,
/// * a hexadecimal number, optionally prefixed with `0x` and optionally
///   containing `_` separators; the least significant bit is CPU 0.
///
/// # Errors
///
/// Fails when the text is not valid hexadecimal, or when it has a bit set
/// for a CPU at or beyond `mask_width()`.
pub fn from_str(cpumask: &str) -> Result<Cpumask> {
    let width = mask_width();

    if cpumask == "none" {
        return Ok(Self {
            mask: bitvec![u64, Lsb0; 0; width],
        });
    }
    if cpumask == "all" {
        return Ok(Self {
            mask: bitvec![u64, Lsb0; 1; width],
        });
    }

    let digits = cpumask
        .strip_prefix("0x")
        .unwrap_or(cpumask)
        .replace('_', "");
    // The hex decoder works on whole bytes, so left-pad odd-length input.
    let hex_str = if digits.len() % 2 == 0 {
        digits
    } else {
        format!("0{digits}")
    };
    let byte_vec =
        hex::decode(&hex_str).with_context(|| format!("Failed to parse cpumask: {cpumask}"))?;

    let mut mask = bitvec![u64, Lsb0; 0; width];
    // Bytes are most-significant first in the text, so walk them backwards to
    // visit CPUs in ascending order.
    for (index, &byte) in byte_vec.iter().rev().enumerate() {
        for bit in 0..8usize {
            if (byte >> bit) & 1 == 0 {
                continue;
            }
            let cpu = index * 8 + bit;
            if cpu >= width {
                bail!(
                    concat!(
                        "Found cpu ({}) in cpumask ({}) which is larger",
                        " than the number of cpus on the machine ({})"
                    ),
                    cpu,
                    cpumask,
                    width
                );
            }
            mask.set(cpu, true);
        }
    }

    Ok(Self { mask })
}
/// Builds a mask from a cpulist string such as `"0-3,8"`.
///
/// CPUs that `set_cpu` rejects are skipped on purpose rather than reported,
/// so the result contains only the ids that fit in the mask.
///
/// # Errors
///
/// Fails only when `read_cpulist` cannot parse the string.
pub fn from_cpulist(cpulist: &str) -> Result<Cpumask> {
    let mut mask = Cpumask::new();
    read_cpulist(cpulist)?.into_iter().for_each(|cpu_id| {
        // Best effort: an out-of-range id is dropped, not treated as an error.
        let _ = mask.set_cpu(cpu_id);
    });
    Ok(mask)
}
/// Renders the mask as a cpulist: comma-separated ids, with runs of
/// consecutive ids collapsed to `first-last` (for example `"0-2,5,10-11"`).
///
/// An empty mask is rendered as `"none"`.
pub fn to_cpulist(&self) -> String {
    let mut cpus = self.iter();
    let Some(first) = cpus.next() else {
        return String::from("none");
    };

    // Formats one run; a run of length one is printed as a bare id.
    let render = |lo: usize, hi: usize| {
        if lo == hi {
            format!("{}", lo)
        } else {
            format!("{}-{}", lo, hi)
        }
    };

    let mut pieces = Vec::new();
    let (mut lo, mut hi) = (first, first);
    for cpu in cpus {
        if cpu == hi + 1 {
            // Still inside the current run: just extend it.
            hi = cpu;
            continue;
        }
        pieces.push(render(lo, hi));
        lo = cpu;
        hi = cpu;
    }
    pieces.push(render(lo, hi));

    pieces.join(",")
}
/// Wraps raw `u64` words as a mask; the words are taken as-is and no width
/// check is performed here.
pub fn from_vec(vec: Vec<u64>) -> Self {
    let mask = BitVec::from_vec(vec);
    Self { mask }
}
/// Wraps an existing bit vector as a mask without copying or resizing it.
pub fn from_bitvec(bitvec: BitVec<u64, Lsb0>) -> Self {
    let mask = bitvec;
    Self { mask }
}
pub fn as_raw_slice(&self) -> &[u64] {
self.mask.as_raw_slice()
}
pub fn as_raw_bitvec_mut(&mut self) -> &mut BitVec<u64, Lsb0> {
&mut self.mask
}
pub fn as_raw_bitvec(&self) -> &BitVec<u64, Lsb0> {
&self.mask
}
/// Sets every bit of the backing bit vector.
pub fn set_all(&mut self) {
    self.as_raw_bitvec_mut().fill(true);
}
/// Clears every bit of the backing bit vector.
pub fn clear_all(&mut self) {
    self.as_raw_bitvec_mut().fill(false);
}
/// Adds `cpu` to the mask.
///
/// # Errors
///
/// Returns the error from `check_cpu` when `cpu` is rejected; the mask is
/// left untouched in that case.
pub fn set_cpu(&mut self, cpu: usize) -> Result<()> {
    self.check_cpu(cpu).map(|()| self.mask.set(cpu, true))
}
/// Removes `cpu` from the mask.
///
/// # Errors
///
/// Returns the error from `check_cpu` when `cpu` is rejected; the mask is
/// left untouched in that case.
pub fn clear_cpu(&mut self, cpu: usize) -> Result<()> {
    self.check_cpu(cpu).map(|()| self.mask.set(cpu, false))
}
/// Reports whether `cpu` is in the mask.
///
/// An index past the end of the backing storage is simply "not set".
pub fn test_cpu(&self, cpu: usize) -> bool {
    self.mask.get(cpu).map(|bit| *bit).unwrap_or(false)
}
/// Returns how many bits are set in the mask.
pub fn weight(&self) -> usize {
    self.as_raw_bitvec().count_ones()
}
/// Returns `true` when no bit is set.
pub fn is_empty(&self) -> bool {
    self.mask.not_any()
}
/// Returns `true` when the number of set bits equals `mask_width()`.
pub fn is_full(&self) -> bool {
    let set_bits = self.mask.count_ones();
    set_bits == mask_width()
}
/// Returns the mask width as reported by `mask_width()`.
///
/// Note that this is not the length of the backing bit vector: the value
/// does not depend on `self`.
pub fn len(&self) -> usize {
mask_width()
}
pub fn not(&self) -> Cpumask {
let mut new = self.clone();
new.mask = !new.mask;
new
}
/// Returns a new mask holding the bitwise AND of `self` and `other`.
///
/// Neither operand is modified.
pub fn and(&self, other: &Cpumask) -> Cpumask {
    let mut new = self.clone();
    // Borrow the right-hand side; cloning it first would only add an
    // allocation.
    new.mask &= &other.mask;
    new
}
/// Returns a new mask holding the bitwise OR of `self` and `other`.
///
/// Neither operand is modified.
pub fn or(&self, other: &Cpumask) -> Cpumask {
    let mut new = self.clone();
    // Borrow the right-hand side; cloning it first would only add an
    // allocation.
    new.mask |= &other.mask;
    new
}
/// Returns a new mask holding the bitwise XOR of `self` and `other`.
///
/// Neither operand is modified.
pub fn xor(&self, other: &Cpumask) -> Cpumask {
    let mut new = self.clone();
    // Borrow the right-hand side; cloning it first would only add an
    // allocation.
    new.mask ^= &other.mask;
    new
}
/// Returns an iterator over the ids of the set CPUs, in ascending order.
pub fn iter(&self) -> CpumaskIterator<'_> {
    let index = 0;
    CpumaskIterator { mask: self, index }
}
/// Copies the raw `u64` words of this mask into the buffer at `bpfptr`.
///
/// Exactly `len` words are written. No fixed-size view of the destination
/// is formed, so buffers shorter or longer than 64 words are handled
/// correctly.
///
/// # Safety
///
/// `bpfptr` must be properly aligned and valid for writes of `len`
/// consecutive `u64` values, and that memory must not be accessed through
/// any other reference for the duration of the call.
///
/// # Errors
///
/// Fails when `len` differs from the number of words backing this mask, or
/// when `bpfptr` is null.
pub unsafe fn write_to_ptr(&self, bpfptr: *mut u64, len: usize) -> Result<()> {
    let cpumask_slice = self.as_raw_slice();
    if len != cpumask_slice.len() {
        bail!(
            "BPF CPU mask has length {} u64s, Cpumask size is {}",
            len,
            cpumask_slice.len()
        );
    }
    if bpfptr.is_null() {
        bail!("BPF CPU mask pointer is null");
    }
    // SAFETY: `bpfptr` is non-null (checked above) and the caller guarantees
    // it is aligned and valid for writes of `len` u64 values with no aliasing
    // references alive.
    let dst = unsafe { std::slice::from_raw_parts_mut(bpfptr, len) };
    // Lengths are equal (checked above), so this cannot panic.
    dst.copy_from_slice(cpumask_slice);
    Ok(())
}
/// Writes the mask as comma-separated 32-bit hexadecimal groups, most
/// significant group first.
///
/// `case` selects the digit case: `'x'` for lowercase, `'X'` for uppercase.
/// Any other value is a programming error and hits `unreachable!()`.
///
/// The leading group is zero-padded to just enough digits to cover
/// `mask_width()`; every following group is padded to eight digits.
///
/// Bits stored above `mask_width()` are masked off so they cannot widen the
/// leading group, and storage with fewer words than `mask_width()` needs is
/// treated as zero-extended rather than causing a panic.
fn fmt_with(&self, f: &mut fmt::Formatter<'_>, case: char) -> fmt::Result {
    let nr_cpus = mask_width();

    // Split each u64 word into its low and high 32-bit halves, low first.
    let mut masks: Vec<u32> = self
        .as_raw_slice()
        .iter()
        .flat_map(|x| [*x as u32, (x >> 32) as u32])
        .collect();
    // Keep exactly one group per 32 CPUs: drops the stray upper half of the
    // last word and zero-pads storage that is too short.
    masks.resize(nr_cpus.div_ceil(32), 0);

    let Some(top) = masks.pop() else {
        // Zero-width mask: nothing to print.
        return Ok(());
    };
    // The format width below is only a minimum, so clear bits past the mask
    // width explicitly to keep them out of the output.
    let top = match nr_cpus % 32 {
        0 => top,
        live_bits => top & ((1u32 << live_bits) - 1),
    };

    // Number of hex digits needed by the leading (possibly partial) group.
    let width = match nr_cpus.div_ceil(4) % 8 {
        0 => 8,
        v => v,
    };
    match case {
        'x' => write!(f, "{:0width$x}", top, width = width)?,
        'X' => write!(f, "{:0width$X}", top, width = width)?,
        _ => unreachable!(),
    }

    for submask in masks.iter().rev() {
        match case {
            'x' => write!(f, ",{submask:08x}")?,
            'X' => write!(f, ",{submask:08X}")?,
            _ => unreachable!(),
        }
    }
    Ok(())
}
}
pub fn read_cpulist(cpulist: &str) -> Result<Vec<usize>> {
let cpulist = cpulist.trim_end_matches('\0');
let cpu_groups: Vec<&str> = cpulist.split(',').collect();
let mut cpu_ids = vec![];
for group in cpu_groups.iter() {
let (min, max) = match sscanf!(group.trim(), "{usize}-{usize}") {
Ok((x, y)) => (x, y),
Err(_) => match sscanf!(group.trim(), "{usize}") {
Ok(x) => (x, x),
Err(_) => {
bail!("Failed to parse cpulist {}", group.trim());
}
},
};
for i in min..(max + 1) {
cpu_ids.push(i);
}
}
Ok(cpu_ids)
}
/// Iterator over the ids of the CPUs set in a borrowed [`Cpumask`].
pub struct CpumaskIterator<'mask> {
    /// Mask being walked.
    mask: &'mask Cpumask,
    /// Next CPU id to examine.
    index: usize,
}
impl Iterator for CpumaskIterator<'_> {
    type Item = usize;

    /// Yields the next set CPU id at or after the current position, or
    /// `None` once every id below `mask_width()` has been examined.
    fn next(&mut self) -> Option<Self::Item> {
        let limit = mask_width();
        let found = (self.index..limit).find(|&cpu| self.mask.test_cpu(cpu));
        match found {
            // Resume just past the hit on the next call.
            Some(cpu) => self.index = cpu + 1,
            // Everything up to the limit was scanned; never move backwards.
            None => self.index = self.index.max(limit),
        }
        found
    }
}
impl fmt::Display for Cpumask {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt_with(f, 'x')
}
}
impl fmt::LowerHex for Cpumask {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt_with(f, 'x')
}
}
impl fmt::UpperHex for Cpumask {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.fmt_with(f, 'X')
}
}
/// `mask &= &other`: in-place bitwise AND with a borrowed mask.
impl BitAndAssign<&Self> for Cpumask {
    fn bitand_assign(&mut self, rhs: &Self) {
        let Self { mask } = self;
        *mask &= &rhs.mask;
    }
}
/// `mask |= &other`: in-place bitwise OR with a borrowed mask.
impl BitOrAssign<&Self> for Cpumask {
    fn bitor_assign(&mut self, rhs: &Self) {
        let Self { mask } = self;
        *mask |= &rhs.mask;
    }
}
/// `mask ^= &other`: in-place bitwise XOR with a borrowed mask.
impl BitXorAssign<&Self> for Cpumask {
    fn bitxor_assign(&mut self, rhs: &Self) {
        let Self { mask } = self;
        *mask ^= &rhs.mask;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mask width forced for every test in this module, so the results do
    /// not depend on how many CPUs the machine running the tests has. It is
    /// larger than the highest CPU id used below.
    const TEST_WIDTH: usize = 64;

    /// Pins the mask width for the current test thread and returns an empty
    /// mask of that width.
    fn empty_mask() -> Cpumask {
        set_cpumask_test_width(TEST_WIDTH);
        Cpumask::new()
    }

    #[test]
    fn test_to_cpulist_empty() {
        let mask = empty_mask();
        assert_eq!(mask.to_cpulist(), "none");
    }

    #[test]
    fn test_to_cpulist_single_cpu() {
        let mut mask = empty_mask();
        mask.set_cpu(5).unwrap();
        assert_eq!(mask.to_cpulist(), "5");
    }

    #[test]
    fn test_to_cpulist_contiguous_range() {
        let mut mask = empty_mask();
        for cpu in 0..8 {
            mask.set_cpu(cpu).unwrap();
        }
        assert_eq!(mask.to_cpulist(), "0-7");
    }

    #[test]
    fn test_to_cpulist_multiple_ranges() {
        let mut mask = empty_mask();
        for cpu in 0..4 {
            mask.set_cpu(cpu).unwrap();
        }
        for cpu in 8..12 {
            mask.set_cpu(cpu).unwrap();
        }
        assert_eq!(mask.to_cpulist(), "0-3,8-11");
    }

    #[test]
    fn test_to_cpulist_scattered() {
        let mut mask = empty_mask();
        mask.set_cpu(1).unwrap();
        mask.set_cpu(3).unwrap();
        mask.set_cpu(5).unwrap();
        assert_eq!(mask.to_cpulist(), "1,3,5");
    }

    #[test]
    fn test_to_cpulist_mixed() {
        let mut mask = empty_mask();
        mask.set_cpu(0).unwrap();
        mask.set_cpu(1).unwrap();
        mask.set_cpu(2).unwrap();
        mask.set_cpu(5).unwrap();
        mask.set_cpu(10).unwrap();
        mask.set_cpu(11).unwrap();
        assert_eq!(mask.to_cpulist(), "0-2,5,10-11");
    }

    #[test]
    fn test_to_cpulist_roundtrip() {
        // `from_cpulist` silently drops ids that do not fit, so the width
        // must be pinned before parsing for the round trip to be meaningful.
        set_cpumask_test_width(TEST_WIDTH);
        let original = "0-3,8-11,16";
        let mask = Cpumask::from_cpulist(original).unwrap();
        assert_eq!(mask.to_cpulist(), original);
    }
}