sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Level-2 (structural) caption generation.
//!
//! Detects and describes **temporal patterns** within each sensor channel:
//!
//! * **Trends** – monotonically increasing / decreasing / stable segments
//!   identified by fitting linear regression over overlapping windows.
//! * **Anomalies** – significant peaks, spikes, and valleys found by a
//!   prominence-threshold peak detector.
//!
//! # Algorithm overview
//!
//! ```text
//! 1. Denormalise (T, C) array.
//! 2. Downsample time axis: 1440 → 36 points (factor 40, average pooling).
//! 3. For every channel in every group:
//!    a. Fit linear regression over windows of size 6, 8, 12 data points
//!       (with 50% overlap).  Classify each window as increasing / decreasing /
//!       stable using slope and range thresholds.
//!    b. Run prominence-based peak / valley detector on the downsampled signal.
//! 4. Sample up to `max_per_category` insights per category, format with random templates.
//! 5. Concatenate all group captions.
//! ```

use ndarray::{Array2, ArrayView2};
use rand::{seq::SliceRandom, Rng};

use crate::constants::CHANNEL_GROUPS;
use crate::data::captioning::templates::{ANOMALY_TEMPLATES, TREND_TEMPLATES};
use crate::data::preprocessing::{average_downsample_ct, denormalized};
use crate::error::Result;

