use crate::core::{ConfigError, FormicaXError};
use std::time::Duration;
/// Selects which GMM variant a configuration requests.
///
/// The default is [`GMMVariant::EM`]. The enum is fieldless, so it derives
/// full `Eq` and `Hash` in addition to `PartialEq`; this lets it be used as a
/// key in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum GMMVariant {
    /// Default variant (presumably expectation-maximization; confirm against
    /// the fitting code).
    #[default]
    EM,
    /// Variational Bayes variant.
    VariationalBayes,
    /// Robust variant.
    Robust,
}
/// Settings for a GMM run.
///
/// All fields are public and can be set directly; call
/// [`GMMConfig::validate`] afterwards to check the numeric constraints, or go
/// through [`GMMConfig::builder`], whose `build` step validates for you.
#[derive(Debug, Clone)]
pub struct GMMConfig {
/// Number of components. `validate` rejects 0. Defaults to 3.
pub n_components: usize,
/// Requested algorithm variant. Defaults to `GMMVariant::EM`.
pub variant: GMMVariant,
/// Iteration cap. `validate` rejects 0. Defaults to 100.
pub max_iterations: usize,
/// Tolerance value; `validate` requires it to be greater than 0.
/// Presumably the convergence threshold (confirm against the fitting code).
/// Defaults to 1e-6.
pub tolerance: f64,
/// Optional random seed. Defaults to `None`.
pub random_seed: Option<u64>,
/// Requested covariance structure. Defaults to `CovarianceType::Full`.
pub covariance_type: CovarianceType,
/// Regularization amount; `validate` rejects negative values.
/// Defaults to 1e-6.
pub regularization: f64,
/// Parallel-execution flag. Defaults to `false`.
pub parallel: bool,
/// Thread count. `validate` rejects 0. Defaults to `num_cpus::get()`.
pub num_threads: usize,
/// Optional time limit. Defaults to 30 seconds.
pub timeout: Option<Duration>,
}
impl Default for GMMConfig {
fn default() -> Self {
Self {
n_components: 3,
variant: GMMVariant::EM,
max_iterations: 100,
tolerance: 1e-6,
random_seed: None,
covariance_type: CovarianceType::Full,
regularization: 1e-6,
parallel: false,
num_threads: num_cpus::get(),
timeout: Some(Duration::from_secs(30)),
}
}
}
/// Selects the covariance structure a configuration requests.
///
/// The default is [`CovarianceType::Full`]. The enum is fieldless, so it
/// derives full `Eq` and `Hash` in addition to `PartialEq`; this lets it be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum CovarianceType {
    /// Full covariance (the default).
    #[default]
    Full,
    /// Diagonal covariance.
    Diagonal,
    /// Tied covariance.
    Tied,
    /// Spherical covariance.
    Spherical,
}
impl GMMConfig {
    /// Checks that the numeric settings are usable.
    ///
    /// The checks run in a fixed order and the first failure is returned:
    ///
    /// 1. `n_components` must be greater than 0
    /// 2. `max_iterations` must be greater than 0
    /// 3. `tolerance` must be positive
    /// 4. `regularization` must be non-negative
    /// 5. `num_threads` must be greater than 0
    /// 6. `tolerance` must be finite (rejects NaN and +infinity)
    /// 7. `regularization` must be finite (rejects NaN and +infinity)
    ///
    /// # Errors
    ///
    /// Returns `FormicaXError::Config(ConfigError::ValidationFailed { .. })`
    /// whose `message` names the first violated constraint.
    pub fn validate(&self) -> Result<(), FormicaXError> {
        // Every failure has the same shape; only the message differs.
        fn invalid(message: &str) -> Result<(), FormicaXError> {
            Err(FormicaXError::Config(ConfigError::ValidationFailed {
                message: message.to_string(),
            }))
        }

        if self.n_components == 0 {
            return invalid("n_components must be greater than 0");
        }
        if self.max_iterations == 0 {
            return invalid("max_iterations must be greater than 0");
        }
        if self.tolerance <= 0.0 {
            return invalid("tolerance must be positive");
        }
        if self.regularization < 0.0 {
            return invalid("regularization must be non-negative");
        }
        if self.num_threads == 0 {
            return invalid("num_threads must be greater than 0");
        }

        // NaN compares false against every number, so it slips through the
        // sign checks above; +infinity passes them as well. These checks run
        // last so that any configuration the sign/zero checks already reject
        // keeps reporting exactly the same error.
        if !self.tolerance.is_finite() {
            return invalid("tolerance must be finite");
        }
        if !self.regularization.is_finite() {
            return invalid("regularization must be finite");
        }

        Ok(())
    }
}
/// Chainable builder for [`GMMConfig`].
///
/// Each setter consumes the builder and returns it with one field replaced;
/// `build` validates the result before handing it back.
#[derive(Debug, Default)]
pub struct GMMConfigBuilder {
// Configuration being assembled; starts from `GMMConfig::default()`.
config: GMMConfig,
}
impl GMMConfigBuilder {
    /// Creates a builder whose configuration starts at `GMMConfig::default()`.
    pub fn new() -> Self {
        let config = GMMConfig::default();
        Self { config }
    }

    /// Runs `change` against the configuration under construction and hands
    /// the builder back so calls can be chained.
    fn set_with(mut self, change: impl FnOnce(&mut GMMConfig)) -> Self {
        change(&mut self.config);
        self
    }

    /// Sets `n_components`.
    pub fn n_components(self, n_components: usize) -> Self {
        self.set_with(|cfg| cfg.n_components = n_components)
    }

