sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Dataset download utilities.
//!
//! The original SensorLM training data (59.7 M hours from 103 000 participants)
//! is internal to Google and is not publicly available.  This module provides:
//!
//! 1. A **synthetic data generator** that creates plausible random wearable
//!    tensors for local development, unit tests, and integration tests.
//! 2. **Downloaders** for two publicly available datasets that can substitute
//!    as proof-of-concept training data:
//!    * [PAMAP2](https://archive.ics.uci.edu/dataset/231/pamap2+physical+activity+monitoring)
//!    * [WESAD](https://archive.ics.uci.edu/dataset/465/wesad+wearable+stress+and+affect+detection)
//! 3. A generic [`download_file`] helper with SHA-256 checksum verification
//!    and a progress bar.

use std::{
    fs,
    io::Write,
    path::{Path, PathBuf},
};

use indicatif::{ProgressBar, ProgressStyle};
use sha2::{Digest, Sha256};
use tracing::{info, warn};

use crate::error::{Result, SensorLMError};

// ---------------------------------------------------------------------------
// Known public datasets
// ---------------------------------------------------------------------------

/// Registry entry for a downloadable dataset.
///
/// All fields are `'static` so entries can live in the [`KNOWN_DATASETS`]
/// table and be handed out by reference from [`find_dataset`].
#[derive(Debug, Clone)]
pub struct DatasetEntry {
    /// Human-readable name; [`find_dataset`] matches it ignoring ASCII case.
    pub name: &'static str,
    /// Primary download URL.
    pub url: &'static str,
    /// Expected SHA-256 hex digest of the downloaded file (empty = skip check).
    pub sha256: &'static str,
    /// Approximate size in bytes, intended for sizing a progress bar.
    ///
    /// NOTE(review): [`download_file`] takes its total from the HTTP
    /// `Content-Length` header and does not read this field; confirm whether
    /// the figure is the archive size or the unpacked size.
    pub size_bytes: u64,
}

/// Publicly downloadable datasets compatible with this pipeline.
///
/// Every entry currently has an empty `sha256`, so a download driven by these
/// entries is not integrity-checked.
pub const KNOWN_DATASETS: &[DatasetEntry] = &[
    DatasetEntry {
        name: "PAMAP2",
        url: "https://archive.ics.uci.edu/static/public/231/pamap2+physical+activity+monitoring.zip",
        sha256: "",   // skip verification – checksum not published by UCI
        size_bytes: 680_000_000,
    },
    DatasetEntry {
        name: "WESAD",
        url: "https://uni-siegen.sciebo.de/s/HGdUkoNlW1Ub0Gx/download",
        sha256: "", // skip verification – no checksum recorded for this entry
        size_bytes: 1_800_000_000,
    },
];

/// Resolve a dataset entry by name (ASCII case-insensitive).
///
/// Returns the first entry of [`KNOWN_DATASETS`] whose `name` equals `name`
/// when ASCII letters are compared without regard to case, or `None` when no
/// entry matches.
pub fn find_dataset(name: &str) -> Option<&'static DatasetEntry> {
    // `eq_ignore_ascii_case` compares in place; lower-casing both sides would
    // allocate two temporary `String`s for every candidate.
    KNOWN_DATASETS
        .iter()
        .find(|entry| entry.name.eq_ignore_ascii_case(name))
}

// ---------------------------------------------------------------------------
// Generic HTTP downloader
// ---------------------------------------------------------------------------

