/// Tunable parameters for Wasserstein-distance computations.
///
/// Every field is public; the [`Default`] impl below supplies the baseline
/// values, which can then be overridden field-by-field.
#[derive(Debug, Clone)]
pub struct WassersteinConfig {
    /// Number of projections (default `100`).
    ///
    /// NOTE(review): presumably the number of 1-D directions used by a sliced
    /// estimator — confirm against the solver that consumes this config.
    pub num_projections: usize,
    /// Regularization strength (default `0.1`).
    ///
    /// NOTE(review): presumably the entropic epsilon of a Sinkhorn-style
    /// solver — confirm.
    pub regularization: f64,
    /// Upper bound on iterations (default `100`).
    pub max_iterations: usize,
    /// Tolerance value (default `1e-6`).
    ///
    /// NOTE(review): presumably the convergence threshold of an iterative
    /// solver — confirm.
    pub threshold: f64,
    /// Exponent `p` (default `2.0`); presumably the order of the `W_p`
    /// distance — confirm.
    pub p: f64,
    /// Optional seed (default `None`); presumably feeds a random number
    /// generator — confirm.
    pub seed: Option<u64>,
}

impl Default for WassersteinConfig {
    /// Baseline configuration.
    ///
    /// The values are: 100 projections, regularization 0.1, 100 iterations,
    /// threshold 1e-6, p = 2.0, and no seed.
    fn default() -> Self {
        WassersteinConfig {
            seed: None,
            p: 2.0,
            threshold: 1e-6,
            max_iterations: 100,
            regularization: 0.1,
            num_projections: 100,
        }
    }
}
impl WassersteinConfig {
    /// Creates a configuration populated with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` with `num_projections` set to `n`.
    #[must_use]
    pub fn with_projections(mut self, n: usize) -> Self {
        self.num_projections = n;
        self
    }

    /// Returns `self` with `regularization` set to `eps`.
    #[must_use]
    pub fn with_regularization(mut self, eps: f64) -> Self {
        self.regularization = eps;
        self
    }

    /// Returns `self` with `max_iterations` set to `max_iter`.
    #[must_use]
    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
        self.max_iterations = max_iter;
        self
    }

    /// Returns `self` with `threshold` set to `threshold`.
    #[must_use]
    pub fn with_threshold(mut self, threshold: f64) -> Self {
        self.threshold = threshold;
        self
    }

    /// Returns `self` with the exponent `p` set to `p`.
    #[must_use]
    pub fn with_power(mut self, p: f64) -> Self {
        self.p = p;
        self
    }

    /// Returns `self` with `seed` set to `Some(seed)`.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Checks that the configuration's parameters are in range.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error naming the first offending field.
    /// Fields are checked in this order:
    /// - `num_projections` is `0`;
    /// - `regularization` is NaN or `<= 0`;
    /// - `p` is NaN or `<= 0`;
    /// - `threshold` is NaN or `< 0`.
    pub fn validate(&self) -> crate::Result<()> {
        if self.num_projections == 0 {
            return Err(crate::MathError::invalid_parameter(
                "num_projections",
                "must be > 0",
            ));
        }
        // Check NaN explicitly: every comparison against NaN is false, so a
        // bare `<= 0.0` test would let NaN through.
        if self.regularization.is_nan() || self.regularization <= 0.0 {
            return Err(crate::MathError::invalid_parameter(
                "regularization",
                "must be > 0",
            ));
        }
        if self.p.is_nan() || self.p <= 0.0 {
            return Err(crate::MathError::invalid_parameter("p", "must be > 0"));
        }
        // A negative or NaN tolerance can never be satisfied by a
        // non-negative error measure. This check runs last so configs that
        // previously failed still report the same first error.
        if self.threshold.is_nan() || self.threshold < 0.0 {
            return Err(crate::MathError::invalid_parameter(
                "threshold",
                "must be >= 0",
            ));
        }
        Ok(())
    }
}