sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
//! Hierarchical configuration structs for every subsystem.
//!
//! All structs implement [`serde::Serialize`] / [`serde::Deserialize`] so they
//! can be persisted to / loaded from JSON config files.
//!
//! The defaults mirror the reference Python configuration exactly.

use serde::{Deserialize, Serialize};

use crate::constants::*;

// ===========================================================================
// Sensor encoder (ViT)
// ===========================================================================

/// Configuration for the Vision Transformer (ViT) sensor encoder.
///
/// The encoder treats wearable sensor data as a 2-D grid
/// `(time_steps × num_channels)` and divides it into rectangular patches of
/// shape `(patch_h × patch_w)`.  The resulting token count is available via
/// [`SensorEncoderConfig::num_patches`].
///
/// The [`Default`] implementation takes the grid geometry and transformer
/// dimensions from `crate::constants`, disables dropout, selects
/// [`PoolType::Map`] and uses an attention chunk size of 64.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensorEncoderConfig {
    /// Number of time-steps in the input signal (default: 1440 = 24 h × 60 min).
    pub time_steps: usize,
    /// Number of sensor channels (features) per time-step (default: 34).
    pub num_channels: usize,
    /// Patch height (time axis), default: 10 minutes.
    pub patch_h: usize,
    /// Patch width (channel axis), default: 2 channels.
    pub patch_w: usize,
    /// Transformer hidden dimension (ViT-B = 768).
    pub d_model: usize,
    /// Number of transformer layers (ViT-B = 12).
    pub depth: usize,
    /// Number of attention heads per layer (ViT-B = 12).
    pub num_heads: usize,
    /// Feed-forward MLP hidden dimension (ViT-B = 3072).
    pub mlp_dim: usize,
    /// Dropout probability applied inside each transformer block
    /// (default: `0.0`).
    pub dropout: f64,
    /// Type of sequence pooling used after the transformer.
    ///
    /// [`PoolType::Map`] (Multihead Attention Pooling) is the default and
    /// matches the reference implementation.  [`PoolType::Gap`] (global
    /// average pooling) is a cheaper alternative.
    pub pool_type: PoolType,
    /// Whether to zero-initialise the output projection in the MAP head
    /// (default: `false`).
    pub head_zeroinit: bool,
    /// Chunked-attention window size (number of query rows per chunk).
    ///
    /// Limits **forward-pass** peak attention memory from `O(B·H·N²)` to
    /// `O(B·H·chunk·N)` and keeps individual WGPU GPU dispatches small
    /// enough to avoid OS watchdog (TDR) timeouts.
    ///
    /// **⚠ Training caveat — chunking does NOT save backward memory.**
    /// Burn's autodiff tape records every intermediate tensor produced inside
    /// the chunk loop (`q_chunk`, `scores`, `attn`, `chunk_out`).  All chunks
    /// for all layers are kept alive simultaneously until `loss.backward()`
    /// completes.  True backward memory savings require gradient checkpointing,
    /// which is not yet implemented.
    ///
    /// Rule of thumb for GPU VRAM (fp32, ViT-B, N = 2448, H = 12):
    ///
    /// | chunk | fwd attn @ B=4 | fwd attn @ B=8 |
    /// |-------|----------------|----------------|
    /// |  2448 (off) | 4.3 GB | 8.6 GB         |
    /// |   256 | 450 MB         | 900 MB         |
    /// |   128 | 225 MB         | 450 MB         |
    /// |    64 | 112 MB         | 225 MB         |
    ///
    /// Set to `0` to disable chunking (full N×N matrix — **not recommended
    /// on GPU** due to TDR risk and peak memory).
    ///
    /// Default: `64`.
    pub attn_chunk_size: usize,
}

impl Default for SensorEncoderConfig {
    fn default() -> Self {
        Self {
            time_steps: TIME_STEPS,
            num_channels: NUM_CHANNELS,
            patch_h: PATCH_H,
            patch_w: PATCH_W,
            d_model: VIT_WIDTH,
            depth: VIT_DEPTH,
            num_heads: VIT_HEADS,
            mlp_dim: VIT_MLP_DIM,
            dropout: 0.0,
            pool_type: PoolType::Map,
            head_zeroinit: false,
            attn_chunk_size: 64,
        }
    }
}

