Skip to main content

fdars_core/gmm/
cluster.rs

1//! Clustering wrapper and prediction for GMM.
2
3use super::em::{e_step, gmm_em, hard_assignments, resp_to_membership};
4use super::init::build_features;
5use super::{CovType, GmmClusterResult, GmmResult};
6use crate::basis::projection::ProjectionBasisType;
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9
10/// Run multiple initializations for a single K and return the best by log-likelihood.
11pub(super) fn run_multiple_inits(
12    features: &[Vec<f64>],
13    k: usize,
14    cov_type: CovType,
15    max_iter: usize,
16    tol: f64,
17    n_init: usize,
18    base_seed: u64,
19) -> Option<GmmResult> {
20    let mut best: Option<GmmResult> = None;
21    for init in 0..n_init.max(1) {
22        let seed = base_seed.wrapping_add(init as u64 * 1000 + k as u64);
23        if let Ok(result) = gmm_em(features, k, cov_type, max_iter, tol, seed) {
24            let is_better = best
25                .as_ref()
26                .map_or(true, |b| result.log_likelihood > b.log_likelihood);
27            if is_better {
28                best = Some(result);
29            }
30        }
31    }
32    best
33}
34
35/// Configuration for GMM-based functional clustering.
36///
37/// Collects all tuning parameters for [`gmm_cluster_with_config`], with sensible
38/// defaults obtained via [`GmmClusterConfig::default()`].
39///
40/// # Example
41/// ```no_run
42/// use fdars_core::gmm::cluster::GmmClusterConfig;
43/// use fdars_core::gmm::CovType;
44/// use fdars_core::basis::ProjectionBasisType;
45///
46/// let config = GmmClusterConfig {
47///     nbasis: 10,
48///     cov_type: CovType::Full,
49///     ..GmmClusterConfig::default()
50/// };
51/// ```
52#[derive(Debug, Clone)]
53pub struct GmmClusterConfig {
54    /// Number of basis functions for projection (default: 5).
55    pub nbasis: usize,
56    /// Basis type for projection (default: `Bspline`).
57    pub basis_type: ProjectionBasisType,
58    /// Covariance structure (default: `Diagonal`).
59    pub cov_type: CovType,
60    /// Scaling factor for covariates (default: 1.0).
61    pub cov_weight: f64,
62    /// Maximum EM iterations per K (default: 200).
63    pub max_iter: usize,
64    /// Convergence tolerance (default: 1e-6).
65    pub tol: f64,
66    /// Number of random initializations per K (default: 3).
67    pub n_init: usize,
68    /// Base random seed (default: 42).
69    pub seed: u64,
70    /// If true, select K by ICL; otherwise by BIC (default: false).
71    pub use_icl: bool,
72}
73
74impl Default for GmmClusterConfig {
75    fn default() -> Self {
76        Self {
77            nbasis: 5,
78            basis_type: ProjectionBasisType::Bspline,
79            cov_type: CovType::Diagonal,
80            cov_weight: 1.0,
81            max_iter: 200,
82            tol: 1e-6,
83            n_init: 3,
84            seed: 42,
85            use_icl: false,
86        }
87    }
88}
89
90/// GMM clustering using a [`GmmClusterConfig`] struct.
91///
92/// This is the config-based alternative to [`gmm_cluster`]. It takes data
93/// parameters directly and reads all tuning parameters from the config.
94///
95/// # Arguments
96/// * `data` — Functional data matrix (n x m)
97/// * `argvals` — Evaluation points (length m)
98/// * `covariates` — Optional scalar covariates (n x p)
99/// * `k_range` — Range of K values to try
100/// * `config` — Tuning parameters
101///
102/// # Errors
103///
104/// Returns [`FdarError::ComputationFailed`] if basis projection fails or no
105/// valid GMM fit is found for any K in the given range.
106#[must_use = "expensive computation whose result should not be discarded"]
107pub fn gmm_cluster_with_config(
108    data: &FdMatrix,
109    argvals: &[f64],
110    covariates: Option<&FdMatrix>,
111    k_range: &[usize],
112    config: &GmmClusterConfig,
113) -> Result<GmmClusterResult, FdarError> {
114    gmm_cluster(
115        data,
116        argvals,
117        covariates,
118        k_range,
119        config.nbasis,
120        config.basis_type,
121        config.cov_type,
122        config.cov_weight,
123        config.max_iter,
124        config.tol,
125        config.n_init,
126        config.seed,
127        config.use_icl,
128    )
129}
130
131/// Main clustering function: project curves onto basis, concatenate covariates,
132/// and fit GMM with automatic K selection.
133///
134/// # Arguments
135/// * `data` — Functional data matrix (n × m)
136/// * `argvals` — Evaluation points (length m)
137/// * `covariates` — Optional scalar covariates (n × p)
138/// * `k_range` — Range of K values to try (e.g., `2..=5`)
139/// * `nbasis` — Number of basis functions for projection
140/// * `basis_type` — Basis type for projection
141/// * `cov_type` — Covariance structure
142/// * `cov_weight` — Scaling factor for covariates (default 1.0)
143/// * `max_iter` — Maximum EM iterations per K
144/// * `tol` — Convergence tolerance
145/// * `n_init` — Number of random initializations per K
146/// * `seed` — Base random seed
147/// * `use_icl` — If true, select K by ICL; otherwise by BIC
148///
149/// # Errors
150///
151/// Returns [`FdarError::ComputationFailed`] if basis projection fails or no
152/// valid GMM fit is found for any K in the given range.
153#[must_use = "expensive computation whose result should not be discarded"]
154pub fn gmm_cluster(
155    data: &FdMatrix,
156    argvals: &[f64],
157    covariates: Option<&FdMatrix>,
158    k_range: &[usize],
159    nbasis: usize,
160    basis_type: ProjectionBasisType,
161    cov_type: CovType,
162    cov_weight: f64,
163    max_iter: usize,
164    tol: f64,
165    n_init: usize,
166    seed: u64,
167    use_icl: bool,
168) -> Result<GmmClusterResult, FdarError> {
169    let (features, _d) = build_features(data, argvals, covariates, nbasis, basis_type, cov_weight)
170        .ok_or_else(|| FdarError::ComputationFailed {
171            operation: "build_features",
172            detail: "basis projection failed; check that nbasis <= number of evaluation points and data is non-degenerate".to_string(),
173        })?;
174
175    let mut bic_values = Vec::new();
176    let mut icl_values = Vec::new();
177    let mut best_result: Option<GmmResult> = None;
178    let mut best_criterion = f64::INFINITY;
179
180    for &k in k_range {
181        let best_for_k = run_multiple_inits(&features, k, cov_type, max_iter, tol, n_init, seed);
182        let Some(result) = best_for_k else {
183            continue;
184        };
185
186        bic_values.push((k, result.bic));
187        icl_values.push((k, result.icl));
188
189        let criterion = if use_icl { result.icl } else { result.bic };
190        if criterion < best_criterion {
191            best_criterion = criterion;
192            best_result = Some(result);
193        }
194    }
195
196    best_result
197        .map(|best| GmmClusterResult {
198            best,
199            bic_values,
200            icl_values,
201        })
202        .ok_or_else(|| FdarError::ComputationFailed {
203            operation: "gmm_cluster",
204            detail: "no valid GMM fit found for any K in range; try widening k_range, increasing n_init, or reducing nbasis".to_string(),
205        })
206}
207
208/// Predict cluster assignments for new observations.
209///
210/// # Arguments
211/// * `new_data` — New functional data (n_new × m)
212/// * `argvals` — Evaluation points
213/// * `new_covariates` — Optional scalar covariates for new data
214/// * `result` — Fitted GMM result
215/// * `nbasis` — Number of basis functions (must match training)
216/// * `basis_type` — Basis type (must match training)
217/// * `cov_weight` — Covariate weight (must match training)
218/// * `cov_type` — Covariance type (must match training)
219///
220/// # Errors
221///
222/// Returns [`FdarError::ComputationFailed`] if basis projection fails for the
223/// new data.
224#[must_use = "expensive computation whose result should not be discarded"]
225pub fn predict_gmm(
226    new_data: &FdMatrix,
227    argvals: &[f64],
228    new_covariates: Option<&FdMatrix>,
229    result: &GmmResult,
230    nbasis: usize,
231    basis_type: ProjectionBasisType,
232    cov_weight: f64,
233    cov_type: CovType,
234) -> Result<(Vec<usize>, FdMatrix), FdarError> {
235    let (features, _d) = build_features(
236        new_data,
237        argvals,
238        new_covariates,
239        nbasis,
240        basis_type,
241        cov_weight,
242    )
243    .ok_or_else(|| FdarError::ComputationFailed {
244        operation: "build_features",
245        detail: "basis projection failed for new data; ensure new_data has the same number of evaluation points as the training data".to_string(),
246    })?;
247
248    let k = result.k;
249    let d = result.d;
250    let n = features.len();
251
252    let (resp, _ll) = e_step(
253        &features,
254        &result.means,
255        &result.covariances,
256        &result.weights,
257        k,
258        d,
259        cov_type,
260    );
261
262    let cluster = hard_assignments(&resp, n, k);
263    let membership = resp_to_membership(&resp, n, k);
264
265    Ok((cluster, membership))
266}