use std::fmt;
use oci_spec::runtime::{MemoryPolicyFlagType, MemoryPolicyModeType};
use crate::syscall::{Syscall, SyscallError};
/// Errors produced while validating or applying a container memory policy.
#[derive(Debug, thiserror::Error)]
pub enum MemoryPolicyError {
/// A flag was supplied for a mode that does not accept flags.
#[error("Invalid memory policy flag: {0}")]
InvalidFlag(String),
/// The node string could not be parsed, was supplied for a mode that takes
/// no nodes, or was missing/empty for a mode that requires nodes.
#[error("Invalid node specification: {0}")]
InvalidNodes(String),
/// The requested flag is not allowed together with the requested mode.
#[error("Incompatible flag and mode combination: {0}")]
IncompatibleFlagMode(String),
/// Two flags that must not be combined were both requested.
#[error("Mutually exclusive flags: {0}")]
MutuallyExclusiveFlags(String),
/// The underlying `set_mempolicy` call reported an error.
#[error("Syscall error: {0}")]
Syscall(#[from] SyscallError),
}
/// Result alias used throughout this module, with [`MemoryPolicyError`] as the error type.
type Result<T> = std::result::Result<T, MemoryPolicyError>;
/// Memory policy modes, each carrying the integer value that is passed as the
/// mode argument of `set_mempolicy`.
///
/// The numbering is presumably the kernel's `MPOL_*` numbering (the `Display`
/// implementation prints the matching `MPOL_*` names).
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryPolicyMode {
    /// Printed as `MPOL_DEFAULT`.
    Default = 0,
    /// Printed as `MPOL_PREFERRED`.
    Preferred = 1,
    /// Printed as `MPOL_BIND`.
    Bind = 2,
    /// Printed as `MPOL_INTERLEAVE`.
    Interleave = 3,
    /// Printed as `MPOL_LOCAL`.
    Local = 4,
    /// Printed as `MPOL_PREFERRED_MANY`.
    PreferredMany = 5,
    /// Printed as `MPOL_WEIGHTED_INTERLEAVE`.
    WeightedInterleave = 6,
}

/// Exposes the raw discriminant so it can be OR-ed with flag bits.
impl From<MemoryPolicyMode> for i32 {
    fn from(mode: MemoryPolicyMode) -> Self {
        // The enum is `#[repr(i32)]`, so the cast yields the declared discriminant.
        mode as Self
    }
}
/// Maps the mode from the OCI runtime spec onto this module's numeric mode enum,
/// one variant to one variant.
impl From<MemoryPolicyModeType> for MemoryPolicyMode {
    fn from(mode: MemoryPolicyModeType) -> Self {
        match mode {
            MemoryPolicyModeType::MpolDefault => Self::Default,
            MemoryPolicyModeType::MpolPreferred => Self::Preferred,
            MemoryPolicyModeType::MpolBind => Self::Bind,
            MemoryPolicyModeType::MpolInterleave => Self::Interleave,
            MemoryPolicyModeType::MpolLocal => Self::Local,
            MemoryPolicyModeType::MpolPreferredMany => Self::PreferredMany,
            MemoryPolicyModeType::MpolWeightedInterleave => Self::WeightedInterleave,
        }
    }
}
impl fmt::Display for MemoryPolicyMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
MemoryPolicyMode::Default => "MPOL_DEFAULT",
MemoryPolicyMode::Preferred => "MPOL_PREFERRED",
MemoryPolicyMode::Bind => "MPOL_BIND",
MemoryPolicyMode::Interleave => "MPOL_INTERLEAVE",
MemoryPolicyMode::Local => "MPOL_LOCAL",
MemoryPolicyMode::PreferredMany => "MPOL_PREFERRED_MANY",
MemoryPolicyMode::WeightedInterleave => "MPOL_WEIGHTED_INTERLEAVE",
};
write!(f, "{}", s)
}
}
/// Optional mode flags; each variant is a single bit that gets OR-ed into the
/// mode value handed to `set_mempolicy`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryPolicyFlag {
    /// Bit 13. Corresponds to the spec flag `MpolFNumaBalancing`.
    NumaBalancing = 1 << 13,
    /// Bit 14. Corresponds to the spec flag `MpolFRelativeNodes`.
    RelativeNodes = 1 << 14,
    /// Bit 15. Corresponds to the spec flag `MpolFStaticNodes`.
    StaticNodes = 1 << 15,
}

