sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Sensor signal normalisation and denormalisation.
//!
//! The raw sensor values captured by wearable devices span very different
//! physical scales (e.g. heart rate ≈ 60–180 bpm vs. LF power ≈ 0–10 000).
//! Before feeding data into the ViT encoder every channel is z-score
//! normalised using population-level statistics.
//!
//! # Normalisation formula
//!
//! ```text
//! z[t, c] = (x[t, c] - mean[c]) / std[c]
//! ```
//!
//! # Denormalisation (used in caption generation)
//!
//! ```text
//! x[t, c] = z[t, c] * std[c] + mean[c]
//! ```
//!
//! After denormalisation certain channels are clamped to be non-negative
//! (e.g. `steps`, `sleep_coefficient`).

use ndarray::{Array2, ArrayView2};

use crate::constants::{FEATURE_NAMES, NORM_PARAMS, NON_NEGATIVE_CHANNELS, NUM_CHANNELS};
use crate::error::{Result, SensorLMError};

// ---------------------------------------------------------------------------
// Channel statistics
// ---------------------------------------------------------------------------

/// Per-channel normalisation statistics resolved from [`NORM_PARAMS`].
///
/// The two vectors are parallel: `mean[i]` and `std[i]` describe the same
/// channel `i`, in the order in which the entries appear in [`NORM_PARAMS`].
#[derive(Debug, Clone)]
pub struct ChannelStats {
    /// Population mean for each of the 34 channels.
    pub mean: Vec<f64>,
    /// Population standard deviation for each channel.
    pub std: Vec<f64>,
}

impl ChannelStats {
    /// Build [`ChannelStats`] from the compile-time [`NORM_PARAMS`] table.
    ///
    /// Each `(mean, std)` entry of the table is split into the two parallel
    /// vectors of the returned value, preserving the table order.
    pub fn from_constants() -> Self {
        let (mean, std): (Vec<f64>, Vec<f64>) = NORM_PARAMS.iter().copied().unzip();
        Self { mean, std }
    }
}

// ---------------------------------------------------------------------------
// Normalise
// ---------------------------------------------------------------------------

/// Z-score normalise a `(T, C)` raw sensor array in-place.
///
/// `data` must have shape `[T, NUM_CHANNELS]`.  Every column `c` is rewritten
/// as `(x - mean[c]) / std[c]`, with `(mean, std)` taken from
/// [`NORM_PARAMS`].  Columns whose `std` is exactly `0.0` are left untouched
/// so that no division by zero takes place.
///
/// # Errors
///
/// Returns [`SensorLMError::ShapeMismatch`] if the number of columns ≠
/// `NUM_CHANNELS`.
pub fn normalize(data: &mut Array2<f64>) -> Result<()> {
    let (rows, cols) = data.dim();
    if cols != NUM_CHANNELS {
        return Err(SensorLMError::ShapeMismatch {
            expected: vec![rows, NUM_CHANNELS],
            actual: vec![rows, cols],
        });
    }

    (0..NUM_CHANNELS).for_each(|ch| {
        let (mean, std) = NORM_PARAMS[ch];
        // A zero standard deviation cannot be divided by; keep the raw values.
        if std != 0.0 {
            data.column_mut(ch)
                .iter_mut()
                .for_each(|x| *x = (*x - mean) / std);
        }
    });

    Ok(())
}

/// Non-mutating variant of [`normalize`].
///
/// Copies the view into an owned array, normalises the copy and returns it.
/// Errors are the same as for [`normalize`].
pub fn normalized(data: &ArrayView2<f64>) -> Result<Array2<f64>> {
    let mut copy = data.to_owned();
    normalize(&mut copy).map(|()| copy)
}

// ---------------------------------------------------------------------------
// Denormalise
// ---------------------------------------------------------------------------

