1use ndarray::Array2;
27
28use super::NoiseCov;
29
/// Regularisation strategy used by [`compute_covariance`] and
/// [`compute_covariance_epochs`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Regularization {
    /// The plain unbiased sample covariance, without regularisation.
    Empirical,
    /// Shrinkage towards a scaled identity, `(1 − α)·C + α·μ·I`, where `C` is the
    /// sample covariance and `μ` the mean of its diagonal.
    ///
    /// `Some(α)` uses the given intensity (clamped to `[0, 1]`); `None` estimates
    /// it with the Ledoit–Wolf formula.
    ShrunkIdentity(Option<f64>),
    /// Per-channel variances only; all cross-channel terms are dropped.
    Diagonal,
}
43
44pub fn compute_covariance(data: &Array2<f64>, reg: Regularization) -> NoiseCov {
58 let (n_ch, n_t) = data.dim();
59 assert!(n_t > 1, "Need at least 2 time points for covariance");
60
61 let means = data.mean_axis(ndarray::Axis(1)).unwrap();
63 let mut centered = data.clone();
64 for i in 0..n_ch {
65 for j in 0..n_t {
66 centered[[i, j]] -= means[i];
67 }
68 }
69
70 match reg {
71 Regularization::Empirical => {
72 let cov = centered.dot(¢ered.t()) / (n_t - 1) as f64;
73 NoiseCov::full(cov)
74 }
75 Regularization::ShrunkIdentity(alpha_opt) => {
76 let cov = centered.dot(¢ered.t()) / (n_t - 1) as f64;
77 let alpha = alpha_opt.unwrap_or_else(|| ledoit_wolf_alpha(¢ered, &cov));
78 let alpha = alpha.clamp(0.0, 1.0);
79 let trace = cov.diag().sum();
80 let mu = trace / n_ch as f64;
81 let shrunk = cov.mapv(|v| v * (1.0 - alpha)) + Array2::<f64>::eye(n_ch).mapv(|v: f64| v * alpha * mu);
82 NoiseCov::full(shrunk)
83 }
84 Regularization::Diagonal => {
85 let mut vars = Vec::with_capacity(n_ch);
86 for i in 0..n_ch {
87 let mut sum_sq = 0.0;
88 for j in 0..n_t {
89 sum_sq += centered[[i, j]].powi(2);
90 }
91 vars.push(sum_sq / (n_t - 1) as f64);
92 }
93 NoiseCov::diagonal(vars)
94 }
95 }
96}
97
98pub fn compute_covariance_epochs(
103 epochs: &ndarray::Array3<f64>,
104 reg: Regularization,
105) -> NoiseCov {
106 let (n_epochs, n_ch, n_t) = epochs.dim();
107 let total_t = n_epochs * n_t;
108
109 let mut concat = Array2::zeros((n_ch, total_t));
111 for e in 0..n_epochs {
112 let epoch = epochs.slice(ndarray::s![e, .., ..]);
113 let mean = epoch.mean_axis(ndarray::Axis(1)).unwrap();
114 for i in 0..n_ch {
115 for j in 0..n_t {
116 concat[[i, e * n_t + j]] = epoch[[i, j]] - mean[i];
117 }
118 }
119 }
120
121 let (_, total) = concat.dim();
123 match reg {
124 Regularization::Empirical => {
125 let cov = concat.dot(&concat.t()) / (total - 1) as f64;
126 NoiseCov::full(cov)
127 }
128 Regularization::ShrunkIdentity(alpha_opt) => {
129 let cov = concat.dot(&concat.t()) / (total - 1) as f64;
130 let alpha = alpha_opt.unwrap_or_else(|| ledoit_wolf_alpha(&concat, &cov));
131 let alpha = alpha.clamp(0.0, 1.0);
132 let trace = cov.diag().sum();
133 let mu = trace / n_ch as f64;
134 let shrunk = cov.mapv(|v| v * (1.0 - alpha)) + Array2::<f64>::eye(n_ch).mapv(|v: f64| v * alpha * mu);
135 NoiseCov::full(shrunk)
136 }
137 Regularization::Diagonal => {
138 let mut vars = Vec::with_capacity(n_ch);
139 for i in 0..n_ch {
140 let mut sum_sq = 0.0;
141 for j in 0..total {
142 sum_sq += concat[[i, j]].powi(2);
143 }
144 vars.push(sum_sq / (total - 1) as f64);
145 }
146 NoiseCov::diagonal(vars)
147 }
148 }
149}
150
151fn ledoit_wolf_alpha(x: &Array2<f64>, sample_cov: &Array2<f64>) -> f64 {
156 let (p, n) = x.dim(); if n < 2 {
159 return 1.0;
160 }
161
162 let trace_s = sample_cov.diag().sum();
163 let trace_s2 = sample_cov.iter().map(|v| v * v).sum::<f64>();
164 let mu = trace_s / p as f64;
165
166 let mut beta_sum = 0.0;
169 for t in 0..n {
170 let mut xtx = 0.0;
173 for i in 0..p {
174 xtx += x[[i, t]] * x[[i, t]];
175 }
176 let mut xt_s_xt = 0.0;
177 for i in 0..p {
178 let mut row_dot = 0.0;
179 for j in 0..p {
180 row_dot += sample_cov[[i, j]] * x[[j, t]];
181 }
182 xt_s_xt += x[[i, t]] * row_dot;
183 }
184 beta_sum += xtx * xtx - 2.0 * xt_s_xt + trace_s2;
185 }
186 let beta = beta_sum / (n * n) as f64;
187
188 let delta = trace_s2 - p as f64 * mu * mu;
190
191 if delta <= 0.0 {
192 return 1.0;
193 }
194
195 (beta / delta).clamp(0.0, 1.0)
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_empirical_covariance_shape() {
        let data = Array2::<f64>::from_shape_fn((4, 100), |(i, j)| {
            ((i * 100 + j) as f64 * 0.1).sin()
        });
        let cov = compute_covariance(&data, Regularization::Empirical);
        assert_eq!(cov.n_channels(), 4);
        let full = cov.to_full();
        assert_eq!(full.dim(), (4, 4));
    }

    #[test]
    fn test_empirical_covariance_symmetric() {
        let data = Array2::<f64>::from_shape_fn((5, 200), |(i, j)| {
            ((i * 200 + j) as f64 * 0.3).cos() * 1e-6
        });
        let cov = compute_covariance(&data, Regularization::Empirical);
        let full = cov.to_full();
        for i in 0..5 {
            for j in 0..5 {
                approx::assert_abs_diff_eq!(full[[i, j]], full[[j, i]], epsilon = 1e-15);
            }
        }
    }

    #[test]
    fn test_empirical_covariance_positive_diagonal() {
        let data = Array2::<f64>::from_shape_fn((3, 500), |(i, j)| {
            ((i * 500 + j) as f64 * 0.7).sin() * 1e-6
        });
        let cov = compute_covariance(&data, Regularization::Empirical);
        let diag = cov.diag_elements();
        for &v in diag.iter() {
            assert!(v > 0.0, "Diagonal should be positive");
        }
    }

    #[test]
    fn test_diagonal_covariance() {
        let data = Array2::<f64>::from_shape_fn((3, 500), |(i, j)| {
            ((i * 500 + j) as f64 * 0.2).sin() * (i as f64 + 1.0) * 1e-6
        });
        let cov = compute_covariance(&data, Regularization::Diagonal);
        assert!(cov.diag);
        assert_eq!(cov.n_channels(), 3);
    }

    /// The Diagonal variances must equal the diagonal of the empirical covariance.
    #[test]
    fn test_diagonal_matches_empirical_diagonal() {
        let data = Array2::<f64>::from_shape_fn((3, 300), |(i, j)| {
            ((i * 300 + j) as f64 * 0.23).sin() * (i as f64 + 1.0)
        });
        let emp = compute_covariance(&data, Regularization::Empirical).to_full();
        let diag = compute_covariance(&data, Regularization::Diagonal);
        let mut count = 0;
        for (i, &v) in diag.diag_elements().iter().enumerate() {
            approx::assert_abs_diff_eq!(v, emp[[i, i]], epsilon = 1e-12);
            count += 1;
        }
        assert_eq!(count, 3);
    }

    #[test]
    fn test_shrunk_covariance_between_empirical_and_identity() {
        let data = Array2::<f64>::from_shape_fn((4, 200), |(i, j)| {
            ((i * 200 + j) as f64 * 0.5).sin() * 1e-6
        });
        let emp = compute_covariance(&data, Regularization::Empirical).to_full();
        let shrunk = compute_covariance(&data, Regularization::ShrunkIdentity(None)).to_full();

        // Total absolute off-diagonal mass; shrinkage with α ∈ [0, 1] scales it by (1 − α).
        let mut emp_offdiag = 0.0;
        let mut shrunk_offdiag = 0.0;
        for i in 0..4 {
            for j in 0..4 {
                if i != j {
                    emp_offdiag += emp[[i, j]].abs();
                    shrunk_offdiag += shrunk[[i, j]].abs();
                }
            }
        }
        assert!(
            shrunk_offdiag <= emp_offdiag + 1e-20,
            "Shrinkage should reduce off-diagonal: shrunk={shrunk_offdiag}, emp={emp_offdiag}"
        );
    }

    /// An explicit α must produce exactly `(1 − α)·C + α·μ·I`.
    #[test]
    fn test_shrunk_covariance_with_explicit_alpha_matches_formula() {
        let data = Array2::<f64>::from_shape_fn((3, 120), |(i, j)| {
            ((i * 120 + j) as f64 * 0.37).sin() * (i as f64 + 1.0)
        });
        let alpha = 0.3;
        let emp = compute_covariance(&data, Regularization::Empirical).to_full();
        let shrunk =
            compute_covariance(&data, Regularization::ShrunkIdentity(Some(alpha))).to_full();
        let mu = (0..3).map(|i| emp[[i, i]]).sum::<f64>() / 3.0;
        for i in 0..3 {
            for j in 0..3 {
                let target = if i == j { mu } else { 0.0 };
                let expected = (1.0 - alpha) * emp[[i, j]] + alpha * target;
                approx::assert_abs_diff_eq!(shrunk[[i, j]], expected, epsilon = 1e-12);
            }
        }
    }

    /// Out-of-range α values are clamped to `[0, 1]`.
    #[test]
    fn test_shrinkage_alpha_is_clamped() {
        let data = Array2::<f64>::from_shape_fn((3, 80), |(i, j)| {
            ((i * 80 + j) as f64 * 0.41).cos() * (i as f64 + 1.0)
        });
        let cov_for = |a: f64| {
            compute_covariance(&data, Regularization::ShrunkIdentity(Some(a))).to_full()
        };
        let (over, one) = (cov_for(2.0), cov_for(1.0));
        let (under, zero) = (cov_for(-1.0), cov_for(0.0));
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(over[[i, j]], one[[i, j]]);
                assert_eq!(under[[i, j]], zero[[i, j]]);
            }
        }
    }

    #[test]
    #[should_panic(expected = "at least 2 time points")]
    fn test_single_time_point_panics() {
        let data = Array2::<f64>::zeros((2, 1));
        let _ = compute_covariance(&data, Regularization::Empirical);
    }

    #[test]
    fn test_covariance_from_epochs() {
        let epochs = ndarray::Array3::<f64>::from_shape_fn((10, 3, 50), |(e, i, j)| {
            ((e * 150 + i * 50 + j) as f64 * 0.4).sin() * 1e-6
        });
        let cov = compute_covariance_epochs(&epochs, Regularization::Empirical);
        assert_eq!(cov.n_channels(), 3);
        let full = cov.to_full();
        // Symmetric …
        for i in 0..3 {
            for j in 0..3 {
                approx::assert_abs_diff_eq!(full[[i, j]], full[[j, i]], epsilon = 1e-15);
            }
        }
        // … with a positive diagonal.
        for i in 0..3 {
            assert!(full[[i, i]] > 0.0);
        }
    }

    /// Each epoch must be centred on its own mean before concatenation.
    #[test]
    fn test_epochs_covariance_matches_concatenated_centered_epochs() {
        // The per-epoch offset `e` would inflate the variance if epochs were not
        // centred individually.
        let epochs = ndarray::Array3::<f64>::from_shape_fn((4, 2, 30), |(e, i, j)| {
            ((e * 60 + i * 30 + j) as f64 * 0.21).cos() + e as f64
        });
        let (n_e, n_ch, n_t) = epochs.dim();
        let mut concat = Array2::<f64>::zeros((n_ch, n_e * n_t));
        for e in 0..n_e {
            for i in 0..n_ch {
                let mean = (0..n_t).map(|j| epochs[[e, i, j]]).sum::<f64>() / n_t as f64;
                for j in 0..n_t {
                    concat[[i, e * n_t + j]] = epochs[[e, i, j]] - mean;
                }
            }
        }
        let from_epochs = compute_covariance_epochs(&epochs, Regularization::Empirical).to_full();
        let direct = compute_covariance(&concat, Regularization::Empirical).to_full();
        for i in 0..n_ch {
            for j in 0..n_ch {
                approx::assert_abs_diff_eq!(from_epochs[[i, j]], direct[[i, j]], epsilon = 1e-12);
            }
        }
    }
}