impl SensorEncoderConfig {
    /// Number of patch tokens produced from one input grid.
    ///
    /// Computed as `(time_steps / patch_h) × ceil(num_channels / patch_w)`.
    /// The time axis uses truncating integer division, whereas the channel
    /// axis is rounded up, i.e. the channel dimension is treated as padded to
    /// the next multiple of `patch_w` when it is not evenly divisible.
    ///
    /// # Panics
    ///
    /// Panics if `patch_h` or `patch_w` is zero.
    pub fn num_patches(&self) -> usize {
        let &Self {
            time_steps,
            num_channels,
            patch_h,
            patch_w,
            ..
        } = self;

        // Whole patches along the time axis (remainder is dropped).
        let time_patches = time_steps / patch_h;
        // Ceiling division along the channel axis.
        let channel_patches = (num_channels + patch_w - 1) / patch_w;

        time_patches * channel_patches
    }
}

// ===========================================================================
// Named model-size presets
// ===========================================================================

/// Named ViT model-size variants, matching the standard ViT paper dimensions.
///
/// The derived [`Default`] is [`ModelSize::Tiny`].
///
/// Memory figures assume fp32, N = 2448 patches, chunk = 64, and cover
/// *attention score/weight tensors for one transformer layer* (the practical
/// backward-pass peak).  Total GPU memory is 3–5× higher once weights,
/// activations, and Adam optimizer states are included.
///
/// | Size  | d_model | heads | ~params | per-layer bwd B=16 | per-layer bwd B=4 |
/// |-------|---------|-------|---------|--------------------|-------------------|
/// | Tiny  |   192   |   3   |  ~11 M  |  2.1 GB            | 0.5 GB            |
/// | Small |   384   |   6   |  ~44 M  |  4.4 GB            | 1.1 GB            |
/// | Base  |   768   |  12   | ~205 M  | 17.5 GB ✗          | 2.2 GB            |
///
/// Recommended `--batch-size` per preset (WGPU / Metal, 16 GB device):
/// - `tiny`:  up to **16** — comfortable; per-layer bwd ≈ 2.1 GB
/// - `small`: up to **8**  — comfortable; per-layer bwd ≈ 2.2 GB
/// - `base`:  up to **4**  — per-layer bwd ≈ 2.2 GB; total ≈ 10 GB
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ModelSize {
    /// ViT-Ti: d=192, depth=12, heads=3, mlp=768. Fits in ~2 GB VRAM.
    ///
    /// This is the variant returned by `ModelSize::default()`.
    #[default]
    Tiny,
    /// ViT-S: d=384, depth=12, heads=6, mlp=1536. Fits in ~6 GB VRAM.
    Small,
    /// ViT-B: d=768, depth=12, heads=12, mlp=3072. Requires ≥ 16 GB VRAM.
    Base,
}

impl ModelSize {
    /// Return the transformer hidden dimension for this size.
    pub fn d_model(self) -> usize {
        match self {
            Self::Tiny  => 192,
            Self::Small => 384,
            Self::Base  => VIT_WIDTH, // 768
        }
    }

    /// Return the number of transformer layers.
    pub fn depth(self) -> usize {
        12 // same across all ViT variants
    }

    /// Return the number of attention heads.
    pub fn num_heads(self) -> usize {
        match self {
            Self::Tiny  => 3,
            Self::Small => 6,
            Self::Base  => VIT_HEADS, // 12
        }
    }

    /// Return the MLP hidden dimension (4 × d_model).
    pub fn mlp_dim(self) -> usize {
        self.d_model() * 4
    }

    /// Build a [`SensorEncoderConfig`] for this size with sensible defaults.
    pub fn sensor_encoder_config(self) -> SensorEncoderConfig {
        SensorEncoderConfig {
            d_model: self.d_model(),
            depth:   self.depth(),
            num_heads: self.num_heads(),
            mlp_dim: self.mlp_dim(),
            // All sizes use chunking.  The per-dispatch attention tensor is
            // (B, H, chunk, N) × 4 bytes — its size depends on H, not d_model,
            // so full attention (chunk=0) would still produce a ≥1 GB kernel at
            // B=16 for Tiny (H=3) which risks an OS GPU watchdog (TDR) timeout.
            //   chunk=128, B=16, H=3:  dispatch = 16×3×128×2448×4 ≈ 144 MB ✓
            //   chunk=64,  B=16, H=3:  dispatch = 16×3× 64×2448×4 ≈  72 MB ✓
            attn_chunk_size: 64, // safe for all sizes and recommended batch sizes
            ..SensorEncoderConfig::default()
        }
    }

