// sensorlm-rs 0.1.0
//
// SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
//! Physical constants, feature lists, and normalisation parameters.
//!
//! All values are ported directly from the Python reference implementation
//! (`sensorlm/constants.py`).

// ---------------------------------------------------------------------------
// Sensor channel definition
// ---------------------------------------------------------------------------

/// Names of the 34 features that are stored in the dataset and used for both
/// the ViT sensor encoder and the statistical captioning pipeline.
///
/// The feature order is significant: index `i` corresponds to channel `i` in
/// every `(T, C)` sensor tensor, and to entry `i` of [`NORM_PARAMS`].
pub const FEATURE_NAMES: &[&str] = &[
    "HR",
    "eda_level_real",
    "leads_contact_counts",
    "steps",
    "jerk_auto",
    "log_energy",
    "covariance",
    "log_energy_ratio",
    "zero_crossing_std",
    "zero_crossing_avg",
    "axis_mean",
    "altim_std",
    "kurtosis",
    "sleep_coefficient",
    "wrist_temperatures",
    "rr_med",
    "sdnn0595",
    "rmssd0595",
    "pnn20",
    "coherence",
    "ShEnRR",
    "LF",
    "HF",
    "LF_HF",
    "VLF",
    "spectralEn",
    "percent_good",
    "sleep_stage_awake",
    "sleep_stage_light",
    "sleep_stage_deep",
    "sleep_stage_rem",
    "spo2",
    "spo2_confidence",
    "spo2_coverage",
];

/// Number of sensor channels (currently 34).
///
/// Derived from [`FEATURE_NAMES`] so the two can never drift apart.
pub const NUM_CHANNELS: usize = FEATURE_NAMES.len();

/// Number of time-steps per sample (one per minute of a day: 24 × 60).
pub const TIME_STEPS: usize = 24 * 60;

// ---------------------------------------------------------------------------
// Normalisation parameters (mean, std)
//
// Each entry is a `(mean, std)` tuple.  Normalised value:
//   z = (x - mean) / std
// Denormalised value:
//   x = z * std + mean
// ---------------------------------------------------------------------------

/// Population-level `(mean, std)` pairs for every channel in [`FEATURE_NAMES`],
/// in the same order (entry `i` belongs to channel `i`).
pub const NORM_PARAMS: &[(f64, f64)] = &[
    // 0: HR
    (75.958_6, 16.188_7),
    // 1: eda_level_real
    (4.176_7, 5.589_3),
    // 2: leads_contact_counts
    (226.486_4, 67.331_2),
    // 3: steps
    (5.167_9, 18.892_6),
    // 4: jerk_auto
    (203.467_2, 30.056_3),
    // 5: log_energy
    (53.080_4, 49.652_6),
    // 6: covariance
    (43.407_7, 13.952_9),
    // 7: log_energy_ratio
    (44.848_3, 22.974_6),
    // 8: zero_crossing_std
    (155.186_3, 28.237_8),
    // 9: zero_crossing_avg
    (51.004_3, 37.475_6),
    // 10: axis_mean
    (123.165_9, 21.471_0),
    // 11: altim_std
    (0.004_2, 0.059_7),
    // 12: kurtosis
    (105.595_4, 66.849_5),
    // 13: sleep_coefficient
    (7.262_3, 5.394_6),
    // 14: wrist_temperatures
    (31.674_5, 2.578_9),
    // 15: rr_med
    (856.830_4, 160.118_1),
    // 16: sdnn0595
    (64.800_3, 55.585_0),
    // 17: rmssd0595
    (65.342_1, 74.783_1),
    // 18: pnn20
    (0.566_7, 0.262_3),
    // 19: coherence
    (0.180_8, 0.130_5),
    // 20: ShEnRR
    (3.058_2, 0.667_3),
    // 21: LF
    (1_551.837_6, 2_399.422_8),
    // 22: HF
    (757.227_1, 1_873.923_9),
    // 23: LF_HF
    (4.126_5, 4.506_6),
    // 24: VLF
    (1_303.384_8, 1_906.101_7),
    // 25: spectralEn
    (2.525_5, 0.393_1),
    // 26: percent_good
    (0.484_6, 0.343_9),
    // 27: sleep_stage_awake
    (0.042_4, 0.191_6),
    // 28: sleep_stage_light
    (0.043_4, 0.202_1),
    // 29: sleep_stage_deep
    (0.185_5, 0.383_0),
    // 30: sleep_stage_rem
    (0.057_5, 0.230_1),
    // 31: spo2
    (95.201_9, 2.464_6),
    // 32: spo2_confidence
    (56.639_1, 42.106_4),
    // 33: spo2_coverage
    (50.125_1, 19.097_1),
];

// ---------------------------------------------------------------------------
// Features that must be clamped to ≥ 0 after denormalisation
// ---------------------------------------------------------------------------

