Skip to main content

sensorlm/
constants.rs

//! Physical constants, feature lists, and normalisation parameters.
//!
//! All values are ported directly from the Python reference implementation
//! (`sensorlm/constants.py`).

// ---------------------------------------------------------------------------
// Sensor channel definition
// ---------------------------------------------------------------------------

/// Names of the 34 features that are stored in the dataset and used for both
/// the ViT sensor encoder and the statistical captioning pipeline.
///
/// The feature order is significant: index `i` corresponds to channel `i` in
/// every `(T, C)` sensor tensor.
pub const FEATURE_NAMES: &[&str] = &[
    "HR",                   // 0
    "eda_level_real",       // 1
    "leads_contact_counts", // 2
    "steps",                // 3
    "jerk_auto",            // 4
    "log_energy",           // 5
    "covariance",           // 6
    "log_energy_ratio",     // 7
    "zero_crossing_std",    // 8
    "zero_crossing_avg",    // 9
    "axis_mean",            // 10
    "altim_std",            // 11
    "kurtosis",             // 12
    "sleep_coefficient",    // 13
    "wrist_temperatures",   // 14
    "rr_med",               // 15
    "sdnn0595",             // 16
    "rmssd0595",            // 17
    "pnn20",                // 18
    "coherence",            // 19
    "ShEnRR",               // 20
    "LF",                   // 21
    "HF",                   // 22
    "LF_HF",                // 23
    "VLF",                  // 24
    "spectralEn",           // 25
    "percent_good",         // 26
    "sleep_stage_awake",    // 27
    "sleep_stage_light",    // 28
    "sleep_stage_deep",     // 29
    "sleep_stage_rem",      // 30
    "spo2",                 // 31
    "spo2_confidence",      // 32
    "spo2_coverage",        // 33
];

/// Number of sensor channels.
///
/// Derived from [`FEATURE_NAMES`] rather than hard-coded, so adding or removing
/// a feature cannot leave the two out of sync.
pub const NUM_CHANNELS: usize = FEATURE_NAMES.len();

/// Number of time-steps per sample (minutes in a day).
pub const TIME_STEPS: usize = 24 * 60;

// ---------------------------------------------------------------------------
// Normalisation parameters (mean, std)
//
// Each entry is a `(mean, std)` tuple.  Normalised value:
//   z = (x - mean) / std
// Denormalised value:
//   x = z * std + mean
// ---------------------------------------------------------------------------

/// Population-level `(mean, std)` pairs, one per channel in [`FEATURE_NAMES`]
/// and in the same order.
pub const NORM_PARAMS: &[(f64, f64)] = &[
    // 0: HR
    (75.958_6, 16.188_7),
    // 1: eda_level_real
    (4.176_7, 5.589_3),
    // 2: leads_contact_counts
    (226.486_4, 67.331_2),
    // 3: steps
    (5.167_9, 18.892_6),
    // 4: jerk_auto
    (203.467_2, 30.056_3),
    // 5: log_energy
    (53.080_4, 49.652_6),
    // 6: covariance
    (43.407_7, 13.952_9),
    // 7: log_energy_ratio
    (44.848_3, 22.974_6),
    // 8: zero_crossing_std
    (155.186_3, 28.237_8),
    // 9: zero_crossing_avg
    (51.004_3, 37.475_6),
    // 10: axis_mean
    (123.165_9, 21.471_0),
    // 11: altim_std
    (0.004_2, 0.059_7),
    // 12: kurtosis
    (105.595_4, 66.849_5),
    // 13: sleep_coefficient
    (7.262_3, 5.394_6),
    // 14: wrist_temperatures
    (31.674_5, 2.578_9),
    // 15: rr_med
    (856.830_4, 160.118_1),
    // 16: sdnn0595
    (64.800_3, 55.585_0),
    // 17: rmssd0595
    (65.342_1, 74.783_1),
    // 18: pnn20
    (0.566_7, 0.262_3),
    // 19: coherence
    (0.180_8, 0.130_5),
    // 20: ShEnRR
    (3.058_2, 0.667_3),
    // 21: LF
    (1_551.837_6, 2_399.422_8),
    // 22: HF
    (757.227_1, 1_873.923_9),
    // 23: LF_HF
    (4.126_5, 4.506_6),
    // 24: VLF
    (1_303.384_8, 1_906.101_7),
    // 25: spectralEn
    (2.525_5, 0.393_1),
    // 26: percent_good
    (0.484_6, 0.343_9),
    // 27: sleep_stage_awake
    (0.042_4, 0.191_6),
    // 28: sleep_stage_light
    (0.043_4, 0.202_1),
    // 29: sleep_stage_deep
    (0.185_5, 0.383_0),
    // 30: sleep_stage_rem
    (0.057_5, 0.230_1),
    // 31: spo2
    (95.201_9, 2.464_6),
    // 32: spo2_confidence
    (56.639_1, 42.106_4),
    // 33: spo2_coverage
    (50.125_1, 19.097_1),
];