/// Build the level-2 (structural) caption for one sensor recording.
///
/// The normalised input is converted back to physical units, transposed to a
/// channel-major `(C, T)` layout and average-pooled down to `TARGET_T` points
/// per channel. For every entry of [`CHANNEL_GROUPS`], each primary channel is
/// scanned for trends and for spikes / drops, and every finding is rendered
/// through a randomly chosen sentence template. One line of the form
/// `"<category>: <sentences>\n"` is appended per group.
///
/// # Arguments
///
/// * `x_norm`            – Normalised `(T, C)` sensor array.
/// * `max_per_category`  – Upper bound on the number of insight sentences kept
///   for each group (the reference uses 7).
/// * `rng`               – Random number generator used for template choice
///   and for subsampling.
///
/// # Errors
///
/// Propagates any error returned by `denormalized`.
pub fn generate_structural_caption<R: Rng>(
    x_norm: &ArrayView2<f64>,
    max_per_category: usize,
    rng: &mut R,
) -> Result<String> {
    // Points kept per channel after pooling, and the number of minutes each
    // pooled point stands for (1440 → 36 points, 40-minute averages).
    const TARGET_T: usize = 36;
    const DOWNSAMPLE_SCALE: usize = 40;

    // Back to physical units, then channel-major so each row is one channel.
    let physical = denormalized(x_norm)?;
    let channel_major: Array2<f64> = physical.t().to_owned();
    let pooled = average_downsample_ct(&channel_major, TARGET_T);

    let mut caption = String::new();

    for group in CHANNEL_GROUPS {
        // Each entry is (minute used for ordering, rendered sentence).
        let mut insights: Vec<(usize, String)> = Vec::new();

        for &(display_name, ch_idx) in group.primary {
            // Channels missing from the pooled array are silently skipped.
            if ch_idx >= pooled.nrows() {
                continue;
            }
            let series: Vec<f64> = pooled.row(ch_idx).iter().copied().collect();

            // Trends first, anomalies second, so the RNG is consumed in a
            // fixed, reproducible order.
            let trends = identify_trends(&series, DOWNSAMPLE_SCALE);
            insights.extend(trends.iter().map(|(start, end, kind, ..)| {
                (*start, describe_trend(display_name, kind, *start, *end, rng))
            }));

            let anomalies = detect_peaks_valleys(&series, DOWNSAMPLE_SCALE);
            insights.extend(anomalies.iter().map(|(minute, kind)| {
                (*minute, describe_anomaly(display_name, kind, *minute, rng))
            }));
        }

        // Over budget: keep a random subset, then restore chronological order.
        if insights.len() > max_per_category {
            insights.shuffle(rng);
            insights.truncate(max_per_category);
            insights.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        }

        let body = insights
            .iter()
            .map(|(_, sentence)| sentence.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        caption += &format!("{}: {}\n", group.category, body);
    }

    Ok(caption)
}

// ---------------------------------------------------------------------------
// Trend detection via linear regression
// ---------------------------------------------------------------------------

/// Result of trend detection for a single window.
///
/// Fields: `(start_minute, end_minute, trend_type, slope, delta, segment_size)`.
///
/// * `start_minute` / `end_minute` – window bounds as produced by
///   `identify_trends`: `(start + 1) * downsample_scale` and
///   `(start + segment_size) * downsample_scale`.
/// * `trend_type` – one of `"increasing"`, `"decreasing"` or `"stable"`.
/// * `slope` – least-squares slope of the window, per downsampled point.
/// * `delta` – ranking value: the positive magnitude of `last - first` for
///   increasing / decreasing windows; for stable windows it holds the last
///   value of the window instead.
/// * `segment_size` – window length in downsampled points.
type TrendResult = (usize, usize, String, f64, f64, usize);

/// Pick at most three trend segments from `data` using multi-scale
/// sliding-window linear regression.
///
/// Windows of 6, 8 and 12 downsampled points are slid over the signal with a
/// stride of half the window length. Each window is labelled `"increasing"`,
/// `"decreasing"` or `"stable"` from its regression slope and value change,
/// all measured relative to the full-signal range; windows matching none of
/// the three are dropped. Candidates are ranked by their stored delta value
/// (descending) and accepted greedily, skipping any candidate that overlaps an
/// already accepted one by more than 30 % of the shorter segment.
///
/// Each downsampled point represents `downsample_scale` minutes; the reported
/// start / end are expressed in minutes.
fn identify_trends(data: &[f64], downsample_scale: usize) -> Vec<TrendResult> {
    // (window length in points, slope-threshold multiplier), mirroring the reference.
    const WINDOW_SCALES: [(usize, f64); 3] = [(6, 1.5), (8, 1.3), (12, 1.0)];
    const MAX_TRENDS: usize = 3;
    const MAX_OVERLAP: f64 = 0.3;

    /// Max-minus-min of a slice.
    fn spread(values: &[f64]) -> f64 {
        let hi = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let lo = values.iter().cloned().fold(f64::INFINITY, f64::min);
        hi - lo
    }

    if data.is_empty() {
        return Vec::new();
    }

    // Floor the range so a flat signal does not produce zero-valued thresholds.
    let range = spread(data).max(1e-9);
    let min_change = 0.2 * range;
    let flat_slope = 0.01 * range;
    let flat_spread = 0.1 * range;

    let mut candidates: Vec<TrendResult> = Vec::new();

    for &(seg, scale) in WINDOW_SCALES.iter() {
        let steep_slope = scale * range / 40.0;
        let stride = seg / 2; // 50 % overlap between consecutive windows

        for (k, window) in data.windows(seg).step_by(stride).enumerate() {
            let first_idx = k * stride;
            let slope = linear_regression_slope(window);
            let change = window[seg - 1] - window[0];
            let start_min = (first_idx + 1) * downsample_scale;
            let end_min = (first_idx + seg) * downsample_scale;

            // Label plus the value later used for ranking.
            let verdict = if slope > steep_slope && change > min_change {
                Some(("increasing", change))
            } else if slope < -steep_slope && -change > min_change {
                Some(("decreasing", -change))
            } else if slope.abs() < flat_slope && spread(window) < flat_spread {
                Some(("stable", window[seg - 1]))
            } else {
                None
            };

            if let Some((label, rank_value)) = verdict {
                candidates.push((start_min, end_min, label.into(), slope, rank_value, seg));
            }
        }
    }

    // Highest ranking value first; the stable sort keeps generation order on ties.
    candidates.sort_by(|lhs, rhs| {
        rhs.4
            .partial_cmp(&lhs.4)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut selected: Vec<TrendResult> = Vec::new();
    for cand in candidates {
        if selected.len() == MAX_TRENDS {
            break;
        }
        let clashes = selected
            .iter()
            .any(|kept| overlap_fraction(cand.0, cand.1, kept.0, kept.1) > MAX_OVERLAP);
        if !clashes {
            selected.push(cand);
        }
    }

    selected
}

/// Overlap between `[s1, e1]` and `[s2, e2]`, expressed as a fraction of the
/// shorter of the two segments.
///
/// Returns `0.0` when the segments are disjoint or when the shorter segment
/// has zero length. Each segment is expected to satisfy `start <= end`.
fn overlap_fraction(s1: usize, e1: usize, s2: usize, e2: usize) -> f64 {
    let first_len = e1 - s1;
    let second_len = e2 - s2;
    let shorter = first_len.min(second_len);
    if shorter == 0 {
        return 0.0;
    }

    // Signed arithmetic so disjoint segments clamp to zero instead of wrapping.
    let overlap_start = s1.max(s2) as isize;
    let overlap_end = e1.min(e2) as isize;
    let shared = if overlap_end > overlap_start {
        overlap_end - overlap_start
    } else {
        0
    };

    shared as f64 / shorter as f64
}

/// Least-squares slope of `y` against the implicit abscissa `0, 1, 2, …`.
///
/// Returns `0.0` when the x-variance is zero, which is the case for empty and
/// single-element input.
fn linear_regression_slope(y: &[f64]) -> f64 {
    let count = y.len() as f64;
    let centre_x = (count - 1.0) / 2.0;
    let centre_y = y.iter().sum::<f64>() / count;

    // Deviation of sample index `idx` from the mean index.
    let x_dev = |idx: usize| idx as f64 - centre_x;

    let covariance: f64 = y
        .iter()
        .enumerate()
        .map(|(idx, &value)| x_dev(idx) * (value - centre_y))
        .sum();
    let x_variance: f64 = (0..y.len()).map(|idx| x_dev(idx).powi(2)).sum();

    if x_variance == 0.0 {
        return 0.0;
    }
    covariance / x_variance
}

// ---------------------------------------------------------------------------
// Peak / valley detection
// ---------------------------------------------------------------------------

/// Locate pronounced spikes and drops in `data`.
///
/// Peaks are searched on the signal itself and valleys on its negation. Both
/// searches share one prominence threshold (a fraction of the signal range)
/// and one minimum spacing; the height thresholds are derived from the signal
/// mean and range.
///
/// Returns `(minute, event_type)` pairs with all `"spike"` entries first,
/// followed by all `"drop"` entries. `minute` is
/// `(index + 1) * downsample_scale`. Signals shorter than three samples yield
/// an empty vector.
fn detect_peaks_valleys(data: &[f64], downsample_scale: usize) -> Vec<(usize, String)> {
    const PROMINENCE_THRESHOLD: f64 = 0.5;
    const HEIGHT_THRESHOLD: f64 = 0.6;
    const DISTANCE: usize = 5;

    if data.len() < 3 {
        return Vec::new();
    }

    let highest = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let lowest = data.iter().cloned().fold(f64::INFINITY, f64::min);
    let mean: f64 = data.iter().sum::<f64>() / data.len() as f64;
    // Floor the range so a flat signal does not give a zero prominence threshold.
    let range = (highest - lowest).max(1e-9);

    let min_prominence = PROMINENCE_THRESHOLD * range;
    let peak_floor = HEIGHT_THRESHOLD * range + mean;
    // Expressed in the negated domain, because valleys are found on `-data`.
    let valley_floor = -(mean + (1.0 - HEIGHT_THRESHOLD) * range);

    let to_minute = |idx: usize| (idx + 1) * downsample_scale;

    let spikes = find_peaks(data, min_prominence, Some(peak_floor), DISTANCE)
        .into_iter()
        .map(|idx| (to_minute(idx), "spike".to_string()));

    let negated: Vec<f64> = data.iter().map(|value| -value).collect();
    let drops = find_peaks(&negated, min_prominence, Some(valley_floor), DISTANCE)
        .into_iter()
        .map(|idx| (to_minute(idx), "drop".to_string()));

    spikes.chain(drops).collect()
}

/// Simple local-maximum peak finder with height, prominence and spacing filters.
///
/// A sample is a candidate peak when it is strictly greater than both of its
/// immediate neighbours, so the first and last samples and flat plateaus never
/// qualify. A candidate is then kept only if:
///
/// * its value is at least `height_threshold` (when one is given), and
/// * its approximate prominence — its value minus the larger of the minimum to
///   its left and the minimum to its right — is at least
///   `prominence_threshold`.
///
/// Finally, peaks are accepted greedily from highest to lowest value, rejecting
/// any peak closer than `min_distance` samples to one already accepted.
///
/// Returns the indices of the accepted peaks, ordered by descending peak value
/// (equal values keep ascending index order). Inputs with fewer than three
/// samples have no interior sample and yield an empty vector.
fn find_peaks(
    data: &[f64],
    prominence_threshold: f64,
    height_threshold: Option<f64>,
    min_distance: usize,
) -> Vec<usize> {
    let n = data.len();
    // Without this guard `n - 1` underflows for empty input.
    if n < 3 {
        return Vec::new();
    }

    let mut peaks: Vec<(usize, f64)> = Vec::new();

    for i in 1..n - 1 {
        let value = data[i];
        if !(value > data[i - 1] && value > data[i + 1]) {
            continue;
        }
        // Check height.
        if let Some(ht) = height_threshold {
            if value < ht {
                continue;
            }
        }
        // Approximate prominence: difference to the higher of the two
        // surrounding bases (global minimum on each side).
        let left_min = data[..i].iter().cloned().fold(f64::INFINITY, f64::min);
        let right_min = data[i + 1..].iter().cloned().fold(f64::INFINITY, f64::min);
        let prominence = value - left_min.max(right_min);
        if prominence >= prominence_threshold {
            peaks.push((i, value));
        }
    }

    // Enforce minimum distance: greedily keep highest peaks. Candidate values
    // are never NaN (they passed a `>` comparison), so the fallback ordering
    // is not reached in practice.
    peaks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let mut selected: Vec<usize> = Vec::new();
    for (idx, _) in peaks {
        if selected.iter().all(|&kept| idx.abs_diff(kept) >= min_distance) {
            selected.push(idx);
        }
    }
    selected
}

// ---------------------------------------------------------------------------
// Formatting helpers
// ---------------------------------------------------------------------------

/// Render one trend sentence.
///
/// Picks a random entry of `TREND_TEMPLATES` and substitutes the
/// `{sensor_name}`, `{trend_type}`, `{start}` and `{end}` placeholders, in
/// that order.
fn describe_trend<R: Rng>(
    sensor_name: &str,
    trend_type: &str,
    start: usize,
    end: usize,
    rng: &mut R,
) -> String {
    let template = match TREND_TEMPLATES.choose(rng) {
        Some(picked) => *picked,
        None => TREND_TEMPLATES[0],
    };

    let start_text = start.to_string();
    let end_text = end.to_string();
    let substitutions: [(&str, &str); 4] = [
        ("{sensor_name}", sensor_name),
        ("{trend_type}", trend_type),
        ("{start}", start_text.as_str()),
        ("{end}", end_text.as_str()),
    ];

    substitutions
        .iter()
        .fold(template.to_string(), |text, &(placeholder, value)| {
            text.replace(placeholder, value)
        })
}

/// Render one anomaly sentence.
///
/// Picks a random entry of `ANOMALY_TEMPLATES` and substitutes the
/// `{sensor_name}`, `{anomaly}` and `{time}` placeholders, in that order.
fn describe_anomaly<R: Rng>(
    sensor_name: &str,
    anomaly: &str,
    time: usize,
    rng: &mut R,
) -> String {
    let template = match ANOMALY_TEMPLATES.choose(rng) {
        Some(picked) => *picked,
        None => ANOMALY_TEMPLATES[0],
    };

    let time_text = time.to_string();
    let substitutions: [(&str, &str); 3] = [
        ("{sensor_name}", sensor_name),
        ("{anomaly}", anomaly),
        ("{time}", time_text.as_str()),
    ];

    substitutions
        .iter()
        .fold(template.to_string(), |text, &(placeholder, value)| {
            text.replace(placeholder, value)
        })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::NUM_CHANNELS;
    use ndarray::Array2;
    use rand::{rngs::StdRng, SeedableRng};

    /// Smoke test: an all-zero day of data still produces a non-empty caption.
    #[test]
    fn test_structural_caption_runs() {
        let mut rng = StdRng::seed_from_u64(7);
        let zeros: Array2<f64> = Array2::zeros((1440, NUM_CHANNELS));
        let caption = generate_structural_caption(&zeros.view(), 7, &mut rng).unwrap();
        assert!(!caption.is_empty());
    }

    /// A unit-step ramp has a regression slope of exactly one.
    #[test]
    fn test_linreg_slope() {
        let ramp: Vec<f64> = (0..10).map(f64::from).collect();
        let slope = linear_regression_slope(&ramp);
        assert!((slope - 1.0).abs() < 1e-9, "slope should be 1.0, got {slope}");
    }

    /// Two isolated unit peaks, five samples apart, are both reported.
    #[test]
    fn test_find_peaks() {
        let signal = [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let found = find_peaks(&signal, 0.5, None, 2);
        assert_eq!(found.len(), 2);
    }
}