/// Indices (into [`FEATURE_NAMES`]) of channels whose physical value cannot be
/// negative. These are clamped after denormalisation (steps and sleep_coefficient).
pub const NON_NEGATIVE_CHANNELS: &[usize] = &[
    3,  // steps
    13, // sleep_coefficient
];

// Compile-time consistency checks for the per-channel tables above.
// A mismatch fails the build instead of silently mis-normalising a channel
// (or dividing by a zero std) at runtime.
const _: () = {
    assert!(
        NORM_PARAMS.len() == NUM_CHANNELS,
        "NORM_PARAMS must have exactly one (mean, std) pair per FEATURE_NAMES entry"
    );

    let mut c = 0;
    while c < NORM_PARAMS.len() {
        // `std` is used as a divisor in z = (x - mean) / std.
        assert!(
            NORM_PARAMS[c].1 > 0.0,
            "every normalisation std must be strictly positive"
        );
        c += 1;
    }

    let mut k = 0;
    while k < NON_NEGATIVE_CHANNELS.len() {
        assert!(
            NON_NEGATIVE_CHANNELS[k] < NUM_CHANNELS,
            "NON_NEGATIVE_CHANNELS contains an out-of-range channel index"
        );
        k += 1;
    }
};

// ---------------------------------------------------------------------------
// Channel groupings used by the captioning pipeline
// ---------------------------------------------------------------------------

/// A logical group of sensor channels used by the captioning pipeline.
///
/// Every `usize` index below refers to a position in [`FEATURE_NAMES`]
/// (i.e. a channel of the `(T, C)` sensor tensor).
#[derive(Debug, Clone)]
pub struct ChannelGroup {
    /// Human-readable category name, e.g. `"Heart"`.
    pub category: &'static str,
    /// `(display_name, feature_index)` pairs for primary channels.
    pub primary: &'static [(&'static str, usize)],
    /// `(display_name, feature_index)` pairs forming the pool of randomly sampled channels.
    pub random: &'static [(&'static str, usize)],
    /// How many channels to sample from `random` per caption.
    ///
    /// Must not exceed `random.len()`; this is checked at compile time for
    /// every entry of [`CHANNEL_GROUPS`].
    pub random_k: usize,
}

/// The four physiological channel groups used in captioning.
pub const CHANNEL_GROUPS: &[ChannelGroup] = &[
    ChannelGroup {
        category: "Heart",
        primary: &[
            ("heart rate", 0),              // HR
            ("hrv rr", 15),                 // rr_med
            ("hrv shannon entropy rr", 20), // ShEnRR
            ("sdnn percentile", 16),        // sdnn0595
        ],
        random: &[
            // Placeholder – HR again (no hr_at_rest in the 34-feature set).
            ("hr at rest mean", 0),
            // NOTE(review): maps to rr_med (a median), not an 80th percentile —
            // presumably a placeholder; confirm.
            ("hrv rr 80th percentile", 15),
            // NOTE(review): maps to ShEnRR, same channel as the primary entry — confirm.
            ("hrv shannon entropy rrd", 20),
            ("rmssd percentile mean", 17), // rmssd0595
        ],
        random_k: 2,
    },
    ChannelGroup {
        category: "Activity",
        primary: &[
            ("steps", 3),
            ("jerk", 4),
            ("log energy", 5),
            ("kurtosis", 12),
        ],
        random: &[
            ("covariance", 6),
            ("log energy ratio", 7),
            ("zero crossing std", 8),
            ("zero crossing avg", 9),
            ("axis mean", 10),
            ("altim std", 11),
        ],
        random_k: 1,
    },
    ChannelGroup {
        category: "Sleep",
        primary: &[("sleep coefficient", 13)],
        random: &[],
        random_k: 0,
    },
    ChannelGroup {
        category: "EDA",
        primary: &[
            ("eda level", 1), // eda_level_real
            // NOTE(review): the feature set has no temperature-slope channel, so
            // this reuses the raw wrist_temperatures channel — confirm intent.
            ("skin temperature slope", 14),
            ("wrist temperatures", 14),
        ],
        random: &[
            ("leads contact counts", 2),
            // NOTE(review): maps to eda_level_real (a level, not a slope) — confirm.
            ("ceda slope real micro siemens", 1),
        ],
        random_k: 1,
    },
];

// Compile-time check: a group must not request more random channels than its
// pool contains.
const _: () = {
    let mut g = 0;
    while g < CHANNEL_GROUPS.len() {
        assert!(
            CHANNEL_GROUPS[g].random_k <= CHANNEL_GROUPS[g].random.len(),
            "ChannelGroup::random_k exceeds the size of its random pool"
        );
        g += 1;
    }
};

// ---------------------------------------------------------------------------
// Caption token budgets
// ---------------------------------------------------------------------------