// ---------------------------------------------------------------------------
// Features that must be clamped to ≥ 0 after denormalisation
// ---------------------------------------------------------------------------

/// Indices (into [`FEATURE_NAMES`]) of the channels that are clamped to ≥ 0
/// after denormalisation because their physical value cannot be negative:
/// `steps` and `sleep_coefficient`.
pub const NON_NEGATIVE_CHANNELS: &[usize] = &[
    3,  // steps
    13, // sleep_coefficient
];

// Compile-time guards: the per-channel tables above must line up with
// FEATURE_NAMES, otherwise normalisation would silently use the wrong
// statistics (or divide by zero) for some channels.
const _: () = {
    assert!(
        NORM_PARAMS.len() == NUM_CHANNELS,
        "NORM_PARAMS must contain exactly one (mean, std) pair per channel"
    );

    let mut i = 0;
    while i < NORM_PARAMS.len() {
        // `std` is used as a divisor in z = (x - mean) / std.
        assert!(NORM_PARAMS[i].1 > 0.0, "NORM_PARAMS std values must be strictly positive");
        i += 1;
    }

    let mut j = 0;
    while j < NON_NEGATIVE_CHANNELS.len() {
        assert!(
            NON_NEGATIVE_CHANNELS[j] < NUM_CHANNELS,
            "NON_NEGATIVE_CHANNELS contains an out-of-range channel index"
        );
        j += 1;
    }
};

// ---------------------------------------------------------------------------
// Channel groupings used by the captioning pipeline
// ---------------------------------------------------------------------------

/// A logical group of sensor channels used by the captioning pipeline.
///
/// Every channel is referenced by a `(display_name, feature_index)` pair:
/// `display_name` is the text shown in captions and `feature_index` indexes
/// into [`FEATURE_NAMES`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChannelGroup {
    /// Human-readable category name, e.g. `"Heart"`.
    pub category: &'static str,
    /// `(display_name, feature_index)` pairs for primary channels.
    pub primary: &'static [(&'static str, usize)],
    /// `(display_name, feature_index)` pairs for randomly sampled channels.
    pub random: &'static [(&'static str, usize)],
    /// How many random channels to sample per caption.
    ///
    /// Must not exceed `random.len()`; this is enforced at compile time for
    /// every entry of [`CHANNEL_GROUPS`].
    pub random_k: usize,
}

/// The four physiological channel groups used in captioning.
///
/// NOTE(review): several display names reuse a feature index that another
/// entry already uses (e.g. HR at 0, `rr_med` at 15, `ShEnRR` at 20,
/// `wrist_temperatures` at 14, `eda_level_real` at 1). These look like
/// stand-ins for features absent from the 34-feature set. Confirm this
/// against the reference implementation.
pub const CHANNEL_GROUPS: &[ChannelGroup] = &[
    ChannelGroup {
        category: "Heart",
        primary: &[
            ("heart rate", 0),              // HR
            ("hrv rr", 15),                 // rr_med
            ("hrv shannon entropy rr", 20), // ShEnRR
            ("sdnn percentile", 16),        // sdnn0595
        ],
        random: &[
            ("hr at rest mean", 0), // placeholder – HR again (no hr_at_rest in 34-feature set)
            ("hrv rr 80th percentile", 15),  // rr_med
            ("hrv shannon entropy rrd", 20), // ShEnRR
            ("rmssd percentile mean", 17),   // rmssd0595
        ],
        random_k: 2,
    },
    ChannelGroup {
        category: "Activity",
        primary: &[
            ("steps", 3),
            ("jerk", 4),
            ("log energy", 5),
            ("kurtosis", 12),
        ],
        random: &[
            ("covariance", 6),
            ("log energy ratio", 7),
            ("zero crossing std", 8),
            ("zero crossing avg", 9),
            ("axis mean", 10),
            ("altim std", 11),
        ],
        random_k: 1,
    },
    ChannelGroup {
        category: "Sleep",
        primary: &[("sleep coefficient", 13)],
        random: &[],
        random_k: 0,
    },
    ChannelGroup {
        category: "EDA",
        primary: &[
            ("eda level", 1),
            ("skin temperature slope", 14), // wrist_temperatures
            ("wrist temperatures", 14),
        ],
        random: &[
            ("leads contact counts", 2),
            ("ceda slope real micro siemens", 1), // eda_level_real
        ],
        random_k: 1,
    },
];

