use serde::{Deserialize, Serialize};
/// Selects which per-partition statistic(s) make up the feature vector.
///
/// The number of values contributed per partition depends on the family; see
/// `Config::feature_dimension`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
pub enum FeatureFamily {
    /// One value per partition (the mean). This is the default family.
    #[default]
    Mean,
    /// Three values per partition: median, 10th percentile and 90th percentile.
    Robust3,
    /// One value per partition: the centered square (variance).
    CenteredSquare,
}
impl std::fmt::Display for FeatureFamily {
    /// Writes a human-readable label for the feature family.
    ///
    /// The label goes through [`std::fmt::Formatter::pad`] rather than `write!`, so
    /// width, fill/alignment and precision flags (e.g. `{:<30}` in a table) are
    /// honoured exactly as they are for `&str`. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            FeatureFamily::Mean => "Mean",
            FeatureFamily::Robust3 => "Robust3 (median, P10, P90)",
            FeatureFamily::CenteredSquare => "CenteredSquare (variance)",
        };
        f.pad(label)
    }
}
/// Partitioning settings; `num_partitions` determines `Config::feature_dimension`.
///
/// Derives `PartialEq` (in addition to the original traits) so whole configurations
/// can be compared, e.g. after a serialize/deserialize round trip.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PartitionConfig {
    /// Number of partitions.
    pub num_partitions: usize,
    /// Whether partitions overlap (the amount is given by `overlap_fraction`).
    pub overlap: bool,
    /// Overlap amount. `PartitionConfig::with_overlap` clamps it to `[0.0, 0.5]`;
    /// direct writes to this public field are not checked.
    pub overlap_fraction: f64,
}
impl Default for PartitionConfig {
fn default() -> Self {
Self {
num_partitions: 32,
overlap: false,
overlap_fraction: 0.25,
}
}
}
impl PartitionConfig {
    /// Creates a configuration with `num_partitions` partitions; every other field
    /// comes from `PartitionConfig::default()`.
    #[must_use]
    pub fn new(num_partitions: usize) -> Self {
        Self {
            num_partitions,
            ..Default::default()
        }
    }

    /// Enables overlap and stores `fraction`, clamped to `[0.0, 0.5]`.
    ///
    /// `f64::clamp` propagates NaN, so a NaN `fraction` is mapped to `0.0`
    /// explicitly; otherwise the stored value would escape the clamped range.
    #[must_use]
    pub fn with_overlap(mut self, fraction: f64) -> Self {
        self.overlap = true;
        self.overlap_fraction = if fraction.is_nan() {
            0.0
        } else {
            fraction.clamp(0.0, 0.5)
        };
        self
    }
}
/// Preprocessing switches. Presets: `PreprocessingConfig::none()`,
/// `PreprocessingConfig::standard()` (same as `default()`), `PreprocessingConfig::full()`.
///
/// Derives `PartialEq` (in addition to the original traits) so presets and loaded
/// configurations can be compared as a whole.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PreprocessingConfig {
    /// Lower winsorization bound. Given the default (0.01) and the `none()` value
    /// (0.0), this looks like a percentile on a 0–100 scale — confirm with the consumer.
    pub winsorize_lower: f64,
    /// Upper winsorization bound, on the same scale as `winsorize_lower`.
    pub winsorize_upper: f64,
    /// Whether mean normalization is applied.
    pub normalize_mean: bool,
    /// Whether variance normalization is applied.
    pub normalize_variance: bool,
}
impl Default for PreprocessingConfig {
    /// The "standard" preset: winsorization bounds of 0.01 and 99.99, mean
    /// normalization on, variance normalization off.
    fn default() -> Self {
        let (winsorize_lower, winsorize_upper) = (0.01, 99.99);
        Self {
            normalize_mean: true,
            normalize_variance: false,
            winsorize_lower,
            winsorize_upper,
        }
    }
}
impl PreprocessingConfig {
pub fn none() -> Self {
Self {
winsorize_lower: 0.0,
winsorize_upper: 100.0,
normalize_mean: false,
normalize_variance: false,
}
}
pub fn standard() -> Self {
Self::default()
}
pub fn full() -> Self {
Self {
normalize_variance: true,
..Default::default()
}
}
}
/// Top-level configuration: feature family, partitioning, preprocessing and the
/// remaining scalar settings. Start from `Config::new()` / `Config::default()` and
/// adjust with the `with_*` / `without_*` builder methods.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Config {
    /// Which per-partition statistic(s) form the feature vector.
    pub feature_family: FeatureFamily,
    /// Partitioning settings (partition count and overlap).
    pub partition: PartitionConfig,
    /// Preprocessing switches (winsorization bounds, normalization flags).
    pub preprocessing: PreprocessingConfig,
    /// NOTE(review): multiplier for some floor/threshold; semantics not visible in
    /// this file — confirm with the consumer.
    pub floor_multiplier: f64,
    /// Stage-wise flag; `Config::without_stages` clears it.
    pub stage_wise: bool,
    /// Optional seed. Presumably `None` means unseeded/non-deterministic — confirm.
    pub seed: Option<u64>,
    /// Presumably the number of bootstrap resamples — confirm with the consumer.
    pub bootstrap_iterations: usize,
    /// Lower decision threshold (presumably a probability) — confirm semantics.
    pub pass_threshold: f64,
    /// Upper decision threshold (presumably a probability) — confirm semantics.
    pub fail_threshold: f64,
}
impl Default for Config {
fn default() -> Self {
Self {
feature_family: FeatureFamily::default(),
partition: PartitionConfig::default(),
preprocessing: PreprocessingConfig::default(),
floor_multiplier: 5.0,
stage_wise: true,
seed: None,
bootstrap_iterations: 2000,
pass_threshold: 0.05,
fail_threshold: 0.95,
}
}
}
impl Config {
    /// Equivalent to [`Config::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the feature family.
    ///
    /// Like every builder here, this consumes `self` and returns the updated value;
    /// `#[must_use]` warns when the result is accidentally discarded.
    #[must_use]
    pub fn with_feature_family(mut self, family: FeatureFamily) -> Self {
        self.feature_family = family;
        self
    }