    /// Sets `variant`.
    pub fn variant(self, variant: GMMVariant) -> Self {
        self.set_with(|cfg| cfg.variant = variant)
    }

    /// Sets `max_iterations`.
    pub fn max_iterations(self, max_iterations: usize) -> Self {
        self.set_with(|cfg| cfg.max_iterations = max_iterations)
    }

    /// Sets `tolerance`.
    pub fn tolerance(self, tolerance: f64) -> Self {
        self.set_with(|cfg| cfg.tolerance = tolerance)
    }

    /// Sets `random_seed` to `Some(random_seed)`.
    pub fn random_seed(self, random_seed: u64) -> Self {
        self.set_with(|cfg| cfg.random_seed = Some(random_seed))
    }

    /// Sets `covariance_type`.
    pub fn covariance_type(self, covariance_type: CovarianceType) -> Self {
        self.set_with(|cfg| cfg.covariance_type = covariance_type)
    }

    /// Sets `regularization`.
    pub fn regularization(self, regularization: f64) -> Self {
        self.set_with(|cfg| cfg.regularization = regularization)
    }

    /// Sets the `parallel` flag.
    pub fn parallel(self, parallel: bool) -> Self {
        self.set_with(|cfg| cfg.parallel = parallel)
    }

    /// Sets `num_threads`.
    pub fn num_threads(self, num_threads: usize) -> Self {
        self.set_with(|cfg| cfg.num_threads = num_threads)
    }

    /// Sets `timeout` to `Some(timeout)`.
    pub fn timeout(self, timeout: Duration) -> Self {
        self.set_with(|cfg| cfg.timeout = Some(timeout))
    }

    /// Validates the assembled configuration and returns it.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `GMMConfig::validate` reports.
    pub fn build(self) -> Result<GMMConfig, FormicaXError> {
        self.config.validate().map(|()| self.config)
    }
}
impl GMMConfig {
    /// Starts a [`GMMConfigBuilder`] pre-populated with the default
    /// configuration.
    pub fn builder() -> GMMConfigBuilder {
        // The builder's derived `Default` wraps `GMMConfig::default()`, which
        // is exactly what `GMMConfigBuilder::new()` constructs.
        GMMConfigBuilder::default()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default value is pinned, and the defaults must pass validation.
    #[test]
    fn test_gmm_config_default() {
        let config = GMMConfig::default();
        assert_eq!(config.n_components, 3);
        assert_eq!(config.variant, GMMVariant::EM);
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.tolerance, 1e-6);
        assert!(config.random_seed.is_none());
        assert_eq!(config.covariance_type, CovarianceType::Full);
        assert_eq!(config.regularization, 1e-6);
        assert!(!config.parallel);
        assert_eq!(config.timeout, Some(Duration::from_secs(30)));
        // `num_threads` is machine-dependent, so it is checked indirectly:
        // validation rejects a thread count of 0.
        assert!(config.validate().is_ok());
    }

    /// Each builder setter lands in the matching field.
    #[test]
    fn test_gmm_config_builder() {
        let config = GMMConfig::builder()
            .n_components(5)
            .variant(GMMVariant::VariationalBayes)
            .max_iterations(200)
            .tolerance(1e-8)
            .random_seed(42)
            .covariance_type(CovarianceType::Diagonal)
            .regularization(1e-5)
            .parallel(true)
            .num_threads(4)
            .timeout(Duration::from_secs(60))
            .build()
            .unwrap();
        assert_eq!(config.n_components, 5);
        assert_eq!(config.variant, GMMVariant::VariationalBayes);
        assert_eq!(config.max_iterations, 200);
        assert_eq!(config.tolerance, 1e-8);
        assert_eq!(config.random_seed, Some(42));
        assert_eq!(config.covariance_type, CovarianceType::Diagonal);
        assert_eq!(config.regularization, 1e-5);
        assert!(config.parallel);
        assert_eq!(config.num_threads, 4);
        assert_eq!(config.timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_gmm_config_validation_success() {
        let config = GMMConfig::builder()
            .n_components(3)
            .max_iterations(100)
            .tolerance(1e-6)
            .regularization(1e-6)
            .num_threads(1)
            .build()
            .unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_gmm_config_validation_n_components_zero() {
        let config = GMMConfig::builder().n_components(0).build();
        assert!(config.is_err());
    }

    #[test]
    fn test_gmm_config_validation_max_iterations_zero() {
        let config = GMMConfig::builder().max_iterations(0).build();
        assert!(config.is_err());
    }

    #[test]
    fn test_gmm_config_validation_negative_tolerance() {
        let config = GMMConfig::builder().tolerance(-1.0).build();
        assert!(config.is_err());
    }

    /// Zero sits on the boundary of the "must be positive" check.
    #[test]
    fn test_gmm_config_validation_zero_tolerance() {
        let config = GMMConfig::builder().tolerance(0.0).build();
        assert!(config.is_err());
    }

    #[test]
    fn test_gmm_config_validation_negative_regularization() {
        let config = GMMConfig::builder().regularization(-1.0).build();
        assert!(config.is_err());
    }

    /// Zero regularization is allowed: only negative values are rejected.
    #[test]
    fn test_gmm_config_validation_zero_regularization() {
        let config = GMMConfig::builder().regularization(0.0).build();
        assert!(config.is_ok());
    }

    #[test]
    fn test_gmm_config_validation_zero_threads() {
        let config = GMMConfig::builder().num_threads(0).build();
        assert!(config.is_err());
    }
}