1use super::em::{e_step, gmm_em, hard_assignments, resp_to_membership};
4use super::init::build_features;
5use super::{CovType, GmmClusterResult, GmmResult};
6use crate::basis::projection::ProjectionBasisType;
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9
10pub(super) fn run_multiple_inits(
12 features: &[Vec<f64>],
13 k: usize,
14 cov_type: CovType,
15 max_iter: usize,
16 tol: f64,
17 n_init: usize,
18 base_seed: u64,
19) -> Option<GmmResult> {
20 let mut best: Option<GmmResult> = None;
21 for init in 0..n_init.max(1) {
22 let seed = base_seed.wrapping_add(init as u64 * 1000 + k as u64);
23 if let Ok(result) = gmm_em(features, k, cov_type, max_iter, tol, seed) {
24 let is_better = best
25 .as_ref()
26 .map_or(true, |b| result.log_likelihood > b.log_likelihood);
27 if is_better {
28 best = Some(result);
29 }
30 }
31 }
32 best
33}
34
/// Settings for [`gmm_cluster_with_config`], bundling the tuning parameters of
/// [`gmm_cluster`].
///
/// Every field is forwarded unchanged to [`gmm_cluster`]. Default values come
/// from this type's `Default` impl. Because the struct is `#[non_exhaustive]`,
/// code outside this crate must start from `GmmClusterConfig::default()` and
/// overwrite individual fields.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct GmmClusterConfig {
    /// Number of basis functions each curve is projected onto (forwarded to
    /// `build_features`). Should not exceed the number of evaluation points.
    pub nbasis: usize,
    /// Basis family used for the projection (forwarded to `build_features`).
    pub basis_type: ProjectionBasisType,
    /// Covariance structure of the mixture components (forwarded to `gmm_em`).
    pub cov_type: CovType,
    /// Weight forwarded to `build_features` together with any covariates;
    /// presumably scales covariate columns relative to basis scores — TODO confirm.
    pub cov_weight: f64,
    /// Iteration cap forwarded to the EM routine (`gmm_em`).
    pub max_iter: usize,
    /// Convergence tolerance forwarded to the EM routine (`gmm_em`).
    pub tol: f64,
    /// Number of EM restarts per candidate K; the restart with the highest
    /// log-likelihood is kept. A value of 0 is treated as 1.
    pub n_init: usize,
    /// Base seed; each restart derives its own seed from this value, the
    /// restart index and K.
    pub seed: u64,
    /// If `true`, K is chosen by minimum ICL; otherwise by minimum BIC.
    pub use_icl: bool,
}
71
72impl Default for GmmClusterConfig {
73 fn default() -> Self {
74 Self {
75 nbasis: 5,
76 basis_type: ProjectionBasisType::Bspline,
77 cov_type: CovType::Diagonal,
78 cov_weight: 1.0,
79 max_iter: 200,
80 tol: 1e-6,
81 n_init: 3,
82 seed: 42,
83 use_icl: false,
84 }
85 }
86}
87
88#[must_use = "expensive computation whose result should not be discarded"]
105pub fn gmm_cluster_with_config(
106 data: &FdMatrix,
107 argvals: &[f64],
108 covariates: Option<&FdMatrix>,
109 k_range: &[usize],
110 config: &GmmClusterConfig,
111) -> Result<GmmClusterResult, FdarError> {
112 gmm_cluster(
113 data,
114 argvals,
115 covariates,
116 k_range,
117 config.nbasis,
118 config.basis_type,
119 config.cov_type,
120 config.cov_weight,
121 config.max_iter,
122 config.tol,
123 config.n_init,
124 config.seed,
125 config.use_icl,
126 )
127}
128
129#[must_use = "expensive computation whose result should not be discarded"]
152pub fn gmm_cluster(
153 data: &FdMatrix,
154 argvals: &[f64],
155 covariates: Option<&FdMatrix>,
156 k_range: &[usize],
157 nbasis: usize,
158 basis_type: ProjectionBasisType,
159 cov_type: CovType,
160 cov_weight: f64,
161 max_iter: usize,
162 tol: f64,
163 n_init: usize,
164 seed: u64,
165 use_icl: bool,
166) -> Result<GmmClusterResult, FdarError> {
167 let (features, _d) = build_features(data, argvals, covariates, nbasis, basis_type, cov_weight)
168 .ok_or_else(|| FdarError::ComputationFailed {
169 operation: "build_features",
170 detail: "basis projection failed; check that nbasis <= number of evaluation points and data is non-degenerate".to_string(),
171 })?;
172
173 let mut bic_values = Vec::new();
174 let mut icl_values = Vec::new();
175 let mut best_result: Option<GmmResult> = None;
176 let mut best_criterion = f64::INFINITY;
177
178 for &k in k_range {
179 let best_for_k = run_multiple_inits(&features, k, cov_type, max_iter, tol, n_init, seed);
180 let Some(result) = best_for_k else {
181 continue;
182 };
183
184 bic_values.push((k, result.bic));
185 icl_values.push((k, result.icl));
186
187 let criterion = if use_icl { result.icl } else { result.bic };
188 if criterion < best_criterion {
189 best_criterion = criterion;
190 best_result = Some(result);
191 }
192 }
193
194 best_result
195 .map(|best| GmmClusterResult {
196 best,
197 bic_values,
198 icl_values,
199 })
200 .ok_or_else(|| FdarError::ComputationFailed {
201 operation: "gmm_cluster",
202 detail: "no valid GMM fit found for any K in range; try widening k_range, increasing n_init, or reducing nbasis".to_string(),
203 })
204}
205
206#[must_use = "expensive computation whose result should not be discarded"]
223pub fn predict_gmm(
224 new_data: &FdMatrix,
225 argvals: &[f64],
226 new_covariates: Option<&FdMatrix>,
227 result: &GmmResult,
228 nbasis: usize,
229 basis_type: ProjectionBasisType,
230 cov_weight: f64,
231 cov_type: CovType,
232) -> Result<(Vec<usize>, FdMatrix), FdarError> {
233 let (features, _d) = build_features(
234 new_data,
235 argvals,
236 new_covariates,
237 nbasis,
238 basis_type,
239 cov_weight,
240 )
241 .ok_or_else(|| FdarError::ComputationFailed {
242 operation: "build_features",
243 detail: "basis projection failed for new data; ensure new_data has the same number of evaluation points as the training data".to_string(),
244 })?;
245
246 let k = result.k;
247 let d = result.d;
248 let n = features.len();
249
250 let (resp, _ll) = e_step(
251 &features,
252 &result.means,
253 &result.covariances,
254 &result.weights,
255 k,
256 d,
257 cov_type,
258 );
259
260 let cluster = hard_assignments(&resp, n, k);
261 let membership = resp_to_membership(&resp, n, k);
262
263 Ok((cluster, membership))
264}