1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
//! `sensorlm` – command-line interface for the SensorLM pipeline.
//!
//! # Commands
//!
//! ```text
//! sensorlm train [options] Train a SensorLM-SigLIP model
//! sensorlm infer [options] Run inference / retrieval
//! sensorlm quantize [options] Post-training quantisation
//! sensorlm download [options] Download a public dataset
//! sensorlm generate-captions [options] Generate captions from a sensor file
//! ```
use clap::{Parser, Subcommand};
// Top-level argument parser for the `sensorlm` binary.
//
// Plain `//` comments are used deliberately here: clap turns `///` doc
// comments into help text, and the help text for this struct is fully
// specified by the attribute below.
#[derive(Parser)]
#[command(
    name = "sensorlm",
    version = env!("CARGO_PKG_VERSION"),
    about = "SensorLM – wearable sensor foundation model (Burn + WGPU)",
    long_about = None,
)]
struct Cli {
    // The selected subcommand together with its own arguments.
    #[command(subcommand)]
    command: Commands,
}
// All `sensorlm` subcommands and their flags.
//
// NOTE: `///` doc comments in this enum are rendered by clap as `--help`
// text; review-only remarks therefore use plain `//` comments.
//
// Short flags must be unique within each subcommand: clap's debug
// assertions panic when two arguments share a short option, and in release
// builds the second argument's short form would be unreachable.
#[derive(Subcommand)]
enum Commands {
    /// Train a SensorLM-SigLIP model.
    Train {
        /// Path to training config JSON (uses defaults if not provided).
        #[arg(short, long)]
        config: Option<String>,
        /// Directory where checkpoints and logs are saved.
        #[arg(short, long, default_value = "./artifacts")]
        artifact_dir: String,
        /// Path to the dataset directory (Parquet files).
        #[arg(short, long, default_value = "./data")]
        data_dir: String,
        /// Model size preset: tiny | small | base
        ///
        /// Selects ViT variant for both the sensor and text encoders:
        /// tiny – ViT-Ti d=192 heads=3 mlp=768 ~11 M params ≤ 2 GB VRAM
        /// small – ViT-S d=384 heads=6 mlp=1536 ~44 M params ≤ 6 GB VRAM
        /// base – ViT-B d=768 heads=12 mlp=3072 ~205 M params ≥ 16 GB VRAM
        #[arg(long, default_value = "tiny")]
        model_size: String,
        // NOTE(review): the 16 GB suggestions below (base → 4, small → 8)
        // disagree with the `--vram-gb 16` row of the next flag's table
        // (base 9, small 18); the two tables may assume different chunk
        // sizes / budgets — confirm which one is current.
        /// Batch size.
        ///
        /// The Burn autodiff tape holds all chunked-attention intermediates for
        /// one transformer layer simultaneously; peak scales as B × H × chunks × N.
        /// Suggested maximums with chunk=64 on a 16 GB device:
        /// tiny → 16 (per-layer bwd ≈ 2.1 GB)
        /// small → 8 (per-layer bwd ≈ 2.2 GB)
        /// base → 4 (per-layer bwd ≈ 2.2 GB)
        #[arg(short, long, default_value_t = 16)]
        batch_size: usize,
        /// Available GPU / unified-memory VRAM in gigabytes.
        ///
        /// When provided the tool derives the attention-tensor budget as
        /// VRAM/3 and **auto-caps --batch-size** to the largest value that
        /// fits — you no longer need to tune batch size manually.
        ///
        /// --vram-gb 8 → base max batch 4, small max batch 9
        /// --vram-gb 16 → base max batch 9, small max batch 18
        /// --vram-gb 24 → base max batch 13, small max batch 27
        /// --vram-gb 32 → base max batch 18, small max batch 36
        ///
        /// Apple Silicon example (M2 Max, 32 GB unified memory):
        /// cargo run … train --model-size base --vram-gb 32
        #[arg(long)]
        vram_gb: Option<f64>,
        /// DataLoader worker threads for CPU-side data preparation (minimum 1).
        ///
        /// WGPU (including Metal on macOS) is thread-safe; worker threads can
        /// create GPU tensors without causing synchronisation stalls.
        /// 2 is a good default. Raise on machines with many CPU cores.
        #[arg(long, default_value_t = 2)]
        num_workers: usize,
        /// Skip the pre-flight VRAM safety check.
        ///
        /// Use only when --vram-gb does not cover your use case and you are
        /// certain the GPU has enough free VRAM. OOM errors and driver
        /// crashes are your responsibility.
        #[arg(long)]
        no_vram_check: bool,
        /// Print Burn's Learner Summary table at the end of training.
        ///
        /// Shows per-metric train/valid values in a formatted table.
        /// Hidden by default to keep output clean.
        #[arg(long)]
        summary: bool,
        /// Use CPU backend instead of WGPU (for testing on machines without a GPU).
        #[arg(long)]
        cpu: bool,
    },
    /// Run zero-shot inference on a sensor file.
    Infer {
        /// Path to model checkpoint.
        #[arg(short, long)]
        checkpoint: String,
        /// Path to tokeniser model file.
        #[arg(long, default_value = "tokenizer.model")]
        tokenizer: String,
        // Long-only: `-c` already belongs to `--checkpoint`. Deriving a short
        // here as well would give two arguments the same `-c` flag.
        /// Comma-separated class labels for zero-shot classification.
        #[arg(long, default_value = "walking,running,sleeping,sedentary")]
        classes: String,
        /// Use CPU backend.
        #[arg(long)]
        cpu: bool,
    },
    /// Post-training INT8 quantisation.
    Quantize {
        /// FP32 checkpoint to quantise.
        #[arg(short, long)]
        checkpoint: String,
        /// Output path for the quantised model JSON.
        #[arg(short, long, default_value = "./artifacts/model_int8.json")]
        output: String,
        /// Path to calibration dataset (Parquet).
        #[arg(long, default_value = "./data/calibration.parquet")]
        calibration_data: String,
        /// Number of calibration batches.
        #[arg(long, default_value_t = 100)]
        num_batches: usize,
    },
    /// Download a public dataset.
    Download {
        /// Dataset name: `pamap2` or `wesad`.
        #[arg(short, long)]
        dataset: String,
        // Long-only: `-d` already belongs to `--dataset`. Deriving a short
        // here as well would give two arguments the same `-d` flag.
        /// Destination directory.
        #[arg(long, default_value = "./data")]
        dest: String,
    },
    /// Generate text captions from a single normalised sensor file.
    GenerateCaptions {
        /// Path to a CSV/Parquet file with columns matching FEATURE_NAMES.
        #[arg(short, long)]
        input: String,
        /// Caption type: low, middle, high-summary, high-all, or combinations.
        #[arg(short, long, default_value = "high-summary")]
        caption_type: String,
        /// Random seed for template selection.
        #[arg(long, default_value_t = 42)]
        seed: u64,
    },
}
/// CLI entry point: initialises logging, parses the command line and
/// dispatches to the selected subcommand.
///
/// Every failure path terminates the process with exit status 1 (via
/// [`exit_with_error`]) so shell scripts and CI jobs can detect errors from
/// the exit code instead of scraping stderr.
fn main() {
    // Initialise tracing.
    //
    // Default filter: suppress noisy WGPU/Metal internals (especially the
    // `Device::maintain: waiting for submission index N` spam that floods
    // the terminal during GPU training), while keeping sensorlm and burn logs.
    //
    // Override with RUST_LOG env var, e.g.:
    //   RUST_LOG=debug cargo run ...   (show everything)
    //   RUST_LOG=error cargo run ...   (errors only)
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                tracing_subscriber::EnvFilter::new(
                    // Global default: warn (suppresses GPU backend spam).
                    // burn_train=error: silences the harmless "Failed to install
                    // the experiment logger" warn that fires because Burn's
                    // LearnerBuilder tries to set a second global
                    // tracing-subscriber after we've already set ours.
                    // sensorlm=info: keep our own pre-flight messages visible.
                    "warn,burn_train=error,sensorlm=info",
                )
            }),
        )
        .init();

    let cli = Cli::parse();
    match cli.command {
        Commands::Train {
            config,
            artifact_dir,
            data_dir,
            model_size,
            batch_size,
            vram_gb,
            num_workers,
            no_vram_check,
            summary,
            cpu,
        } => {
            use sensorlm::config::{ModelSize, TrainingConfig};

            // Parse model-size preset (case-insensitive, short aliases accepted).
            let size = match model_size.to_lowercase().as_str() {
                "tiny" | "ti" => ModelSize::Tiny,
                "small" | "s" => ModelSize::Small,
                "base" | "b" => ModelSize::Base,
                other => exit_with_error(format!(
                    "Unknown model size '{other}'. Choose: tiny | small | base"
                )),
            };

            // Build model config from the chosen preset.
            let model_cfg = size.sensorlm_config();
            let mut train_cfg = TrainingConfig::default();
            train_cfg.model_size = size;
            train_cfg.artifact_dir = artifact_dir;
            train_cfg.data_dir = data_dir;
            train_cfg.batch_size = batch_size;
            train_cfg.vram_gb = vram_gb;
            train_cfg.num_workers = num_workers;
            train_cfg.skip_vram_check = no_vram_check;
            train_cfg.show_summary = summary;

            if let Some(cfg_path) = config {
                // Config-file loading is not wired up yet; CLI flags + defaults are used.
                eprintln!("Loading config from {cfg_path} (not yet implemented – using defaults)");
            }
            // NOTE(review): `batch` is the requested value; per the --vram-gb help
            // text it may be auto-capped later — presumably inside `train`.
            eprintln!(
                "Model: {size:?} ({} params), batch={batch_size}, workers={num_workers}",
                size.approx_params(),
            );

            if cpu {
                // CPU training with NdArray backend.
                use sensorlm::CpuTrainBackend;
                eprintln!("Training on CPU (NdArray backend)…");
                match sensorlm::training::learner::train::<CpuTrainBackend>(model_cfg, train_cfg) {
                    Ok(()) => eprintln!("Training complete."),
                    Err(e) => exit_with_error(format!("Training failed: {e}")),
                }
            } else {
                // GPU training – requires --features wgpu.
                // Falling back to CPU if wgpu feature is not enabled.
                #[cfg(feature = "wgpu")]
                {
                    use sensorlm::TrainBackend;
                    eprintln!("Training on GPU (WGPU backend)…");
                    match sensorlm::training::learner::train::<TrainBackend>(model_cfg, train_cfg) {
                        Ok(()) => eprintln!("Training complete."),
                        Err(e) => exit_with_error(format!("Training failed: {e}")),
                    }
                }
                #[cfg(not(feature = "wgpu"))]
                {
                    eprintln!("WGPU backend not compiled in. Re-run with `--features wgpu` or use --cpu.");
                    eprintln!("Falling back to CPU (NdArray backend)…");
                    match sensorlm::training::learner::train::<sensorlm::CpuTrainBackend>(
                        model_cfg, train_cfg,
                    ) {
                        Ok(()) => eprintln!("Training complete."),
                        Err(e) => exit_with_error(format!("Training failed: {e}")),
                    }
                }
            }
        }
        Commands::Infer { checkpoint, tokenizer, classes, cpu: _ } => {
            // Trim each label and drop empty ones so stray / trailing commas
            // ("a,,b," or "a, b,") do not produce blank classes.
            let class_names: Vec<String> = classes
                .split(',')
                .map(str::trim)
                .filter(|name| !name.is_empty())
                .map(String::from)
                .collect();
            if class_names.is_empty() {
                exit_with_error("No class labels given. Pass e.g. --classes walking,running");
            }
            eprintln!("Running zero-shot inference with {} classes", class_names.len());
            eprintln!("Checkpoint : {checkpoint}");
            eprintln!("Tokenizer : {tokenizer}");
            eprintln!("Classes : {}", class_names.join(", "));
            eprintln!("(Full inference pipeline requires a loaded checkpoint – see examples/inference_demo.rs)");
        }
        Commands::Quantize { checkpoint, output, calibration_data, num_batches } => {
            use sensorlm::config::SensorLMConfig;
            use sensorlm::quantization::int8::quantize_model_weights;

            eprintln!("Quantising {checkpoint} → {output}");
            eprintln!("Calibration data : {calibration_data}");
            eprintln!("Calibration batches: {num_batches}");

            // Demonstration: quantise a set of constant-valued weights.
            // NOTE(review): `checkpoint`, `calibration_data` and `num_batches`
            // are only echoed above; they are not used for quantisation yet.
            let config = SensorLMConfig::default();
            let config_json = serde_json::to_string(&config).unwrap_or_else(|e| {
                exit_with_error(format!("Failed to serialise model config: {e}"))
            });
            // In a real run, extract weights from the loaded checkpoint.
            let dummy_layers = vec![(
                "sensor_encoder.patch_embed.proj.weight".to_string(),
                vec![0.01f32; 768 * 100],
                vec![768, 100],
                None::<Vec<f32>>,
            )];
            let qm = quantize_model_weights(config_json, dummy_layers.into_iter());

            // Report fractional MiB: integer division printed "0 MB" for any
            // model smaller than 1 MiB (including the demo layer above).
            const MIB: f64 = 1024.0 * 1024.0;
            eprintln!(
                "Compression: {:.1}x ({:.2} MB → {:.2} MB)",
                qm.compression_ratio(),
                qm.total_fp32_bytes as f64 / MIB,
                qm.total_quantized_bytes as f64 / MIB,
            );

            let out_path = std::path::Path::new(&output);
            if let Some(parent) = out_path.parent() {
                // `create_dir_all("")` (bare file name) is a no-op Ok, so only
                // genuine failures (permissions, file in the way, …) end up here.
                if let Err(e) = std::fs::create_dir_all(parent) {
                    exit_with_error(format!(
                        "Failed to create output directory {}: {e}",
                        parent.display()
                    ));
                }
            }
            match qm.save(out_path) {
                Ok(()) => eprintln!("Saved quantised model to {output}"),
                Err(e) => exit_with_error(format!("Save failed: {e}")),
            }
        }
        Commands::Download { dataset, dest } => {
            use sensorlm::data::download::{download_file, find_dataset};
            use std::path::PathBuf;

            match find_dataset(&dataset) {
                Some(entry) => {
                    let dest_path = PathBuf::from(&dest).join(format!("{}.zip", entry.name));
                    eprintln!("Downloading {} to {}", entry.name, dest_path.display());
                    match download_file(entry.url, &dest_path, entry.sha256) {
                        Ok(()) => eprintln!("Download complete."),
                        Err(e) => exit_with_error(format!("Download failed: {e}")),
                    }
                }
                None => {
                    let available = sensorlm::data::download::KNOWN_DATASETS
                        .iter()
                        .map(|d| d.name)
                        .collect::<Vec<_>>()
                        .join(", ");
                    exit_with_error(format!(
                        "Unknown dataset '{}'. Available datasets: {}",
                        dataset, available
                    ));
                }
            }
        }
        Commands::GenerateCaptions { input, caption_type, seed } => {
            use ndarray::Array2;
            use rand::{rngs::StdRng, SeedableRng};
            use sensorlm::config::CaptionKey;
            use sensorlm::constants::NUM_CHANNELS;
            use sensorlm::data::captioning::{generate_caption, CaptionContext};

            eprintln!("Generating '{caption_type}' caption for {input}");
            // NOTE(review): `input` is not read yet — the caption is generated
            // from a zero-valued (1440, NUM_CHANNELS) placeholder array.
            let x = Array2::<f64>::zeros((1440, NUM_CHANNELS));
            let ctx = CaptionContext::new();
            let key = match caption_type.as_str() {
                "low" => CaptionKey::LowLevel,
                "middle" => CaptionKey::MiddleLevel,
                "high-summary" => CaptionKey::HighLevelSummary,
                "high-all" => CaptionKey::HighLevelAll,
                "middle-low" => CaptionKey::MiddleLow,
                "high-low" => CaptionKey::HighLow,
                "high-middle" => CaptionKey::HighMiddle,
                "all" => CaptionKey::HighMiddleLow,
                other => {
                    // Deliberately lenient: warn and fall back instead of failing.
                    eprintln!("Unknown caption type '{other}', defaulting to high-summary");
                    CaptionKey::HighLevelSummary
                }
            };
            let mut rng = StdRng::seed_from_u64(seed);
            match generate_caption(&x.view(), None, &ctx, key, &mut rng) {
                Ok(caption) => println!("{caption}"),
                Err(e) => exit_with_error(format!("Caption generation failed: {e}")),
            }
        }
    }
}

/// Print `message` to stderr and terminate the process with exit status 1.
///
/// Returns `!`, so it can be used directly in `match` arms or closures that
/// would otherwise need to produce a value.
fn exit_with_error(message: impl std::fmt::Display) -> ! {
    eprintln!("{message}");
    std::process::exit(1)
}