    /// Build a [`TextEncoderConfig`] for this size.
    pub fn text_encoder_config(self) -> TextEncoderConfig {
        TextEncoderConfig {
            d_model:   self.d_model(),
            depth:     self.depth(),
            num_heads: self.num_heads(),
            mlp_dim:   self.mlp_dim(),
            out_dim:   Some(self.d_model()), // embed_dim matches d_model
            ..TextEncoderConfig::default()
        }
    }

    /// Build a complete [`SensorLMConfig`] for this size.
    pub fn sensorlm_config(self) -> SensorLMConfig {
        SensorLMConfig {
            sensor_encoder: self.sensor_encoder_config(),
            text_encoder:   self.text_encoder_config(),
            embed_dim:      self.d_model(),
            ..SensorLMConfig::default()
        }
    }

    /// Human-readable approximate parameter count for both towers combined.
    pub fn approx_params(self) -> &'static str {
        match self {
            Self::Tiny  => "~11 M",
            Self::Small => "~44 M",
            Self::Base  => "~205 M",
        }
    }
}

/// Sequence-level pooling strategy after the ViT encoder.
///
/// Selected through [`SensorEncoderConfig::pool_type`]; the encoder's
/// `Default` configuration uses [`PoolType::Map`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PoolType {
    /// **Multihead Attention Pooling** – a learnable probe token attends over
    /// all patch tokens (reference implementation default, matches
    /// [`crate::model::sensor_encoder`]).
    Map,
    /// **Global Average Pooling** – mean over the patch-token sequence
    /// (cheaper, slightly lower quality).
    Gap,
}

// ===========================================================================
// Text encoder
// ===========================================================================

/// Configuration for the text transformer encoder.
///
/// The [`Default`] implementation reuses the ViT width / depth / heads / MLP
/// constants from `crate::constants`, allows sequences of up to 1024 tokens,
/// disables dropout and projects the output to `EMBED_DIM`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextEncoderConfig {
    /// Vocabulary size (default: 32 000, c4_en SentencePiece vocabulary).
    pub vocab_size: usize,
    /// Maximum token sequence length (default: 1024).
    pub max_seq_len: usize,
    /// Transformer hidden dimension (ViT-B = 768).
    pub d_model: usize,
    /// Number of transformer layers (ViT-B = 12).
    pub depth: usize,
    /// Number of attention heads per layer.
    pub num_heads: usize,
    /// Feed-forward MLP hidden dimension.
    pub mlp_dim: usize,
    /// Dropout probability (default: `0.0`).
    pub dropout: f64,
    /// Output projection dimension.  `None` means no projection (identity).
    pub out_dim: Option<usize>,
}

impl Default for TextEncoderConfig {
    fn default() -> Self {
        Self {
            vocab_size: VOCAB_SIZE,
            max_seq_len: 1024,
            d_model: VIT_WIDTH,
            depth: VIT_DEPTH,
            num_heads: VIT_HEADS,
            mlp_dim: VIT_MLP_DIM,
            dropout: 0.0,
            out_dim: Some(EMBED_DIM),
        }
    }
}

// ===========================================================================
// Two-tower SensorLM
// ===========================================================================

/// Top-level configuration for the combined SensorLM model.
///
/// Bundles one configuration per tower plus the scalars shared by both.
/// The [`Default`] implementation uses each tower's own default and reads
/// `embed_dim`, `temperature_init` and `bias_init` from `crate::constants`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensorLMConfig {
    /// Sensor (image) encoder configuration.
    pub sensor_encoder: SensorEncoderConfig,
    /// Text encoder configuration.
    pub text_encoder: TextEncoderConfig,
    /// Shared embedding dimensionality (must match both encoder outputs).
    pub embed_dim: usize,
    /// Initial value of the SigLIP temperature scalar (log-scale before exp).
    pub temperature_init: f32,
    /// Initial value of the SigLIP bias scalar.
    pub bias_init: f32,
}

