use crate::types::BitWidth;
/// Public training parameters.
///
/// `validate_config` accepts `bits` in `9..=16` and `threshold` in
/// `(0.0, 1.0]`. `From<Config> for TrainingConfig` converts this into the
/// internal representation.
#[derive(Clone, Copy, Debug)]
pub struct Config {
    /// Bit width; cast to `BitWidth` when converted to `TrainingConfig`.
    pub bits: u32,
    /// Becomes `DynamicThreshold::sample_fraction` on conversion.
    pub threshold: f64,
    /// Optional seed, copied unchanged into `TrainingConfig::seed`
    /// (presumably for a sampling RNG — confirm).
    pub seed: Option<u64>,
}

/// The value returned by `Config::default()`: 12 bits, threshold 0.2, no seed.
pub const DEFAULT_CONFIG: Config = Config {
    bits: 12,
    threshold: 0.2,
    seed: None,
};

impl Default for Config {
    /// Returns a copy of [`DEFAULT_CONFIG`].
    fn default() -> Config {
        DEFAULT_CONFIG
    }
}
/// Errors reported by this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// A configuration value was outside its accepted range
    /// (returned by `validate_config`).
    InvalidArg,
}

impl std::fmt::Display for Error {
    /// Writes a fixed, human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match *self {
            Self::InvalidArg => "onpair: invalid argument",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for Error {}
/// A threshold given directly as a `u8` value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct FixedThreshold {
    pub(crate) value: u8,
}

/// A threshold parameterised by `sample_fraction`.
/// NOTE(review): presumably the fraction of input sampled to derive the
/// threshold — confirm against the trainer.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) struct DynamicThreshold {
    pub(crate) sample_fraction: f64,
}

impl Default for DynamicThreshold {
    /// `sample_fraction` defaults to 0.2, matching `DEFAULT_CONFIG.threshold`.
    fn default() -> Self {
        let sample_fraction = 0.2;
        DynamicThreshold { sample_fraction }
    }
}

/// How the training threshold is specified.
#[derive(Clone, Copy, Debug)]
// `Fixed` is not constructed in the code visible here; the allow presumably
// silences the resulting dead-code warning — confirm it is still needed.
#[allow(dead_code)]
pub(crate) enum ThresholdSpec {
    Fixed(FixedThreshold),
    Dynamic(DynamicThreshold),
}

impl Default for ThresholdSpec {
    /// Defaults to a dynamic threshold with `DynamicThreshold::default()`.
    fn default() -> Self {
        ThresholdSpec::Dynamic(Default::default())
    }
}
/// Internal training parameters. Unlike `Config`, the threshold is expressed
/// as a [`ThresholdSpec`] instead of a bare `f64`.
#[derive(Debug, Clone)]
pub(crate) struct TrainingConfig {
    /// Bit width (type `BitWidth` from `crate::types`).
    pub(crate) bits: BitWidth,
    /// How the threshold is chosen.
    pub(crate) threshold: ThresholdSpec,
    /// Optional seed, copied from `Config::seed` by the `From` impl.
    pub(crate) seed: Option<u64>,
}

impl Default for TrainingConfig {
    // NOTE(review): this uses 16 bits, but `DEFAULT_CONFIG.bits` is 12, so
    // `TrainingConfig::default()` and `TrainingConfig::from(Config::default())`
    // disagree on `bits`. Confirm which default is intended.
    fn default() -> Self {
        TrainingConfig {
            bits: 16,
            threshold: Default::default(),
            seed: None,
        }
    }
}

impl From<Config> for TrainingConfig {
    /// Converts the public config. The result always uses a dynamic threshold
    /// whose `sample_fraction` is `Config::threshold`.
    // NOTE(review): `as` truncates if `bits` does not fit in `BitWidth`;
    // callers are presumably expected to run `validate_config` (9..=16) first.
    fn from(cfg: Config) -> Self {
        let sample_fraction = cfg.threshold;
        TrainingConfig {
            bits: cfg.bits as BitWidth,
            threshold: ThresholdSpec::Dynamic(DynamicThreshold { sample_fraction }),
            seed: cfg.seed,
        }
    }
}
pub(crate) fn validate_config(cfg: Config) -> Result<(), Error> {
if !(9..=16).contains(&cfg.bits) {
return Err(Error::InvalidArg);
}
if !(cfg.threshold > 0.0 && cfg.threshold <= 1.0) {
return Err(Error::InvalidArg);
}
Ok(())
}