/// Download a file from `url` to `dest_path`.
///
/// * Shows a progress bar via `indicatif`.
/// * Verifies the SHA-256 digest if `expected_sha256` is non-empty.
/// * Skips the download if the file already exists **and** has the correct
///   checksum.
///
/// # Errors
///
/// Returns an error if the HTTP request fails, the write fails, or the
/// checksum does not match.
pub fn download_file(url: &str, dest_path: &Path, expected_sha256: &str) -> Result<()> {
    // Check if already downloaded with correct checksum.
    if dest_path.exists() && !expected_sha256.is_empty() {
        let existing_hash = sha256_of_file(dest_path)?;
        if existing_hash.eq_ignore_ascii_case(expected_sha256) {
            info!("✓ {} already downloaded and verified.", dest_path.display());
            return Ok(());
        } else {
            warn!(
                "Checksum mismatch for {}: expected {} got {}. Re-downloading.",
                dest_path.display(),
                expected_sha256,
                existing_hash
            );
        }
    }

    info!("Downloading {} → {}", url, dest_path.display());

    // Create parent directories.
    if let Some(parent) = dest_path.parent() {
        fs::create_dir_all(parent)?;
    }

    let client = reqwest::blocking::Client::builder()
        .timeout(std::time::Duration::from_secs(3600))
        .build()
        .map_err(|e| SensorLMError::DownloadError { url: url.to_string(), source: e })?;

    let mut response = client
        .get(url)
        .send()
        .map_err(|e| SensorLMError::DownloadError { url: url.to_string(), source: e })?;

    let total_bytes = response.content_length().unwrap_or(0);
    let pb = ProgressBar::new(total_bytes);
    pb.set_style(
        ProgressStyle::with_template(
            "{spinner:.green} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
        )
        .unwrap()
        .progress_chars("=>-"),
    );

    let mut file = fs::File::create(dest_path)?;
    let mut downloaded = 0u64;
    let mut buf = vec![0u8; 8192];

    loop {
        use std::io::Read;
        let n = response
            .read(&mut buf)
            .map_err(|e| SensorLMError::Io(e))?;
        if n == 0 {
            break;
        }
        file.write_all(&buf[..n])?;
        downloaded += n as u64;
        pb.set_position(downloaded);
    }
    pb.finish_with_message("Download complete");

    // Verify checksum.
    if !expected_sha256.is_empty() {
        let actual_hash = sha256_of_file(dest_path)?;
        if !actual_hash.eq_ignore_ascii_case(expected_sha256) {
            fs::remove_file(dest_path)?;
            return Err(SensorLMError::DatasetError(format!(
                "SHA-256 mismatch: expected {expected_sha256}, got {actual_hash}"
            )));
        }
        info!("✓ Checksum verified.");
    }

    Ok(())
}

/// Compute the SHA-256 hex digest of a file on disk.
///
/// The file is hashed in fixed-size chunks, so memory use stays constant
/// regardless of file size (the registered archives are hundreds of megabytes
/// to gigabytes).
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read.
fn sha256_of_file(path: &Path) -> Result<String> {
    use std::io::Read;

    let mut file = fs::File::open(path)?;
    let mut hasher = Sha256::new();
    let mut buf = vec![0u8; 64 * 1024];

    loop {
        let n = match file.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => n,
            // `fs::read` retries interrupted reads internally; do the same.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        };
        hasher.update(&buf[..n]);
    }

    Ok(hex::encode(hasher.finalize()))
}

// ---------------------------------------------------------------------------
// Default data directory
// ---------------------------------------------------------------------------

/// Return the platform-appropriate data directory for sensorlm-rs.
///
/// The result is `dirs::data_local_dir()` with `sensorlm` appended:
///
/// * Linux: `$XDG_DATA_HOME/sensorlm`, or `~/.local/share/sensorlm` when
///   `XDG_DATA_HOME` is unset
/// * macOS: `~/Library/Application Support/sensorlm`
/// * Windows: `%LOCALAPPDATA%\sensorlm` (the *local*, non-roaming profile)
///
/// When the platform directory cannot be determined the function falls back
/// to `./sensorlm`, relative to the current working directory.  The directory
/// is only named here, not created.
pub fn default_data_dir() -> PathBuf {
    dirs::data_local_dir()
        .unwrap_or_else(|| PathBuf::from("."))
        .join("sensorlm")
}

// ---------------------------------------------------------------------------
// Synthetic dataset generator
// ---------------------------------------------------------------------------

use ndarray::Array2;
use rand::{Rng as _, SeedableRng};
use rand_distr::{Distribution, Normal};

/// Knobs for the synthetic wearable-data generator.
#[derive(Debug, Clone)]
pub struct SyntheticDataConfig {
    /// How many samples (one per simulated individual) to produce.
    pub num_samples: usize,
    /// Seed for the generator's RNG; the same seed reproduces the same data.
    pub seed: u64,
    /// Overlay a circadian-style sinusoid on the heart-rate signal.
    pub add_circadian: bool,
    /// Mark some time-steps as missing to imitate wearable dropout.
    pub add_missingness: bool,
    /// Share of time-steps to flag as missing, expected in `[0, 1]`.
    pub missingness_rate: f64,
}

impl Default for SyntheticDataConfig {
    /// Defaults: 1000 samples, seed 42, circadian structure and missingness
    /// both enabled, with a 10 % missingness rate.
    fn default() -> Self {
        let num_samples = 1000;
        let seed = 42;
        let missingness_rate = 0.1;

        SyntheticDataConfig {
            num_samples,
            seed,
            add_circadian: true,
            add_missingness: true,
            missingness_rate,
        }
    }
}

