# sensorlm-rs
A complete Rust implementation of **SensorLM** — *Learning the Language of Wearable Sensors* (NeurIPS 2025) — using the [Burn](https://burn.dev) deep-learning framework and the WGPU GPU backend.
---
## Table of contents
1. [What is SensorLM?](#what-is-sensorlm)
2. [Architecture](#architecture)
   - [Sensor encoder (ViT-B/10/2)](#sensor-encoder)
   - [Text encoder (Transformer-B)](#text-encoder)
   - [Two-tower model](#two-tower-model)
   - [SigLIP contrastive loss](#siglip-contrastive-loss)
3. [Caption generation pipeline](#caption-generation-pipeline)
   - [Level 1 – Statistical](#level-1--statistical-captions)
   - [Level 2 – Structural](#level-2--structural-captions)
   - [Level 3 – Semantic](#level-3--semantic-captions)
   - [Caption variants](#caption-variants)
4. [Crate structure](#crate-structure)
5. [Quick start](#quick-start)
6. [Training](#training)
7. [Inference](#inference)
8. [Quantisation](#quantisation)
9. [Dataset download](#dataset-download)
10. [Configuration reference](#configuration-reference)
11. [Backend selection](#backend-selection)
12. [Design decisions](#design-decisions)
13. [Citation](#citation)
14. [License](#license)
---
## What is SensorLM?
SensorLM is a family of **sensor-language foundation models** that learn to align wearable device sensor data (heart rate, accelerometry, skin conductance, SpO₂, …) with natural language descriptions. Training uses a novel **hierarchical automatic captioning pipeline** that generates paired (sensor, text) data without human annotation, applied to 59.7 million hours of data from 103,000+ individuals.
Key capabilities:
| Capability | How it works |
|---|---|
| Zero-shot activity classification | Cosine similarity to class-name prompts |
| Few-shot learning | Fine-tune with a handful of labelled examples |
| Cross-modal retrieval | Nearest-neighbour in the shared embedding space |
| Sensor captioning | Generate natural-language descriptions of recordings |
---
## Architecture
```
┌──────────────────────────────────────────────────────────────────────────────┐
│ SensorLM (Two-Tower Model) │
│ │
│ ┌──────────────────────────────────┐ ┌──────────────────────────────────┐ │
│ │ SENSOR ENCODER │ │ TEXT ENCODER │ │
│ │ Input: (B, 1440, 34) f32 │ │ Input: token IDs (B, L) i32 │ │
│ │ │ │ │ │
│ │ ┌────────────────────────────┐ │ │ TokenEmbedding(32000, 768) │ │
│ │ │ PatchEmbedding │ │ │ + PositionalEmbedding(L, 768) │ │
│ │ │ Conv2d(1→768, k=(10,2), │ │ │ │ │
│ │ │ stride=(10,2)) │ │ │ ┌──────────────────────────┐ │ │
│ │ │ Output: (B, 2448, 768) │ │ │ │ TransformerBlock × 12 │ │ │
│ │ └────────────────────────────┘ │ │ │ ├─ LayerNorm │ │ │
│ │ + PosEmbed │ │ │ ├─ MHSA (12 heads) │ │ │
│ │ │ │ │ ├─ LayerNorm │ │ │
│ │ ┌────────────────────────────┐ │ │ │ └─ MLP (768→3072→768) │ │ │
│ │ │ TransformerBlock × 12 │ │ │ └──────────────────────────┘ │ │
│ │ │ ├─ LayerNorm │ │ │ │ │
│ │ │ ├─ MHSA (12 heads) │ │ │ Masked mean-pool → (B, 768) │ │
│ │ │ ├─ LayerNorm │ │ │ Linear projection → (B, 768) │ │
│ │ │ └─ MLP (768→3072→768) │ │ │ L2-normalise → (B, 768) │ │
│ │ └────────────────────────────┘ │ └──────────────────┬───────────────┘ │
│ │ │ │ │
│ │ MAPHead (probe cross-attn) │ │ │
│ │ L2-normalise → (B, 768) │ │ │
│ └──────────────────┬───────────────┘ │ │
│ │ │ │
│ └──────────────┬───────────────────────┘ │
│ ▼ │
│ S[i,j] = temperature × dot(z_s[i], z_t[j]) + bias │
│ │ │
│ SigLIP loss = -mean_{i,j}[ log(sigmoid(y[i,j] · S[i,j])) ] │
│ where y[i,j] = +1 if i==j, else -1 │
└──────────────────────────────────────────────────────────────────────────────┘
```
### Sensor encoder
The sensor encoder is a **Vision Transformer (ViT-B)** adapted for 2-D sensor grids:
| Property | Value |
|---|---|
| Input shape | `(B, 1440, 34)` – batch × time-steps × channels |
| Patch size | `(10, 2)` – 10 minutes × 2 channels per patch |
| Patch grid | `144 × 17 = 2448` patches |
| Hidden dim | 768 |
| Depth | 12 transformer blocks |
| Heads | 12 |
| MLP dim | 3072 (4× hidden) |
| Pooling | MAP (Multihead Attention Pooling) |
| Output | `(B, 768)` L2-normalised embedding |
**Patch embedding** treats the `(T, C)` sensor grid like a single-channel image:
```
Input : (B, T, C)
Reshape: (B, 1, T, C) ← add image channel dim
Conv2d : kernel=(10,2), stride=(10,2), out_channels=768
Output : (B, 768, 144, 17)
Reshape: (B, 2448, 768) ← flatten spatial dims
```
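The shape arithmetic above can be checked with a few lines of plain Rust. This is a minimal sketch, not code from the crate: `patch_grid` is a hypothetical helper, and rejecting grids that do not divide evenly is an assumption (a strided convolution without padding would silently drop the remainder instead).

```rust
/// Rows, columns and total patch count for a (T, C) sensor grid cut into
/// non-overlapping (patch_h, patch_w) patches.
/// Returns `None` when the grid is not evenly divisible by the patch size
/// (assumption: the real patch embedding may behave differently here).
pub fn patch_grid(
    time_steps: usize,
    channels: usize,
    patch_h: usize,
    patch_w: usize,
) -> Option<(usize, usize, usize)> {
    if patch_h == 0 || patch_w == 0 {
        return None;
    }
    if time_steps % patch_h != 0 || channels % patch_w != 0 {
        return None;
    }
    let rows = time_steps / patch_h; // patches along the time axis
    let cols = channels / patch_w; // patches along the channel axis
    Some((rows, cols, rows * cols))
}
```

With the default configuration, `patch_grid(1440, 34, 10, 2)` gives the `144 × 17 = 2448` grid listed in the table.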
**MAP head (Multihead Attention Pooling)** replaces the standard `[CLS]` token:
```
probe : Param(1, 1, 768) ← learnable
expand : (B, 1, 768)
Q = proj(probe) ← query is the single probe token
K = proj(tokens) ← keys from all 2448 patch tokens
V = proj(tokens) ← values from all 2448 patch tokens
ctx = Attn(Q, K, V) ← (B, 1, 768)
out = ctx + MLP(LayerNorm(ctx)) ← (B, 768)
```
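The core of MAP pooling is attention with a single query. The sketch below illustrates only that step in plain Rust, under simplifying assumptions: one head, no Q/K/V projections, and no residual MLP. `attention_pool` is a hypothetical name and is not the crate's `MAPHead`.

```rust
/// Single-head attention pooling with one query vector (the "probe").
/// Simplification: Q/K/V projections and the residual MLP are omitted.
pub fn attention_pool(probe: &[f32], tokens: &[Vec<f32>]) -> Vec<f32> {
    let d = probe.len();
    assert!(d > 0 && !tokens.is_empty());
    assert!(tokens.iter().all(|t| t.len() == d));
    let scale = (d as f32).sqrt();
    // Scaled dot-product score of the probe against every token.
    let scores: Vec<f32> = tokens
        .iter()
        .map(|t| probe.iter().zip(t).map(|(q, k)| q * k).sum::<f32>() / scale)
        .collect();
    // Softmax with max-subtraction for numerical stability.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let denom: f32 = exps.iter().sum();
    // Attention-weighted sum of the tokens, one vector per sequence.
    let mut out = vec![0.0f32; d];
    for (w, t) in exps.iter().zip(tokens) {
        for (o, v) in out.iter_mut().zip(t) {
            *o += (w / denom) * v;
        }
    }
    out
}
```

Because the probe is learnable, the model can decide which patches matter for the pooled embedding rather than averaging all of them equally.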
### Text encoder
| Property | Value |
|---|---|
| Vocabulary | 32 000 (c4_en SentencePiece) |
| Max sequence length | 1024 tokens |
| Hidden dim | 768 |
| Depth | 12 transformer blocks |
| Heads | 12 |
| MLP dim | 3072 |
| Pooling | Masked mean pooling |
| Output projection | Linear(768, 768) |
| Output | `(B, 768)` L2-normalised embedding |
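
Masked mean pooling and L2 normalisation are simple enough to sketch without a tensor library. The helpers below are hypothetical, operate on one sequence at a time, and skip the linear projection that sits between the two steps in the real encoder. Returning a zero vector for a fully masked sequence is an assumption.

```rust
/// Mean of the token vectors whose mask entry is `true` (padding is skipped).
/// A fully masked sequence yields a zero vector (assumption).
pub fn masked_mean_pool(tokens: &[Vec<f32>], mask: &[bool]) -> Vec<f32> {
    assert_eq!(tokens.len(), mask.len());
    let d = tokens.first().map_or(0, |t| t.len());
    let mut sum = vec![0.0f32; d];
    let mut count = 0usize;
    for (tok, &keep) in tokens.iter().zip(mask) {
        if keep {
            for (s, v) in sum.iter_mut().zip(tok) {
                *s += v;
            }
            count += 1;
        }
    }
    let denom = count.max(1) as f32; // avoid dividing by zero
    sum.iter().map(|s| s / denom).collect()
}

/// Scale a vector to unit Euclidean length.
pub fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
    v.iter().map(|x| x / norm).collect()
}
```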
### Two-tower model
The two towers share **no weights**; they are coupled only through the contrastive loss.
Two learnable scalars:
- `log_temperature` – initialised so `exp(log_temperature) ≈ 10.0`
- `bias` – initialised to `-10.0`
### SigLIP contrastive loss
SigLIP replaces the standard softmax CLIP loss with independent sigmoid binary
cross-entropy over every element of the `(B, B)` similarity matrix. This
avoids the softmax denominator's dependence on batch size and improves training
stability.
```
S[i,j] = temperature · dot(z_sensor[i], z_text[j]) + bias
y[i,j] = +1 if i == j (same sample)
-1 if i != j
L = -1/B² · Σ_{i,j} log( sigmoid( y[i,j] · S[i,j] ) )
```
The bias initialisation of `-10` keeps the initial probability of almost every
pair close to zero. This matches the heavy imbalance between the `B` matching
pairs and the `B² − B` non-matching pairs, so training does not start with
large corrective updates.
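
The formula can be written out directly as a reference implementation in plain Rust. This is a sketch for checking values by hand, not the crate's `loss.rs`: the function names are hypothetical, it uses nested `Vec<f64>` instead of Burn tensors, and it follows the `1/B²` normalisation given above.

```rust
/// Numerically stable log(sigmoid(x)) = min(x, 0) - ln(1 + e^{-|x|}).
fn log_sigmoid(x: f64) -> f64 {
    x.min(0.0) - (-x.abs()).exp().ln_1p()
}

/// SigLIP loss over a batch of L2-normalised sensor and text embeddings.
/// Row i of `z_sensor` and row i of `z_text` form the matching pair.
pub fn siglip_loss(
    z_sensor: &[Vec<f64>],
    z_text: &[Vec<f64>],
    temperature: f64,
    bias: f64,
) -> f64 {
    let b = z_sensor.len();
    assert!(b > 0 && b == z_text.len());
    let mut total = 0.0;
    for (i, zs) in z_sensor.iter().enumerate() {
        for (j, zt) in z_text.iter().enumerate() {
            let dot: f64 = zs.iter().zip(zt).map(|(a, c)| a * c).sum();
            let s = temperature * dot + bias; // S[i,j]
            let y = if i == j { 1.0 } else { -1.0 }; // y[i,j]
            total += log_sigmoid(y * s);
        }
    }
    -total / (b * b) as f64
}
```

With `temperature = 10` and `bias = -10`, a perfectly aligned pair has `S = 0` and contributes `ln 2`, while an orthogonal non-matching pair has `S = -10` and contributes almost nothing.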
---
## Caption generation pipeline
The **key contribution** of SensorLM is the automatic generation of rich text
descriptions for unlabelled wearable recordings. Three complementary levels
of captions are generated and combined.
### Level 1 – Statistical captions
Describes per-channel statistics (mean, max, min, std) for each physiological
group after denormalising back to physical units.
```
For Heart, heart rate mean, max, min, std are 72.3, 96.1, 57.8, 8.4.
hrv rr exhibits a mean of 818.2, with range 960.0 to 680.0 and a
standard deviation of 45.2.
For Activity, steps mean, max, min, std are 3.1, 14.0, 0.0, 2.3. …
```
**Implementation** (`src/data/captioning/statistical.rs`):
1. Denormalise `(T, C)` array → physical units.
2. Apply missingness mask (set imputed values to NaN).
3. Compute `(mean, max, min, std)` per channel ignoring NaN.
4. For each group: pick all primary channels + `random_k` channels from the
random pool.
5. Fill a randomly selected template string.
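
Steps 2, 3 and 5 can be sketched as follows. The names `ChannelStats`, `nan_stats` and `describe` are hypothetical and not taken from `statistical.rs`. Using the population standard deviation is an assumption, and the template is just the one shown in the example output above.

```rust
/// Summary statistics of one channel in physical units.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ChannelStats {
    pub mean: f64,
    pub max: f64,
    pub min: f64,
    pub std: f64,
}

/// Compute (mean, max, min, std) while ignoring NaN entries.
/// Returns `None` when every value is missing.
pub fn nan_stats(values: &[f64]) -> Option<ChannelStats> {
    let valid: Vec<f64> = values.iter().copied().filter(|v| !v.is_nan()).collect();
    if valid.is_empty() {
        return None;
    }
    let n = valid.len() as f64;
    let mean = valid.iter().sum::<f64>() / n;
    let max = valid.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let min = valid.iter().copied().fold(f64::INFINITY, f64::min);
    // Population variance (divide by n); an assumption for this sketch.
    let var = valid.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    Some(ChannelStats { mean, max, min, std: var.sqrt() })
}

/// Fill one fixed template with a channel name and its statistics.
pub fn describe(name: &str, s: &ChannelStats) -> String {
    format!(
        "{} mean, max, min, std are {:.1}, {:.1}, {:.1}, {:.1}.",
        name, s.mean, s.max, s.min, s.std
    )
}
```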
### Level 2 – Structural captions
Describes temporal patterns – trends and anomalies.
```
Heart: From minute 240 to 480, heart rate exhibits an increasing trend.
A spike is detected for heart rate at minute 600.
steps is decreasing between minute 720 and 800.
```
**Trend detection** (multi-scale sliding-window OLS regression):
- Window sizes: 6, 8, 12 downsampled points (each point = 40 minutes).
- Slope threshold: `scale × range / 40` (adaptive to signal amplitude).
- Keep top-3 non-overlapping trends per channel.
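
A single-scale version of the sliding-window regression might look like the sketch below. The function names are hypothetical. The caller supplies the slope threshold, because the `scale` factor in the formula above is not specified here, and the top-3 non-overlapping selection is left out.

```rust
/// Ordinary least-squares slope of `y` against the indices 0, 1, 2, ...
/// Returns `None` for fewer than two points.
pub fn ols_slope(y: &[f64]) -> Option<f64> {
    let n = y.len();
    if n < 2 {
        return None;
    }
    let nf = n as f64;
    let x_mean = (nf - 1.0) / 2.0;
    let y_mean = y.iter().sum::<f64>() / nf;
    let mut num = 0.0;
    let mut den = 0.0;
    for (i, &yi) in y.iter().enumerate() {
        let dx = i as f64 - x_mean;
        num += dx * (yi - y_mean);
        den += dx * dx;
    }
    Some(num / den)
}

/// Slide a window over `series` and return `(start_index, slope)` for every
/// window whose absolute slope reaches `threshold`.
pub fn window_trends(series: &[f64], window: usize, threshold: f64) -> Vec<(usize, f64)> {
    if window < 2 || series.len() < window {
        return Vec::new();
    }
    series
        .windows(window)
        .enumerate()
        .filter_map(|(start, w)| ols_slope(w).map(|s| (start, s)))
        .filter(|(_, s)| s.abs() >= threshold)
        .collect()
}
```

A positive slope corresponds to an "increasing" caption and a negative slope to a "decreasing" one.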
**Anomaly detection** (prominence-based peak finder):
- Prominence threshold: `0.5 × range`.
- Height threshold: `mean + 0.6 × range`.
- Minimum distance between peaks: 5 points.
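
The height and distance rules can be illustrated with a small peak filter. This sketch deliberately omits the prominence test, so it is not a full prominence-based peak finder. `find_spikes` is a hypothetical name, and the input is assumed to contain no NaN values.

```rust
use std::cmp::Ordering;

/// Indices of strict local maxima that reach `mean + 0.6 × range` and are at
/// least `min_distance` points apart (taller peaks win). Prominence is NOT
/// checked in this simplified sketch.
pub fn find_spikes(y: &[f64], min_distance: usize) -> Vec<usize> {
    if y.len() < 3 {
        return Vec::new();
    }
    let max = y.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let min = y.iter().copied().fold(f64::INFINITY, f64::min);
    let mean = y.iter().sum::<f64>() / y.len() as f64;
    let height = mean + 0.6 * (max - min);
    // Candidate peaks: strictly higher than both neighbours and tall enough.
    let mut candidates: Vec<usize> = (1..y.len() - 1)
        .filter(|&i| y[i] > y[i - 1] && y[i] > y[i + 1] && y[i] >= height)
        .collect();
    // Visit the tallest candidates first.
    candidates.sort_by(|&a, &b| y[b].partial_cmp(&y[a]).unwrap_or(Ordering::Equal));
    let mut kept: Vec<usize> = Vec::new();
    for c in candidates {
        if kept.iter().all(|&k| c.abs_diff(k) >= min_distance) {
            kept.push(c);
        }
    }
    kept.sort_unstable(); // report in chronological order
    kept
}
```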
### Level 3 – Semantic captions
Describes high-level events: labelled activities, sleep periods, mood logs.
```
Walking was detected between minutes 480 and 540.
Running occurred from minute 600 to 660.
Sleep during minutes 0 to 440.
The person logged their mood as calm at minute 300.
```
Activities are filtered by minimum duration (20 min) and the top-8 longest
are kept. Up to 2 sleep periods are included.
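
The activity filter could be sketched like this. `ActivitySpan`, `select_activities` and `describe_activity` are hypothetical names, and re-sorting the survivors chronologically before captioning is an assumption.

```rust
/// One labelled activity interval, in minutes from the start of the recording.
#[derive(Debug, Clone, PartialEq)]
pub struct ActivitySpan {
    pub label: String,
    pub start_min: u32,
    pub end_min: u32,
}

fn duration(s: &ActivitySpan) -> u32 {
    s.end_min.saturating_sub(s.start_min)
}

/// Drop spans shorter than `min_duration`, keep the `top_k` longest, and
/// return them in chronological order (the ordering is an assumption).
pub fn select_activities(
    spans: &[ActivitySpan],
    min_duration: u32,
    top_k: usize,
) -> Vec<ActivitySpan> {
    let mut kept: Vec<ActivitySpan> = spans
        .iter()
        .filter(|s| duration(s) >= min_duration)
        .cloned()
        .collect();
    kept.sort_by_key(|s| std::cmp::Reverse(duration(s))); // longest first
    kept.truncate(top_k);
    kept.sort_by_key(|s| s.start_min);
    kept
}

/// Render one span with a template taken from the example output above.
pub fn describe_activity(s: &ActivitySpan) -> String {
    format!(
        "{} was detected between minutes {} and {}.",
        s.label, s.start_min, s.end_min
    )
}
```

Calling `select_activities(&spans, 20, 8)` applies the thresholds described above.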
### Caption variants
| Caption key | Levels combined | Max tokens |
|---|---|---|
| `low_level_caption` | 1 | 512 |
| `middle_level_caption` | 2 | 512 |
| `high_level_summary_caption` | 3 (short) | 256 |
| `high_level_all_caption` | 3 (full) | 1024 |
| `middle_low_level_caption` | 2 + 1 | 1024 |
| `high_low_level_caption` | 3 + 1 | 1024 |
| `high_middle_level_caption` | 3 + 2 | 512 |
| `high_middle_low_level_caption` | 3 + 2 + 1 | 1024 |
---
## Crate structure
```
sensorlm-rs/
├── Cargo.toml
├── README.md
├── examples/
│ ├── train_siglip.rs CPU training demo on synthetic data
│ ├── inference_demo.rs Zero-shot classification + Recall@k
│ └── quantize_model.rs INT8 PTQ demo
└── src/
├── lib.rs Crate root, backend type aliases
├── config.rs All configuration structs (serde JSON)
├── constants.rs Feature names, norm params, ViT hyperparams
├── error.rs Crate-wide error type
├── loss.rs SigLIP sigmoid contrastive loss, Recall@k
│
├── data/
│ ├── mod.rs
│ ├── preprocessing.rs Normalise / denormalise, masking, downsample
│ ├── dataset.rs SyntheticSensorDataset, ParquetSensorDataset
│ ├── download.rs HTTP download helper, PAMAP2/WESAD registry
│ └── captioning/
│ ├── mod.rs generate_caption() entry point
│ ├── templates.rs All text template strings
│ ├── statistical.rs Level-1 statistical captions
│ ├── structural.rs Level-2 trend + anomaly captions
│ └── semantic.rs Level-3 activity / sleep / mood captions
│
├── model/
│ ├── mod.rs Architecture documentation diagram
│ ├── sensor_encoder.rs PatchEmbedding, MAPHead, EncoderBlock, ViT
│ ├── text_encoder.rs TextEncoder (embedding + transformer + pool)
│ └── sensorlm.rs Two-tower model, SensorLMBatcher, TrainStep
│
├── training/
│ ├── mod.rs
│ ├── learner.rs train(), save_model(), load_model()
│ └── scheduler.rs RsqrtScheduler (warmup + rsqrt + cooldown)
│
├── inference/
│ ├── mod.rs
│ ├── zero_shot.rs ZeroShotClassifier
│ └── retrieval.rs RetrievalEngine (sensor↔text KNN)
│
├── quantization/
│ ├── mod.rs
│ └── int8.rs INT8 PTQ, FP16 export, QuantizedModel
│
└── bin/
└── sensorlm.rs CLI: train / infer / quantize / download / captions
```
---
## Quick start
```bash
# Build the crate (requires Rust 1.75+)
cd sensorlm-rs
cargo build --release
# Run the inference demo (CPU, no GPU required)
cargo run --example inference_demo
# Run the quantisation demo
cargo run --example quantize_model
# Generate a caption from a dummy sensor file
cargo run --bin sensorlm -- generate-captions \
--input dummy.csv \
--caption-type high-summary
```
---
## Training
### Full-precision GPU training
```bash
cargo run --bin sensorlm -- train \
--data-dir /path/to/dataset \
--artifact-dir ./artifacts \
--batch-size 1024
```
### CPU training (for testing)
```bash
cargo run --bin sensorlm -- train \
--data-dir /path/to/dataset \
--artifact-dir ./artifacts \
--batch-size 16 \
--cpu
```
### Training configuration
| Parameter | Default | Description |
|---|---|---|
| `batch_size` | 1024 | Mini-batch size |
| `lr` | 5e-4 | Peak learning rate |
| `weight_decay` | 1e-4 | AdamW weight decay |
| `beta2` | 0.999 | Adam β₂ |
| `grad_clip_norm` | 1.0 | Gradient clip norm |
| `warmup_fraction` | 0.2 | Fraction of steps for linear warm-up |
| `cooldown_fraction` | 0.2 | Fraction of steps for linear cool-down |
| `total_steps` | 48828 | Steps for 50 M examples at batch 1024 |
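
To show how these values interact, the sketch below derives `total_steps` and evaluates one plausible warm-up, reciprocal-square-root, cool-down schedule. The exact formula used by `RsqrtScheduler` is not documented here, so the shape of the middle and cool-down phases is an assumption and both function names are hypothetical.

```rust
/// Whole optimisation steps needed to see `num_examples` once.
pub fn total_steps(num_examples: u64, batch_size: u64) -> u64 {
    num_examples.checked_div(batch_size).unwrap_or(0)
}

/// Learning rate at `step` (assumed shape, not the crate's scheduler):
/// linear warm-up to `peak`, then peak * sqrt(warmup / step), then a linear
/// cool-down to zero over the last `cooldown_frac` of training.
pub fn lr_at(step: u64, total: u64, peak: f64, warmup_frac: f64, cooldown_frac: f64) -> f64 {
    let total_f = total as f64;
    let warmup = (total_f * warmup_frac).round().max(1.0);
    let cooldown_start = total_f * (1.0 - cooldown_frac);
    let s = step.min(total) as f64;
    // Reciprocal-square-root decay, clamped so it never exceeds `peak`.
    let rsqrt = |x: f64| peak * (warmup / x.max(warmup)).sqrt();
    if s < warmup {
        peak * s / warmup
    } else if s <= cooldown_start {
        rsqrt(s)
    } else {
        let remaining = (total_f - s) / (total_f - cooldown_start);
        rsqrt(cooldown_start) * remaining
    }
}
```

For example, `total_steps(50_000_000, 1024)` reproduces the default of 48828 steps.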
---
## Inference
### Zero-shot classification
```rust
use sensorlm::inference::zero_shot::{ClassifierConfig, ZeroShotClassifier};
let cfg = ClassifierConfig {
class_names: vec!["walking".into(), "running".into(), "sleeping".into()],
prompt_template: "The person is {label}.".into(),
};
let clf = ZeroShotClassifier::new(model, &cfg, tokenize_fn);
let predictions = clf.predict(sensor_batch); // Vec<(class_idx, name, score)>
```
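Conceptually, zero-shot classification fills the prompt template for every class, embeds the prompts, and picks the class whose embedding is closest to the sensor embedding. The sketch below shows only that logic on plain slices. It does not use the crate's `ZeroShotClassifier`, and all three function names are hypothetical.

```rust
use std::cmp::Ordering;

/// Substitute each class name into the `{label}` placeholder.
pub fn build_prompts(template: &str, class_names: &[&str]) -> Vec<String> {
    class_names
        .iter()
        .map(|name| template.replace("{label}", name))
        .collect()
}

/// Cosine similarity between two vectors of equal length.
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (na * nb).max(1e-12)
}

/// Index and score of the class embedding most similar to `sensor`.
/// Returns `None` when there are no classes.
pub fn classify(sensor: &[f32], class_embeds: &[Vec<f32>]) -> Option<(usize, f32)> {
    class_embeds
        .iter()
        .enumerate()
        .map(|(i, e)| (i, cosine(sensor, e)))
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
}
```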
### Cross-modal retrieval
```rust
use sensorlm::inference::retrieval::RetrievalEngine;
let mut engine = RetrievalEngine::new(model);
engine.index_text(text_batches);
let results = engine.sensor_to_text(query_sensor, 5); // top-5 texts per query
```
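Retrieval reduces to a nearest-neighbour search over L2-normalised embeddings, and Recall@k measures how often the matching item appears among the top k. The brute-force sketch below illustrates both ideas. It is not the crate's `RetrievalEngine`, the function names are hypothetical, and it assumes that query `i` matches index entry `i`.

```rust
use std::cmp::Ordering;

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// The `k` index entries with the highest dot product to `query`,
/// as `(index, score)` pairs sorted from best to worst.
pub fn top_k(query: &[f32], index: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut scored: Vec<(usize, f32)> = index
        .iter()
        .enumerate()
        .map(|(i, e)| (i, dot(query, e)))
        .collect();
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
    scored.truncate(k);
    scored
}

/// Fraction of queries whose matching entry (same position) is in the top k.
pub fn recall_at_k(queries: &[Vec<f32>], index: &[Vec<f32>], k: usize) -> f32 {
    if queries.is_empty() {
        return 0.0;
    }
    let hits = queries
        .iter()
        .enumerate()
        .filter(|(i, q)| top_k(q, index, k).iter().any(|(j, _)| j == i))
        .count();
    hits as f32 / queries.len() as f32
}
```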
---
## Quantisation
### INT8 weight-only PTQ
```bash
cargo run --bin sensorlm -- quantize \
--checkpoint ./artifacts/model_final \
--output ./artifacts/model_int8.json \
--calibration-data ./data/calibration.parquet \
--num-batches 100
```
Expected compression:
| Precision | Size | Quality |
|---|---|---|
| FP32 | 1346 MB (ViT-B × 2) | baseline |
| FP16 | 673 MB | ≈ FP32 |
| INT8 | 336 MB | −0.5–1 % Recall@1 |
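
The size reduction follows from storing each weight in one byte instead of four. As an illustration, the sketch below applies symmetric per-tensor quantisation to a slice of weights. Whether `int8.rs` uses per-tensor or per-channel scales is not stated here, so this scheme is an assumption and the function names are hypothetical.

```rust
/// Symmetric per-tensor INT8 quantisation: w ≈ q * scale, with q in [-127, 127].
/// An all-zero tensor gets a scale of 1.0 to avoid dividing by zero.
pub fn quantize_int8(weights: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = weights.iter().fold(0.0f32, |m, w| m.max(w.abs()));
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let q: Vec<i8> = weights
        .iter()
        .map(|w| (w / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

/// Map INT8 values back to approximate FP32 weights.
pub fn dequantize_int8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}
```

The round-trip error per weight is bounded by about half the scale, which is why the largest absolute weight in a tensor determines its quantisation precision.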
---
## Dataset download
```bash
# PAMAP2 physical activity monitoring (~680 MB)
cargo run --bin sensorlm -- download --dataset pamap2 --dest ./data
# WESAD wearable stress & affect detection (~1.8 GB)
cargo run --bin sensorlm -- download --dataset wesad --dest ./data
```
The original SensorLM training corpus is proprietary (Google internal). The
PAMAP2 and WESAD datasets are publicly available substitutes for research use.
---
## Configuration reference
All configuration structs implement `serde::Serialize / Deserialize` and can
be written/read as JSON:
```json
{
"sensor_encoder": {
"time_steps": 1440,
"num_channels": 34,
"patch_h": 10,
"patch_w": 2,
"d_model": 768,
"depth": 12,
"num_heads": 12,
"mlp_dim": 3072,
"dropout": 0.0,
"pool_type": "Map"
},
"text_encoder": {
"vocab_size": 32000,
"max_seq_len": 1024,
"d_model": 768,
"depth": 12,
"num_heads": 12,
"mlp_dim": 3072,
"dropout": 0.0,
"out_dim": 768
},
"embed_dim": 768,
"temperature_init": 10.0,
"bias_init": -10.0
}
```
---
## Backend selection
| Use case | Burn backend | Type alias |
|---|---|---|
| GPU training | `Autodiff<Wgpu>` | `TrainBackend` |
| GPU inference | `Wgpu` | `WgpuBackend` |
| CPU training | `Autodiff<NdArray>` | `CpuTrainBackend` |
| CPU inference / tests | `NdArray` | `CpuBackend` |
Switch backend by changing the generic type parameter on `SensorLMModel<B>`.
---
## Design decisions
| Decision | Rationale |
|---|---|
| Manual transformer blocks | Full control over architecture, no dependency on burn's internal transformer API versioning |
| MAP pooling | Matches reference implementation; outperforms GAP on sensor data |
| Learnable `Param` tensors for pos-embed / probe | Allows positional embeddings to adapt to the irregular sensor grid structure |
| Rectangular patches `(10, 2)` | Captures 10-minute temporal patterns while keeping spatial channel locality |
| SigLIP loss over CLIP | Batch-size independent; avoids softmax instability at large batch sizes |
| INT8 weight-only quant | Halves memory vs FP16 with ~1% quality delta; suitable for edge deployment |
| Separate `CaptionContext` struct | Cleanly separates activity/sleep metadata from raw sensor preprocessing |
| `f64` in captioning pipeline | Avoids floating-point error accumulating in statistical aggregations over 1440 time steps |
---
## Citation
If you use this library in academic work, please cite it as:
```bibtex
@software{kosmyna2026sensorlm,
author = {Kosmyna, Nataliya},
title = {{SensorLM}: Rust implementation of SensorLM},
year = {2026},
version = {0.0.1},
url = {https://github.com/nataliyakosmyna/sensorlm-rs},
note = {Learning the Language of Wearable Sensors},
license = {Apache 2.0},
}
```
Or in plain text:
> Kosmyna, N. (2026). *SensorLM: Rust implementation of SensorLM* (v0.0.1).
> <https://github.com/nataliyakosmyna/sensorlm-rs>
To cite the original SensorLM work:
```bibtex
@inproceedings{zhang2025sensorlm,
title={SensorLM: Learning the Language of Wearable Sensors},
author={Yuwei Zhang and Kumar Ayush and Siyuan Qiao and A. Ali Heydari and
Girish Narayanswamy and Maxwell A. Xu and Ahmed A. Metwally and
Shawn Xu and Jake Garrison and Xuhai Xu and Tim Althoff and
Yun Liu and Pushmeet Kohli and Jiening Zhan and Mark Malhotra and
Shwetak Patel and Cecilia Mascolo and Xin Liu and
Daniel McDuff and Yuzhe Yang},
booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
year={2025}
}
```
## License
Apache 2.0 — see [LICENSE](LICENSE).
> **Disclaimer**: This is a research re-implementation. It is not an officially
> supported Google product and is not intended for clinical use.