use std::{collections::HashSet, env, time::Duration};
use anyhow::{bail, Context};
use tracing::warn;
/// How [`ConstraintsBuilder::build`] chooses the set of usable CPU ids.
#[derive(Debug, Default)]
enum AutoCpus {
    /// Use ids `0..n`, where `n` is the physical CPU count reported by `num_cpus` (the default).
    #[default]
    Auto,
    /// Use ids `0..n` for the given count `n`.
    Count(usize),
    /// Use the ids named by a CPU list string such as `"0-3,6"`, parsed by `cpu_list_to_hashset`.
    List(String),
}
/// Builder for [`Constraints`]; every field left unset is resolved by `build`.
#[derive(Debug, Default)]
pub struct ConstraintsBuilder {
    /// Total RAM in MB; `None` means `build` uses the system's available memory.
    total_ram: Option<usize>,
    /// RAM per agent in MB; `None` means `build` derives it from the total RAM and CPU count.
    agent_ram: Option<usize>,
    /// How the CPU id set is chosen.
    cpus: AutoCpus,
    /// CPUs per agent; `None` means `build` uses 1.
    cpus_per_agent: Option<usize>,
    /// Overall time budget; `None` means `build` uses `Duration::MAX`.
    time_budget: Option<Duration>,
    /// Per-action timeout; `None` means `build` uses `Duration::MAX`.
    action_timeout: Option<Duration>,
    /// Time margin; `build` forces it to zero when neither time limit above is set.
    time_margin: Duration,
}
impl ConstraintsBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn from_env() -> Self {
fn parse_usize(var: &str) -> Option<usize> {
env::var(var).ok()?.parse().ok()
}
fn parse_duration_secs(var: &str) -> Option<Duration> {
env::var(var)
.ok()?
.parse::<u64>()
.ok()
.map(Duration::from_secs)
}
fn parse_duration_millis(var: &str) -> Option<Duration> {
env::var(var)
.ok()?
.parse::<u64>()
.ok()
.map(Duration::from_millis)
}
let max_total_ram = parse_usize("MAX_TOTAL_RAM");
let ram_per_agent = parse_usize("RAM_PER_AGENT");
let cpu_list = env::var("CPU_LIST").ok();
let total_cpu_count = parse_usize("TOTAL_CPU_COUNT");
let cpus_per_agent = parse_usize("CPUS_PER_AGENT");
let time_budget = parse_duration_secs("TIME_BUDGET_SECS");
let action_timeout = parse_duration_millis("ACTION_TIMEOUT_MS");
let time_margin = env::var("TIME_MARGIN_MS")
.ok()
.and_then(|s| s.parse::<u64>().ok())
.map(Duration::from_millis)
.unwrap_or(Duration::ZERO);
let cpus = if let Some(cpus_str) = cpu_list {
AutoCpus::List(cpus_str)
} else if let Some(count) = total_cpu_count {
AutoCpus::Count(count)
} else {
AutoCpus::Auto
};
ConstraintsBuilder {
total_ram: max_total_ram,
agent_ram: ram_per_agent,
cpus,
cpus_per_agent,
time_budget,
action_timeout,
time_margin,
}
}
#[must_use]
pub fn with_max_total_ram(self, max: usize) -> Self {
Self {
total_ram: Some(max),
..self
}
}
#[must_use]
pub fn with_ram_per_agent(self, max: usize) -> Self {
Self {
agent_ram: Some(max),
..self
}
}
#[must_use]
pub fn with_cpu_list(self, cpus: &str) -> Self {
Self {
cpus: AutoCpus::List(cpus.to_string()),
..self
}
}
#[must_use]
pub fn with_total_cpu_count(self, max: usize) -> Self {
if let AutoCpus::List(_) = self.cpus {
warn!("`with_total_cpu_count` is ignored if `with_cpu_list` is used!");
self
} else {
Self {
cpus: AutoCpus::Count(max),
..self
}
}
}
#[must_use]
pub fn with_cpus_per_agent(self, max: usize) -> Self {
Self {
cpus_per_agent: Some(max),
..self
}
}
#[must_use]
pub fn with_time_budget(self, duration: Duration) -> Self {
Self {
time_budget: Some(duration),
..self
}
}
#[must_use]
pub fn with_action_timeout(self, duration: Duration) -> Self {
Self {
action_timeout: Some(duration),
..self
}
}
#[must_use]
pub fn with_time_margin(self, duration: Duration) -> Self {
Self {
time_margin: duration,
..self
}
}
pub fn build(self) -> anyhow::Result<Constraints> {
let mut sys = sysinfo::System::new();
let total_ram = self.total_ram.map(|i| i * 1_000_000).unwrap_or_else(|| {
sys.refresh_memory();
sys.available_memory() as usize
});
if total_ram < (self.agent_ram.unwrap_or(0) * 1_000_000) {
bail!(
"Agent RAM size ({}MB) is greater than total RAM ({}MB)",
self.agent_ram.unwrap_or(0),
total_ram / 1_000_000
);
}
let cpus = match self.cpus {
AutoCpus::Auto => {
sys.refresh_cpu_all();
let num_cpus = num_cpus::get_physical() as u8;
(0..num_cpus).collect::<HashSet<u8>>()
}
AutoCpus::Count(num_cpus) => (0..(num_cpus as u8)).collect::<HashSet<u8>>(),
AutoCpus::List(s) => {
cpu_list_to_hashset(&s).map_err(|e| e.context("error parsing cpu list"))?
}
};
let cpus_per_agent = self.cpus_per_agent.unwrap_or(1);
let agent_ram = self
.agent_ram
.map(|i| i * 1_000_000)
.unwrap_or_else(|| total_ram / (cpus.len() / cpus_per_agent));
let time_budget = self.time_budget.unwrap_or(Duration::MAX);
let action_timeout = self.action_timeout.unwrap_or(Duration::MAX);
let time_margin = if self.time_budget.is_none() && self.action_timeout.is_none() {
Duration::ZERO
} else {
self.time_margin
};
Ok(Constraints {
total_ram,
agent_ram,
cpus,
cpus_per_agent,
time_budget,
action_timeout,
time_margin,
})
}
}
/// Parses a CPU list such as `"0,2-4, 7"` into the set of CPU ids it names.
///
/// Each comma-separated item is either a single id (`"3"`) or an inclusive range (`"2-4"`).
/// A backwards range (`"4-2"`) is accepted and normalized. Whitespace around items and
/// range bounds is ignored.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the string is empty or only whitespace;
/// - an item is neither a number nor a two-bound range;
/// - an id is not a valid `u8`.
fn cpu_list_to_hashset(s: &str) -> anyhow::Result<HashSet<u8>> {
    // Parses one CPU id, naming the offending text on failure.
    fn parse_id(text: &str) -> anyhow::Result<u8> {
        text.parse()
            .with_context(|| format!("could not parse {text}"))
    }

    if s.trim().is_empty() {
        bail!("Empty string");
    }

    let mut set: HashSet<u8> = HashSet::new();
    for item in s.split(',') {
        let bounds: Vec<&str> = item.split('-').map(str::trim).collect();
        match bounds.as_slice() {
            [single] => {
                set.insert(parse_id(single)?);
            }
            [start, end] => {
                let (start, end) = (parse_id(start)?, parse_id(end)?);
                set.extend(start.min(end)..=start.max(end));
            }
            _ => bail!(
                "each comma-separated item must be a number or a range (e.g. '0-3'), got '{item}'"
            ),
        }
    }
    Ok(set)
}
/// Resolved resource limits produced by [`ConstraintsBuilder::build`].
///
/// A value also acts as a pool: sub-`Constraints` are carved out of it with `take` /
/// `try_take` and merged back in with `add`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Constraints {
    /// RAM held by this pool, in bytes when produced by `build` (MB × 1_000_000).
    pub(crate) total_ram: usize,
    /// RAM for a single agent, in the same unit as `total_ram`.
    pub(crate) agent_ram: usize,
    /// CPU ids held by this pool.
    pub(crate) cpus: HashSet<u8>,
    /// CPUs per agent (`build` defaults this to 1).
    pub(crate) cpus_per_agent: usize,
    /// Overall time budget; `build` sets `Duration::MAX` when none was configured.
    pub(crate) time_budget: Duration,
    /// Per-action timeout; `build` sets `Duration::MAX` when none was configured.
    pub(crate) action_timeout: Duration,
    /// Time margin; `build` sets zero when neither time limit was configured.
    pub(crate) time_margin: Duration,
}
impl Constraints {
    /// Returns a fresh [`ConstraintsBuilder`] with every limit unset.
    #[must_use]
    pub fn builder() -> ConstraintsBuilder {
        ConstraintsBuilder::new()
    }