/// Exposes the raw bit so flags can be combined with bitwise OR.
impl From<MemoryPolicyFlag> for u32 {
    fn from(flag: MemoryPolicyFlag) -> Self {
        // The enum is `#[repr(u32)]`, so the cast yields the declared bit value.
        flag as Self
    }
}
/// A memory policy that passed validation, already converted into the three
/// arguments that `setup_memory_policy` forwards to `Syscall::set_mempolicy`.
struct ValidatedMemoryPolicy {
/// Numeric mode with all requested flag bits OR-ed in.
mode_with_flags: i32,
/// Node bitmask split into `c_ulong` words (node `n` is bit `n % word_bits`
/// of word `n / word_bits`); empty when the policy carries no nodes.
nodemask: Vec<libc::c_ulong>,
/// Size of `nodemask` in bits (word count times bits per word); 0 when
/// `nodemask` is empty.
maxnode: u64,
}
fn validate_memory_policy(
memory_policy: &Option<oci_spec::runtime::LinuxMemoryPolicy>,
) -> Result<Option<ValidatedMemoryPolicy>> {
let Some(policy) = memory_policy else {
return Ok(None);
};
let base_mode = MemoryPolicyMode::from(policy.mode());
let (flags_value, has_static, has_relative) = policy
.flags()
.as_ref()
.map(|flags| {
flags
.iter()
.fold((0u32, false, false), |(val, s, r), flag| match flag {
MemoryPolicyFlagType::MpolFNumaBalancing => {
(val | u32::from(MemoryPolicyFlag::NumaBalancing), s, r)
}
MemoryPolicyFlagType::MpolFStaticNodes => {
(val | u32::from(MemoryPolicyFlag::StaticNodes), true, r)
}
MemoryPolicyFlagType::MpolFRelativeNodes => {
(val | u32::from(MemoryPolicyFlag::RelativeNodes), s, true)
}
})
})
.unwrap_or((0, false, false));
if let Some(flags) = policy.flags() {
if flags.contains(&MemoryPolicyFlagType::MpolFNumaBalancing)
&& base_mode != MemoryPolicyMode::Bind
{
return Err(MemoryPolicyError::IncompatibleFlagMode(
"MPOL_F_NUMA_BALANCING can only be used with MPOL_BIND".to_string(),
));
}
}
if has_static && has_relative {
return Err(MemoryPolicyError::MutuallyExclusiveFlags(
"MPOL_F_STATIC_NODES and MPOL_F_RELATIVE_NODES are mutually exclusive".to_string(),
));
}
let mode_with_flags = i32::from(base_mode) | (flags_value as i32);
match base_mode {
MemoryPolicyMode::Default | MemoryPolicyMode::Local => {
let mode_name = base_mode.to_string();
if let Some(nodes) = policy.nodes() {
if !nodes.trim().is_empty() {
return Err(MemoryPolicyError::InvalidNodes(format!(
"{} does not accept node specification",
mode_name
)));
}
}
if flags_value != 0 {
return Err(MemoryPolicyError::InvalidFlag(format!(
"{} does not accept flags",
mode_name
)));
}
Ok(Some(ValidatedMemoryPolicy {
mode_with_flags,
nodemask: Vec::new(),
maxnode: 0,
}))
}
MemoryPolicyMode::Preferred => {
let relative_or_static: u32 = u32::from(MemoryPolicyFlag::RelativeNodes)
| u32::from(MemoryPolicyFlag::StaticNodes);
let check_empty_nodes_flags = |flags_value: u32| -> Result<()> {
if flags_value & relative_or_static != 0u32 {
return Err(MemoryPolicyError::IncompatibleFlagMode(
"MPOL_PREFERRED with empty nodes cannot use MPOL_F_STATIC_NODES or MPOL_F_RELATIVE_NODES flags".to_string(),
));
}
Ok(())
};
match policy.nodes() {
None => {
check_empty_nodes_flags(flags_value)?;
Ok(Some(ValidatedMemoryPolicy {
mode_with_flags,
nodemask: Vec::new(),
maxnode: 0,
}))
}
Some(nodes) if nodes.trim().is_empty() => {
check_empty_nodes_flags(flags_value)?;
Ok(Some(ValidatedMemoryPolicy {
mode_with_flags,
nodemask: Vec::new(),
maxnode: 0,
}))
}
Some(nodes) => {
let (nodemask, maxnode) = build_nodemask(nodes)?;
if maxnode == 0 {
check_empty_nodes_flags(flags_value)?;
return Ok(Some(ValidatedMemoryPolicy {
mode_with_flags,
nodemask: Vec::new(),
maxnode: 0,
}));
}
Ok(Some(ValidatedMemoryPolicy {
mode_with_flags,
nodemask,
maxnode,
}))
}
}
}
_ => {
let mode_name = base_mode.to_string();
let nodes = match policy.nodes() {
None => {
return Err(MemoryPolicyError::InvalidNodes(format!(
"Mode {} requires non-empty node specification",
mode_name
)));
}
Some(nodes) if nodes.trim().is_empty() => {
return Err(MemoryPolicyError::InvalidNodes(format!(
"Mode {} requires non-empty node specification",
mode_name
)));
}
Some(nodes) => nodes,
};
let (nodemask, maxnode) = build_nodemask(nodes)?;
if maxnode == 0 {
return Err(MemoryPolicyError::InvalidNodes(format!(
"Mode {} requires non-empty node specification (parsed result is empty)",
mode_name
)));
}
Ok(Some(ValidatedMemoryPolicy {
mode_with_flags,
nodemask,
maxnode,
}))
}
}
}
/// Validates `memory_policy` and, if one is configured, applies it through
/// `syscall.set_mempolicy`.
///
/// Does nothing and returns `Ok(())` when `memory_policy` is `None`.
///
/// # Errors
///
/// Returns the validation error from `validate_memory_policy`, or
/// [`MemoryPolicyError::Syscall`] when `set_mempolicy` fails (the failure is
/// also logged via `tracing::error!`).
pub fn setup_memory_policy(
    memory_policy: &Option<oci_spec::runtime::LinuxMemoryPolicy>,
    syscall: &dyn Syscall,
) -> Result<()> {
    let Some(valid) = validate_memory_policy(memory_policy)? else {
        return Ok(());
    };

    match syscall.set_mempolicy(valid.mode_with_flags, &valid.nodemask, valid.maxnode) {
        Ok(_) => Ok(()),
        Err(err) => {
            tracing::error!(?err, "failed to set memory policy");
            Err(MemoryPolicyError::Syscall(err))
        }
    }
}
/// Parses a node list string and packs it into a bitmask of `c_ulong` words.
///
/// Returns `(nodemask, maxnode)`: node `n` is bit `n % word_bits` of word
/// `n / word_bits`, the mask is just long enough to hold the highest node, and
/// `maxnode` is the mask's total size in bits. A string that names no node
/// yields `(vec![], 0)`.
///
/// NOTE(review): `maxnode` is passed on as the mask size in bits. The kernel's
/// `set_mempolicy` handling may only honour `maxnode - 1` bits, which would drop
/// a node stored in the top bit of the last word (e.g. node 63) — confirm
/// against the `Syscall::set_mempolicy` implementation.
///
/// # Errors
///
/// Propagates the parse error from `parse_node_string`.
fn build_nodemask(nodes: &str) -> Result<(Vec<libc::c_ulong>, u64)> {
    let node_ids = parse_node_string(nodes)?;

    // `max()` is `None` exactly when no node was named.
    let Some(highest_node) = node_ids.iter().copied().max() else {
        return Ok((Vec::new(), 0));
    };

    let word_bits = std::mem::size_of::<libc::c_ulong>() * 8;
    let word_count = highest_node as usize / word_bits + 1;

    let mut mask = vec![0 as libc::c_ulong; word_count];
    for id in node_ids {
        let id = id as usize;
        // `id <= highest_node`, so the word index is always within `mask`.
        mask[id / word_bits] |= (1 as libc::c_ulong) << (id % word_bits);
    }

    Ok((mask, (word_count * word_bits) as u64))
}
/// Parses a NUMA node list such as `"0,2-3"` into the node IDs it names.
///
/// The string is a comma-separated list whose items are either a single node
/// ID or an inclusive `start-end` range. Whitespace around items and around
/// range bounds is ignored, empty items are skipped, and an empty or
/// whitespace-only string yields an empty list. IDs are returned in the order
/// written; duplicates are kept.
///
/// # Errors
///
/// Returns [`MemoryPolicyError::InvalidNodes`] when an item is not a valid
/// unsigned number, when a range has `start > end`, or when a node ID is
/// 1024 or larger.
fn parse_node_string(nodes: &str) -> Result<Vec<u32>> {
    // Exclusive upper bound on accepted node IDs. Without a bound, a range like
    // "0-4294967295" would be expanded into a multi-gigabyte vector before any
    // later check could reject it.
    // NOTE(review): 1024 is believed to match the largest node count the Linux
    // kernel can be configured for — confirm before raising or lowering it.
    const MAX_NODES: u32 = 1024;

    let ensure_in_range = |node: u32| -> Result<u32> {
        if node >= MAX_NODES {
            return Err(MemoryPolicyError::InvalidNodes(format!(
                "Node {} exceeds the maximum supported node {}",
                node,
                MAX_NODES - 1
            )));
        }
        Ok(node)
    };

    let mut node_ids = Vec::new();
    for item in nodes.split(',').map(str::trim).filter(|item| !item.is_empty()) {
        match item.split_once('-') {
            Some((start_str, end_str)) => {
                let start_str = start_str.trim();
                let end_str = end_str.trim();
                let start: u32 = start_str.parse().map_err(|_| {
                    MemoryPolicyError::InvalidNodes(format!(
                        "Invalid node range start: {}",
                        start_str
                    ))
                })?;
                let end: u32 = end_str.parse().map_err(|_| {
                    MemoryPolicyError::InvalidNodes(format!("Invalid node range end: {}", end_str))
                })?;
                if start > end {
                    return Err(MemoryPolicyError::InvalidNodes(format!(
                        "Invalid node range: {}-{}",
                        start, end
                    )));
                }
                // `start <= end`, so checking the upper bound covers the whole
                // range and must happen before the range is expanded.
                ensure_in_range(end)?;
                node_ids.extend(start..=end);
            }
            None => {
                let node: u32 = item.parse().map_err(|_| {
                    MemoryPolicyError::InvalidNodes(format!("Invalid node: {}", item))
                })?;
                node_ids.push(ensure_in_range(node)?);
            }
        }
    }
    Ok(node_ids)
}
#[cfg(test)]
mod tests {
    use oci_spec::runtime::{
        LinuxMemoryPolicy, LinuxMemoryPolicyBuilder, MemoryPolicyFlagType, MemoryPolicyModeType,
    };