/// Token budget per caption type for the training pipeline.
///
/// Each entry is a `(caption_key, token_budget)` pair; the table is small, so
/// a linear scan by key is the intended lookup.
pub const CAPTION_TOKEN_BUDGET: &[(&str, usize)] = &[
    // Captions drawn from a single level.
    ("low_level_caption", 512),
    ("middle_level_caption", 512),
    ("high_level_summary_caption", 256),
    ("high_level_all_caption", 1024),
    // Captions combining several levels.
    ("middle_low_level_caption", 1024),
    ("high_low_level_caption", 1024),
    ("high_middle_level_caption", 512),
    ("high_middle_low_level_caption", 1024),
];

// ---------------------------------------------------------------------------
// Model / training hyper-parameters (defaults mirroring the reference config)
// ---------------------------------------------------------------------------

/// Default vocabulary size used by the c4_en / T5 SentencePiece tokeniser.
pub const VOCAB_SIZE: usize = 32_000;

/// ViT-B hidden dimension.
pub const VIT_WIDTH: usize = 768;
/// ViT-B depth (number of transformer blocks).
pub const VIT_DEPTH: usize = 12;
/// ViT-B MLP hidden dimension: 4 × [`VIT_WIDTH`] (= 3072).
///
/// Derived from the width so the two cannot drift apart.
pub const VIT_MLP_DIM: usize = 4 * VIT_WIDTH;
/// ViT-B number of attention heads.
pub const VIT_HEADS: usize = 12;

// Assumes the attention layers split the width evenly across heads
// (768 / 12 = 64 per head, as in standard ViT-B); reject edits that break it.
const _: () = assert!(
    VIT_WIDTH % VIT_HEADS == 0,
    "VIT_WIDTH must be divisible by VIT_HEADS"
);

/// Patch height along the time axis: 10 time-steps (minutes) per patch.
pub const PATCH_H: usize = 10;
/// Patch width along the channel axis: 2 channels per patch.
pub const PATCH_W: usize = 2;

/// Number of patches along the time axis (`1440 / 10 = 144`).
pub const NUM_PATCHES_T: usize = TIME_STEPS / PATCH_H;
/// Number of patches along the channel axis, rounded up (`ceil(34 / 2) = 17`).
///
/// When the channel count is not a multiple of `PATCH_W` the count rounds up,
/// so the last patch would cover a partially filled channel slice
/// (NOTE(review): padding presumably happens in the encoder — confirm).
pub const NUM_PATCHES_C: usize = NUM_CHANNELS.div_ceil(PATCH_W);
/// Total number of patches fed into the transformer (`144 × 17 = 2448`).
pub const NUM_PATCHES: usize = NUM_PATCHES_T * NUM_PATCHES_C;

// The time axis uses floor division: if TIME_STEPS were not a multiple of
// PATCH_H, NUM_PATCHES_T would silently drop the trailing minutes.
// Fail the build instead of truncating.
const _: () = assert!(
    TIME_STEPS % PATCH_H == 0,
    "TIME_STEPS must be a multiple of PATCH_H"
);

/// Width of the joint embedding space that both modalities are projected into
/// by their final projection layers.
pub const EMBED_DIM: usize = 768;

/// Starting value for the SigLIP temperature parameter.
pub const TEMPERATURE_INIT: f32 = 10.0_f32;
/// Starting value for the SigLIP bias parameter.
pub const BIAS_INIT: f32 = -10.0_f32;

/// Default training batch size.
///
/// # ⚠ Memory warning
///
/// Each sample yields N = 2 448 patch tokens in the ViT-B sensor encoder, and
/// attention score tensors grow as `B × H × chunk × N`. Even with
/// `attn_chunk_size = 64`, a single fp32 chunk for a batch of 8 costs about:
///
/// ```text
/// 8 × 12 × 64 × 2448 × 4 bytes ≈ 60 MB  per chunk (forward only)
/// ```
///
/// **During the backward pass the Burn autodiff tape keeps the intermediates of
/// ALL chunks alive at the same time**, so scale that figure by `ceil(N / chunk)`
/// chunks and then by `depth` layers. On a 16 GB GPU keep `batch_size` ≤ 8 for
/// ViT-B; for quick experiments, use `--cpu` together with a smaller model config.
pub const DEFAULT_BATCH_SIZE: usize = 8;
/// Default learning rate.
pub const DEFAULT_LR: f64 = 5.0e-4;
/// Default weight-decay coefficient.
pub const DEFAULT_WD: f64 = 1.0e-4;
/// Second-moment decay rate (β₂) for Adam.
pub const ADAM_BETA2: f64 = 0.999;
/// Norm threshold used for gradient clipping.
pub const GRAD_CLIP_NORM: f64 = 1.0;
/// Total number of pre-training examples (50 million).
pub const TOTAL_EXAMPLES: usize = 50 * 1_000_000;