    /// Merges the RAM and CPUs of `res` back into this pool.
    ///
    /// The other fields of `res` are ignored.
    pub(crate) fn add(&mut self, res: Constraints) {
        self.total_ram += res.total_ram;
        self.cpus.extend(res.cpus);
    }

    /// Removes `num_cpus` CPUs and `ram` of RAM from this pool and returns them as a new
    /// `Constraints`. All other fields are copied from `self`.
    ///
    /// # Panics
    /// Panics if the pool holds fewer than `num_cpus` CPUs or less than `ram` RAM.
    /// Both are checked before anything is removed, so a failed call leaves the pool
    /// untouched. Use [`Constraints::try_take`] for a non-panicking variant.
    pub(crate) fn take(&mut self, num_cpus: usize, ram: usize) -> Constraints {
        // Explicit checks: a bare `self.total_ram -= ram` would wrap silently in release builds.
        assert!(
            num_cpus <= self.cpus.len(),
            "requested {num_cpus} CPU(s) but only {} are available",
            self.cpus.len()
        );
        assert!(
            ram <= self.total_ram,
            "requested {ram} RAM but only {} is available",
            self.total_ram
        );
        let cpus: HashSet<u8> = (0..num_cpus).map(|_| self.take_one_cpu()).collect();
        self.total_ram -= ram;
        Constraints {
            total_ram: ram,
            cpus,
            ..*self
        }
    }

    /// Like [`Constraints::take`], but returns `None` instead of panicking when the pool
    /// lacks enough CPUs or RAM.
    pub(crate) fn try_take(&mut self, num_cpus: usize, ram: usize) -> Option<Constraints> {
        if self.cpus.len() >= num_cpus && self.total_ram >= ram {
            Some(self.take(num_cpus, ram))
        } else {
            None
        }
    }

    /// Removes and returns an arbitrary CPU id from the pool.
    ///
    /// # Panics
    /// Panics if the pool has no CPUs left.
    pub(crate) fn take_one_cpu(&mut self) -> u8 {
        let cpu = *self
            .cpus
            .iter()
            .next()
            .expect("take_one_cpu called on a pool with no CPUs");
        self.cpus.remove(&cpu);
        cpu
    }
}