    use super::*;
    use crate::syscall::syscall::create_syscall;
    use crate::syscall::test::TestHelperSyscall;

    /// Builds a spec memory policy from its three parts.
    fn make_policy(
        mode: MemoryPolicyModeType,
        nodes: &str,
        flags: Vec<MemoryPolicyFlagType>,
    ) -> LinuxMemoryPolicy {
        LinuxMemoryPolicyBuilder::default()
            .mode(mode)
            .nodes(nodes.to_string())
            .flags(flags)
            .build()
            .unwrap()
    }

    /// Node strings that must parse, with the IDs they expand to, followed by
    /// strings that must be rejected.
    #[test]
    fn test_parse_node_string() {
        let accepted: &[(&str, &[u32])] = &[
            ("", &[]),
            ("0", &[0]),
            ("1", &[1]),
            ("2", &[2]),
            ("0-2", &[0, 1, 2]),
            ("1-3", &[1, 2, 3]),
            ("0,2", &[0, 2]),
            ("0,1,3", &[0, 1, 3]),
            ("0-1,3", &[0, 1, 3]),
            ("0,2-3", &[0, 2, 3]),
            (" 0 , 2 ", &[0, 2]),
            (" 0 - 2 ", &[0, 1, 2]),
            (" ", &[]),
            (" , , ", &[]),
        ];
        for &(input, expected) in accepted {
            assert_eq!(
                parse_node_string(input).unwrap(),
                expected,
                "input: {:?}",
                input
            );
        }

        // Reversed range, non-numeric item, non-numeric range end.
        for input in ["2-1", "abc", "0-abc"] {
            assert!(parse_node_string(input).is_err(), "input: {:?}", input);
        }
    }