impl Default for SensorLMConfig {
    fn default() -> Self {
        Self {
            sensor_encoder: SensorEncoderConfig::default(),
            text_encoder: TextEncoderConfig::default(),
            embed_dim: EMBED_DIM,
            temperature_init: TEMPERATURE_INIT,
            bias_init: BIAS_INIT,
        }
    }
}

// ===========================================================================
// Training
// ===========================================================================

/// Learning-rate schedule type.
///
/// A field-less selector stored in [`TrainingConfig::lr_schedule`].  It
/// derives the same trait set as the other selector enums in this module
/// (`Copy`, `PartialEq`, `Eq`), so values can be compared and passed by
/// value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LrScheduleType {
    /// Inverse-square-root schedule with linear warm-up and cool-down.
    /// Matches the reference `decay_type='rsqrt'` setting.
    RsqrtWithWarmupCooldown,
    /// Cosine annealing.
    Cosine,
    /// Constant learning rate (no schedule).
    Constant,
}

/// Training hyperparameters.
///
/// Covers the optimiser settings, learning-rate schedule, logging and
/// checkpoint cadence, data / artifact locations and the VRAM pre-flight
/// options.  See the [`Default`] implementation for the initial value of
/// every field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Model size preset (`tiny` / `small` / `base`).
    ///
    /// Overrides `SensorLMConfig` when passed through the CLI.  Building
    /// a config from a preset is the recommended way to avoid mismatched
    /// `d_model` / `embed_dim` values between the two towers.
    ///
    /// Defaults to `ModelSize::default()`, i.e. [`ModelSize::Tiny`].
    pub model_size: ModelSize,
    /// Total number of gradient update steps.
    ///
    /// The default is `TOTAL_EXAMPLES / DEFAULT_BATCH_SIZE` (integer division).
    pub total_steps: usize,
    /// Mini-batch size (default: 8).
    pub batch_size: usize,
    /// Peak learning rate (default: 5 × 10⁻⁴).
    pub lr: f64,
    /// AdamW weight decay (default: 1 × 10⁻⁴).
    pub weight_decay: f64,
    /// Adam β₂ (default: 0.999, reference uses `scale_by_adam b2=0.999`).
    pub beta2: f64,
    /// Adam β₁ (default: 0.9).
    pub beta1: f64,
    /// Adam ε (default: 1 × 10⁻⁸).
    pub epsilon: f64,
    /// Gradient clip norm (default: 1.0).
    pub grad_clip_norm: f64,
    /// Fraction of total steps used for linear warm-up (default: 0.2).
    pub warmup_fraction: f64,
    /// Fraction of total steps used for cool-down (default: 0.2).
    pub cooldown_fraction: f64,
    /// LR schedule type (default: [`LrScheduleType::RsqrtWithWarmupCooldown`]).
    pub lr_schedule: LrScheduleType,
    /// Save a checkpoint every N steps (default: 500).
    pub checkpoint_every: usize,
    /// Log metrics every N steps (default: 50).
    pub log_every: usize,
    /// Random seed (default: 0).
    pub seed: u64,
    /// Caption type key to use during this training run
    /// (default: [`CaptionKey::HighLevelSummary`]).
    pub caption_key: CaptionKey,
    /// Path to SentencePiece tokeniser model file
    /// (default: `"tokenizer.model"`).
    pub tokenizer_path: String,
    /// Directory to write checkpoints / logs (default: `"./artifacts"`).
    pub artifact_dir: String,
    /// Directory containing the dataset (Parquet or raw files)
    /// (default: `"./data"`).
    pub data_dir: String,
    /// Number of DataLoader worker threads for CPU-side data preparation.
    ///
    /// Must be **≥ 1** — Burn's `PartialDataset::split` divides the dataset
    /// length by `num_workers`, so `0` causes a divide-by-zero panic.
    ///
    /// The WGPU backend (including Metal on macOS) is internally thread-safe:
    /// worker threads can call `Tensor::from_floats(…, &device)` safely.
    /// 2 workers is a reasonable default; increase on machines with many CPU
    /// cores and fast NVMe storage.  Use 1 if you observe data-loading becoming
    /// the training bottleneck (rare with synthetic data).
    ///
    /// NOTE(review): the last sentence reads inverted — a data-loading
    /// bottleneck would normally call for *more* workers; confirm the
    /// intended advice.
    pub num_workers: usize,
    /// Available GPU VRAM in gigabytes (default: `None`).
    ///
    /// When set the pre-flight guard derives the attention-tensor budget as
    /// `vram_gb / 3` and **auto-caps `batch_size`** to the largest value that
    /// fits, so you never have to tune `--batch-size` manually.
    ///
    /// Memory split used (all figures are estimates):
    /// ```text
    /// ┌─────────────────────────────────────────────────────┐
    /// │  1/3 → attention score/weight tensors (one layer)  │
    /// │  1/3 → model weights + gradients + Adam states     │
    /// │  1/3 → non-attention activations + OS/driver slack │
    /// └─────────────────────────────────────────────────────┘
    /// ```
    ///
    /// NOTE(review): the `vram_gb / 3` split above disagrees with the 70 %
    /// budget described below and used in the table; the pre-flight guard is
    /// not visible from this file, so confirm which rule it implements.
    ///
    /// Examples for ViT-B (base), depth=12, H=12, chunk=64, N=2448:
    ///
    /// The peak memory is `depth × per-layer` because Burn's forward pass
    /// builds autodiff tape for **all** transformer layers before `backward()`
    /// starts.  70% of VRAM is budgeted for this tape; the rest covers
    /// model weights + gradients + Adam states + other activations.
    ///
    /// | VRAM | attn budget (×0.7) | max batch | all-layers peak |
    /// |------|-------------------|-----------|-----------------|
    /// |  8 GB |  5.6 GB          |     1     |   6.6 GB        |
    /// | 16 GB | 11.2 GB          |     1     |   6.6 GB        |
    /// | 24 GB | 16.8 GB          |     2     |  13.1 GB        |
    /// | 32 GB | 22.4 GB          |     3     |  19.7 GB        |
    /// | 48 GB | 33.6 GB          |     5     |  32.8 GB        |
    /// | 80 GB | 56.0 GB          |     8     |  52.4 GB        |
    pub vram_gb: Option<f64>,
    /// Skip the pre-flight VRAM safety check and proceed even if the
    /// estimated attention memory exceeds the computed limit
    /// (default: `false`).
    ///
    /// Use this only when you are certain your GPU has enough free VRAM.
    /// You accept full responsibility for OOM errors or GPU driver crashes.
    pub skip_vram_check: bool,

    /// Print Burn's `═══ Learner Summary ═══` table after training.
    ///
    /// Disabled by default to keep the terminal output clean.
    /// Pass `--summary` on the CLI to enable.
    pub show_summary: bool,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        let total_examples = TOTAL_EXAMPLES;
        let batch_size = DEFAULT_BATCH_SIZE;
        let total_steps = total_examples / batch_size;
        Self {
            model_size: ModelSize::default(), // Tiny
            total_steps,
            batch_size,
            lr: DEFAULT_LR,
            weight_decay: DEFAULT_WD,
            beta1: 0.9,
            beta2: ADAM_BETA2,
            epsilon: 1e-8,
            grad_clip_norm: GRAD_CLIP_NORM,
            warmup_fraction: 0.2,
            cooldown_fraction: 0.2,
            lr_schedule: LrScheduleType::RsqrtWithWarmupCooldown,
            checkpoint_every: 500,
            log_every: 50,
            seed: 0,
            caption_key: CaptionKey::HighLevelSummary,
            tokenizer_path: "tokenizer.model".to_string(),
            artifact_dir: "./artifacts".to_string(),
            data_dir: "./data".to_string(),
            num_workers: 2,
            vram_gb: None,
            skip_vram_check: false,
            show_summary: false,
        }
    }
}

