sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Level-1 (statistical) caption generation.
//!
//! Produces a text description of the **mean, maximum, minimum, and standard
//! deviation** of every sensor channel, grouped by physiological category.
//!
//! # Pipeline
//!
//! 1. Denormalise the input `(T × C)` array back to physical units.
//! 2. Optionally apply the missingness mask (set imputed values to NaN).
//! 3. For each physiological group in [`crate::constants::CHANNEL_GROUPS`]:
//!    a. Look up the stats for the primary channels.
//!    b. Sample `random_k` additional channels from the random pool.
//!    c. Format each channel description using a randomly selected template,
//!       skipping any channel whose stats contain NaN.
//! 4. Concatenate all group descriptions into a single string, one line per
//!    group; groups with no describable channels are omitted.
//!
//! # Example output
//!
//! ```text
//! For Heart, heart rate mean, max, min, std are 72.1, 95.3, 58.4, 8.2.
//! hrv rr exhibits a mean of 820.4, with range 960.0 to 680.0 and a
//! standard deviation of 45.1.
//! For Activity, steps mean, max, min, std are 3.2, 12.0, 0.0, 2.1. ...
//! ```

use ndarray::{Array2, ArrayView2};
use rand::{seq::SliceRandom, Rng};

use crate::constants::CHANNEL_GROUPS;
use crate::data::preprocessing::{channel_stats, denormalized};
use crate::data::captioning::templates::LOW_LEVEL_TEMPLATES;
use crate::error::Result;

/// Build the level-1 (statistical) caption for one sensor window.
///
/// The input is denormalised, optionally masked, and summarised per channel.
/// Every group in [`CHANNEL_GROUPS`] that has at least one describable channel
/// then contributes one line of the form `For <category>, <descriptions>`.
///
/// # Arguments
///
/// * `x_norm`  – Normalised sensor array, shape `(T, C)`.
/// * `mask`    – Optional missingness mask, shape `(T, C)`.
///   `mask[t, c] == 1` means the value was imputed; pass `None` to include
///   all values.
/// * `rng`     – Random number generator used to pick templates and random
///   channel subsets.
///
/// # Returns
///
/// A multi-line string suitable for use as the `low_level_caption` text pair.
/// Channels whose statistics contain NaN are left out, and a group with no
/// remaining channels produces no line at all.
///
/// # Errors
///
/// Propagates any error returned by `denormalized` or `apply_mask`.
pub fn generate_statistical_caption<R: Rng>(
    x_norm: &ArrayView2<f64>,
    mask: Option<&Array2<u8>>,
    rng: &mut R,
) -> Result<String> {
    use crate::data::preprocessing::apply_mask;

    /// Describe a single channel, or yield `None` when any statistic is NaN.
    ///
    /// The RNG is only touched when a description is actually produced.
    fn describe_channel<G: Rng>(
        name: &str,
        (mean, max, min, std): (f64, f64, f64, f64),
        rng: &mut G,
    ) -> Option<String> {
        if mean.is_nan() || max.is_nan() || min.is_nan() || std.is_nan() {
            return None;
        }
        Some(describe_low_level(name, mean, max, min, std, rng))
    }

    // Back to physical units, then blank out imputed samples if a mask is given.
    let mut physical = denormalized(x_norm)?;
    if let Some(missing) = mask {
        apply_mask(&mut physical, missing)?;
    }

    // One `(mean, max, min, std)` tuple per channel, indexed by channel number.
    let stats = channel_stats(&physical);

    let mut caption = String::new();

    for group in CHANNEL_GROUPS {
        // Primary channels are always considered, in declaration order.
        let mut described: Vec<String> = group
            .primary
            .iter()
            .filter_map(|&(name, idx)| describe_channel(name, stats[idx], &mut *rng))
            .collect();

        // Then a random subset of the group's optional channels.
        if group.random_k > 0 && !group.random.is_empty() {
            let picked: Vec<_> = group
                .random
                .choose_multiple(rng, group.random_k)
                .collect();
            described.extend(
                picked
                    .into_iter()
                    .filter_map(|&(name, idx)| describe_channel(name, stats[idx], &mut *rng)),
            );
        }

        // A group with nothing to say is omitted entirely.
        if described.is_empty() {
            continue;
        }
        caption.push_str(&format!("For {}, {}\n", group.category, described.join(" ")));
    }

    Ok(caption)
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Pick a random low-level template and fill in the placeholders.
///
/// A template is drawn at random from [`LOW_LEVEL_TEMPLATES`]. The literal
/// placeholders `{name}`, `{mean_val:.1}`, `{max_val:.1}`, `{min_val:.1}` and
/// `{std_val:.1}` are then replaced with the channel name and the statistics
/// formatted to one decimal place.
///
/// # Panics
///
/// Panics if [`LOW_LEVEL_TEMPLATES`] is empty.
fn describe_low_level<R: Rng>(
    name: &str,
    mean_val: f64,
    max_val: f64,
    min_val: f64,
    std_val: f64,
    rng: &mut R,
) -> String {
    // `choose` yields `None` only for an empty slice, so indexing element 0 as
    // a fallback could never succeed; state the invariant explicitly instead.
    let tmpl = LOW_LEVEL_TEMPLATES
        .choose(rng)
        .copied()
        .expect("LOW_LEVEL_TEMPLATES must contain at least one template");

    // Substitute the numeric placeholders first and `{name}` last, so that a
    // display name containing placeholder-like text is inserted verbatim
    // rather than being rewritten by a later replace.
    tmpl.replace("{mean_val:.1}", &format!("{mean_val:.1}"))
        .replace("{max_val:.1}", &format!("{max_val:.1}"))
        .replace("{min_val:.1}", &format!("{min_val:.1}"))
        .replace("{std_val:.1}", &format!("{std_val:.1}"))
        .replace("{name}", name)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::NUM_CHANNELS;
    use ndarray::Array2;
    use rand::{rngs::StdRng, SeedableRng};

    /// Smoke test: an all-zero normalised window with no mask must yield a
    /// non-empty caption that mentions the Heart group.
    #[test]
    fn test_statistical_caption_runs() {
        let mut seeded = StdRng::seed_from_u64(42);
        let window: Array2<f64> = Array2::zeros((1440, NUM_CHANNELS));

        let caption = generate_statistical_caption(&window.view(), None, &mut seeded).unwrap();

        assert!(!caption.is_empty(), "Caption must be non-empty");
        assert!(caption.contains("Heart"), "Caption must mention Heart group");
    }
}