sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Burn [`Dataset`] implementations for SensorLM.
//!
//! Two dataset types are provided:
//!
//! * [`SyntheticSensorDataset`] – Uses the built-in synthetic data generator.
//!   Good for unit tests, profiling, and quick experiments without real data.
//!
//! * [`CsvSensorDataset`] – Loads pre-processed sensor data from a CSV file.
//!   The file is expected to have a header row followed by one row per sample.
//!   Columns are read by position, not by header name: the first
//!   `TIME_STEPS × NUM_CHANNELS` columns hold the normalised `f32` sensor values
//!   and the next column holds the caption text. Header names are not checked.
//!
//! Both implement [`burn::data::dataset::Dataset`] so they can be wrapped by
//! [`burn::data::dataloader::DataLoaderBuilder`].

use std::path::Path;

use burn::data::dataset::Dataset;

use crate::constants::{NUM_CHANNELS, TIME_STEPS};
use crate::data::download::{generate_synthetic_dataset, SyntheticDataConfig};
use crate::error::{Result, SensorLMError};

// ---------------------------------------------------------------------------
// Shared item type
// ---------------------------------------------------------------------------

/// A single (sensor, caption) training pair.
///
/// Both dataset types in this module build items whose `token_ids` and
/// `attention_mask` have the same length (`max_seq_len`).
#[derive(Debug, Clone, PartialEq)]
pub struct SensorTextItem {
    /// Normalised sensor values, flat `f32` buffer of length `T × C`
    /// (row-major: index `t * C + c`).
    pub sensor: Vec<f32>,
    /// Tokenised caption as token IDs, truncated / right-padded with id `1`
    /// to `max_seq_len`.
    pub token_ids: Vec<i32>,
    /// Padding mask aligned with `token_ids`: `1` for real tokens, `0` for padding.
    pub attention_mask: Vec<i32>,
    /// Raw caption text (kept for debugging / evaluation).
    pub caption_text: String,
}

// ---------------------------------------------------------------------------
// Synthetic dataset
// ---------------------------------------------------------------------------

/// An in-memory dataset of synthetically generated sensor-text pairs.
///
/// Useful for smoke-testing the training pipeline without a real dataset.
/// All items are generated up front and cloned out on access.
///
/// # Example
///
/// ```rust,no_run
/// use sensorlm::data::dataset::SyntheticSensorDataset;
/// use burn::data::dataset::Dataset;
/// let ds = SyntheticSensorDataset::new(512, 42, 256);
/// println!("Dataset size: {}", ds.len());
/// ```
#[derive(Debug, Clone)]
pub struct SyntheticSensorDataset {
    items: Vec<SensorTextItem>,
}

impl SyntheticSensorDataset {
    /// Generate a fresh in-memory synthetic dataset.
    ///
    /// Samples are requested from [`generate_synthetic_dataset`] with
    /// `add_circadian` and `add_missingness` enabled and
    /// `missingness_rate = 0.05`. Each caption is "tokenised" one character per
    /// token (the character's Unicode scalar value modulo 32 000), cut to
    /// `max_seq_len` characters and right-padded with id `1`. The attention
    /// mask marks real tokens with `1` and padding with `0`.
    ///
    /// # Arguments
    ///
    /// * `num_samples` – Number of (sensor, caption) pairs to request from the generator.
    /// * `seed`        – Random seed forwarded to the generator.
    /// * `max_seq_len` – Length of every `token_ids` / `attention_mask` vector.
    pub fn new(num_samples: usize, seed: u64, max_seq_len: usize) -> Self {
        let config = SyntheticDataConfig {
            num_samples,
            seed,
            add_circadian: true,
            add_missingness: true,
            missingness_rate: 0.05,
        };

        let mut items = Vec::new();
        for sample in generate_synthetic_dataset(&config) {
            // Character-level stand-in for the SentencePiece tokeniser used in
            // real training runs.
            let mut token_ids: Vec<i32> = sample
                .caption
                .chars()
                .take(max_seq_len)
                .map(|ch| (ch as i32) % 32_000)
                .collect();
            let real_tokens = token_ids.len();
            token_ids.resize(max_seq_len, 1);

            // Positions before `real_tokens` hold caption characters; the rest are padding.
            let attention_mask: Vec<i32> = (0..max_seq_len)
                .map(|pos| i32::from(pos < real_tokens))
                .collect();

            items.push(SensorTextItem {
                sensor: sample.sensor.iter().copied().collect(),
                token_ids,
                attention_mask,
                caption_text: sample.caption,
            });
        }

        Self { items }
    }
}