/// Reverse a previous call to [`normalize`].
///
/// Converts normalised values back to physical units via
/// `x = z * std + mean` and then clamps the channels listed in
/// [`NON_NEGATIVE_CHANNELS`] to `≥ 0`.
///
/// Channels whose `std` in [`NORM_PARAMS`] is exactly `0.0` are left
/// untouched.  [`normalize`] skips those channels, so skipping them here as
/// well keeps the two functions inverse to each other.
///
/// `NaN` entries are preserved by the clamp instead of being replaced by
/// `0.0`, so missing values stay recognisable as missing.
///
/// # Errors
///
/// Returns [`SensorLMError::ShapeMismatch`] if the number of columns ≠
/// `NUM_CHANNELS`.
pub fn denormalize(data: &mut Array2<f64>) -> Result<()> {
    let (t, c) = (data.nrows(), data.ncols());
    if c != NUM_CHANNELS {
        return Err(SensorLMError::ShapeMismatch {
            expected: vec![t, NUM_CHANNELS],
            actual: vec![t, c],
        });
    }
    for ch in 0..NUM_CHANNELS {
        let (mean, std) = NORM_PARAMS[ch];
        // Mirror the guard in `normalize`: a zero-std channel was never
        // scaled, so there is nothing to undo.
        if std == 0.0 {
            continue;
        }
        data.column_mut(ch).mapv_inplace(|z| z * std + mean);
    }
    // Clamp non-negative channels.  An explicit comparison is used rather
    // than `f64::max`, because `NaN.max(0.0)` evaluates to `0.0` and would
    // silently erase missing-value markers.
    for &ch in NON_NEGATIVE_CHANNELS {
        data.column_mut(ch)
            .mapv_inplace(|x| if x < 0.0 { 0.0 } else { x });
    }
    Ok(())
}

/// Non-mutating variant of [`denormalize`].
///
/// Copies the view into an owned array, denormalises the copy and returns it.
/// Errors are the same as for [`denormalize`].
pub fn denormalized(data: &ArrayView2<f64>) -> Result<Array2<f64>> {
    let mut copy = data.to_owned();
    denormalize(&mut copy).map(|()| copy)
}

// ---------------------------------------------------------------------------
// Missingness handling
// ---------------------------------------------------------------------------

/// Overwrite imputed samples with `NaN` according to a missingness mask.
///
/// A mask entry equal to `1` marks the value at the same position in `data`
/// as imputed (not observed); that value is replaced by `NaN`.  Every other
/// mask value leaves the corresponding sample unchanged.
///
/// # Errors
///
/// Returns [`SensorLMError::ShapeMismatch`] if `data` and `mask` do not have
/// the same shape; `data` is not modified in that case.
pub fn apply_mask(data: &mut Array2<f64>, mask: &Array2<u8>) -> Result<()> {
    let (data_shape, mask_shape) = (data.shape(), mask.shape());
    if data_shape != mask_shape {
        return Err(SensorLMError::ShapeMismatch {
            expected: data_shape.to_vec(),
            actual: mask_shape.to_vec(),
        });
    }

    data.iter_mut()
        .zip(mask.iter())
        .filter(|(_, flag)| **flag == 1)
        .for_each(|(value, _)| *value = f64::NAN);

    Ok(())
}

// ---------------------------------------------------------------------------
// Downsample
// ---------------------------------------------------------------------------

/// Average-pool a `(C, T)` array down to `(C, target_t)` time-steps.
///
/// The time axis is cut into `target_t` consecutive windows of
/// `T / target_t` samples each, and every window is replaced by its
/// arithmetic mean.  An empty window (only possible when `T == 0`) yields
/// `0.0`.
///
/// Used by the structural caption generator to reduce the 1440 time-steps
/// to a manageable 36 points (factor 40).
///
/// # Panics
///
/// Panics if `target_t` is zero or if `T` is not divisible by `target_t`.
pub fn average_downsample_ct(data: &Array2<f64>, target_t: usize) -> Array2<f64> {
    let (channels, t) = (data.nrows(), data.ncols());
    // Checked explicitly so that a zero target produces a meaningful message
    // rather than a bare remainder-by-zero panic on the next line.
    assert!(target_t > 0, "target_t must be greater than zero");
    assert_eq!(t % target_t, 0, "T must be divisible by target_t");
    let factor = t / target_t;
    let mut out = Array2::<f64>::zeros((channels, target_t));
    for c in 0..channels {
        for i in 0..target_t {
            let start = i * factor;
            let window = data.slice(ndarray::s![c, start..start + factor]);
            out[[c, i]] = window.mean().unwrap_or(0.0);
        }
    }
    out
}

// ---------------------------------------------------------------------------
// Compute per-channel stats (used by captioning)
// ---------------------------------------------------------------------------