// Compile-time guard: a group must never be asked to sample more random
// channels than it actually lists.
const _: () = {
    let mut g = 0;
    while g < CHANNEL_GROUPS.len() {
        assert!(
            CHANNEL_GROUPS[g].random_k <= CHANNEL_GROUPS[g].random.len(),
            "ChannelGroup::random_k exceeds the number of random channels"
        );
        g += 1;
    }
};

// ---------------------------------------------------------------------------
// Caption token budgets
// ---------------------------------------------------------------------------

/// Token budget for every caption type used by the training pipeline.
///
/// Each entry is a `(caption_key, token_count)` pair giving the number of
/// tokens allotted to captions of that type.
pub const CAPTION_TOKEN_BUDGET: &'static [(&'static str, usize)] = &[
    ("low_level_caption", 512),
    ("middle_level_caption", 512),
    ("high_level_summary_caption", 256),
    ("high_level_all_caption", 1024),
    ("middle_low_level_caption", 1024),
    ("high_low_level_caption", 1024),
    ("high_middle_level_caption", 512),
    ("high_middle_low_level_caption", 1024),
];

// ---------------------------------------------------------------------------
// Model / training hyper-parameters (defaults mirroring the reference config)
// ---------------------------------------------------------------------------

/// Default vocabulary size used by the c4_en / T5 SentencePiece tokeniser.
pub const VOCAB_SIZE: usize = 32_000;

/// ViT-B hidden dimension.
pub const VIT_WIDTH: usize = 768;
/// ViT-B depth (number of transformer blocks).
pub const VIT_DEPTH: usize = 12;
/// ViT-B MLP hidden dimension: 4 × [`VIT_WIDTH`].
///
/// Derived from the width so that changing the width keeps the expansion
/// ratio documented here.
pub const VIT_MLP_DIM: usize = 4 * VIT_WIDTH;
/// ViT-B number of attention heads.
pub const VIT_HEADS: usize = 12;

// Standard multi-head attention splits the hidden dimension evenly across
// heads, so reject configurations where that split is not exact.
const _: () = assert!(
    VIT_WIDTH % VIT_HEADS == 0,
    "VIT_WIDTH must be divisible by VIT_HEADS"
);

259/// Patch height (time-axis): 10 minutes per patch.
260pub const PATCH_H: usize = 10;
261/// Patch width (channel-axis): 2 channels per patch.
262pub const PATCH_W: usize = 2;
263
264/// Number of patches along the time axis (1440 / 10).
265pub const NUM_PATCHES_T: usize = TIME_STEPS / PATCH_H;
266/// Number of patches along the channel axis (34 / 2 = 17; padded to even if needed).
267pub const NUM_PATCHES_C: usize = (NUM_CHANNELS + PATCH_W - 1) / PATCH_W;
268/// Total number of patches fed into the transformer.
269pub const NUM_PATCHES: usize = NUM_PATCHES_T * NUM_PATCHES_C;
270
/// Embedding dimension shared by both modalities once each has gone through
/// its final projection.
pub const EMBED_DIM: usize = 768;

/// Initial value of the SigLIP temperature.
///
/// NOTE(review): confirm whether the model stores this value directly or as
/// its logarithm.
pub const TEMPERATURE_INIT: f32 = 10.0;
/// Initial value of the SigLIP bias.
pub const BIAS_INIT: f32 = -10.0;

/// Default training batch size.
///
/// # ⚠ Memory warning
///
/// The ViT-B sensor encoder emits N = 2448 patch tokens per sample
/// (see [`NUM_PATCHES`]). Attention score tensors grow as `B × H × chunk × N`,
/// so even with `attn_chunk_size = 64`, a batch of 8 samples at fp32 needs:
///
/// ```text
/// 8 × 12 × 64 × 2448 × 4 bytes ≈ 60 MB  per chunk (forward only)
/// ```
///
/// **During the backward pass the Burn autodiff tape keeps EVERY chunk's
/// intermediates alive at once.** Multiply the figure above by
/// `ceil(N / chunk)` chunks and by `depth` layers. For ViT-B on a 16 GB GPU,
/// keep `batch_size` ≤ 8. For quick experiments, use `--cpu` together with a
/// smaller model config.
pub const DEFAULT_BATCH_SIZE: usize = 8;
/// Default learning rate.
pub const DEFAULT_LR: f64 = 5.0e-4;
/// Default weight decay.
pub const DEFAULT_WD: f64 = 1.0e-4;
/// Second-moment decay rate (β₂) for Adam.
pub const ADAM_BETA2: f64 = 0.999;
/// Global gradient-norm clipping threshold.
pub const GRAD_CLIP_NORM: f64 = 1.0;
/// Total number of pre-training examples (50 million).
pub const TOTAL_EXAMPLES: usize = 50 * 1_000_000;