sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
# sensorlm-rs

A complete Rust implementation of **SensorLM** — *Learning the Language of Wearable Sensors* (NeurIPS 2025) — using the [Burn](https://burn.dev) deep-learning framework and the WGPU GPU backend.

---

## Table of contents

1. [What is SensorLM?](#what-is-sensorlm)
2. [Architecture](#architecture)
   - [Sensor encoder (ViT-B/10/2)](#sensor-encoder)
   - [Text encoder (Transformer-B)](#text-encoder)
   - [Two-tower model](#two-tower-model)
   - [SigLIP contrastive loss](#siglip-contrastive-loss)
3. [Caption generation pipeline](#caption-generation-pipeline)
   - [Level 1 – Statistical](#level-1--statistical-captions)
   - [Level 2 – Structural](#level-2--structural-captions)
   - [Level 3 – Semantic](#level-3--semantic-captions)
   - [Caption variants](#caption-variants)
4. [Crate structure](#crate-structure)
5. [Quick start](#quick-start)
6. [Training](#training)
7. [Inference](#inference)
8. [Quantisation](#quantisation)
9. [Dataset download](#dataset-download)
10. [Configuration reference](#configuration-reference)
11. [Backend selection](#backend-selection)
12. [Design decisions](#design-decisions)

---

## What is SensorLM?

SensorLM is a family of **sensor-language foundation models** that learn to align wearable device sensor data (heart rate, accelerometry, skin conductance, SpO₂, …) with natural language descriptions.  Training uses a novel **hierarchical automatic captioning pipeline** that generates paired (sensor, text) data without human annotation, applied to 59.7 million hours of data from 103,000+ individuals.

Key capabilities:

| Task | Method |
|------|--------|
| Zero-shot activity classification | Cosine similarity to class-name prompts |
| Few-shot learning | Fine-tune with a handful of labelled examples |
| Cross-modal retrieval | Nearest-neighbour in the shared embedding space |
| Sensor captioning | Generate natural-language descriptions of recordings |

---

## Architecture

```
┌──────────────────────────────────────────────────────────────────────────────┐
│                         SensorLM  (Two-Tower Model)                          │
│                                                                              │
│  ┌──────────────────────────────────┐   ┌──────────────────────────────────┐ │
│  │       SENSOR ENCODER             │   │       TEXT ENCODER               │ │
│  │  Input: (B, 1440, 34)  f32       │   │  Input: token IDs (B, L)  i32    │ │
│  │                                  │   │                                  │ │
│  │  ┌────────────────────────────┐  │   │  TokenEmbedding(32000, 768)      │ │
│  │  │  PatchEmbedding            │  │   │  + PositionalEmbedding(L, 768)   │ │
│  │  │  Conv2d(1→768, k=(10,2),   │  │   │                                  │ │
│  │  │         stride=(10,2))     │  │   │  ┌──────────────────────────┐    │ │
│  │  │  Output: (B, 2448, 768)    │  │   │  │  TransformerBlock × 12   │    │ │
│  │  └────────────────────────────┘  │   │  │  ├─ LayerNorm            │    │ │
│  │           + PosEmbed             │   │  │  ├─ MHSA (12 heads)      │    │ │
│  │                                  │   │  │  ├─ LayerNorm            │    │ │
│  │  ┌────────────────────────────┐  │   │  │  └─ MLP (768→3072→768)   │    │ │
│  │  │  TransformerBlock × 12     │  │   │  └──────────────────────────┘    │ │
│  │  │  ├─ LayerNorm              │  │   │                                  │ │
│  │  │  ├─ MHSA (12 heads)        │  │   │  Masked mean-pool → (B, 768)     │ │
│  │  │  ├─ LayerNorm              │  │   │  Linear projection → (B, 768)    │ │
│  │  │  └─ MLP (768→3072→768)     │  │   │  L2-normalise → (B, 768)         │ │
│  │  └────────────────────────────┘  │   └──────────────────┬───────────────┘ │
│  │                                  │                      │                 │
│  │  MAPHead (probe cross-attn)      │                      │                 │
│  │  L2-normalise → (B, 768)         │                      │                 │
│  └──────────────────┬───────────────┘                      │                 │
│                     │                                      │                 │
│                     └──────────────┬───────────────────────┘                 │
│                                    ▼                                         │
│              S[i,j] = temperature × dot(z_s[i], z_t[j]) + bias               │
│                                    │                                         │
│              SigLIP loss = -mean_{i,j}[ log(sigmoid(y[i,j] · S[i,j])) ]      │
│              where y[i,j] = +1 if i==j, else -1                              │
└──────────────────────────────────────────────────────────────────────────────┘
```

### Sensor encoder

The sensor encoder is a **Vision Transformer (ViT-B)** adapted for 2-D sensor grids:

| Property | Value |
|----------|-------|
| Input shape | `(B, 1440, 34)` – batch × time-steps × channels |
| Patch size | `(10, 2)` – 10 minutes × 2 channels per patch |
| Patch grid | `144 × 17 = 2448` patches |
| Hidden dim | 768 |
| Depth | 12 transformer blocks |
| Heads | 12 |
| MLP dim | 3072 (4× hidden) |
| Pooling | MAP (Multihead Attention Pooling) |
| Output | `(B, 768)` L2-normalised embedding |

**Patch embedding** treats the `(T, C)` sensor grid like a single-channel image:

```
Input  : (B, T, C)
Reshape: (B, 1, T, C)   ← add image channel dim
Conv2d : kernel=(10,2), stride=(10,2), out_channels=768
Output : (B, 768, 144, 17)
Reshape: (B, 2448, 768)  ← flatten spatial dims
```
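
The patch-grid arithmetic above can be checked with a few lines of plain Rust. This is a minimal sketch; `patch_grid` is a hypothetical helper written for illustration, not a function exported by the crate.

```rust
/// Patch grid produced by a non-overlapping patch embedding
/// (kernel == stride), as in the Conv2d step above.
/// Returns (rows, cols, total), or None if the patch does not tile
/// the (T, C) grid exactly.
fn patch_grid(
    time_steps: usize,
    channels: usize,
    patch_h: usize,
    patch_w: usize,
) -> Option<(usize, usize, usize)> {
    if patch_h == 0 || patch_w == 0 {
        return None;
    }
    if time_steps % patch_h != 0 || channels % patch_w != 0 {
        return None;
    }
    let rows = time_steps / patch_h; // patches along the time axis
    let cols = channels / patch_w; // patches along the channel axis
    Some((rows, cols, rows * cols))
}
```

With the default configuration, `patch_grid(1440, 34, 10, 2)` gives `Some((144, 17, 2448))`, matching the sequence length quoted in the table.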

**MAP head (Multihead Attention Pooling)** replaces the standard `[CLS]` token:

```
probe   : Param(1, 1, 768)          ← learnable
expand  : (B, 1, 768)
Q = proj(probe)  ← query is the single probe token
K = proj(tokens) ← keys from all 2448 patch tokens
V = proj(tokens) ← values from all 2448 patch tokens
ctx = Attn(Q, K, V)                 ← (B, 1, 768)
out = ctx + MLP(LayerNorm(ctx))      ← (B, 768)
```
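
To make the pooling step concrete, here is a simplified single-head sketch in plain Rust. It assumes identity Q/K/V projections and omits the multiple heads, the residual MLP and Burn tensors, so it only illustrates how one probe vector attends over all patch tokens. `attention_pool` is a hypothetical name.

```rust
/// Single-head attention pooling with one query vector.
/// `tokens` holds N token vectors, each with the same length as `probe`.
/// Learned projections, multiple heads and the residual MLP are omitted.
fn attention_pool(probe: &[f32], tokens: &[Vec<f32>]) -> Vec<f32> {
    let d = probe.len();
    let scale = 1.0 / (d as f32).sqrt();
    // Scaled dot-product score between the probe and every token.
    let scores: Vec<f32> = tokens
        .iter()
        .map(|t| probe.iter().zip(t).map(|(q, k)| q * k).sum::<f32>() * scale)
        .collect();
    // Numerically stable softmax over the N scores.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    // Attention-weighted sum of the token vectors.
    let mut out = vec![0.0f32; d];
    for (w, t) in exps.iter().zip(tokens) {
        for (o, v) in out.iter_mut().zip(t) {
            *o += (w / denom) * v;
        }
    }
    out
}
```

With an all-zero probe every token receives the same weight, so the result reduces to plain mean pooling; a trained probe shifts that weight toward the informative patches.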

### Text encoder

| Property | Value |
|----------|-------|
| Vocabulary | 32 000 (c4_en SentencePiece) |
| Max sequence length | 1024 tokens |
| Hidden dim | 768 |
| Depth | 12 transformer blocks |
| Heads | 12 |
| MLP dim | 3072 |
| Pooling | Masked mean pooling |
| Output projection | Linear(768, 768) |
| Output | `(B, 768)` L2-normalised embedding |
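
The pooling and normalisation rows can be illustrated without any tensor library. The sketch below works on one sequence at a time; `masked_mean_pool` and `l2_normalise` are hypothetical helper names, and the linear projection between the two steps is left out.

```rust
/// Mean of the token vectors whose mask entry is true (padding excluded).
fn masked_mean_pool(tokens: &[Vec<f32>], mask: &[bool]) -> Vec<f32> {
    let d = tokens.first().map_or(0, |t| t.len());
    let mut sum = vec![0.0f32; d];
    let mut count = 0usize;
    for (t, &keep) in tokens.iter().zip(mask) {
        if keep {
            for (s, v) in sum.iter_mut().zip(t) {
                *s += v;
            }
            count += 1;
        }
    }
    // Guard against a sequence that is entirely padding.
    let n = count.max(1) as f32;
    sum.iter().map(|s| s / n).collect()
}

/// Scale a vector to unit L2 norm (the floor avoids division by zero).
fn l2_normalise(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    v.iter().map(|x| x / norm).collect()
}
```

Excluding padded positions keeps short captions from being diluted by padding tokens when sequences are batched to a common length.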

### Two-tower model

The two towers share **no weights**; they are coupled only through the contrastive loss.

Two learnable scalars:
- `log_temperature` – initialised so `exp(log_temperature) ≈ 10.0`
- `bias`            – initialised to `-10.0`

### SigLIP contrastive loss

SigLIP replaces the standard softmax CLIP loss with independent sigmoid binary
cross-entropy over every element of the `(B, B)` similarity matrix.  This
avoids the softmax denominator's dependence on batch size and improves training
stability.

```
S[i,j] = temperature · dot(z_sensor[i], z_text[j]) + bias

y[i,j] = +1   if i == j  (same sample)
         -1   if i != j

L = -1/B² · Σ_{i,j} log( sigmoid( y[i,j] · S[i,j] ) )
```

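In a `(B, B)` similarity matrix, negative pairs outnumber positive pairs by
`B − 1` to 1.  Initialising the bias to `-10` starts every pair with a match
probability near zero (`sigmoid(-10) ≈ 4.5e-5` when the embeddings are roughly
uncorrelated), so the many negatives begin almost correctly classified and do
not dominate the first updates.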
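
The loss formula above translates directly into plain Rust. This is a reference sketch on nested `Vec`s rather than Burn tensors; `siglip_loss` and `log_sigmoid` are hypothetical names, and the inputs are assumed to be L2-normalised already.

```rust
/// Numerically stable log(sigmoid(x)).
fn log_sigmoid(x: f64) -> f64 {
    x.min(0.0) - (-x.abs()).exp().ln_1p()
}

/// Mean pairwise sigmoid loss over the (B, B) similarity matrix.
/// `z_sensor` and `z_text` each hold B L2-normalised embeddings.
fn siglip_loss(
    z_sensor: &[Vec<f64>],
    z_text: &[Vec<f64>],
    temperature: f64,
    bias: f64,
) -> f64 {
    let b = z_sensor.len();
    assert_eq!(b, z_text.len(), "both towers must produce B embeddings");
    assert!(b > 0, "batch must not be empty");
    let mut total = 0.0;
    for (i, zs) in z_sensor.iter().enumerate() {
        for (j, zt) in z_text.iter().enumerate() {
            let dot: f64 = zs.iter().zip(zt).map(|(a, c)| a * c).sum();
            let s = temperature * dot + bias;
            // +1 on the diagonal (matching pair), -1 everywhere else.
            let y = if i == j { 1.0 } else { -1.0 };
            total -= log_sigmoid(y * s);
        }
    }
    total / (b * b) as f64
}
```

Each of the `B²` terms depends only on its own pair, which is why no batch-wide softmax normalisation is needed.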

---

## Caption generation pipeline

The **key contribution** of SensorLM is the automatic generation of rich text
descriptions for unlabelled wearable recordings.  Three complementary levels
of captions are generated and combined.

### Level 1 – Statistical captions

Describes per-channel statistics (mean, max, min, std) for each physiological
group after denormalising back to physical units.

```
For Heart, heart rate mean, max, min, std are 72.3, 96.1, 57.8, 8.4.
hrv rr exhibits a mean of 818.2, with range 960.0 to 680.0 and a
standard deviation of 45.2.
For Activity, steps mean, max, min, std are 3.1, 14.0, 0.0, 2.3. …
```

**Implementation** (`src/data/captioning/statistical.rs`):
1. Denormalise `(T, C)` array → physical units.
2. Apply missingness mask (set imputed values to NaN).
3. Compute `(mean, max, min, std)` per channel ignoring NaN.
4. For each group: pick all primary channels + `random_k` channels from the
   random pool.
5. Fill a randomly selected template string.
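
Steps 2, 3 and 5 can be sketched in plain Rust as follows. The helper names are hypothetical, the population standard deviation (divide by `n`) is an assumption, and the template wording simply mirrors the sample caption above rather than the crate's actual template set.

```rust
/// (mean, max, min, std) over the non-NaN entries; None if nothing is valid.
/// Uses the population standard deviation, which is an assumption here.
fn nan_stats(values: &[f64]) -> Option<(f64, f64, f64, f64)> {
    let valid: Vec<f64> = values.iter().copied().filter(|v| !v.is_nan()).collect();
    if valid.is_empty() {
        return None;
    }
    let n = valid.len() as f64;
    let mean = valid.iter().sum::<f64>() / n;
    let max = valid.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let min = valid.iter().copied().fold(f64::INFINITY, f64::min);
    let var = valid.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    Some((mean, max, min, var.sqrt()))
}

/// Fill one Level-1 style sentence from a (mean, max, min, std) tuple.
fn statistical_sentence(group: &str, channel: &str, s: (f64, f64, f64, f64)) -> String {
    format!(
        "For {group}, {channel} mean, max, min, std are {:.1}, {:.1}, {:.1}, {:.1}.",
        s.0, s.1, s.2, s.3
    )
}
```

Marking imputed samples as NaN before aggregating keeps filled-in values from biasing the reported statistics.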

### Level 2 – Structural captions

Describes temporal patterns – trends and anomalies.

```
Heart: From minute 240 to 480, heart rate exhibits an increasing trend.
A spike is detected for heart rate at minute 600.
steps is decreasing between minute 720 and 800.
```

**Trend detection** (multi-scale sliding-window OLS regression):
- Window sizes: 6, 8, 12 downsampled points (each = 40 minutes).
- Slope threshold: `scale × range / 40` (adaptive to signal amplitude).
- Keep top-3 non-overlapping trends per channel.
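
A minimal sketch of the trend rules above, in plain Rust. It assumes `scale` is a tunable multiplier, measures slope per downsampled point, and resolves overlaps greedily by steepness; the names `ols_slope`, `Trend` and `detect_trends` are hypothetical and the crate's own selection logic may differ.

```rust
/// Least-squares slope of `ys` against x = 0, 1, ..., n-1.
fn ols_slope(ys: &[f64]) -> f64 {
    if ys.len() < 2 {
        return 0.0;
    }
    let n = ys.len() as f64;
    let x_mean = (n - 1.0) / 2.0;
    let y_mean = ys.iter().sum::<f64>() / n;
    let (mut num, mut den) = (0.0, 0.0);
    for (i, y) in ys.iter().enumerate() {
        let dx = i as f64 - x_mean;
        num += dx * (y - y_mean);
        den += dx * dx;
    }
    num / den
}

#[derive(Debug, Clone, PartialEq)]
struct Trend {
    start: usize, // inclusive index
    end: usize,   // exclusive index
    slope: f64,
}

/// Slide each window size over `signal`, keep windows whose |slope| exceeds
/// `scale * range / 40`, then greedily keep the `top_k` steepest
/// non-overlapping windows.
fn detect_trends(signal: &[f64], window_sizes: &[usize], scale: f64, top_k: usize) -> Vec<Trend> {
    let max = signal.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let min = signal.iter().copied().fold(f64::INFINITY, f64::min);
    let threshold = scale * (max - min) / 40.0;
    let mut candidates = Vec::new();
    for &w in window_sizes {
        if w < 2 || w > signal.len() {
            continue;
        }
        for start in 0..=(signal.len() - w) {
            let slope = ols_slope(&signal[start..start + w]);
            if slope.abs() > threshold {
                candidates.push(Trend { start, end: start + w, slope });
            }
        }
    }
    // Steepest first; the stable sort leaves ties in insertion order.
    candidates.sort_by(|a, b| b.slope.abs().total_cmp(&a.slope.abs()));
    let mut kept: Vec<Trend> = Vec::new();
    for c in candidates {
        let overlaps = kept.iter().any(|k| c.start < k.end && k.start < c.end);
        if !overlaps && kept.len() < top_k {
            kept.push(c);
        }
    }
    kept
}
```

Because the threshold scales with the signal range, a flat channel yields no trends, while a channel with a wide dynamic range needs a proportionally steeper slope to qualify.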

**Anomaly detection** (prominence-based peak finder):
- Prominence threshold: `0.5 × range`.
- Height threshold: `mean + 0.6 × range`.
- Minimum distance between peaks: 5 points.
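
The three thresholds can be combined into a small peak finder. The sketch below is one plausible reading of the rules, written in plain Rust: prominence is computed by walking outward until a taller sample is met, and the distance rule keeps the tallest peak in each neighbourhood. `prominence` and `find_spikes` are hypothetical names, and the tie-breaking choices are assumptions.

```rust
/// Prominence of the local maximum at `i`: its height above the higher of the
/// two lowest points reached before meeting a taller sample on either side.
fn prominence(x: &[f64], i: usize) -> f64 {
    let mut left_min = x[i];
    for j in (0..i).rev() {
        if x[j] > x[i] {
            break;
        }
        left_min = left_min.min(x[j]);
    }
    let mut right_min = x[i];
    for j in (i + 1)..x.len() {
        if x[j] > x[i] {
            break;
        }
        right_min = right_min.min(x[j]);
    }
    x[i] - left_min.max(right_min)
}

/// Indices of peaks that pass the prominence, height and distance rules.
fn find_spikes(x: &[f64], min_distance: usize) -> Vec<usize> {
    if x.len() < 3 {
        return Vec::new();
    }
    let max = x.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let min = x.iter().copied().fold(f64::INFINITY, f64::min);
    let range = max - min;
    let mean = x.iter().sum::<f64>() / x.len() as f64;
    let mut peaks: Vec<usize> = (1..x.len() - 1)
        .filter(|&i| x[i] > x[i - 1] && x[i] >= x[i + 1])
        .filter(|&i| x[i] >= mean + 0.6 * range && prominence(x, i) >= 0.5 * range)
        .collect();
    // Tallest first, then drop any peak too close to one already kept.
    peaks.sort_by(|&a, &b| x[b].total_cmp(&x[a]));
    let mut kept: Vec<usize> = Vec::new();
    for p in peaks {
        if kept.iter().all(|&k| p.abs_diff(k) >= min_distance) {
            kept.push(p);
        }
    }
    kept.sort_unstable();
    kept
}
```

Calling `find_spikes(&signal, 5)` applies the 5-point minimum distance from the list above.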

### Level 3 – Semantic captions

Describes high-level events: labelled activities, sleep periods, mood logs.

```
Walking was detected between minutes 480 and 540.
Running occurred from minute 600 to 660.
Sleep during minutes 0 to 440.
The person logged their mood as calm at minute 300.
```

Activities are filtered by minimum duration (20 min) and the top-8 longest
are kept.  Up to 2 sleep periods are included.
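
The duration filter and top-8 selection can be sketched as below. `Event`, `select_activities` and `semantic_sentence` are hypothetical names, returning the survivors in chronological order is an assumption, and the sentence wording copies the first sample caption above.

```rust
#[derive(Debug, Clone, PartialEq)]
struct Event {
    label: String,
    start_min: u32,
    end_min: u32,
}

impl Event {
    fn duration(&self) -> u32 {
        self.end_min.saturating_sub(self.start_min)
    }
}

/// Keep events lasting at least `min_duration` minutes, retain the `top_k`
/// longest, and return them in chronological order.
fn select_activities(events: &[Event], min_duration: u32, top_k: usize) -> Vec<Event> {
    let mut kept: Vec<Event> = events
        .iter()
        .filter(|e| e.duration() >= min_duration)
        .cloned()
        .collect();
    kept.sort_by(|a, b| b.duration().cmp(&a.duration())); // longest first
    kept.truncate(top_k);
    kept.sort_by_key(|e| e.start_min); // back to time order
    kept
}

/// Render one event using the wording of the first sample caption.
fn semantic_sentence(e: &Event) -> String {
    format!(
        "{} was detected between minutes {} and {}.",
        e.label, e.start_min, e.end_min
    )
}
```

Calling `select_activities(&events, 20, 8)` reproduces the thresholds stated above.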

### Caption variants

| Key | Levels | Max tokens |
|-----|--------|-----------|
| `low_level_caption` | 1 | 512 |
| `middle_level_caption` | 2 | 512 |
| `high_level_summary_caption` | 3 (short) | 256 |
| `high_level_all_caption` | 3 (full) | 1024 |
| `middle_low_level_caption` | 2 + 1 | 1024 |
| `high_low_level_caption` | 3 + 1 | 1024 |
| `high_middle_level_caption` | 3 + 2 | 512 |
| `high_middle_low_level_caption` | 3 + 2 + 1 | 1024 |

---

## Crate structure

```
sensorlm-rs/
├── Cargo.toml
├── README.md
├── examples/
│   ├── train_siglip.rs        CPU training demo on synthetic data
│   ├── inference_demo.rs      Zero-shot classification + Recall@k
│   └── quantize_model.rs      INT8 PTQ demo
└── src/
    ├── lib.rs                 Crate root, backend type aliases
    ├── config.rs              All configuration structs (serde JSON)
    ├── constants.rs           Feature names, norm params, ViT hyperparams
    ├── error.rs               Crate-wide error type
    ├── loss.rs                SigLIP sigmoid contrastive loss, Recall@k
    ├── data/
    │   ├── mod.rs
    │   ├── preprocessing.rs   Normalise / denormalise, masking, downsample
    │   ├── dataset.rs         SyntheticSensorDataset, ParquetSensorDataset
    │   ├── download.rs        HTTP download helper, PAMAP2/WESAD registry
    │   └── captioning/
    │       ├── mod.rs         generate_caption() entry point
    │       ├── templates.rs   All text template strings
    │       ├── statistical.rs Level-1 statistical captions
    │       ├── structural.rs  Level-2 trend + anomaly captions
    │       └── semantic.rs    Level-3 activity / sleep / mood captions
    ├── model/
    │   ├── mod.rs             Architecture documentation diagram
    │   ├── sensor_encoder.rs  PatchEmbedding, MAPHead, EncoderBlock, ViT
    │   ├── text_encoder.rs    TextEncoder (embedding + transformer + pool)
    │   └── sensorlm.rs        Two-tower model, SensorLMBatcher, TrainStep
    ├── training/
    │   ├── mod.rs
    │   ├── learner.rs         train(), save_model(), load_model()
    │   └── scheduler.rs       RsqrtScheduler (warmup + rsqrt + cooldown)
    ├── inference/
    │   ├── mod.rs
    │   ├── zero_shot.rs       ZeroShotClassifier
    │   └── retrieval.rs       RetrievalEngine (sensor↔text KNN)
    ├── quantization/
    │   ├── mod.rs
    │   └── int8.rs            INT8 PTQ, FP16 export, QuantizedModel
    └── bin/
        └── sensorlm.rs        CLI: train / infer / quantize / download / captions
```

---

## Quick start

```bash
# Build the crate (requires Rust 1.75+)
cd /agent/sensorlm-rs

cargo build --release

# Run the inference demo (CPU, no GPU required)
cargo run --example inference_demo

# Run the quantisation demo
cargo run --example quantize_model

# Generate a caption from a dummy sensor file
cargo run --bin sensorlm -- generate-captions \
    --input dummy.csv \
    --caption-type high-summary
```

---

## Training

### Full-precision GPU training

```bash
cargo run --bin sensorlm -- train \
    --data-dir /path/to/dataset \
    --artifact-dir ./artifacts \
    --batch-size 1024
```

### CPU training (for testing)

```bash
cargo run --bin sensorlm -- train \
    --data-dir /path/to/dataset \
    --artifact-dir ./artifacts \
    --batch-size 16 \
    --cpu
```

### Training configuration

| Parameter | Default | Description |
|-----------|---------|-------------|
| `batch_size` | 1024 | Mini-batch size |
| `lr` | 5e-4 | Peak learning rate |
| `weight_decay` | 1e-4 | AdamW weight decay |
| `beta2` | 0.999 | Adam β₂ |
| `grad_clip_norm` | 1.0 | Gradient clip norm |
| `warmup_fraction` | 0.2 | Fraction of steps for linear warm-up |
| `cooldown_fraction` | 0.2 | Fraction of steps for linear cool-down |
| `total_steps` | 48828 | Steps for 50 M examples at batch 1024 |
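
The sketch below shows how the table values fit together. The exact shape of the crate's `RsqrtScheduler` is not spelled out here, so the formula is an assumption: linear warm-up to the peak, inverse-square-root decay afterwards, and a linear ramp to zero over the final cool-down fraction. `total_steps` and `rsqrt_lr` are hypothetical helper names.

```rust
/// Whole optimisation steps needed to see `examples` samples once.
fn total_steps(examples: usize, batch_size: usize) -> usize {
    examples / batch_size
}

/// Learning rate at `step` for an assumed warm-up + rsqrt + cool-down schedule.
fn rsqrt_lr(
    step: usize,
    total_steps: usize,
    peak_lr: f64,
    warmup_fraction: f64,
    cooldown_fraction: f64,
) -> f64 {
    let warmup = (total_steps as f64 * warmup_fraction).round().max(1.0);
    let cooldown = (total_steps as f64 * cooldown_fraction).round();
    let s = step as f64;
    // Linear warm-up from 0 to the peak, then inverse-square-root decay.
    let base = if s < warmup {
        peak_lr * s / warmup
    } else {
        peak_lr * (warmup / s).sqrt()
    };
    // Linear cool-down scales the rate to 0 over the last `cooldown` steps.
    let remaining = total_steps as f64 - s;
    if cooldown > 0.0 && remaining < cooldown {
        base * (remaining / cooldown).max(0.0)
    } else {
        base
    }
}
```

`total_steps(50_000_000, 1024)` returns `48828`, the default in the table; with a warm-up fraction of 0.2 the peak rate is reached after roughly 9766 steps.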

---

## Inference

### Zero-shot classification

```rust
use sensorlm::inference::zero_shot::{ClassifierConfig, ZeroShotClassifier};

let cfg = ClassifierConfig {
    class_names: vec!["walking".into(), "running".into(), "sleeping".into()],
    prompt_template: "The person is {label}.".into(),
};
let clf = ZeroShotClassifier::new(model, &cfg, tokenize_fn);
let predictions = clf.predict(sensor_batch); // Vec<(class_idx, name, score)>
```

### Cross-modal retrieval

```rust
use sensorlm::inference::retrieval::RetrievalEngine;

let mut engine = RetrievalEngine::new(model);
engine.index_text(text_batches);

let results = engine.sensor_to_text(query_sensor, 5); // top-5 texts per query
```
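
Under the hood, retrieval in a shared embedding space is a top-k search over dot products. The sketch below shows that idea, plus Recall@k for paired data, on plain `Vec<f32>` embeddings; it assumes all vectors are already L2-normalised, and `top_k` and `recall_at_k` are hypothetical names rather than the `RetrievalEngine` API.

```rust
/// Indices and scores of the `k` index embeddings most similar to `query`.
/// With L2-normalised vectors the dot product equals cosine similarity.
fn top_k(query: &[f32], index: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut scored: Vec<(usize, f32)> = index
        .iter()
        .enumerate()
        .map(|(i, e)| (i, query.iter().zip(e).map(|(a, b)| a * b).sum::<f32>()))
        .collect();
    scored.sort_by(|a, b| b.1.total_cmp(&a.1)); // highest similarity first
    scored.truncate(k);
    scored
}

/// Recall@k for paired data: the fraction of queries whose matching item
/// (same position in `index`) appears among the top `k` results.
fn recall_at_k(queries: &[Vec<f32>], index: &[Vec<f32>], k: usize) -> f32 {
    if queries.is_empty() {
        return 0.0;
    }
    let mut hits = 0usize;
    for (i, q) in queries.iter().enumerate() {
        if top_k(q, index, k).iter().any(|&(j, _)| j == i) {
            hits += 1;
        }
    }
    hits as f32 / queries.len() as f32
}
```

A brute-force scan like this is linear in the index size, which is adequate for evaluation sets; larger indexes would typically call for an approximate nearest-neighbour structure.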

---

## Quantisation

### INT8 weight-only PTQ

```bash
cargo run --bin sensorlm -- quantize \
    --checkpoint ./artifacts/model_final \
    --output ./artifacts/model_int8.json \
    --calibration-data ./data/calibration.parquet \
    --num-batches 100
```

Expected compression:

| Precision | Memory | Relative quality |
|-----------|--------|-----------------|
| FP32 | 1346 MB (ViT-B × 2) | baseline |
| FP16 | 673 MB | ≈ FP32 |
| INT8 | 336 MB | −0.5–1 % Recall@1 |
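
The idea behind weight-only INT8 quantisation fits in a few lines. The sketch below uses symmetric per-tensor scaling, which is an assumption: the crate may use per-channel scales or calibration-derived ranges instead. `quantize_int8` and `dequantize_int8` are hypothetical names.

```rust
/// Symmetric per-tensor INT8 quantisation: w ≈ scale * q, q in [-127, 127].
fn quantize_int8(weights: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = weights.iter().fold(0.0f32, |m, w| m.max(w.abs()));
    // An all-zero tensor gets a dummy scale so the division stays defined.
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let q = weights
        .iter()
        .map(|w| (w / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

/// Recover approximate f32 weights from the INT8 codes.
fn dequantize_int8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}
```

Each weight shrinks from 4 bytes to 1 plus a shared scale, which is where the roughly 4× reduction from FP32 in the table comes from; the rounding error per weight is at most half a quantisation step.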

---

## Dataset download

```bash
# PAMAP2 physical activity monitoring (~680 MB)
cargo run --bin sensorlm -- download --dataset pamap2 --dest ./data

# WESAD wearable stress & affect detection (~1.8 GB)
cargo run --bin sensorlm -- download --dataset wesad --dest ./data
```

The original SensorLM training corpus is proprietary (Google internal).  The
PAMAP2 and WESAD datasets are publicly available substitutes for research use.

---

## Configuration reference

All configuration structs implement `serde::Serialize` and `serde::Deserialize`,
so they can be written to and read from JSON:

```json
{
  "sensor_encoder": {
    "time_steps": 1440,
    "num_channels": 34,
    "patch_h": 10,
    "patch_w": 2,
    "d_model": 768,
    "depth": 12,
    "num_heads": 12,
    "mlp_dim": 3072,
    "dropout": 0.0,
    "pool_type": "Map"
  },
  "text_encoder": {
    "vocab_size": 32000,
    "max_seq_len": 1024,
    "d_model": 768,
    "depth": 12,
    "num_heads": 12,
    "mlp_dim": 3072,
    "dropout": 0.0,
    "out_dim": 768
  },
  "embed_dim": 768,
  "temperature_init": 10.0,
  "bias_init": -10.0
}
```

---

## Backend selection

| Scenario | Backend | Type alias |
|----------|---------|-----------|
| GPU training | `Autodiff<Wgpu>` | `TrainBackend` |
| GPU inference | `Wgpu` | `WgpuBackend` |
| CPU training | `Autodiff<NdArray>` | `CpuTrainBackend` |
| CPU inference / tests | `NdArray` | `CpuBackend` |

Switch backend by changing the generic type parameter on `SensorLMModel<B>`.

---

## Design decisions

| Decision | Rationale |
|----------|-----------|
| Manual transformer blocks | Full control over architecture, no dependency on burn's internal transformer API versioning |
| MAP pooling | Matches reference implementation; outperforms GAP on sensor data |
| Learnable `Param` tensors for pos-embed / probe | Allows positional embeddings to adapt to the irregular sensor grid structure |
| Rectangular patches `(10, 2)` | Captures 10-minute temporal patterns while keeping spatial channel locality |
| SigLIP loss over CLIP | Batch-size independent; avoids softmax instability at large batch sizes |
| INT8 weight-only quant | Halves memory vs FP16 with ~1% quality delta; suitable for edge deployment |
| Separate `CaptionContext` struct | Cleanly separates activity/sleep metadata from raw sensor preprocessing |
| `f64` in captioning pipeline | Avoids floating-point error accumulating in statistical aggregations over 1440 time steps |

---

## Citation

If you use this library in academic work, please cite it as:

```bibtex
@software{kosmyna2026sensorlm,
  author    = {Kosmyna, Nataliya},
  title     = {{SensorLM}: Rust implementation of SensorLM},
  year      = {2026},
  version   = {0.0.1},
  url       = {https://github.com/nataliyakosmyna/sensorlm-rs},
  note      = {Learning the Language of Wearable Sensors},
  license   = {Apache 2.0},
}
```

Or in plain text:

> Kosmyna, N. (2026). *SensorLM: Rust implementation of SensorLM* (v0.0.1).
> <https://github.com/nataliyakosmyna/sensorlm-rs>

To cite the original SensorLM work:

```bibtex
@inproceedings{zhang2025sensorlm,
    title={SensorLM: Learning the Language of Wearable Sensors},
    author={Yuwei Zhang and Kumar Ayush and Siyuan Qiao and A. Ali Heydari and
            Girish Narayanswamy and Maxwell A. Xu and Ahmed A. Metwally and
            Shawn Xu and Jake Garrison and Xuhai Xu and Tim Althoff and
            Yun Liu and Pushmeet Kohli and Jiening Zhan and Mark Malhotra and
            Shwetak Patel and Cecilia Mascolo and Xin Liu and
            Daniel McDuff and Yuzhe Yang},
    booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
    year={2025}
}
```

## License

Apache 2.0 — see [LICENSE](LICENSE).

> **Disclaimer**: This is a research re-implementation.  It is not an officially
> supported Google product and is not intended for clinical use.