use ndarray::{Array2, ArrayView2};
use crate::constants::{FEATURE_NAMES, NORM_PARAMS, NON_NEGATIVE_CHANNELS, NUM_CHANNELS};
use crate::error::{Result, SensorLMError};
#[derive(Debug, Clone)]
pub struct ChannelStats {
pub mean: Vec<f64>,
pub std: Vec<f64>,
}
impl ChannelStats {
pub fn from_constants() -> Self {
let mean: Vec<f64> = NORM_PARAMS.iter().map(|(m, _)| *m).collect();
let std: Vec<f64> = NORM_PARAMS.iter().map(|(_, s)| *s).collect();
Self { mean, std }
}
}
/// Standardizes every column of `data` in place with the `NORM_PARAMS` table.
///
/// Column `ch` is rewritten as `(x - mean) / std`, where `(mean, std)` is
/// `NORM_PARAMS[ch]`. Columns whose `std` is exactly `0.0` are left untouched
/// so that no division by zero takes place.
///
/// # Errors
///
/// Returns [`SensorLMError::ShapeMismatch`] when `data` does not have exactly
/// `NUM_CHANNELS` columns.
pub fn normalize(data: &mut Array2<f64>) -> Result<()> {
    let rows = data.nrows();
    let cols = data.ncols();
    if cols != NUM_CHANNELS {
        return Err(SensorLMError::ShapeMismatch {
            expected: vec![rows, NUM_CHANNELS],
            actual: vec![rows, cols],
        });
    }

    for idx in 0..NUM_CHANNELS {
        let (mu, sigma) = NORM_PARAMS[idx];
        if sigma != 0.0 {
            let mut column = data.column_mut(idx);
            for value in column.iter_mut() {
                *value = (*value - mu) / sigma;
            }
        }
    }

    Ok(())
}
/// Non-mutating counterpart of [`normalize`].
///
/// Copies `data` into an owned array, runs [`normalize`] on the copy and
/// returns it; the input view is never modified.
///
/// # Errors
///
/// Forwards any error reported by [`normalize`].
pub fn normalized(data: &ArrayView2<f64>) -> Result<Array2<f64>> {
    let mut copy = data.to_owned();
    normalize(&mut copy).map(|()| copy)
}
/// Reverses [`normalize`] in place and clamps the non-negative channels.
///
/// Column `ch` is rewritten as `z * std + mean`, where `(mean, std)` is
/// `NORM_PARAMS[ch]`. Afterwards every column listed in
/// `NON_NEGATIVE_CHANNELS` has its negative values replaced by `0.0`.
///
/// Two details keep this function consistent with the rest of the module:
///
/// * Columns whose `std` is exactly `0.0` are left untouched, mirroring the
///   skip in [`normalize`], so a normalize/denormalize round trip does not
///   collapse such a column to its mean.
/// * `NaN` values survive the clamp. `NaN` is the marker written by
///   [`apply_mask`], and `f64::max` would silently turn it into `0.0`.
///
/// # Errors
///
/// Returns [`SensorLMError::ShapeMismatch`] when `data` does not have exactly
/// `NUM_CHANNELS` columns.
pub fn denormalize(data: &mut Array2<f64>) -> Result<()> {
    let (t, c) = (data.nrows(), data.ncols());
    if c != NUM_CHANNELS {
        return Err(SensorLMError::ShapeMismatch {
            expected: vec![t, NUM_CHANNELS],
            actual: vec![t, c],
        });
    }

    for ch in 0..NUM_CHANNELS {
        let (mean, std) = NORM_PARAMS[ch];
        // `normalize` leaves zero-variance channels unscaled, so the inverse
        // must leave them alone as well.
        if std == 0.0 {
            continue;
        }
        let mut col = data.column_mut(ch);
        col.mapv_inplace(|z| z * std + mean);
    }

    for &ch in NON_NEGATIVE_CHANNELS {
        let mut col = data.column_mut(ch);
        // An explicit comparison is used instead of `x.max(0.0)` because
        // `NaN < 0.0` is false, which keeps missing-value markers intact.
        col.mapv_inplace(|x| if x < 0.0 { 0.0 } else { x });
    }

    Ok(())
}
/// Non-mutating counterpart of [`denormalize`].
///
/// Copies `data` into an owned array, runs [`denormalize`] on the copy and
/// returns it; the input view is never modified.
///
/// # Errors
///
/// Forwards any error reported by [`denormalize`].
pub fn denormalized(data: &ArrayView2<f64>) -> Result<Array2<f64>> {
    let mut copy = data.to_owned();
    denormalize(&mut copy).map(|()| copy)
}
/// Overwrites masked cells of `data` with `NaN`.
///
/// A cell is considered masked when the element of `mask` at the same
/// position equals `1`; every other mask value leaves the cell unchanged.
///
/// # Errors
///
/// Returns [`SensorLMError::ShapeMismatch`] when `data` and `mask` do not have
/// the same shape. `expected` carries the shape of `data` and `actual` the
/// shape of `mask`.
pub fn apply_mask(data: &mut Array2<f64>, mask: &Array2<u8>) -> Result<()> {
    if data.shape() == mask.shape() {
        let masked_cells = data
            .iter_mut()
            .zip(mask.iter())
            .filter(|(_, flag)| **flag == 1);
        for (cell, _) in masked_cells {
            *cell = f64::NAN;
        }
        Ok(())
    } else {
        Err(SensorLMError::ShapeMismatch {
            expected: data.shape().to_vec(),
            actual: mask.shape().to_vec(),
        })
    }
}
/// Downsamples `data` along its column axis by averaging consecutive windows.
///
/// The rows of `data` are kept as they are. The columns are split into
/// `target_t` consecutive windows of equal width (`ncols / target_t`), and
/// each output cell is the mean of one window. The result has shape
/// `(data.nrows(), target_t)`.
///
/// When a window is empty (which happens only if `data` has zero columns) the
/// output cell is `0.0`.
///
/// # Panics
///
/// Panics if `target_t` is zero, or if the number of columns of `data` is not
/// divisible by `target_t`.
pub fn average_downsample_ct(data: &Array2<f64>, target_t: usize) -> Array2<f64> {
    let (channels, t) = (data.nrows(), data.ncols());
    // Checked first so that a zero target gives a readable message instead of
    // the "remainder with a divisor of zero" arithmetic panic below.
    assert!(target_t > 0, "target_t must be greater than zero");
    assert_eq!(t % target_t, 0, "T must be divisible by target_t");

    let factor = t / target_t;
    let mut out = Array2::<f64>::zeros((channels, target_t));
    for c in 0..channels {
        for i in 0..target_t {
            let start = i * factor;
            let window = data.slice(ndarray::s![c, start..start + factor]);
            out[[c, i]] = window.mean().unwrap_or(0.0);
        }
    }
    out
}
/// Computes summary statistics for every column of `data`, ignoring `NaN`s.
///
/// The result holds one tuple per column, in column order, laid out as
/// `(mean, max, min, std)`. Note that `max` comes before `min`. `std` is the
/// population standard deviation (the variance is divided by the number of
/// non-`NaN` values, not by that number minus one).
///
/// A column that contains no non-`NaN` value yields four `NaN`s.
pub fn channel_stats(data: &Array2<f64>) -> Vec<(f64, f64, f64, f64)> {
    let num_columns = data.ncols();
    let mut stats = Vec::with_capacity(num_columns);

    for ch in 0..num_columns {
        let valid: Vec<f64> = data
            .column(ch)
            .iter()
            .filter(|v| !v.is_nan())
            .copied()
            .collect();

        if valid.is_empty() {
            stats.push((f64::NAN, f64::NAN, f64::NAN, f64::NAN));
            continue;
        }

        let count = valid.len() as f64;
        let mean = valid.iter().sum::<f64>() / count;
        let (min, max) = valid.iter().fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(lo, hi), &v| (lo.min(v), hi.max(v)),
        );
        let variance = valid
            .iter()
            .map(|v| (v - mean).powi(2))
            .sum::<f64>()
            / count;

        stats.push((mean, max, min, variance.sqrt()));
    }

    stats
}
/// Converts a flat `f32` buffer into a normalized `(t, c)` array of `f64`.
///
/// The samples are widened to `f64`, reshaped into `t` rows and `c` columns
/// with `Array2::from_shape_vec`, and then passed through [`normalize`].
///
/// # Errors
///
/// * [`SensorLMError::ShapeMismatch`] when `raw.len()` differs from `t * c`.
///   If the product `t * c` itself overflows `usize`, `expected` reports the
///   requested dimensions `[t, c]` instead of the product.
/// * [`SensorLMError::DatasetError`] when the reshape is rejected.
/// * Any error reported by [`normalize`], for instance when `c` is not
///   `NUM_CHANNELS`.
pub fn f32_slice_to_normalised(raw: &[f32], t: usize, c: usize) -> Result<Array2<f64>> {
    // `t * c` would panic in debug builds and wrap in release builds when the
    // product does not fit in `usize`; treat that as a shape mismatch instead.
    let expected_len = match t.checked_mul(c) {
        Some(len) => len,
        None => {
            return Err(SensorLMError::ShapeMismatch {
                expected: vec![t, c],
                actual: vec![raw.len()],
            });
        }
    };
    if raw.len() != expected_len {
        return Err(SensorLMError::ShapeMismatch {
            expected: vec![expected_len],
            actual: vec![raw.len()],
        });
    }

    // `f64::from` makes it explicit that widening f32 -> f64 is lossless.
    let data_f64: Vec<f64> = raw.iter().map(|&x| f64::from(x)).collect();
    let mut arr = Array2::from_shape_vec((t, c), data_f64)
        .map_err(|e| SensorLMError::DatasetError(e.to_string()))?;
    normalize(&mut arr)?;
    Ok(arr)
}
/// Looks up the name of channel `idx` in `FEATURE_NAMES`.
///
/// Returns the literal `"unknown"` when `idx` is outside the table.
pub fn channel_name(idx: usize) -> &'static str {
    match FEATURE_NAMES.get(idx) {
        Some(name) => *name,
        None => "unknown",
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Normalizing and then denormalizing an all-ones array must give back
    /// values that either match the input or are at least non-negative.
    #[test]
    fn test_roundtrip_normalise() {
        let original = Array2::<f64>::from_elem((10, NUM_CHANNELS), 1.0);
        let mut roundtripped = original.clone();

        normalize(&mut roundtripped).unwrap();
        denormalize(&mut roundtripped).unwrap();

        let all_ok = original
            .iter()
            .zip(roundtripped.iter())
            .all(|(before, after)| (before - after).abs() < 1e-9 || *after >= 0.0);
        assert!(all_ok);
    }

    /// Strongly negative inputs must come out clamped on every channel listed
    /// in `NON_NEGATIVE_CHANNELS`.
    #[test]
    fn test_non_negative_clamp() {
        let mut data = Array2::<f64>::from_elem((5, NUM_CHANNELS), -100.0);
        denormalize(&mut data).unwrap();

        for &ch in NON_NEGATIVE_CHANNELS {
            assert!(
                data.column(ch).iter().all(|&value| value >= 0.0),
                "channel {ch} should be >= 0"
            );
        }
    }

    /// Downsampling 1440 columns to 36 keeps the row count and yields 36
    /// columns.
    #[test]
    fn test_downsample() {
        let input = Array2::<f64>::ones((NUM_CHANNELS, 1440));
        let downsampled = average_downsample_ct(&input, 36);
        assert_eq!(downsampled.shape(), &[NUM_CHANNELS, 36]);
    }
}