cgroups_rs/systemd/
cpuset.rs

1// Copyright (c) 2025 Ant Group
2//
3// SPDX-License-Identifier: Apache-2.0 or MIT
4//
5
6use bit_vec::BitVec;
7
8use crate::systemd::error::{Error, Result};
9use crate::systemd::{ALLOWED_CPUS, ALLOWED_MEMORY_NODES};
10
/// Number of bits packed into one byte of the generated mask.
const BYTE_IN_BITS: usize = u8::BITS as usize;
12
13/// Returns the property for cpuset CPUs.
14pub fn cpus(cpus: &str) -> Result<(&'static str, Vec<u8>)> {
15    let mask = convert_list_to_mask(cpus)?;
16    Ok((ALLOWED_CPUS, mask))
17}
18
19/// Returns the property for cpuset memory nodes.
20pub fn mems(mems: &str) -> Result<(&'static str, Vec<u8>)> {
21    let mask = convert_list_to_mask(mems)?;
22    Ok((ALLOWED_MEMORY_NODES, mask))
23}
24
25/// Convert cpuset cpus/mems from the string in comma-separated list format
26/// to bitmask restored in `Vec<u8>`, see [1].
27///
28/// 1: https://man7.org/linux/man-pages/man7/cpuset.7.html
29///
30/// # Arguments
31///
32/// * `list` - A string slice that holds the list of CPUs in the format
33///   "0-3,5,7".
34fn convert_list_to_mask(list: &str) -> Result<Vec<u8>> {
35    let mut bit_vec = BitVec::from_elem(8, false);
36
37    let local_idx =
38        |index: usize| -> usize { index / BYTE_IN_BITS * BYTE_IN_BITS + 7 - index % BYTE_IN_BITS };
39
40    for part1 in list.split(',') {
41        let range: Vec<&str> = part1.split('-').collect();
42        match range.len() {
43            // x-
44            1 => {
45                let left: usize = range[0].parse().map_err(|_| Error::InvalidArgument)?;
46
47                while left >= bit_vec.len() {
48                    bit_vec.grow(BYTE_IN_BITS, false);
49                }
50                bit_vec.set(local_idx(left), true);
51            }
52            // x-y
53            2 => {
54                let left: usize = range[0].parse().map_err(|_| Error::InvalidArgument)?;
55                let right: usize = range[1].parse().map_err(|_| Error::InvalidArgument)?;
56
57                while right >= bit_vec.len() {
58                    bit_vec.grow(BYTE_IN_BITS, false);
59                }
60
61                for index in left..=right {
62                    bit_vec.set(local_idx(index), true);
63                }
64            }
65            _ => {
66                return Err(Error::InvalidArgument);
67            }
68        }
69    }
70
71    let mut mask = bit_vec.to_bytes();
72    mask.reverse();
73
74    Ok(mask)
75}
76
77#[cfg(test)]
78mod tests {
79    use crate::systemd::cpuset::convert_list_to_mask;
80
81    #[test]
82    fn test_convert_list_to_mask() {
83        let mask = convert_list_to_mask("2-4").unwrap();
84        assert_eq!(vec![0b00011100_u8], mask);
85
86        let mask = convert_list_to_mask("1,7").unwrap();
87        assert_eq!(vec![0b10000010_u8], mask);
88
89        let mask = convert_list_to_mask("0-4,9").unwrap();
90        assert_eq!(vec![0b00000010_u8, 0b00011111_u8], mask);
91
92        assert!(convert_list_to_mask("1-3-4").is_err());
93
94        assert!(convert_list_to_mask("1-3,,").is_err());
95    }
96}