use log::{error, warn};
use serde::{Deserialize, Serialize};
/// Fallback thread count returned by `MaxThreadCount::default()` when no
/// machine parallelism has been recorded in `MACHINE_PARALLELISM` for the
/// current thread.
pub const DEFAULT_MAX_THREAD_COUNT: u8 = 4;
#[doc = include_str!("./shared_docs/max_thread_count.md")]
#[derive(Copy, Clone, Debug, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct MaxThreadCount(u8);

impl MaxThreadCount {
    /// Returns the wrapped thread count as a plain `u8`.
    pub fn as_u8(&self) -> u8 {
        // `MaxThreadCount` is `Copy`, so destructuring a copy is free.
        let Self(count) = *self;
        count
    }
}
impl From<u8> for MaxThreadCount {
fn from(max_thread_count: u8) -> Self {
Self(max_thread_count)
}
}
impl Default for MaxThreadCount {
fn default() -> Self {
MaxThreadCount(MACHINE_PARALLELISM.with(|opt| match *opt.borrow() {
None => {
warn!(
"Machine parallelism not set, defaulting max thread count to {}",
DEFAULT_MAX_THREAD_COUNT
);
DEFAULT_MAX_THREAD_COUNT
}
Some(par) => par,
}))
}
}
use std::str::FromStr;
impl FromStr for MaxThreadCount {
type Err = MaxThreadCountError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(MaxThreadCount(u8::from_str(s)?))
}
}
use clap::builder::{OsStr, Str};
impl From<MaxThreadCount> for OsStr {
    /// Converts the count into its decimal text form as a clap `OsStr`.
    fn from(max_thread_count: MaxThreadCount) -> OsStr {
        let rendered: String = max_thread_count.as_u8().to_string();
        Str::from(rendered).into()
    }
}
#[derive(thiserror::Error, Debug)]
pub enum MaxThreadCountError {
#[error("Malformed string input for u8 type")]
MalformedString(#[from] std::num::ParseIntError),
}
use std::cell::RefCell;
// Per-thread cache of the machine's parallelism, used by
// `MaxThreadCount::default()`. Each thread starts with its own `None`, so a
// value stored by `initialize_machine_parallelism` is visible only on the
// thread that stored it.
thread_local!(pub static MACHINE_PARALLELISM: RefCell<Option<u8>> = RefCell::new(None));
/// Queries the machine's available parallelism and stores it in
/// `MACHINE_PARALLELISM` for the calling thread.
///
/// `MACHINE_PARALLELISM` is thread-local, so the stored value is only seen on
/// the thread that calls this function. On other threads it stays `None`, and
/// `MaxThreadCount::default()` falls back to `DEFAULT_MAX_THREAD_COUNT` there.
///
/// If the query fails, the error is logged and `None` is stored. Counts that
/// do not fit in a `u8` are clamped to `u8::MAX`. A plain `as u8` cast would
/// truncate them, e.g. 256 CPUs would become 0 threads.
pub fn initialize_machine_parallelism() {
    // Query (and log) before borrowing the cell, so no borrow is held across
    // the system call or the logger.
    let parallelism = match std::thread::available_parallelism() {
        Ok(par) => Some(u8::try_from(par.get()).unwrap_or(u8::MAX)),
        Err(err) => {
            error!("Problem accessing machine parallelism: {}", err);
            None
        }
    };
    MACHINE_PARALLELISM.with(|opt| *opt.borrow_mut() = parallelism);
}
#[cfg(test)]
mod tests {
    use super::*;

    // Assumes nothing on this test's thread has called
    // `initialize_machine_parallelism`, because `MACHINE_PARALLELISM` is
    // thread-local.
    #[test]
    fn default_without_initializing_machine_parallelism() {
        assert_eq!(MaxThreadCount::default().as_u8(), DEFAULT_MAX_THREAD_COUNT);
    }

    #[test]
    fn from_u8_round_trips_through_as_u8() {
        for raw in [0u8, 1, DEFAULT_MAX_THREAD_COUNT, u8::MAX] {
            assert_eq!(MaxThreadCount::from(raw).as_u8(), raw);
        }
    }

    #[test]
    fn parses_decimal_u8() {
        let parsed: MaxThreadCount = "12".parse().expect("\"12\" is a valid u8");
        assert_eq!(parsed, MaxThreadCount::from(12));
        assert_eq!(
            "255".parse::<MaxThreadCount>().expect("\"255\" is a valid u8").as_u8(),
            u8::MAX
        );
    }

    // Cases: empty, non-numeric, negative, overflow, and leading whitespace
    // (`u8::from_str` does not trim).
    #[test]
    fn rejects_malformed_or_out_of_range_input() {
        for bad in ["", "abc", "-1", "256", " 4"] {
            let result = bad.parse::<MaxThreadCount>();
            assert!(
                matches!(result, Err(MaxThreadCountError::MalformedString(_))),
                "expected MalformedString for {bad:?}"
            );
        }
    }
}