use super::em::{e_step, gmm_em, hard_assignments, resp_to_membership};
use super::init::build_features;
use super::{CovType, GmmClusterResult, GmmResult};
use crate::basis::projection::ProjectionBasisType;
use crate::error::FdarError;
use crate::matrix::FdMatrix;
/// Runs EM from several seeds for a fixed `k` and keeps the fit with the
/// highest log-likelihood.
///
/// At least one initialisation is always attempted (`n_init == 0` is treated
/// as `1`). The seed for attempt `i` is `base_seed + i * 1000 + k`, computed
/// with wrapping arithmetic so the derivation can never overflow-panic.
///
/// Attempts for which `gmm_em` returns an error are skipped.
///
/// # Returns
///
/// `Some(best_fit)` if at least one attempt succeeded, otherwise `None`.
pub(super) fn run_multiple_inits(
    features: &[Vec<f64>],
    k: usize,
    cov_type: CovType,
    max_iter: usize,
    tol: f64,
    n_init: usize,
    base_seed: u64,
) -> Option<GmmResult> {
    let mut best: Option<GmmResult> = None;
    for init in 0..n_init.max(1) {
        // Same seed sequence as plain `init * 1000 + k`, but wrapping throughout
        // to stay consistent with the `wrapping_add` on the base seed.
        let offset = (init as u64).wrapping_mul(1000).wrapping_add(k as u64);
        let seed = base_seed.wrapping_add(offset);

        let Ok(result) = gmm_em(features, k, cov_type, max_iter, tol, seed) else {
            continue;
        };

        let is_better = match best.as_ref() {
            None => true,
            Some(current) => {
                let new_ll = result.log_likelihood;
                let cur_ll = current.log_likelihood;
                // Every comparison against NaN is false, so a NaN incumbent
                // would otherwise block all later (valid) fits. Let any
                // non-NaN fit displace a NaN incumbent.
                new_ll > cur_ll || (cur_ll.is_nan() && !new_ll.is_nan())
            }
        };
        if is_better {
            best = Some(result);
        }
    }
    best
}
/// Settings consumed by `gmm_cluster_with_config`.
///
/// Each field is forwarded unchanged to the positional parameter of the same
/// name on `gmm_cluster`. The `Default` implementation supplies a baseline
/// value for every field.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct GmmClusterConfig {
/// Number of basis functions requested from the feature-building step.
pub nbasis: usize,
/// Basis family used for the projection.
pub basis_type: ProjectionBasisType,
/// Covariance structure handed to the EM fit.
pub cov_type: CovType,
/// Weight passed to the feature-building step together with the covariates
/// (presumably scales the covariate columns; confirm against `build_features`).
pub cov_weight: f64,
/// Iteration limit handed to each EM run.
pub max_iter: usize,
/// Tolerance handed to each EM run (presumably the convergence threshold).
pub tol: f64,
/// Number of EM initialisations tried per candidate K; `0` is treated as `1`.
pub n_init: usize,
/// Base seed from which the per-initialisation seeds are derived.
pub seed: u64,
/// When `true` the model is selected by ICL, otherwise by BIC.
pub use_icl: bool,
}
impl Default for GmmClusterConfig {
    /// Baseline configuration: 5 B-spline basis functions, diagonal
    /// covariances, unit covariate weight, up to 200 EM iterations with
    /// tolerance `1e-6`, 3 initialisations seeded from 42, BIC selection.
    fn default() -> Self {
        // Feature construction.
        let nbasis = 5;
        let basis_type = ProjectionBasisType::Bspline;
        let cov_weight = 1.0;

        // Mixture model and EM loop.
        let cov_type = CovType::Diagonal;
        let max_iter = 200;
        let tol = 1e-6;
        let n_init = 3;
        let seed = 42;

        // Model selection criterion (false = BIC).
        let use_icl = false;

        Self {
            nbasis,
            basis_type,
            cov_type,
            cov_weight,
            max_iter,
            tol,
            n_init,
            seed,
            use_icl,
        }
    }
}
#[must_use = "expensive computation whose result should not be discarded"]
pub fn gmm_cluster_with_config(
data: &FdMatrix,
argvals: &[f64],
covariates: Option<&FdMatrix>,
k_range: &[usize],
config: &GmmClusterConfig,
) -> Result<GmmClusterResult, FdarError> {
gmm_cluster(
data,
argvals,
covariates,
k_range,
config.nbasis,
config.basis_type,
config.cov_type,
config.cov_weight,
config.max_iter,
config.tol,
config.n_init,
config.seed,
config.use_icl,
)
}
/// Fits a Gaussian mixture for every candidate `k` and returns the one with
/// the smallest information criterion.
///
/// Features are built once via `build_features`. For each `k` in `k_range`
/// (in the given order) the best of `n_init` EM runs is obtained from
/// `run_multiple_inits`; candidates for which no run succeeds are skipped.
/// The criterion is ICL when `use_icl` is `true`, BIC otherwise, and only a
/// strictly smaller value replaces the current best.
///
/// The returned `bic_values` / `icl_values` hold one `(k, value)` pair for
/// every `k` that produced a fit, including those not selected.
///
/// # Errors
///
/// Returns `FdarError::ComputationFailed` when
/// * `build_features` yields nothing, or
/// * no candidate `k` produced a fit whose criterion is below `+inf`
///   (this includes an empty `k_range`).
#[must_use = "expensive computation whose result should not be discarded"]
pub fn gmm_cluster(
    data: &FdMatrix,
    argvals: &[f64],
    covariates: Option<&FdMatrix>,
    k_range: &[usize],
    nbasis: usize,
    basis_type: ProjectionBasisType,
    cov_type: CovType,
    cov_weight: f64,
    max_iter: usize,
    tol: f64,
    n_init: usize,
    seed: u64,
    use_icl: bool,
) -> Result<GmmClusterResult, FdarError> {
    let Some((features, _d)) =
        build_features(data, argvals, covariates, nbasis, basis_type, cov_weight)
    else {
        return Err(FdarError::ComputationFailed {
            operation: "build_features",
            detail: "basis projection failed; check that nbasis <= number of evaluation points and data is non-degenerate".to_string(),
        });
    };

    let mut bic_values = Vec::new();
    let mut icl_values = Vec::new();
    // Incumbent as (criterion, fit); `None` stands for a threshold of +inf.
    let mut winner: Option<(f64, GmmResult)> = None;

    for &k in k_range {
        if let Some(fit) = run_multiple_inits(&features, k, cov_type, max_iter, tol, n_init, seed)
        {
            bic_values.push((k, fit.bic));
            icl_values.push((k, fit.icl));

            let score = if use_icl { fit.icl } else { fit.bic };
            let threshold = winner
                .as_ref()
                .map_or(f64::INFINITY, |(current, _)| *current);
            // Strict `<`: ties keep the earlier k, NaN / +inf never win.
            if score < threshold {
                winner = Some((score, fit));
            }
        }
    }

    match winner {
        Some((_, best)) => Ok(GmmClusterResult {
            best,
            bic_values,
            icl_values,
        }),
        None => Err(FdarError::ComputationFailed {
            operation: "gmm_cluster",
            detail: "no valid GMM fit found for any K in range; try widening k_range, increasing n_init, or reducing nbasis".to_string(),
        }),
    }
}
/// Assigns new observations to the components of an already fitted mixture.
///
/// The new curves (and optional covariates) are turned into feature vectors
/// with `build_features`, then a single E-step is evaluated against the
/// means, covariances and weights stored in `result`.
///
/// `nbasis`, `basis_type`, `cov_weight` and the presence/shape of covariates
/// must reproduce the feature layout used when `result` was fitted; `cov_type`
/// is passed to the E-step alongside `result.covariances` (presumably it must
/// match the covariance structure used at fit time).
///
/// # Returns
///
/// `(cluster, membership)`: the output of `hard_assignments` and of
/// `resp_to_membership` applied to the E-step responsibilities.
///
/// # Errors
///
/// Returns `FdarError::ComputationFailed` when
/// * `build_features` yields nothing for the new data, or
/// * any feature vector built for the new data does not have exactly
///   `result.d` entries.
#[must_use = "expensive computation whose result should not be discarded"]
pub fn predict_gmm(
    new_data: &FdMatrix,
    argvals: &[f64],
    new_covariates: Option<&FdMatrix>,
    result: &GmmResult,
    nbasis: usize,
    basis_type: ProjectionBasisType,
    cov_weight: f64,
    cov_type: CovType,
) -> Result<(Vec<usize>, FdMatrix), FdarError> {
    let (features, _d) = build_features(
        new_data,
        argvals,
        new_covariates,
        nbasis,
        basis_type,
        cov_weight,
    )
    .ok_or_else(|| FdarError::ComputationFailed {
        operation: "build_features",
        detail: "basis projection failed for new data; ensure new_data has the same number of evaluation points as the training data".to_string(),
    })?;

    let k = result.k;
    let d = result.d;

    // The E-step is driven by the model's dimension `d`; reject feature
    // vectors of any other width up front instead of handing inconsistent
    // shapes to `e_step`.
    if let Some(row) = features.iter().find(|row| row.len() != d) {
        return Err(FdarError::ComputationFailed {
            operation: "predict_gmm",
            detail: format!(
                "feature dimension mismatch: new data yields {} features per observation but the fitted model expects {}; use the same nbasis, basis_type and covariates as during fitting",
                row.len(),
                d
            ),
        });
    }

    let n = features.len();
    let (resp, _ll) = e_step(
        &features,
        &result.means,
        &result.covariances,
        &result.weights,
        k,
        d,
        cov_type,
    );

    let cluster = hard_assignments(&resp, n, k);
    let membership = resp_to_membership(&resp, n, k);
    Ok((cluster, membership))
}