1use super::em::{e_step, gmm_em, hard_assignments, resp_to_membership};
4use super::init::build_features;
5use super::{CovType, GmmClusterResult, GmmResult};
6use crate::basis::projection::ProjectionBasisType;
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9
10pub(super) fn run_multiple_inits(
12 features: &[Vec<f64>],
13 k: usize,
14 cov_type: CovType,
15 max_iter: usize,
16 tol: f64,
17 n_init: usize,
18 base_seed: u64,
19) -> Option<GmmResult> {
20 let mut best: Option<GmmResult> = None;
21 for init in 0..n_init.max(1) {
22 let seed = base_seed.wrapping_add(init as u64 * 1000 + k as u64);
23 if let Ok(result) = gmm_em(features, k, cov_type, max_iter, tol, seed) {
24 let is_better = best
25 .as_ref()
26 .map_or(true, |b| result.log_likelihood > b.log_likelihood);
27 if is_better {
28 best = Some(result);
29 }
30 }
31 }
32 best
33}
34
35#[derive(Debug, Clone)]
53pub struct GmmClusterConfig {
54 pub nbasis: usize,
56 pub basis_type: ProjectionBasisType,
58 pub cov_type: CovType,
60 pub cov_weight: f64,
62 pub max_iter: usize,
64 pub tol: f64,
66 pub n_init: usize,
68 pub seed: u64,
70 pub use_icl: bool,
72}
73
74impl Default for GmmClusterConfig {
75 fn default() -> Self {
76 Self {
77 nbasis: 5,
78 basis_type: ProjectionBasisType::Bspline,
79 cov_type: CovType::Diagonal,
80 cov_weight: 1.0,
81 max_iter: 200,
82 tol: 1e-6,
83 n_init: 3,
84 seed: 42,
85 use_icl: false,
86 }
87 }
88}
89
90#[must_use = "expensive computation whose result should not be discarded"]
107pub fn gmm_cluster_with_config(
108 data: &FdMatrix,
109 argvals: &[f64],
110 covariates: Option<&FdMatrix>,
111 k_range: &[usize],
112 config: &GmmClusterConfig,
113) -> Result<GmmClusterResult, FdarError> {
114 gmm_cluster(
115 data,
116 argvals,
117 covariates,
118 k_range,
119 config.nbasis,
120 config.basis_type,
121 config.cov_type,
122 config.cov_weight,
123 config.max_iter,
124 config.tol,
125 config.n_init,
126 config.seed,
127 config.use_icl,
128 )
129}
130
131#[must_use = "expensive computation whose result should not be discarded"]
154pub fn gmm_cluster(
155 data: &FdMatrix,
156 argvals: &[f64],
157 covariates: Option<&FdMatrix>,
158 k_range: &[usize],
159 nbasis: usize,
160 basis_type: ProjectionBasisType,
161 cov_type: CovType,
162 cov_weight: f64,
163 max_iter: usize,
164 tol: f64,
165 n_init: usize,
166 seed: u64,
167 use_icl: bool,
168) -> Result<GmmClusterResult, FdarError> {
169 let (features, _d) = build_features(data, argvals, covariates, nbasis, basis_type, cov_weight)
170 .ok_or_else(|| FdarError::ComputationFailed {
171 operation: "build_features",
172 detail: "basis projection failed; check that nbasis <= number of evaluation points and data is non-degenerate".to_string(),
173 })?;
174
175 let mut bic_values = Vec::new();
176 let mut icl_values = Vec::new();
177 let mut best_result: Option<GmmResult> = None;
178 let mut best_criterion = f64::INFINITY;
179
180 for &k in k_range {
181 let best_for_k = run_multiple_inits(&features, k, cov_type, max_iter, tol, n_init, seed);
182 let Some(result) = best_for_k else {
183 continue;
184 };
185
186 bic_values.push((k, result.bic));
187 icl_values.push((k, result.icl));
188
189 let criterion = if use_icl { result.icl } else { result.bic };
190 if criterion < best_criterion {
191 best_criterion = criterion;
192 best_result = Some(result);
193 }
194 }
195
196 best_result
197 .map(|best| GmmClusterResult {
198 best,
199 bic_values,
200 icl_values,
201 })
202 .ok_or_else(|| FdarError::ComputationFailed {
203 operation: "gmm_cluster",
204 detail: "no valid GMM fit found for any K in range; try widening k_range, increasing n_init, or reducing nbasis".to_string(),
205 })
206}
207
208#[must_use = "expensive computation whose result should not be discarded"]
225pub fn predict_gmm(
226 new_data: &FdMatrix,
227 argvals: &[f64],
228 new_covariates: Option<&FdMatrix>,
229 result: &GmmResult,
230 nbasis: usize,
231 basis_type: ProjectionBasisType,
232 cov_weight: f64,
233 cov_type: CovType,
234) -> Result<(Vec<usize>, FdMatrix), FdarError> {
235 let (features, _d) = build_features(
236 new_data,
237 argvals,
238 new_covariates,
239 nbasis,
240 basis_type,
241 cov_weight,
242 )
243 .ok_or_else(|| FdarError::ComputationFailed {
244 operation: "build_features",
245 detail: "basis projection failed for new data; ensure new_data has the same number of evaluation points as the training data".to_string(),
246 })?;
247
248 let k = result.k;
249 let d = result.d;
250 let n = features.len();
251
252 let (resp, _ll) = e_step(
253 &features,
254 &result.means,
255 &result.covariances,
256 &result.weights,
257 k,
258 d,
259 cov_type,
260 );
261
262 let cluster = hard_assignments(&resp, n, k);
263 let membership = resp_to_membership(&resp, n, k);
264
265 Ok((cluster, membership))
266}