picard/
config.rs

1// src/config.rs
2
3//! Configuration for the PICARD algorithm.
4
5use crate::density::DensityType;
6use crate::error::{PicardError, Result};
7use ndarray::Array2;
8
9/// Configuration parameters for the PICARD algorithm.
10#[derive(Clone)]
11pub struct PicardConfig {
12    /// Density function to use for ICA.
13    pub density: DensityType,
14
15    /// Number of components to extract. If None, uses min(n_features, n_samples).
16    pub n_components: Option<usize>,
17
18    /// If true, uses Picard-O with orthogonal constraint.
19    pub ortho: bool,
20
21    /// If true, uses extended algorithm for sub/super-Gaussian sources.
22    /// Defaults to same value as `ortho` if not specified.
23    pub extended: Option<bool>,
24
25    /// If true, perform whitening on the data.
26    pub whiten: bool,
27
28    /// If true, center the data before processing.
29    pub centering: bool,
30
31    /// Maximum number of iterations.
32    pub max_iter: usize,
33
34    /// Convergence tolerance for gradient norm.
35    pub tol: f64,
36
37    /// Size of L-BFGS memory.
38    pub m: usize,
39
40    /// Maximum line search attempts.
41    pub ls_tries: usize,
42
43    /// Minimum eigenvalue for Hessian regularization.
44    pub lambda_min: f64,
45
46    /// Initial unmixing matrix. If None, uses random initialization.
47    pub w_init: Option<Array2<f64>>,
48
49    /// Number of FastICA iterations before PICARD. If None, skip FastICA.
50    pub fastica_it: Option<usize>,
51
52    /// Number of JADE iterations before PICARD. If None, skip JADE.
53    /// JADE (Joint Approximate Diagonalization of Eigenmatrices) can provide
54    /// a better warm start than FastICA for some data distributions.
55    pub jade_it: Option<usize>,
56
57    /// Random seed for reproducibility.
58    pub random_state: Option<u64>,
59
60    /// If true, print progress information.
61    pub verbose: bool,
62}
63
64impl Default for PicardConfig {
65    fn default() -> Self {
66        Self {
67            density: DensityType::default(),
68            n_components: None,
69            ortho: true,
70            extended: None,
71            whiten: true,
72            centering: true,
73            max_iter: 500,
74            tol: 1e-7,
75            m: 7,
76            ls_tries: 10,
77            lambda_min: 0.01,
78            w_init: None,
79            fastica_it: None,
80            jade_it: None,
81            random_state: None,
82            verbose: false,
83        }
84    }
85}
86
87impl PicardConfig {
88    /// Create a new configuration with default values.
89    pub fn new() -> Self {
90        Self::default()
91    }
92
93    /// Create a builder for constructing a configuration.
94    pub fn builder() -> ConfigBuilder {
95        ConfigBuilder::new()
96    }
97
98    /// Get the effective value of `extended` (defaults to `ortho` if not set).
99    pub fn effective_extended(&self) -> bool {
100        self.extended.unwrap_or(self.ortho)
101    }
102
103    /// Validate the configuration.
104    pub fn validate(&self) -> Result<()> {
105        if self.max_iter == 0 {
106            return Err(PicardError::InvalidConfig {
107                parameter: "max_iter".into(),
108                message: "must be greater than 0".into(),
109            });
110        }
111
112        if self.tol <= 0.0 {
113            return Err(PicardError::InvalidConfig {
114                parameter: "tol".into(),
115                message: "must be positive".into(),
116            });
117        }
118
119        if self.lambda_min <= 0.0 {
120            return Err(PicardError::InvalidConfig {
121                parameter: "lambda_min".into(),
122                message: "must be positive".into(),
123            });
124        }
125
126        if self.m == 0 {
127            return Err(PicardError::InvalidConfig {
128                parameter: "m".into(),
129                message: "L-BFGS memory size must be at least 1".into(),
130            });
131        }
132
133        if self.fastica_it.is_some() && self.jade_it.is_some() {
134            return Err(PicardError::InvalidConfig {
135                parameter: "jade_it".into(),
136                message: "cannot use both fastica_it and jade_it; choose one warm start method"
137                    .into(),
138            });
139        }
140
141        Ok(())
142    }
143}
144
145/// Builder for constructing `PicardConfig` with a fluent API.
146#[derive(Default)]
147pub struct ConfigBuilder {
148    config: PicardConfig,
149}
150
151impl ConfigBuilder {
152    /// Create a new builder with default values.
153    pub fn new() -> Self {
154        Self {
155            config: PicardConfig::default(),
156        }
157    }
158
159    /// Set the density function.
160    pub fn density(mut self, density: DensityType) -> Self {
161        self.config.density = density;
162        self
163    }
164
165    /// Set the number of components to extract.
166    pub fn n_components(mut self, n: usize) -> Self {
167        self.config.n_components = Some(n);
168        self
169    }
170
171    /// Enable or disable orthogonal constraint (Picard-O).
172    pub fn ortho(mut self, ortho: bool) -> Self {
173        self.config.ortho = ortho;
174        self
175    }
176
177    /// Enable or disable extended algorithm for mixed sub/super-Gaussian sources.
178    pub fn extended(mut self, extended: bool) -> Self {
179        self.config.extended = Some(extended);
180        self
181    }
182
183    /// Enable or disable whitening.
184    pub fn whiten(mut self, whiten: bool) -> Self {
185        self.config.whiten = whiten;
186        self
187    }
188
189    /// Enable or disable centering.
190    pub fn centering(mut self, centering: bool) -> Self {
191        self.config.centering = centering;
192        self
193    }
194
195    /// Set the maximum number of iterations.
196    pub fn max_iter(mut self, max_iter: usize) -> Self {
197        self.config.max_iter = max_iter;
198        self
199    }
200
201    /// Set the convergence tolerance.
202    pub fn tol(mut self, tol: f64) -> Self {
203        self.config.tol = tol;
204        self
205    }
206
207    /// Set the L-BFGS memory size.
208    pub fn m(mut self, m: usize) -> Self {
209        self.config.m = m;
210        self
211    }
212
213    /// Set the maximum line search attempts.
214    pub fn ls_tries(mut self, ls_tries: usize) -> Self {
215        self.config.ls_tries = ls_tries;
216        self
217    }
218
219    /// Set the minimum eigenvalue for Hessian regularization.
220    pub fn lambda_min(mut self, lambda_min: f64) -> Self {
221        self.config.lambda_min = lambda_min;
222        self
223    }
224
225    /// Set the initial unmixing matrix.
226    pub fn w_init(mut self, w_init: Array2<f64>) -> Self {
227        self.config.w_init = Some(w_init);
228        self
229    }
230
231    /// Set the number of FastICA pre-iterations.
232    ///
233    /// Note: Cannot be used together with `jade_it`.
234    pub fn fastica_it(mut self, iterations: usize) -> Self {
235        self.config.fastica_it = Some(iterations);
236        self
237    }
238
239    /// Set the number of JADE pre-iterations.
240    ///
241    /// JADE (Joint Approximate Diagonalization of Eigenmatrices) uses
242    /// fourth-order cumulants and Jacobi rotations for joint diagonalization.
243    /// It can provide a better warm start than FastICA for some distributions.
244    ///
245    /// Note: Cannot be used together with `fastica_it`.
246    pub fn jade_it(mut self, iterations: usize) -> Self {
247        self.config.jade_it = Some(iterations);
248        self
249    }
250
251    /// Set the random seed.
252    pub fn random_state(mut self, seed: u64) -> Self {
253        self.config.random_state = Some(seed);
254        self
255    }
256
257    /// Enable or disable verbose output.
258    pub fn verbose(mut self, verbose: bool) -> Self {
259        self.config.verbose = verbose;
260        self
261    }
262
263    /// Build the configuration.
264    pub fn build(self) -> PicardConfig {
265        self.config
266    }
267
268    /// Build and validate the configuration.
269    pub fn build_validated(self) -> Result<PicardConfig> {
270        self.config.validate()?;
271        Ok(self.config)
272    }
273}