sensorlm/constants.rs
1//! Physical constants, feature lists, and normalisation parameters.
2//!
3//! All values are ported directly from the Python reference implementation
4//! (`sensorlm/constants.py`).
5
6// ---------------------------------------------------------------------------
7// Sensor channel definition
8// ---------------------------------------------------------------------------
9
/// The 34 feature names stored in the dataset, shared by the ViT sensor
/// encoder and the statistical captioning pipeline.
///
/// Position matters: `FEATURE_NAMES[i]` names channel `i` of every
/// `(T, C)` sensor tensor. The trailing comment on each entry is its index.
pub const FEATURE_NAMES: &[&str] = &[
    "HR",                   // 0
    "eda_level_real",       // 1
    "leads_contact_counts", // 2
    "steps",                // 3
    "jerk_auto",            // 4
    "log_energy",           // 5
    "covariance",           // 6
    "log_energy_ratio",     // 7
    "zero_crossing_std",    // 8
    "zero_crossing_avg",    // 9
    "axis_mean",            // 10
    "altim_std",            // 11
    "kurtosis",             // 12
    "sleep_coefficient",    // 13
    "wrist_temperatures",   // 14
    "rr_med",               // 15
    "sdnn0595",             // 16
    "rmssd0595",            // 17
    "pnn20",                // 18
    "coherence",            // 19
    "ShEnRR",               // 20
    "LF",                   // 21
    "HF",                   // 22
    "LF_HF",                // 23
    "VLF",                  // 24
    "spectralEn",           // 25
    "percent_good",         // 26
    "sleep_stage_awake",    // 27
    "sleep_stage_light",    // 28
    "sleep_stage_deep",     // 29
    "sleep_stage_rem",      // 30
    "spo2",                 // 31
    "spo2_confidence",      // 32
    "spo2_coverage",        // 33
];
51
52/// Number of sensor channels (= `FEATURE_NAMES.len()`).
53pub const NUM_CHANNELS: usize = 34;
54
/// Time-steps per sample: one per minute across a 24-hour day.
pub const TIME_STEPS: usize = 24 * 60;
57
58// ---------------------------------------------------------------------------
59// Normalisation parameters (mean, std)
60//
// Each entry is a `(mean, std)` tuple. Normalised value:
62// z = (x - mean) / std
63// Denormalised value:
64// x = z * std + mean
65// ---------------------------------------------------------------------------
66
/// Population-level `(mean, std)` statistics, one pair per channel.
///
/// Row `i` belongs to `FEATURE_NAMES[i]`; the trailing comment on each row
/// gives that index and the feature name.
pub const NORM_PARAMS: &[(f64, f64)] = &[
    (75.958_6, 16.188_7),       // 0  HR
    (4.176_7, 5.589_3),         // 1  eda_level_real
    (226.486_4, 67.331_2),      // 2  leads_contact_counts
    (5.167_9, 18.892_6),        // 3  steps
    (203.467_2, 30.056_3),      // 4  jerk_auto
    (53.080_4, 49.652_6),       // 5  log_energy
    (43.407_7, 13.952_9),       // 6  covariance
    (44.848_3, 22.974_6),       // 7  log_energy_ratio
    (155.186_3, 28.237_8),      // 8  zero_crossing_std
    (51.004_3, 37.475_6),       // 9  zero_crossing_avg
    (123.165_9, 21.471_0),      // 10 axis_mean
    (0.004_2, 0.059_7),         // 11 altim_std
    (105.595_4, 66.849_5),      // 12 kurtosis
    (7.262_3, 5.394_6),         // 13 sleep_coefficient
    (31.674_5, 2.578_9),        // 14 wrist_temperatures
    (856.830_4, 160.118_1),     // 15 rr_med
    (64.800_3, 55.585_0),       // 16 sdnn0595
    (65.342_1, 74.783_1),       // 17 rmssd0595
    (0.566_7, 0.262_3),         // 18 pnn20
    (0.180_8, 0.130_5),         // 19 coherence
    (3.058_2, 0.667_3),         // 20 ShEnRR
    (1_551.837_6, 2_399.422_8), // 21 LF
    (757.227_1, 1_873.923_9),   // 22 HF
    (4.126_5, 4.506_6),         // 23 LF_HF
    (1_303.384_8, 1_906.101_7), // 24 VLF
    (2.525_5, 0.393_1),         // 25 spectralEn
    (0.484_6, 0.343_9),         // 26 percent_good
    (0.042_4, 0.191_6),         // 27 sleep_stage_awake
    (0.043_4, 0.202_1),         // 28 sleep_stage_light
    (0.185_5, 0.383_0),         // 29 sleep_stage_deep
    (0.057_5, 0.230_1),         // 30 sleep_stage_rem
    (95.201_9, 2.464_6),        // 31 spo2
    (56.639_1, 42.106_4),       // 32 spo2_confidence
    (50.125_1, 19.097_1),       // 33 spo2_coverage
];
138
139// ---------------------------------------------------------------------------
140// Features that must be clamped to ≥ 0 after denormalisation
141// ---------------------------------------------------------------------------
142
/// Channel indices that are clamped to `>= 0` after denormalisation.
///
/// Index `3` is `steps` and index `13` is `sleep_coefficient` in
/// [`FEATURE_NAMES`].
pub const NON_NEGATIVE_CHANNELS: &[usize] = &[3, 13];
149
150// ---------------------------------------------------------------------------
151// Channel groupings used by the captioning pipeline
152// ---------------------------------------------------------------------------
153
/// A named group of sensor channels used when building captions.
///
/// Every `usize` stored here is an index into [`FEATURE_NAMES`]; the paired
/// `&str` is the name displayed in the caption text.
///
/// All fields are plain `'static` data, so the type derives equality and
/// hashing in addition to `Debug`/`Clone`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ChannelGroup {
    /// Human-readable category name, e.g. `"Heart"`.
    pub category: &'static str,
    /// `(display_name, feature_index)` pairs for the primary channels.
    pub primary: &'static [(&'static str, usize)],
    /// `(display_name, feature_index)` pairs for the randomly sampled channels.
    pub random: &'static [(&'static str, usize)],
    /// How many channels from `random` to sample per caption.
    pub random_k: usize,
}
167
168/// The four physiological channel groups used in captioning.
169pub const CHANNEL_GROUPS: &[ChannelGroup] = &[
170 ChannelGroup {
171 category: "Heart",
172 primary: &[
173 ("heart rate", 0), // HR
174 ("hrv rr", 15), // rr_med
175 ("hrv shannon entropy rr", 20), // ShEnRR
176 ("sdnn percentile", 16), // sdnn0595
177 ],
178 random: &[
179 ("hr at rest mean", 0), // placeholder – HR again (no hr_at_rest in 34-feature set)
180 ("hrv rr 80th percentile", 15),
181 ("hrv shannon entropy rrd", 20),
182 ("rmssd percentile mean", 17),
183 ],
184 random_k: 2,
185 },
186 ChannelGroup {
187 category: "Activity",
188 primary: &[
189 ("steps", 3),
190 ("jerk", 4),
191 ("log energy", 5),
192 ("kurtosis", 12),
193 ],
194 random: &[
195 ("covariance", 6),
196 ("log energy ratio", 7),
197 ("zero crossing std", 8),
198 ("zero crossing avg", 9),
199 ("axis mean", 10),
200 ("altim std", 11),
201 ],
202 random_k: 1,
203 },
204 ChannelGroup {
205 category: "Sleep",
206 primary: &[
207 ("sleep coefficient", 13),
208 ],
209 random: &[],
210 random_k: 0,
211 },
212 ChannelGroup {
213 category: "EDA",
214 primary: &[
215 ("eda level", 1),
216 ("skin temperature slope", 14), // wrist_temperatures
217 ("wrist temperatures", 14),
218 ],
219 random: &[
220 ("leads contact counts", 2),
221 ("ceda slope real micro siemens", 1),
222 ],
223 random_k: 1,
224 },
225];
226
227// ---------------------------------------------------------------------------
228// Caption token budgets
229// ---------------------------------------------------------------------------
230
/// Tokens allocated to each caption type in the training pipeline, stored as
/// `(caption_type, token_count)` pairs.
pub const CAPTION_TOKEN_BUDGET: &[(&str, usize)] = &[
    ("low_level_caption", 512),
    ("middle_level_caption", 512),
    ("high_level_summary_caption", 256),
    ("high_level_all_caption", 1024),
    ("middle_low_level_caption", 1024),
    ("high_low_level_caption", 1024),
    ("high_middle_level_caption", 512),
    ("high_middle_low_level_caption", 1024),
];
242
243// ---------------------------------------------------------------------------
244// Model / training hyper-parameters (defaults mirroring the reference config)
245// ---------------------------------------------------------------------------
246
/// Vocabulary size of the c4_en / T5 SentencePiece tokeniser (the default).
pub const VOCAB_SIZE: usize = 32_000;
249
/// ViT-B hidden (model) width.
pub const VIT_WIDTH: usize = 768;
/// ViT-B depth: number of transformer blocks.
pub const VIT_DEPTH: usize = 12;
/// ViT-B MLP hidden size, derived as 4 × [`VIT_WIDTH`] so the two cannot drift.
pub const VIT_MLP_DIM: usize = 4 * VIT_WIDTH;
/// ViT-B number of attention heads.
pub const VIT_HEADS: usize = 12;