/// A single synthetic wearable sample.
///
/// Produced by [`generate_synthetic_dataset`]; `sensor` and `mask` are
/// allocated there with identical shapes.
#[derive(Debug, Clone)]
pub struct SyntheticSample {
    /// Normalised sensor tensor, shape `(TIME_STEPS, NUM_CHANNELS)`.
    pub sensor: Array2<f32>,
    /// Missingness mask, shape `(TIME_STEPS, NUM_CHANNELS)`.  1 = imputed.
    ///
    /// NOTE(review): the generator sets mask cells to 1 without changing the
    /// matching `sensor` value; confirm consumers treat those values as
    /// imputed rather than observed.
    pub mask: Array2<u8>,
    /// Pre-generated caption (high-level summary).
    pub caption: String,
    /// Sample ID (the sample's index in generation order).
    pub id: usize,
}

/// Generate a batch of synthetic wearable samples.
///
/// Produces `cfg.num_samples` samples from one RNG seeded with `cfg.seed`, so
/// the output is fully determined by `cfg`.  Each sample contains:
///
/// * `sensor` – independent zero-mean Gaussian noise (σ = 0.3) per cell,
///   stored directly as a z-score.  `crate::constants::NORM_PARAMS` is looked
///   up per channel but **not** applied; it is reserved for a future
///   de-normalisation step.
/// * When `cfg.add_circadian` is set, a sinusoid of amplitude 0.5 with one
///   full period across the `TIME_STEPS` axis is added to channels 0 and 3
///   (heart rate and steps in the channel layout assumed here).
/// * When `cfg.add_missingness` is set, each channel gets
///   `max(1, floor(TIME_STEPS * cfg.missingness_rate))` time-steps drawn
///   uniformly **with replacement** and flagged `1` in `mask`.  These are
///   isolated points, not contiguous blocks; collisions mean the realised
///   missing fraction can be below the requested rate, and at least one cell
///   per channel is flagged even for a rate of 0.  `sensor` values are left
///   untouched.
/// * `caption` – fixed text parameterised only by the sample id.
///   NOTE(review): it mentions circadian variation even when
///   `cfg.add_circadian` is false.
pub fn generate_synthetic_dataset(cfg: &SyntheticDataConfig) -> Vec<SyntheticSample> {
    use crate::constants::{NUM_CHANNELS, NORM_PARAMS, TIME_STEPS};

    let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed);
    let mut samples = Vec::with_capacity(cfg.num_samples);

    // Loop-invariant values, computed once instead of once per channel.
    let noise = Normal::new(0.0f64, 0.3).expect("0.3 is a finite standard deviation");
    let n_missing = ((TIME_STEPS as f64 * cfg.missingness_rate) as usize).max(1);
    let period = TIME_STEPS as f64;

    for id in 0..cfg.num_samples {
        let mut sensor = Array2::<f32>::zeros((TIME_STEPS, NUM_CHANNELS));
        let mut mask = Array2::<u8>::zeros((TIME_STEPS, NUM_CHANNELS));

        // NOTE: the order of RNG draws below (all noise for a channel, then
        // that channel's missingness draws) determines the output for a given
        // seed; reordering would change every generated dataset.
        for ch in 0..NUM_CHANNELS {
            let (_mean, _std) = NORM_PARAMS[ch]; // reserved for real normalisation

            // Gentle circadian sine wave on HR (channel 0) and steps
            // (channel 3) only.
            let has_circadian = cfg.add_circadian && (ch == 0 || ch == 3);

            for t in 0..TIME_STEPS {
                let base: f64 = noise.sample(&mut rng);

                let circadian = if has_circadian {
                    0.5 * (2.0 * std::f64::consts::PI * t as f64 / period).sin()
                } else {
                    0.0
                };

                // Store as z-score (already normalised by construction).
                sensor[[t, ch]] = (base + circadian) as f32;
            }

            // Simulate missingness: independent uniform draws, so the same
            // time-step may be picked more than once.
            if cfg.add_missingness {
                for _ in 0..n_missing {
                    let t: usize = rng.gen_range(0..TIME_STEPS);
                    mask[[t, ch]] = 1;
                }
            }
        }

        let caption = format!(
            "Synthetic 24-hour recording for individual {id}. \
             Heart rate shows typical circadian variation. \
             Activity patterns reflect normal daily movement."
        );

        samples.push(SyntheticSample { sensor, mask, caption, id });
    }

    samples
}