/// Which caption tier to use as the text pair during training.
///
/// Each variant has an associated token budget, returned by
/// [`CaptionKey::max_tokens`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CaptionKey {
    /// Statistical summary only (level 1).
    LowLevel,
    /// Structural patterns only (level 2).
    MiddleLevel,
    /// High-level semantic summary (level 3, 256 tokens).
    HighLevelSummary,
    /// Full high-level caption (level 3, 1024 tokens).
    HighLevelAll,
    /// Levels 2 + 1.
    MiddleLow,
    /// Levels 3 + 1.
    HighLow,
    /// Levels 3 + 2.
    HighMiddle,
    /// All three levels concatenated.
    HighMiddleLow,
}

impl CaptionKey {
    /// Maximum token budget for this caption type.
    pub fn max_tokens(self) -> usize {
        match self {
            Self::LowLevel => 512,
            Self::MiddleLevel => 512,
            Self::HighLevelSummary => 256,
            Self::HighLevelAll => 1024,
            Self::MiddleLow => 1024,
            Self::HighLow => 1024,
            Self::HighMiddle => 512,
            Self::HighMiddleLow => 1024,
        }
    }
}

// ===========================================================================
// Inference
// ===========================================================================

/// Configuration for inference / evaluation.
///
/// See the [`Default`] implementation for the initial value of every field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Path to model checkpoint (default: `"./artifacts/model_final.bin"`).
    pub checkpoint: String,
    /// Path to tokeniser model (default: `"tokenizer.model"`).
    pub tokenizer_path: String,
    /// Maximum sequence length for text input (default: 256).
    pub max_seq_len: usize,
    /// Batch size for encoding (default: 64).
    pub batch_size: usize,
    /// Use FP16 for faster inference (requires `fp16` feature);
    /// default: `false`.
    pub fp16: bool,
    /// Caption key to use when generating text from sensor data
    /// (default: [`CaptionKey::HighLevelSummary`]).
    pub caption_key: CaptionKey,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            checkpoint: "./artifacts/model_final.bin".to_string(),
            tokenizer_path: "tokenizer.model".to_string(),
            max_seq_len: 256,
            batch_size: 64,
            fp16: false,
            caption_key: CaptionKey::HighLevelSummary,
        }
    }
}