    /// Sets `partition.num_partitions`, leaving the other partition settings as-is.
    #[must_use]
    pub fn with_partitions(mut self, num_partitions: usize) -> Self {
        self.partition.num_partitions = num_partitions;
        self
    }

    /// Sets `floor_multiplier` (stored as given, no validation).
    #[must_use]
    pub fn with_floor_multiplier(mut self, multiplier: f64) -> Self {
        self.floor_multiplier = multiplier;
        self
    }

    /// Clears the `stage_wise` flag.
    #[must_use]
    pub fn without_stages(mut self) -> Self {
        self.stage_wise = false;
        self
    }

    /// Stores `seed` as `Some(seed)`.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Length of the feature vector implied by the current settings: one value per
    /// partition for `Mean` and `CenteredSquare`, three per partition for `Robust3`.
    #[must_use]
    pub fn feature_dimension(&self) -> usize {
        let per_partition = match self.feature_family {
            FeatureFamily::Mean | FeatureFamily::CenteredSquare => 1,
            FeatureFamily::Robust3 => 3,
        };
        per_partition * self.partition.num_partitions
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feature dimension per family for a fixed 32-partition configuration.
    #[test]
    fn test_feature_dimension() {
        let mut cfg = Config::default();
        cfg.partition.num_partitions = 32;

        let expectations = [
            (FeatureFamily::Mean, 32),
            (FeatureFamily::Robust3, 96),
            (FeatureFamily::CenteredSquare, 32),
        ];
        for (family, expected) in expectations {
            cfg.feature_family = family;
            assert_eq!(cfg.feature_dimension(), expected, "family {family:?}");
        }
    }

    /// Chained builder calls land in the expected fields.
    #[test]
    fn test_config_builder() {
        let built = Config::new()
            .with_feature_family(FeatureFamily::Robust3)
            .with_partitions(64)
            .with_floor_multiplier(3.0)
            .with_seed(42);

        assert_eq!(built.feature_family, FeatureFamily::Robust3);
        assert_eq!(built.partition.num_partitions, 64);
        assert_eq!(built.floor_multiplier, 3.0);
        assert_eq!(built.seed, Some(42));
    }

    /// Normalization flags of the three preprocessing presets.
    #[test]
    fn test_preprocessing_presets() {
        let cases = [
            (PreprocessingConfig::none(), false, false),
            (PreprocessingConfig::standard(), true, false),
            (PreprocessingConfig::full(), true, true),
        ];
        for (preset, mean, variance) in cases {
            assert_eq!(preset.normalize_mean, mean);
            assert_eq!(preset.normalize_variance, variance);
        }
    }
}