    /// Valid policies reach the syscall with the expected arguments; invalid
    /// ones are rejected during validation.
    #[test]
    fn test_setup_memory_policy() {
        let syscall = create_syscall();
        let apply =
            |policy: LinuxMemoryPolicy| setup_memory_policy(&Some(policy), syscall.as_ref());

        // No policy configured: nothing to do.
        assert!(setup_memory_policy(&None, syscall.as_ref()).is_ok());

        // MPOL_BIND on nodes 0 and 1, no flags.
        assert!(apply(make_policy(MemoryPolicyModeType::MpolBind, "0,1", vec![])).is_ok());
        let got_args = syscall
            .as_any()
            .downcast_ref::<TestHelperSyscall>()
            .unwrap()
            .get_mempolicy_args();
        assert_eq!(got_args.len(), 1);
        assert_eq!(got_args[0].mode, 2);
        assert_eq!(got_args[0].nodemask.len(), 1);
        assert_eq!(got_args[0].nodemask[0], 3);
        assert_eq!(got_args[0].maxnode, 64);

        // MPOL_BIND on node 0 with MPOL_F_STATIC_NODES: the flag bit is OR-ed into the mode.
        assert!(apply(make_policy(
            MemoryPolicyModeType::MpolBind,
            "0",
            vec![MemoryPolicyFlagType::MpolFStaticNodes],
        ))
        .is_ok());
        let got_args_with_flags = syscall
            .as_any()
            .downcast_ref::<TestHelperSyscall>()
            .unwrap()
            .get_mempolicy_args();
        assert_eq!(got_args_with_flags.len(), 2);
        assert_eq!(got_args_with_flags[1].mode, 2 | (1 << 15));
        assert_eq!(got_args_with_flags[1].nodemask.len(), 1);
        assert_eq!(got_args_with_flags[1].nodemask[0], 1);
        assert_eq!(got_args_with_flags[1].maxnode, 64);

        // Every policy below must fail validation.
        let rejected = vec![
            (
                "static and relative node flags together",
                make_policy(
                    MemoryPolicyModeType::MpolBind,
                    "0",
                    vec![
                        MemoryPolicyFlagType::MpolFStaticNodes,
                        MemoryPolicyFlagType::MpolFRelativeNodes,
                    ],
                ),
            ),
            (
                "numa balancing flag without MPOL_BIND",
                make_policy(
                    MemoryPolicyModeType::MpolInterleave,
                    "0",
                    vec![MemoryPolicyFlagType::MpolFNumaBalancing],
                ),
            ),
            (
                "MPOL_DEFAULT with nodes",
                make_policy(MemoryPolicyModeType::MpolDefault, "0", vec![]),
            ),
            (
                "MPOL_DEFAULT with flags",
                make_policy(
                    MemoryPolicyModeType::MpolDefault,
                    "",
                    vec![MemoryPolicyFlagType::MpolFStaticNodes],
                ),
            ),
            (
                "MPOL_LOCAL with nodes",
                make_policy(MemoryPolicyModeType::MpolLocal, "0", vec![]),
            ),
            (
                "MPOL_BIND with empty nodes",
                make_policy(MemoryPolicyModeType::MpolBind, "", vec![]),
            ),
            (
                "MPOL_BIND with whitespace-only nodes",
                make_policy(MemoryPolicyModeType::MpolBind, " ", vec![]),
            ),
            (
                "MPOL_PREFERRED with empty nodes and a node flag",
                make_policy(
                    MemoryPolicyModeType::MpolPreferred,
                    "",
                    vec![MemoryPolicyFlagType::MpolFStaticNodes],
                ),
            ),
        ];
        for (case, policy) in rejected {
            assert!(apply(policy).is_err(), "expected rejection: {}", case);
        }
    }
}