// Compile-time guard: standard multi-head attention uses a per-head dimension
// of `VIT_WIDTH / VIT_HEADS` (64 for ViT-B), which must divide evenly.
const _: () = assert!(
    VIT_WIDTH % VIT_HEADS == 0,
    "VIT_WIDTH must be divisible by VIT_HEADS"
);
258
259/// Patch height (time-axis): 10 minutes per patch.
260pub const PATCH_H: usize = 10;
261/// Patch width (channel-axis): 2 channels per patch.
262pub const PATCH_W: usize = 2;
263
264/// Number of patches along the time axis (1440 / 10).
265pub const NUM_PATCHES_T: usize = TIME_STEPS / PATCH_H;
266/// Number of patches along the channel axis (34 / 2 = 17; padded to even if needed).
267pub const NUM_PATCHES_C: usize = (NUM_CHANNELS + PATCH_W - 1) / PATCH_W;
268/// Total number of patches fed into the transformer.
269pub const NUM_PATCHES: usize = NUM_PATCHES_T * NUM_PATCHES_C;
270
/// Embedding dimension shared by both modalities once each has passed
/// through its final projection.
pub const EMBED_DIM: usize = 768;

/// Initial value of the SigLIP temperature.
pub const TEMPERATURE_INIT: f32 = 10.0;
/// Initial value of the SigLIP bias.
pub const BIAS_INIT: f32 = -10.0;
278
/// Default training batch size.
///
/// # ⚠ Memory warning
///
/// Each sample yields N = 2 448 patch tokens from the ViT-B sensor encoder,
/// and attention-score tensors grow as `B × H × chunk × N`. Even with
/// `attn_chunk_size = 64`, one forward chunk for a batch of 8 at fp32 takes:
///
/// ```text
/// 8 × 12 × 64 × 2448 × 4 bytes ≈ 60 MB per chunk (forward only)
/// ```
///
/// **In the backward pass the Burn autodiff tape retains the intermediates of
/// every chunk at the same time**, so that figure is multiplied by
/// `ceil(N / chunk)` chunks and again by `depth` layers. For ViT-B on a
/// 16 GB GPU, keep `batch_size` ≤ 8; for quick experiments, pass `--cpu`
/// together with a smaller model config.
pub const DEFAULT_BATCH_SIZE: usize = 8;
/// Default optimiser learning rate.
pub const DEFAULT_LR: f64 = 5.0e-4;
/// Default weight-decay coefficient.
pub const DEFAULT_WD: f64 = 1.0e-4;
/// Adam β₂ (decay rate of the second-moment estimate).
pub const ADAM_BETA2: f64 = 0.999;
/// Norm threshold used for gradient clipping.
pub const GRAD_CLIP_NORM: f64 = 1.0;
/// Total number of pre-training examples (50 million).
pub const TOTAL_EXAMPLES: usize = 50_000_000;