/// Summarise every channel of a `(T, C)` array as `(mean, max, min, std)`.
///
/// `NaN` samples are dropped before any statistic is computed.  The standard
/// deviation is the population form (the squared deviations are divided by
/// the number of non-`NaN` samples).  The result has one tuple per column; a
/// column without any non-`NaN` sample yields four `NaN`s.
pub fn channel_stats(data: &Array2<f64>) -> Vec<(f64, f64, f64, f64)> {
    let num_channels = data.ncols();
    let mut stats = Vec::with_capacity(num_channels);

    for ch in 0..num_channels {
        let observed: Vec<f64> = data
            .column(ch)
            .iter()
            .filter(|v| !v.is_nan())
            .copied()
            .collect();

        if observed.is_empty() {
            stats.push((f64::NAN, f64::NAN, f64::NAN, f64::NAN));
            continue;
        }

        let count = observed.len() as f64;
        let mean = observed.iter().sum::<f64>() / count;
        let (max, min) = observed.iter().fold(
            (f64::NEG_INFINITY, f64::INFINITY),
            |(hi, lo), &v| (hi.max(v), lo.min(v)),
        );
        let variance = observed.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / count;

        stats.push((mean, max, min, variance.sqrt()));
    }

    stats
}

// ---------------------------------------------------------------------------
// Flat f32 ↔ ndarray helpers (for burn tensor creation)
// ---------------------------------------------------------------------------

/// Build a z-score normalised `(T, C)` array from a flat, row-major `f32`
/// buffer.
///
/// The samples are widened to `f64`, reshaped to `(t, c)` and passed through
/// [`normalize`].
///
/// # Errors
///
/// * [`SensorLMError::ShapeMismatch`] if `raw.len() != t * c`, or if
///   [`normalize`] rejects the column count.
/// * [`SensorLMError::DatasetError`] if the array cannot be built from the
///   buffer with the requested shape.
pub fn f32_slice_to_normalised(raw: &[f32], t: usize, c: usize) -> Result<Array2<f64>> {
    let expected_len = t * c;
    if raw.len() != expected_len {
        return Err(SensorLMError::ShapeMismatch {
            expected: vec![expected_len],
            actual: vec![raw.len()],
        });
    }

    let widened = raw.iter().copied().map(f64::from).collect::<Vec<f64>>();
    match Array2::from_shape_vec((t, c), widened) {
        Ok(mut arr) => {
            normalize(&mut arr)?;
            Ok(arr)
        }
        Err(e) => Err(SensorLMError::DatasetError(e.to_string())),
    }
}

/// Look up the human-readable name of a channel in [`FEATURE_NAMES`].
///
/// Indices outside the table map to `"unknown"`.
pub fn channel_name(idx: usize) -> &'static str {
    match FEATURE_NAMES.get(idx) {
        Some(&name) => name,
        None => "unknown",
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// `normalize` followed by `denormalize` must reproduce the input.
    #[test]
    fn test_roundtrip_normalise() {
        let original = Array2::<f64>::from_elem((10, NUM_CHANNELS), 1.0);
        let mut data = original.clone();
        normalize(&mut data).unwrap();
        denormalize(&mut data).unwrap();
        // The input is positive, so the non-negative clamp never changes a
        // value and every scaled channel has to come back to its original.
        for ((t, ch), &restored) in data.indexed_iter() {
            let (mean, std) = NORM_PARAMS[ch];
            // `normalize` leaves zero-std channels untouched, so they are not
            // part of the scale/unscale round trip checked here.
            if std == 0.0 {
                continue;
            }
            let expected = original[[t, ch]];
            // Tolerance grows with the channel's scale to absorb rounding.
            let tol = 1e-9 * (1.0 + mean.abs() + std.abs());
            assert!(
                (restored - expected).abs() <= tol,
                "channel {ch} round-tripped to {restored}, expected {expected}"
            );
        }
    }

    /// Channels in `NON_NEGATIVE_CHANNELS` never end up below zero.
    #[test]
    fn test_non_negative_clamp() {
        let mut data = Array2::<f64>::from_elem((5, NUM_CHANNELS), -100.0);
        denormalize(&mut data).unwrap();
        for &ch in NON_NEGATIVE_CHANNELS {
            for t in 0..5 {
                assert!(data[[t, ch]] >= 0.0, "channel {ch} should be >= 0");
            }
        }
    }

    /// Downsampling 1440 steps to 36 keeps the channel count and averages
    /// each window.
    #[test]
    fn test_downsample() {
        let data = Array2::<f64>::ones((NUM_CHANNELS, 1440));
        let ds = average_downsample_ct(&data, 36);
        assert_eq!(ds.shape(), &[NUM_CHANNELS, 36]);
        // The mean of a window of ones is one.
        assert!(ds.iter().all(|&v| (v - 1.0).abs() < 1e-12));
    }
}