1use crate::error::{ModelError, ModelResult};
15use scirs2_core::ndarray::{Array1, Array2};
16
17#[derive(Debug, Clone)]
26pub struct ActivationHistogram {
27 pub bins: Array2<f32>,
29 pub edges: Array1<f32>,
31 pub num_dims: usize,
33 num_bins: usize,
34 total_counts: Array1<f32>,
35}
36
37impl ActivationHistogram {
38 pub fn new(num_bins: usize, min_val: f32, max_val: f32, num_dims: usize) -> Self {
41 let bins = Array2::zeros((num_bins, num_dims));
42 let step = (max_val - min_val) / num_bins as f32;
43 let edges = Array1::from_vec((0..=num_bins).map(|i| min_val + i as f32 * step).collect());
44 let total_counts = Array1::zeros(num_dims);
45 Self {
46 bins,
47 edges,
48 num_dims,
49 num_bins,
50 total_counts,
51 }
52 }
53
54 pub fn update(&mut self, x: &Array1<f32>) -> ModelResult<()> {
56 if x.len() != self.num_dims {
57 return Err(ModelError::dimension_mismatch(
58 "ActivationHistogram::update",
59 self.num_dims,
60 x.len(),
61 ));
62 }
63 let min_val = self.edges[0];
64 let max_val = self.edges[self.num_bins];
65 let range = max_val - min_val;
66 if range <= 0.0 {
67 return Err(ModelError::invalid_config(
68 "ActivationHistogram: zero-range edges",
69 ));
70 }
71 for (d, &v) in x.iter().enumerate() {
72 let clamped = v.clamp(min_val, max_val - f32::EPSILON * range);
74 let frac = (clamped - min_val) / range;
75 let bin = (frac * self.num_bins as f32) as usize;
76 let bin = bin.min(self.num_bins - 1);
77 self.bins[(bin, d)] += 1.0;
78 self.total_counts[d] += 1.0;
79 }
80 Ok(())
81 }
82
83 pub fn density(&self) -> Array2<f32> {
87 let mut out = Array2::zeros((self.num_bins, self.num_dims));
88 for d in 0..self.num_dims {
89 let total = self.total_counts[d];
90 if total > 0.0 {
91 for b in 0..self.num_bins {
92 out[(b, d)] = self.bins[(b, d)] / total;
93 }
94 }
95 }
96 out
97 }
98
99 pub fn per_dim_entropy(&self) -> Array1<f32> {
101 let density = self.density();
102 let mut entropy = Array1::zeros(self.num_dims);
103 for d in 0..self.num_dims {
104 let mut h = 0.0_f32;
105 for b in 0..self.num_bins {
106 let p = density[(b, d)];
107 if p > 0.0 {
108 h -= p * p.ln();
109 }
110 }
111 entropy[d] = h;
112 }
113 entropy
114 }
115
116 pub fn most_active_dims(&self, top_k: usize) -> Vec<usize> {
118 let entropy = self.per_dim_entropy();
119 let mut indexed: Vec<(usize, f32)> = entropy.iter().copied().enumerate().collect();
120 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
121 let k = top_k.min(self.num_dims);
122 indexed.into_iter().take(k).map(|(i, _)| i).collect()
123 }
124}
125
126#[derive(Debug, Clone)]
132pub struct GatingPatternRecorder {
133 patterns: Vec<Array1<f32>>,
134 max_steps: usize,
135}
136
137impl GatingPatternRecorder {
138 pub fn new(max_steps: usize) -> Self {
140 Self {
141 patterns: Vec::with_capacity(max_steps),
142 max_steps,
143 }
144 }
145
146 pub fn record(&mut self, gate: &Array1<f32>) -> ModelResult<()> {
148 if self.patterns.len() >= self.max_steps {
149 return Err(ModelError::invalid_config(
150 "GatingPatternRecorder: max_steps exceeded",
151 ));
152 }
153 if !self.patterns.is_empty() && gate.len() != self.patterns[0].len() {
154 return Err(ModelError::dimension_mismatch(
155 "GatingPatternRecorder::record",
156 self.patterns[0].len(),
157 gate.len(),
158 ));
159 }
160 self.patterns.push(gate.clone());
161 Ok(())
162 }
163
164 pub fn as_matrix(&self) -> ModelResult<Array2<f32>> {
166 if self.patterns.is_empty() {
167 return Err(ModelError::invalid_config(
168 "GatingPatternRecorder: no patterns recorded",
169 ));
170 }
171 let t = self.patterns.len();
172 let d = self.patterns[0].len();
173 let mut m = Array2::zeros((t, d));
174 for (i, p) in self.patterns.iter().enumerate() {
175 for (j, &v) in p.iter().enumerate() {
176 m[(i, j)] = v;
177 }
178 }
179 Ok(m)
180 }
181
182 pub fn cross_correlation(&self) -> ModelResult<Array2<f32>> {
187 let m = self.as_matrix()?;
188 let t = m.nrows();
189 let d = m.ncols();
190
191 if t < 2 {
192 return Err(ModelError::invalid_config(
193 "GatingPatternRecorder::cross_correlation: need at least 2 time steps",
194 ));
195 }
196
197 let mut means = Array1::<f32>::zeros(d);
199 let mut stds = Array1::<f32>::zeros(d);
200 for j in 0..d {
201 let sum: f32 = (0..t).map(|i| m[(i, j)]).sum();
202 let mean = sum / t as f32;
203 means[j] = mean;
204 let var: f32 = (0..t).map(|i| (m[(i, j)] - mean).powi(2)).sum::<f32>() / t as f32;
205 stds[j] = var.sqrt();
206 }
207
208 let mut corr = Array2::<f32>::zeros((d, d));
209 for a in 0..d {
210 for b in 0..d {
211 if stds[a] < 1e-12 || stds[b] < 1e-12 {
212 corr[(a, b)] = if a == b { 1.0 } else { 0.0 };
214 } else {
215 let cov: f32 = (0..t)
216 .map(|i| (m[(i, a)] - means[a]) * (m[(i, b)] - means[b]))
217 .sum::<f32>()
218 / t as f32;
219 corr[(a, b)] = (cov / (stds[a] * stds[b])).clamp(-1.0, 1.0);
220 }
221 }
222 }
223 Ok(corr)
224 }
225
226 pub fn correlated_dims(&self, threshold: f32) -> ModelResult<Vec<(usize, usize, f32)>> {
228 let corr = self.cross_correlation()?;
229 let d = corr.nrows();
230 let mut result = Vec::new();
231 for a in 0..d {
232 for b in (a + 1)..d {
233 let c = corr[(a, b)];
234 if c.abs() >= threshold {
235 result.push((a, b, c));
236 }
237 }
238 }
239 Ok(result)
240 }
241
242 pub fn smoothed(&self, window: usize) -> ModelResult<Array2<f32>> {
244 let m = self.as_matrix()?;
245 let t = m.nrows();
246 let d = m.ncols();
247 let w = window.max(1);
248 let mut out = Array2::zeros((t, d));
249 for i in 0..t {
250 let start = (i + 1).saturating_sub(w);
251 let count = (i - start + 1) as f32;
252 for j in 0..d {
253 let sum: f32 = (start..=i).map(|k| m[(k, j)]).sum();
254 out[(i, j)] = sum / count;
255 }
256 }
257 Ok(out)
258 }
259}
260
261#[derive(Debug, Clone)]
268pub struct PhasePortrait {
269 trajectory: Vec<Array1<f32>>,
270 dim: usize,
271}
272
273impl PhasePortrait {
274 pub fn new(dim: usize, capacity: usize) -> Self {
277 Self {
278 trajectory: Vec::with_capacity(capacity),
279 dim,
280 }
281 }
282
283 pub fn record(&mut self, state: &Array1<f32>) -> ModelResult<()> {
285 if state.len() != self.dim {
286 return Err(ModelError::dimension_mismatch(
287 "PhasePortrait::record",
288 self.dim,
289 state.len(),
290 ));
291 }
292 self.trajectory.push(state.clone());
293 Ok(())
294 }
295
296 pub fn pca_projection(&self) -> ModelResult<Array2<f32>> {
300 let t = self.trajectory.len();
301 if t < 2 {
302 return Err(ModelError::invalid_config(
303 "PhasePortrait::pca_projection: need at least 2 recorded states",
304 ));
305 }
306 let d = self.dim;
307
308 let mut data = Array2::<f32>::zeros((t, d));
310 for (i, s) in self.trajectory.iter().enumerate() {
311 for (j, &v) in s.iter().enumerate() {
312 data[(i, j)] = v;
313 }
314 }
315 for j in 0..d {
317 let col_mean: f32 = (0..t).map(|i| data[(i, j)]).sum::<f32>() / t as f32;
318 for i in 0..t {
319 data[(i, j)] -= col_mean;
320 }
321 }
322
323 let mut out = Array2::<f32>::zeros((t, 2));
324
325 let mut data_copy = data.clone();
327 for pc_idx in 0..2 {
328 let mut v = Array1::<f32>::zeros(d);
330 v[pc_idx % d] = 1.0;
331
332 for _ in 0..50 {
333 let mut u = Array1::<f32>::zeros(t);
335 for i in 0..t {
336 u[i] = (0..d).map(|j| data_copy[(i, j)] * v[j]).sum();
337 }
338 let mut v_new = Array1::<f32>::zeros(d);
340 for j in 0..d {
341 v_new[j] = (0..t).map(|i| data_copy[(i, j)] * u[i]).sum();
342 }
343 let norm = v_new.iter().map(|&x| x * x).sum::<f32>().sqrt();
345 if norm < 1e-12 {
346 break;
347 }
348 v = v_new.mapv(|x| x / norm);
349 }
350
351 for i in 0..t {
353 let proj: f32 = (0..d).map(|j| data_copy[(i, j)] * v[j]).sum();
354 out[(i, pc_idx)] = proj;
355 }
356
357 for i in 0..t {
359 let score = out[(i, pc_idx)];
360 for j in 0..d {
361 data_copy[(i, j)] -= score * v[j];
362 }
363 }
364 }
365
366 Ok(out)
367 }
368
369 pub fn divergence_estimate(&self) -> ModelResult<f32> {
373 let t = self.trajectory.len();
374 if t < 3 {
375 return Err(ModelError::invalid_config(
376 "PhasePortrait::divergence_estimate: need at least 3 states",
377 ));
378 }
379 let mut log_ratios = Vec::new();
380 let dist = |a: &Array1<f32>, b: &Array1<f32>| -> f32 {
381 a.iter()
382 .zip(b.iter())
383 .map(|(x, y)| (x - y).powi(2))
384 .sum::<f32>()
385 .sqrt()
386 };
387 for i in 0..(t - 2) {
388 let d0 = dist(&self.trajectory[i], &self.trajectory[i + 1]);
389 let d1 = dist(&self.trajectory[i + 1], &self.trajectory[i + 2]);
390 if d0 > 1e-12 && d1 > 1e-12 {
391 log_ratios.push((d1 / d0).ln());
392 }
393 }
394 if log_ratios.is_empty() {
395 return Ok(0.0);
396 }
397 Ok(log_ratios.iter().sum::<f32>() / log_ratios.len() as f32)
398 }
399
400 pub fn fixed_points(&self, tolerance: f32) -> Vec<Array1<f32>> {
404 let mut representatives: Vec<Array1<f32>> = Vec::new();
405 let dist = |a: &Array1<f32>, b: &Array1<f32>| -> f32 {
406 a.iter()
407 .zip(b.iter())
408 .map(|(x, y)| (x - y).powi(2))
409 .sum::<f32>()
410 .sqrt()
411 };
412 for state in &self.trajectory {
413 let already_covered = representatives
414 .iter()
415 .any(|rep| dist(rep, state) <= tolerance);
416 if !already_covered {
417 representatives.push(state.clone());
418 }
419 }
420 representatives
421 }
422
423 pub fn periodicity_score(&self) -> ModelResult<f32> {
428 let t = self.trajectory.len();
429 if t < 4 {
430 return Err(ModelError::invalid_config(
431 "PhasePortrait::periodicity_score: need at least 4 states",
432 ));
433 }
434
435 let norms: Vec<f32> = self
437 .trajectory
438 .iter()
439 .map(|s| s.iter().map(|&x| x * x).sum::<f32>().sqrt())
440 .collect();
441
442 let mean = norms.iter().sum::<f32>() / t as f32;
443 let centered: Vec<f32> = norms.iter().map(|&x| x - mean).collect();
444 let var: f32 = centered.iter().map(|&x| x * x).sum::<f32>() / t as f32;
445
446 if var < 1e-12 {
447 return Ok(1.0);
449 }
450
451 let max_lag = (t / 2).max(1);
453 let mut peak = 0.0_f32;
454 for lag in 1..=max_lag {
455 let cov: f32 = (0..(t - lag))
456 .map(|i| centered[i] * centered[i + lag])
457 .sum::<f32>()
458 / (t - lag) as f32;
459 let acf = (cov / var).abs();
460 if acf > peak {
461 peak = acf;
462 }
463 }
464 Ok(peak.min(1.0))
465 }
466}
467
468pub fn matrix_to_csv(m: &Array2<f32>) -> String {
477 let (rows, cols) = m.dim();
478 let mut lines = Vec::with_capacity(rows);
479 for i in 0..rows {
480 let row_str: Vec<String> = (0..cols).map(|j| format!("{}", m[(i, j)])).collect();
481 lines.push(row_str.join(","));
482 }
483 lines.join("\n")
484}
485
486pub fn signal_to_svg_sparkline(signal: &Array1<f32>, width: usize, height: usize) -> String {
491 let n = signal.len();
492 if n == 0 || width == 0 || height == 0 {
493 return format!(
494 r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}"></svg>"#,
495 width, height
496 );
497 }
498
499 let min_val = signal.iter().cloned().fold(f32::INFINITY, f32::min);
500 let max_val = signal.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
501 let range = (max_val - min_val).max(f32::EPSILON);
502
503 let pad = 2usize;
504 let draw_w = (width.saturating_sub(pad * 2)).max(1) as f32;
505 let draw_h = (height.saturating_sub(pad * 2)).max(1) as f32;
506
507 let points: Vec<String> = signal
508 .iter()
509 .enumerate()
510 .map(|(i, &v)| {
511 let x = pad as f32 + i as f32 * draw_w / (n - 1).max(1) as f32;
512 let y = pad as f32 + (1.0 - (v - min_val) / range) * draw_h;
514 format!("{:.2},{:.2}", x, y)
515 })
516 .collect();
517
518 format!(
519 r##"<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}"><polyline points="{pts}" fill="none" stroke="#4488cc" stroke-width="1.5"/></svg>"##,
520 w = width,
521 h = height,
522 pts = points.join(" ")
523 )
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Cheap deterministic value in `[0, 1)` derived from `seed`.
    ///
    /// One LCG step, keeping the low 16 bits.
    fn pseudo_rand(seed: usize) -> f32 {
        let mixed = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
        (mixed & 0xFFFF) as f32 / 65_536.0
    }

    /// Builds a vector of length `len` whose element `k` is `f(k)`.
    fn vector_from(len: usize, f: impl Fn(usize) -> f32) -> Array1<f32> {
        Array1::from_vec((0..len).map(f).collect())
    }

    #[test]
    fn test_histogram_update_and_density() {
        let (num_bins, num_dims) = (10, 4);
        let mut hist = ActivationHistogram::new(num_bins, -1.0, 1.0, num_dims);

        for step in 0..100 {
            let sample = vector_from(num_dims, |d| pseudo_rand(step * num_dims + d) * 2.0 - 1.0);
            hist.update(&sample).expect("update failed");
        }

        let density = hist.density();
        assert_eq!(density.dim(), (num_bins, num_dims));

        for d in 0..num_dims {
            let col_sum: f32 = density.column(d).iter().sum();
            assert!(
                (col_sum - 1.0).abs() < 1e-3,
                "density sum for dim {d} = {col_sum}"
            );
        }
    }

    #[test]
    fn test_histogram_most_active_dims() {
        let (num_bins, num_dims) = (8, 6);
        let mut hist = ActivationHistogram::new(num_bins, 0.0, 1.0, num_dims);

        for step in 0..80 {
            let sample = vector_from(num_dims, |d| pseudo_rand(step * num_dims + d + 1));
            hist.update(&sample).expect("update failed");
        }

        let top2 = hist.most_active_dims(2);
        assert_eq!(top2.len(), 2);
        assert!(top2.iter().all(|&idx| idx < num_dims));
        assert_ne!(top2[0], top2[1]);
    }

    #[test]
    fn test_gating_pattern_as_matrix() {
        let (dim, steps) = (8, 20);
        let mut recorder = GatingPatternRecorder::new(50);

        for step in 0..steps {
            recorder
                .record(&vector_from(dim, |d| pseudo_rand(step * dim + d)))
                .expect("record failed");
        }

        let m = recorder.as_matrix().expect("as_matrix failed");
        assert_eq!((m.nrows(), m.ncols()), (steps, dim));
    }

    #[test]
    fn test_gating_pattern_cross_correlation_diagonal() {
        let (dim, steps) = (4, 30);
        let mut recorder = GatingPatternRecorder::new(100);

        for step in 0..steps {
            recorder
                .record(&vector_from(dim, |d| pseudo_rand(step * dim + d + 42)))
                .expect("record failed");
        }

        let corr = recorder
            .cross_correlation()
            .expect("cross_correlation failed");
        assert_eq!((corr.nrows(), corr.ncols()), (dim, dim));

        for (d, &diag) in corr.diag().iter().enumerate() {
            assert!(
                (diag - 1.0).abs() < 1e-4,
                "diagonal[{d}] = {diag}, expected ≈ 1.0"
            );
        }
    }

    #[test]
    fn test_phase_portrait_pca_projection() {
        let (dim, steps) = (16, 30);
        let mut pp = PhasePortrait::new(dim, 64);

        for step in 0..steps {
            let state = vector_from(dim, |d| pseudo_rand(step * dim + d + 7) * 2.0 - 1.0);
            pp.record(&state).expect("record failed");
        }

        let proj = pp.pca_projection().expect("pca_projection failed");
        assert_eq!((proj.nrows(), proj.ncols()), (steps, 2));
    }

    #[test]
    fn test_phase_portrait_fixed_points() {
        let fixed = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let mut pp = PhasePortrait::new(fixed.len(), 20);

        for _ in 0..10 {
            pp.record(&fixed).expect("record failed");
        }

        let fps = pp.fixed_points(1e-3);
        assert_eq!(
            fps.len(),
            1,
            "expected exactly 1 fixed point, got {}",
            fps.len()
        );
        assert!(fps[0]
            .iter()
            .zip(fixed.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6));
    }

    #[test]
    fn test_matrix_to_csv_format() {
        let values: Vec<f32> = (1..=12).map(|v| v as f32).collect();
        let m = Array2::from_shape_vec((3, 4), values).expect("shape error");

        let csv = matrix_to_csv(&m);
        let lines: Vec<&str> = csv.lines().collect();
        assert_eq!(lines.len(), 3, "expected 3 lines");

        for line in &lines {
            let comma_count = line.matches(',').count();
            assert_eq!(
                comma_count, 3,
                "expected 3 commas per line, got {comma_count} in '{line}'"
            );
        }
    }

    #[test]
    fn test_signal_to_svg_sparkline_valid() {
        let signal = vector_from(20, |i| (i as f32 * 0.3).sin());
        let svg = signal_to_svg_sparkline(&signal, 200, 50);
        for (needle, message) in [
            ("<svg", "missing <svg tag"),
            ("</svg>", "missing </svg> tag"),
            ("polyline", "missing polyline element"),
        ] {
            assert!(svg.contains(needle), "{}", message);
        }
    }
}