impl Dataset<SensorTextItem> for SyntheticSensorDataset {
    fn get(&self, index: usize) -> Option<SensorTextItem> {
        self.items.get(index).cloned()
    }

    fn len(&self) -> usize {
        self.items.len()
    }
}

// ---------------------------------------------------------------------------
// CSV-backed dataset
// ---------------------------------------------------------------------------

/// A dataset loaded from a CSV file.
///
/// Expected CSV layout (one header row, then one row per sample):
///
/// ```text
/// col_0, col_1, ..., col_{T×C-1}, caption
/// <f32>, <f32>, ..., <f32>,       "The person is walking..."
/// ```
///
/// Columns are read by position. The first `T × C` (`TIME_STEPS * NUM_CHANNELS`)
/// columns are copied, in order, into the flat sensor buffer, and the column
/// immediately after them is the caption. The header row is skipped and its
/// names are not checked. Any columns after the caption are ignored, so a row
/// needs *at least* `T × C + 1` columns.
///
/// If your CSV has one row per time-step, pre-aggregate to one row per sample
/// before loading. All rows are loaded eagerly into memory.
#[derive(Debug, Clone)]
pub struct CsvSensorDataset {
    items: Vec<SensorTextItem>,
}

impl CsvSensorDataset {
    /// Load all rows from a CSV file.
    ///
    /// The first row is treated as a header and skipped (the `csv` crate's
    /// default reader configuration); its column names are not checked. In every
    /// data row:
    ///
    /// * The first `TIME_STEPS * NUM_CHANNELS` columns are parsed as `f32`
    ///   sensor values and copied positionally into the item's flat sensor buffer.
    ///   Empty (or whitespace-only) cells are treated as missing values and
    ///   stored as `0.0`.
    /// * The next column is the caption. It is trimmed, passed to `tokenize`,
    ///   truncated to `max_seq_len` tokens and right-padded with id `1`. The
    ///   attention mask is `1` for real tokens and `0` for padding.
    /// * Any further columns are ignored.
    ///
    /// # Arguments
    ///
    /// * `path`        – Path to the `.csv` file.
    /// * `max_seq_len` – Target token sequence length.
    /// * `tokenize`    – A closure that converts a caption string into token IDs.
    ///
    /// # Errors
    ///
    /// Returns [`SensorLMError::DatasetError`] if any of the following holds:
    ///
    /// * The file cannot be opened or a record cannot be read.
    /// * A row has fewer than `TIME_STEPS * NUM_CHANNELS + 1` columns.
    /// * A non-empty sensor cell is not a valid `f32`. Malformed cells used to
    ///   be silently replaced by `0.0`; they are now reported instead.
    pub fn from_csv<F>(path: &Path, max_seq_len: usize, tokenize: F) -> Result<Self>
    where
        F: Fn(&str) -> Vec<i32>,
    {
        let expected_sensor_len = TIME_STEPS * NUM_CHANNELS;
        let mut items = Vec::new();

        let mut rdr = csv::Reader::from_path(path)
            .map_err(|e| SensorLMError::DatasetError(e.to_string()))?;

        for (row_idx, result) in rdr.records().enumerate() {
            let record = result.map_err(|e| SensorLMError::DatasetError(e.to_string()))?;
            // 1-based data-row number (header excluded) for error messages.
            let row = row_idx + 1;

            if record.len() < expected_sensor_len + 1 {
                return Err(SensorLMError::DatasetError(format!(
                    "Data row {}: expected at least {} columns, got {}",
                    row,
                    expected_sensor_len + 1,
                    record.len()
                )));
            }

            // Length was checked above, so `take` yields exactly `expected_sensor_len` cells.
            let sensor = record
                .iter()
                .take(expected_sensor_len)
                .enumerate()
                .map(|(col, cell)| Self::parse_sensor_cell(cell, row, col))
                .collect::<Result<Vec<f32>>>()?;

            let caption = record[expected_sensor_len].trim().to_string();

            let mut token_ids = tokenize(&caption);
            token_ids.truncate(max_seq_len);
            let mut attention_mask = vec![1i32; token_ids.len()];
            // Pad with id 1, the same padding id the synthetic dataset uses.
            token_ids.resize(max_seq_len, 1);
            attention_mask.resize(max_seq_len, 0);

            items.push(SensorTextItem {
                sensor,
                token_ids,
                attention_mask,
                caption_text: caption,
            });
        }

        Ok(Self { items })
    }

