Skip to main content

fdars_core/gmm/
cluster.rs

1//! Clustering wrapper and prediction for GMM.
2
3use super::em::{e_step, gmm_em, hard_assignments, resp_to_membership};
4use super::init::build_features;
5use super::{CovType, GmmClusterResult, GmmResult};
6use crate::basis::projection::ProjectionBasisType;
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9
10/// Run multiple initializations for a single K and return the best by log-likelihood.
11pub(super) fn run_multiple_inits(
12    features: &[Vec<f64>],
13    k: usize,
14    cov_type: CovType,
15    max_iter: usize,
16    tol: f64,
17    n_init: usize,
18    base_seed: u64,
19) -> Option<GmmResult> {
20    let mut best: Option<GmmResult> = None;
21    for init in 0..n_init.max(1) {
22        let seed = base_seed.wrapping_add(init as u64 * 1000 + k as u64);
23        if let Ok(result) = gmm_em(features, k, cov_type, max_iter, tol, seed) {
24            let is_better = best
25                .as_ref()
26                .map_or(true, |b| result.log_likelihood > b.log_likelihood);
27            if is_better {
28                best = Some(result);
29            }
30        }
31    }
32    best
33}
34
/// Configuration for GMM-based functional clustering.
///
/// Collects all tuning parameters for [`gmm_cluster_with_config`], with sensible
/// defaults obtained via [`GmmClusterConfig::default()`]. Each field is passed
/// unchanged to the argument of the same name of [`gmm_cluster`].
///
/// The struct is `#[non_exhaustive]`: code outside this crate cannot build it
/// with a struct literal, so start from the defaults and assign the fields you
/// want to change, as in the example below.
///
/// # Example
/// ```no_run
/// use fdars_core::gmm::cluster::GmmClusterConfig;
/// use fdars_core::gmm::CovType;
///
/// let mut config = GmmClusterConfig::default();
/// config.nbasis = 10;
/// config.cov_type = CovType::Full;
/// ```
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct GmmClusterConfig {
    /// Number of basis functions for projection (default: 5).
    pub nbasis: usize,
    /// Basis type for projection (default: `Bspline`).
    pub basis_type: ProjectionBasisType,
    /// Covariance structure (default: `Diagonal`).
    pub cov_type: CovType,
    /// Scaling factor for covariates (default: 1.0).
    pub cov_weight: f64,
    /// Maximum EM iterations per K (default: 200).
    pub max_iter: usize,
    /// Convergence tolerance (default: 1e-6).
    pub tol: f64,
    /// Number of random initializations per K (default: 3).
    ///
    /// A value of 0 still results in one initialization per K.
    pub n_init: usize,
    /// Base random seed (default: 42).
    ///
    /// The seed of each individual run is derived from this value.
    pub seed: u64,
    /// If true, select K by ICL; otherwise by BIC (default: false).
    pub use_icl: bool,
}
71
72impl Default for GmmClusterConfig {
73    fn default() -> Self {
74        Self {
75            nbasis: 5,
76            basis_type: ProjectionBasisType::Bspline,
77            cov_type: CovType::Diagonal,
78            cov_weight: 1.0,
79            max_iter: 200,
80            tol: 1e-6,
81            n_init: 3,
82            seed: 42,
83            use_icl: false,
84        }
85    }
86}
87
88/// GMM clustering using a [`GmmClusterConfig`] struct.
89///
90/// This is the config-based alternative to [`gmm_cluster`]. It takes data
91/// parameters directly and reads all tuning parameters from the config.
92///
93/// # Arguments
94/// * `data` — Functional data matrix (n x m)
95/// * `argvals` — Evaluation points (length m)
96/// * `covariates` — Optional scalar covariates (n x p)
97/// * `k_range` — Range of K values to try
98/// * `config` — Tuning parameters
99///
100/// # Errors
101///
102/// Returns [`FdarError::ComputationFailed`] if basis projection fails or no
103/// valid GMM fit is found for any K in the given range.
104#[must_use = "expensive computation whose result should not be discarded"]
105pub fn gmm_cluster_with_config(
106    data: &FdMatrix,
107    argvals: &[f64],
108    covariates: Option<&FdMatrix>,
109    k_range: &[usize],
110    config: &GmmClusterConfig,
111) -> Result<GmmClusterResult, FdarError> {
112    gmm_cluster(
113        data,
114        argvals,
115        covariates,
116        k_range,
117        config.nbasis,
118        config.basis_type,
119        config.cov_type,
120        config.cov_weight,
121        config.max_iter,
122        config.tol,
123        config.n_init,
124        config.seed,
125        config.use_icl,
126    )
127}
128
129/// Main clustering function: project curves onto basis, concatenate covariates,
130/// and fit GMM with automatic K selection.
131///
132/// # Arguments
133/// * `data` — Functional data matrix (n × m)
134/// * `argvals` — Evaluation points (length m)
135/// * `covariates` — Optional scalar covariates (n × p)
136/// * `k_range` — Range of K values to try (e.g., `2..=5`)
137/// * `nbasis` — Number of basis functions for projection
138/// * `basis_type` — Basis type for projection
139/// * `cov_type` — Covariance structure
140/// * `cov_weight` — Scaling factor for covariates (default 1.0)
141/// * `max_iter` — Maximum EM iterations per K
142/// * `tol` — Convergence tolerance
143/// * `n_init` — Number of random initializations per K
144/// * `seed` — Base random seed
145/// * `use_icl` — If true, select K by ICL; otherwise by BIC
146///
147/// # Errors
148///
149/// Returns [`FdarError::ComputationFailed`] if basis projection fails or no
150/// valid GMM fit is found for any K in the given range.
151#[must_use = "expensive computation whose result should not be discarded"]
152pub fn gmm_cluster(
153    data: &FdMatrix,
154    argvals: &[f64],
155    covariates: Option<&FdMatrix>,
156    k_range: &[usize],
157    nbasis: usize,
158    basis_type: ProjectionBasisType,
159    cov_type: CovType,
160    cov_weight: f64,
161    max_iter: usize,
162    tol: f64,
163    n_init: usize,
164    seed: u64,
165    use_icl: bool,
166) -> Result<GmmClusterResult, FdarError> {
167    let (features, _d) = build_features(data, argvals, covariates, nbasis, basis_type, cov_weight)
168        .ok_or_else(|| FdarError::ComputationFailed {
169            operation: "build_features",
170            detail: "basis projection failed; check that nbasis <= number of evaluation points and data is non-degenerate".to_string(),
171        })?;
172
173    let mut bic_values = Vec::new();
174    let mut icl_values = Vec::new();
175    let mut best_result: Option<GmmResult> = None;
176    let mut best_criterion = f64::INFINITY;
177
178    for &k in k_range {
179        let best_for_k = run_multiple_inits(&features, k, cov_type, max_iter, tol, n_init, seed);
180        let Some(result) = best_for_k else {
181            continue;
182        };
183
184        bic_values.push((k, result.bic));
185        icl_values.push((k, result.icl));
186
187        let criterion = if use_icl { result.icl } else { result.bic };
188        if criterion < best_criterion {
189            best_criterion = criterion;
190            best_result = Some(result);
191        }
192    }
193
194    best_result
195        .map(|best| GmmClusterResult {
196            best,
197            bic_values,
198            icl_values,
199        })
200        .ok_or_else(|| FdarError::ComputationFailed {
201            operation: "gmm_cluster",
202            detail: "no valid GMM fit found for any K in range; try widening k_range, increasing n_init, or reducing nbasis".to_string(),
203        })
204}
205
206/// Predict cluster assignments for new observations.
207///
208/// # Arguments
209/// * `new_data` — New functional data (n_new × m)
210/// * `argvals` — Evaluation points
211/// * `new_covariates` — Optional scalar covariates for new data
212/// * `result` — Fitted GMM result
213/// * `nbasis` — Number of basis functions (must match training)
214/// * `basis_type` — Basis type (must match training)
215/// * `cov_weight` — Covariate weight (must match training)
216/// * `cov_type` — Covariance type (must match training)
217///
218/// # Errors
219///
220/// Returns [`FdarError::ComputationFailed`] if basis projection fails for the
221/// new data.
222#[must_use = "expensive computation whose result should not be discarded"]
223pub fn predict_gmm(
224    new_data: &FdMatrix,
225    argvals: &[f64],
226    new_covariates: Option<&FdMatrix>,
227    result: &GmmResult,
228    nbasis: usize,
229    basis_type: ProjectionBasisType,
230    cov_weight: f64,
231    cov_type: CovType,
232) -> Result<(Vec<usize>, FdMatrix), FdarError> {
233    let (features, _d) = build_features(
234        new_data,
235        argvals,
236        new_covariates,
237        nbasis,
238        basis_type,
239        cov_weight,
240    )
241    .ok_or_else(|| FdarError::ComputationFailed {
242        operation: "build_features",
243        detail: "basis projection failed for new data; ensure new_data has the same number of evaluation points as the training data".to_string(),
244    })?;
245
246    let k = result.k;
247    let d = result.d;
248    let n = features.len();
249
250    let (resp, _ll) = e_step(
251        &features,
252        &result.means,
253        &result.covariances,
254        &result.weights,
255        k,
256        d,
257        cov_type,
258    );
259
260    let cluster = hard_assignments(&resp, n, k);
261    let membership = resp_to_membership(&resp, n, k);
262
263    Ok((cluster, membership))
264}