// ===========================================================================
// Quantisation
// ===========================================================================

/// INT8 quantisation scheme.
///
/// A field-less selector stored in [`QuantizationConfig::scheme`].  It
/// derives the same trait set as the other selector enums in this module
/// (`Copy`, `PartialEq`, `Eq`), so values can be compared and passed by
/// value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantScheme {
    /// Symmetric per-tensor quantisation.
    SymmetricPerTensor,
    /// Asymmetric per-tensor quantisation.
    AsymmetricPerTensor,
    /// Symmetric per-channel (output channel) quantisation.
    SymmetricPerChannel,
}

/// Post-training quantisation configuration.
///
/// See the [`Default`] implementation for the initial value of every field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Source FP32 checkpoint to quantise
    /// (default: `"./artifacts/model_final.bin"`).
    pub source_checkpoint: String,
    /// Output path for the quantised model
    /// (default: `"./artifacts/model_int8.bin"`).
    pub output_path: String,
    /// Path to a calibration dataset subset (Parquet)
    /// (default: `"./data/calibration.parquet"`).
    pub calibration_data: String,
    /// Number of calibration batches (default: 100).
    pub num_calibration_batches: usize,
    /// Batch size during calibration (default: 32).
    pub calibration_batch_size: usize,
    /// INT8 quantisation scheme
    /// (default: [`QuantScheme::SymmetricPerTensor`]).
    pub scheme: QuantScheme,
    /// Quantise text encoder weights (in addition to sensor encoder);
    /// default: `true`.
    pub quantise_text_encoder: bool,
    /// Path to tokeniser model (default: `"tokenizer.model"`).
    pub tokenizer_path: String,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            source_checkpoint: "./artifacts/model_final.bin".to_string(),
            output_path: "./artifacts/model_int8.bin".to_string(),
            calibration_data: "./data/calibration.parquet".to_string(),
            num_calibration_batches: 100,
            calibration_batch_size: 32,
            scheme: QuantScheme::SymmetricPerTensor,
            quantise_text_encoder: true,
            tokenizer_path: "tokenizer.model".to_string(),
        }
    }
}