    /// Parse one sensor cell. An empty or whitespace-only cell is a missing
    /// value and becomes `0.0`; anything else must be a valid `f32`.
    fn parse_sensor_cell(cell: &str, row: usize, col: usize) -> Result<f32> {
        let cell = cell.trim();
        if cell.is_empty() {
            return Ok(0.0);
        }
        cell.parse::<f32>().map_err(|e| {
            SensorLMError::DatasetError(format!(
                "Data row {}, column {}: invalid sensor value {:?}: {}",
                row, col, cell, e
            ))
        })
    }
}

impl Dataset<SensorTextItem> for CsvSensorDataset {
    /// Clone out the row loaded at `index`; `None` past the end of the file's rows.
    fn get(&self, index: usize) -> Option<SensorTextItem> {
        let item = self.items.get(index)?;
        Some(item.clone())
    }

    /// Number of data rows loaded from the CSV file.
    fn len(&self) -> usize {
        Vec::len(&self.items)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_synthetic_dataset() {
        let ds = SyntheticSensorDataset::new(16, 99, 256);
        assert_eq!(ds.len(), 16);
        let item = ds.get(0).expect("first item");
        assert_eq!(item.sensor.len(), TIME_STEPS * NUM_CHANNELS);
        assert_eq!(item.token_ids.len(), 256);
        assert_eq!(item.attention_mask.len(), 256);
        assert!(ds.get(16).is_none());
    }

    #[test]
    fn test_synthetic_dataset_truncates_and_masks() {
        let max_seq_len = 4;
        let ds = SyntheticSensorDataset::new(8, 7, max_seq_len);
        for index in 0..ds.len() {
            let item = ds.get(index).expect("item in range");
            let real = item.caption_text.chars().count().min(max_seq_len);
            assert_eq!(item.token_ids.len(), max_seq_len);
            assert_eq!(item.attention_mask.len(), max_seq_len);
            // Real tokens come first and are marked 1; padding is marked 0 and uses id 1.
            assert!(item.attention_mask[..real].iter().all(|&m| m == 1));
            assert!(item.attention_mask[real..].iter().all(|&m| m == 0));
            assert!(item.token_ids[real..].iter().all(|&t| t == 1));
        }
    }

    #[test]
    fn test_csv_dataset_reads_by_position() {
        let n = TIME_STEPS * NUM_CHANNELS;
        let mut header: Vec<String> = (0..n).map(|i| format!("col_{}", i)).collect();
        header.push("caption".to_string());
        let mut row: Vec<String> = (0..n).map(|i| (i % 7).to_string()).collect();
        row.push("hi".to_string());
        let contents = format!("{}\n{}\n", header.join(","), row.join(","));

        let path = std::env::temp_dir().join(format!(
            "sensorlm_csv_dataset_test_{}.csv",
            std::process::id()
        ));
        std::fs::write(&path, contents).expect("write temporary CSV");
        let loaded = CsvSensorDataset::from_csv(&path, 4, |s: &str| {
            s.bytes().map(i32::from).collect()
        });
        let _ = std::fs::remove_file(&path);
        let ds = match loaded {
            Ok(ds) => ds,
            Err(_) => panic!("failed to load temporary CSV"),
        };

        assert_eq!(ds.len(), 1);
        let item = ds.get(0).expect("first row");
        let expected: Vec<f32> = (0..n).map(|i| (i % 7) as f32).collect();
        assert_eq!(item.sensor, expected);
        assert_eq!(item.caption_text, "hi");
        assert_eq!(item.token_ids, vec![104, 105, 1, 1]);
        assert_eq!(item.attention_mask, vec![1, 1, 0, 0]);
    }
}