1use super::covariance::compute_covariances;
4use super::CovType;
5use crate::basis::fdata_to_basis;
6use crate::basis::projection::ProjectionBasisType;
7use crate::matrix::FdMatrix;
8use rand::prelude::*;
9
10pub(super) fn build_features(
17 data: &FdMatrix,
18 argvals: &[f64],
19 covariates: Option<&FdMatrix>,
20 nbasis: usize,
21 basis_type: ProjectionBasisType,
22 cov_weight: f64,
23) -> Option<(Vec<Vec<f64>>, usize)> {
24 let n = data.nrows();
25 let proj = fdata_to_basis(data, argvals, nbasis, basis_type)?;
26 let coef = &proj.coefficients;
27 let d_basis = coef.ncols();
28
29 let d_cov = covariates.map_or(0, super::super::matrix::FdMatrix::ncols);
30 let d = d_basis + d_cov;
31
32 let mut features = Vec::with_capacity(n);
33 for i in 0..n {
34 let mut row = Vec::with_capacity(d);
35 for j in 0..d_basis {
36 row.push(coef[(i, j)]);
37 }
38 if let Some(cov) = covariates {
39 for j in 0..d_cov {
40 row.push(cov[(i, j)] * cov_weight);
41 }
42 }
43 features.push(row);
44 }
45
46 Some((features, d))
47}
48
/// Squared Euclidean distance between two coordinate slices.
///
/// Coordinates are paired positionally; if the slices differ in length, the
/// surplus entries of the longer one are ignored.
fn dist_sq(a: &[f64], b: &[f64]) -> f64 {
    let squared_gaps = a.iter().zip(b).map(|(lhs, rhs)| {
        let gap = lhs - rhs;
        gap.powi(2)
    });
    squared_gaps.sum()
}
57
58fn weighted_sample(weights: &[f64], rng: &mut StdRng) -> usize {
60 let total: f64 = weights.iter().sum();
61 if total < 1e-15 {
62 return rng.gen_range(0..weights.len());
63 }
64 let r = rng.gen::<f64>() * total;
65 let mut cum = 0.0;
66 for (i, &w) in weights.iter().enumerate() {
67 cum += w;
68 if cum >= r {
69 return i;
70 }
71 }
72 weights.len() - 1
73}
74
75fn kmeans_pp_init(features: &[Vec<f64>], k: usize, rng: &mut StdRng) -> Vec<Vec<f64>> {
77 let n = features.len();
78 let mut centers: Vec<Vec<f64>> = Vec::with_capacity(k);
79 centers.push(features[rng.gen_range(0..n)].clone());
80
81 let mut min_dists = vec![f64::INFINITY; n];
82 for c_idx in 1..k {
83 let last = ¢ers[c_idx - 1];
84 for i in 0..n {
85 min_dists[i] = min_dists[i].min(dist_sq(&features[i], last));
86 }
87 let chosen = weighted_sample(&min_dists, rng);
88 centers.push(features[chosen].clone());
89 }
90 centers
91}
92
93fn assign_nearest(features: &[Vec<f64>], centers: &[Vec<f64>]) -> Vec<usize> {
95 features
96 .iter()
97 .map(|f| {
98 centers
99 .iter()
100 .enumerate()
101 .map(|(c, ctr)| (c, dist_sq(f, ctr)))
102 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
103 .map_or(0, |(c, _)| c)
104 })
105 .collect()
106}
107
/// Recomputes each cluster centre as the mean of the feature rows assigned to it.
///
/// * `features` – the feature rows; the width of the first row sets the
///   dimension of the computed means.
/// * `assignments` – cluster index for each row, in row order.
/// * `old_centers` – previous centres, copied through for clusters that
///   received no rows.
/// * `k` – number of clusters.
///
/// Returns `k` centres. An empty `features` slice (with empty `assignments`)
/// is accepted and yields the previous centres unchanged.
///
/// # Panics
///
/// Panics if an assignment is `>= k`, if `assignments` is longer than
/// `features`, if an assigned row is shorter than the first row, or if a
/// cluster without rows has no entry in `old_centers`.
fn update_centers(
    features: &[Vec<f64>],
    assignments: &[usize],
    old_centers: &[Vec<f64>],
    k: usize,
) -> Vec<Vec<f64>> {
    // `first()` rather than `features[0]`, so empty input does not panic.
    let d = features.first().map_or(0, Vec::len);
    let mut counts = vec![0usize; k];
    let mut new_centers = vec![vec![0.0; d]; k];

    for (i, &c) in assignments.iter().enumerate() {
        counts[c] += 1;
        let point = &features[i];
        for (acc, &x) in new_centers[c].iter_mut().zip(&point[..d]) {
            *acc += x;
        }
    }

    for (c, (center, &count)) in new_centers.iter_mut().zip(&counts).enumerate() {
        if count > 0 {
            let members = count as f64;
            for value in center.iter_mut() {
                *value /= members;
            }
        } else {
            // A cluster with no rows keeps its previous centre.
            center.clone_from(&old_centers[c]);
        }
    }
    new_centers
}
135
136pub(super) fn kmeans_init_assignments(
138 features: &[Vec<f64>],
139 k: usize,
140 rng: &mut StdRng,
141) -> Vec<usize> {
142 let mut centers = kmeans_pp_init(features, k, rng);
143 let mut assignments = vec![0usize; features.len()];
144 for _ in 0..10 {
145 assignments = assign_nearest(features, ¢ers);
146 centers = update_centers(features, &assignments, ¢ers, k);
147 }
148 assignments
149}
150
151pub(super) fn init_params_from_assignments(
153 features: &[Vec<f64>],
154 assignments: &[usize],
155 k: usize,
156 d: usize,
157 cov_type: CovType,
158) -> (Vec<Vec<f64>>, Vec<Vec<f64>>, Vec<f64>) {
159 let n = features.len();
160 let mut counts = vec![0usize; k];
161 let mut means = vec![vec![0.0; d]; k];
162
163 for i in 0..n {
164 let c = assignments[i];
165 counts[c] += 1;
166 for j in 0..d {
167 means[c][j] += features[i][j];
168 }
169 }
170 for c in 0..k {
171 if counts[c] > 0 {
172 for j in 0..d {
173 means[c][j] /= counts[c] as f64;
174 }
175 }
176 }
177
178 let reg = 1e-6; let covariances = compute_covariances(features, assignments, &means, k, d, cov_type, reg);
180
181 let weights: Vec<f64> = counts.iter().map(|&c| c.max(1) as f64 / n as f64).collect();
182
183 (means, covariances, weights)
184}