// sensorlm-rs 0.1.0
//
// SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
// Documentation
//! `sensorlm` – command-line interface for the SensorLM pipeline.
//!
//! # Commands
//!
//! ```text
//! sensorlm train       [options]   Train a SensorLM-SigLIP model
//! sensorlm infer       [options]   Run inference / retrieval
//! sensorlm quantize    [options]   Post-training quantisation
//! sensorlm download    [options]   Download a public dataset
//! sensorlm generate-captions       Generate captions from a sensor file
//! ```

use clap::{Parser, Subcommand};

// Top-level command-line parser for the `sensorlm` binary.
//
// Plain `//` comments are used on purpose: clap's derive macro turns `///`
// doc comments into help text, and the help text for this struct is already
// set explicitly through `about` / `long_about` below.
#[derive(Parser)]
#[command(
    name = "sensorlm",
    about = "SensorLM – wearable sensor foundation model (Burn + WGPU)",
    // Version string is taken from Cargo.toml at compile time.
    version = env!("CARGO_PKG_VERSION"),
    long_about = None,
)]
struct Cli {
    // The subcommand chosen on the command line; `main` matches on it.
    #[command(subcommand)]
    command: Commands,
}

// Subcommands of the `sensorlm` binary.
//
// NOTE: every `///` comment in this enum is user-facing — clap's derive macro
// renders it as `--help` text — so maintainer notes use plain `//` comments.
//
// `#[arg(short)]` derives the short flag from the first letter of the field
// name.  Short flags must be unique within one subcommand, so wherever two
// fields of the same variant start with the same letter the second one is
// given an explicit letter.
#[derive(Subcommand)]
enum Commands {
    /// Train a SensorLM-SigLIP model.
    Train {
        /// Path to training config JSON (uses defaults if not provided).
        #[arg(short, long)]
        config: Option<String>,
        /// Directory where checkpoints and logs are saved.
        #[arg(short, long, default_value = "./artifacts")]
        artifact_dir: String,
        /// Path to the dataset directory (Parquet files).
        #[arg(short, long, default_value = "./data")]
        data_dir: String,
        /// Model size preset: tiny | small | base
        ///
        /// Selects ViT variant for both the sensor and text encoders:
        ///   tiny  – ViT-Ti  d=192 heads=3  mlp=768   ~11 M params   ≤ 2 GB VRAM
        ///   small – ViT-S   d=384 heads=6  mlp=1536  ~44 M params   ≤ 6 GB VRAM
        ///   base  – ViT-B   d=768 heads=12 mlp=3072 ~205 M params  ≥ 16 GB VRAM
        #[arg(long, default_value = "tiny")]
        model_size: String,
        /// Batch size.
        ///
        /// The Burn autodiff tape holds all chunked-attention intermediates for
        /// one transformer layer simultaneously; peak scales as B × H × chunks × N.
        /// Suggested maximums with chunk=64 on a 16 GB device:
        ///   tiny  → 16  (per-layer bwd ≈ 2.1 GB)
        ///   small →  8  (per-layer bwd ≈ 2.2 GB)
        ///   base  →  4  (per-layer bwd ≈ 2.2 GB)
        #[arg(short, long, default_value_t = 16)]
        batch_size: usize,
        /// Available GPU / unified-memory VRAM in gigabytes.
        ///
        /// When provided the tool derives the attention-tensor budget as
        /// VRAM/3 and **auto-caps --batch-size** to the largest value that
        /// fits — you no longer need to tune batch size manually.
        ///
        ///   --vram-gb 8   → base max batch  4,  small max batch  9
        ///   --vram-gb 16  → base max batch  9,  small max batch 18
        ///   --vram-gb 24  → base max batch 13,  small max batch 27
        ///   --vram-gb 32  → base max batch 18,  small max batch 36
        ///
        /// Apple Silicon example (M2 Max, 32 GB unified memory):
        ///   cargo run … train --model-size base --vram-gb 32
        #[arg(long)]
        vram_gb: Option<f64>,
        /// DataLoader worker threads for CPU-side data preparation (minimum 1).
        ///
        /// WGPU (including Metal on macOS) is thread-safe; worker threads can
        /// create GPU tensors without causing synchronisation stalls.
        /// 2 is a good default. Raise on machines with many CPU cores.
        #[arg(long, default_value_t = 2)]
        num_workers: usize,
        /// Skip the pre-flight VRAM safety check.
        ///
        /// Use only when --vram-gb does not cover your use case and you are
        /// certain the GPU has enough free VRAM.  OOM errors and driver
        /// crashes are your responsibility.
        #[arg(long)]
        no_vram_check: bool,
        /// Print Burn's Learner Summary table at the end of training.
        ///
        /// Shows per-metric train/valid values in a formatted table.
        /// Hidden by default to keep output clean.
        #[arg(long)]
        summary: bool,
        /// Use CPU backend instead of WGPU (for testing on machines without a GPU).
        #[arg(long)]
        cpu: bool,
    },

    /// Run zero-shot inference on a sensor file.
    Infer {
        /// Path to model checkpoint.
        #[arg(short, long)]
        checkpoint: String,
        /// Path to tokeniser model file.
        #[arg(long, default_value = "tokenizer.model")]
        tokenizer: String,
        /// Comma-separated class labels for zero-shot classification.
        // Explicit `-l` (labels): the derived short would be `-c`, which is
        // already taken by `checkpoint` in this subcommand.
        #[arg(short = 'l', long, default_value = "walking,running,sleeping,sedentary")]
        classes: String,
        /// Use CPU backend.
        #[arg(long)]
        cpu: bool,
    },

    /// Post-training INT8 quantisation.
    Quantize {
        /// FP32 checkpoint to quantise.
        #[arg(short, long)]
        checkpoint: String,
        /// Output path for the quantised model JSON.
        #[arg(short, long, default_value = "./artifacts/model_int8.json")]
        output: String,
        /// Path to calibration dataset (Parquet).
        #[arg(long, default_value = "./data/calibration.parquet")]
        calibration_data: String,
        /// Number of calibration batches.
        #[arg(long, default_value_t = 100)]
        num_batches: usize,
    },

    /// Download a public dataset.
    Download {
        /// Dataset name: `pamap2` or `wesad`.
        #[arg(short, long)]
        dataset: String,
        /// Destination directory.
        // Explicit `-o`: the derived short would be `-d`, which is already
        // taken by `dataset` in this subcommand.
        #[arg(short = 'o', long, default_value = "./data")]
        dest: String,
    },

    /// Generate text captions from a single normalised sensor file.
    GenerateCaptions {
        /// Path to a CSV/Parquet file with columns matching FEATURE_NAMES.
        #[arg(short, long)]
        input: String,
        /// Caption type: low, middle, high-summary, high-all, or combinations.
        #[arg(short, long, default_value = "high-summary")]
        caption_type: String,
        /// Random seed for template selection.
        #[arg(long, default_value_t = 42)]
        seed: u64,
    },
}

/// Entry point: initialises logging, parses the command line and dispatches
/// to the selected subcommand.
///
/// Every failure path (unknown model size, failed training, failed save,
/// failed or unknown download, failed caption generation) prints a message to
/// stderr and terminates the process with exit status 1 so that shells and CI
/// scripts can detect the failure.
fn main() {
    // Initialise tracing.
    //
    // Default filter: suppress noisy WGPU/Metal internals (especially the
    // `Device::maintain: waiting for submission index N` spam that floods
    // the terminal during GPU training), while keeping sensorlm and burn logs.
    //
    // Override with RUST_LOG env var, e.g.:
    //   RUST_LOG=debug cargo run ...    (show everything)
    //   RUST_LOG=error cargo run ...   (errors only)
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(
                    // Global default: warn (suppresses GPU backend spam).
                    // burn_train=error: silences the harmless "Failed to install
                    //   the experiment logger" warn that fires because Burn's
                    //   LearnerBuilder tries to set a second global
                    //   tracing-subscriber after we've already set ours.
                    // sensorlm=info: keep our own pre-flight messages visible.
                    "warn,burn_train=error,sensorlm=info"
                )),
        )
        .init();

    let cli = Cli::parse();

    match cli.command {
        Commands::Train {
            config,
            artifact_dir,
            data_dir,
            model_size,
            batch_size,
            vram_gb,
            num_workers,
            no_vram_check,
            summary,
            cpu,
        } => {
            use sensorlm::config::{ModelSize, TrainingConfig};

            // Parse model-size preset (case-insensitive, short aliases allowed).
            let size = match model_size.to_lowercase().as_str() {
                "tiny"  | "ti" => ModelSize::Tiny,
                "small" | "s"  => ModelSize::Small,
                "base"  | "b"  => ModelSize::Base,
                other => {
                    eprintln!("Unknown model size '{other}'. Choose: tiny | small | base");
                    std::process::exit(1);
                }
            };

            // Build model config from the chosen preset.
            let model_cfg = size.sensorlm_config();

            // Start from the library defaults and overlay the CLI options.
            let mut train_cfg = TrainingConfig::default();
            train_cfg.model_size      = size;
            train_cfg.artifact_dir    = artifact_dir;
            train_cfg.data_dir        = data_dir;
            train_cfg.batch_size      = batch_size;
            train_cfg.vram_gb         = vram_gb;
            train_cfg.num_workers     = num_workers;
            train_cfg.skip_vram_check = no_vram_check;
            train_cfg.show_summary    = summary;

            // `--config` is accepted but not read yet; tell the user so.
            if let Some(cfg_path) = config {
                eprintln!("Loading config from {cfg_path} (not yet implemented – using defaults)");
            }

            eprintln!(
                "Model: {size:?} ({} params), batch={batch_size}, workers={num_workers}",
                size.approx_params(),
            );

            if cpu {
                // CPU training with NdArray backend.
                use sensorlm::CpuTrainBackend;
                eprintln!("Training on CPU (NdArray backend)…");
                report_training_outcome(
                    sensorlm::training::learner::train::<CpuTrainBackend>(model_cfg, train_cfg),
                );
            } else {
                // GPU training – requires --features wgpu.
                // Falling back to CPU if wgpu feature is not enabled.
                #[cfg(feature = "wgpu")]
                {
                    use sensorlm::TrainBackend;
                    eprintln!("Training on GPU (WGPU backend)…");
                    report_training_outcome(
                        sensorlm::training::learner::train::<TrainBackend>(model_cfg, train_cfg),
                    );
                }
                #[cfg(not(feature = "wgpu"))]
                {
                    eprintln!("WGPU backend not compiled in. Re-run with `--features wgpu` or use --cpu.");
                    eprintln!("Falling back to CPU (NdArray backend)…");
                    report_training_outcome(
                        sensorlm::training::learner::train::<sensorlm::CpuTrainBackend>(model_cfg, train_cfg),
                    );
                }
            }
        }

        // The `--cpu` flag is parsed but unused: this arm only echoes its inputs.
        Commands::Infer { checkpoint, tokenizer, classes, cpu: _ } => {
            // Drop empty labels produced by stray commas ("a,,b" or "a,b,").
            let class_names: Vec<String> = classes
                .split(',')
                .map(str::trim)
                .filter(|label| !label.is_empty())
                .map(String::from)
                .collect();
            eprintln!("Running zero-shot inference with {} classes", class_names.len());
            eprintln!("Checkpoint : {checkpoint}");
            eprintln!("Tokenizer  : {tokenizer}");
            eprintln!("Classes    : {}", class_names.join(", "));
            eprintln!("(Full inference pipeline requires a loaded checkpoint – see examples/inference_demo.rs)");
        }

        Commands::Quantize { checkpoint, output, calibration_data, num_batches } => {
            use sensorlm::quantization::int8::quantize_model_weights;
            use sensorlm::config::SensorLMConfig;

            // Separate source and destination so the two paths do not run together.
            eprintln!("Quantising {checkpoint} → {output}");
            eprintln!("Calibration data : {calibration_data}");
            eprintln!("Calibration batches: {num_batches}");

            // Demonstration: quantise a fixed dummy weight tensor.  The
            // checkpoint and calibration arguments are only echoed above.
            let config = SensorLMConfig::default();
            let config_json = match serde_json::to_string(&config) {
                Ok(json) => json,
                Err(e) => {
                    eprintln!("Failed to serialise model config: {e}");
                    std::process::exit(1);
                }
            };

            // In a real run, extract weights from the loaded checkpoint.
            let dummy_layers = vec![
                (
                    "sensor_encoder.patch_embed.proj.weight".to_string(),
                    vec![0.01f32; 768 * 100],
                    vec![768, 100],
                    None::<Vec<f32>>,
                ),
            ];

            let qm = quantize_model_weights(config_json, dummy_layers.into_iter());
            eprintln!(
                "Compression: {:.1}x ({} MB → {} MB)",
                qm.compression_ratio(),
                qm.total_fp32_bytes / (1024 * 1024),
                qm.total_quantized_bytes / (1024 * 1024),
            );

            let out_path = std::path::Path::new(&output);
            if let Some(parent) = out_path.parent() {
                // Best effort: warn and continue, because `save` below will
                // report the definitive error if the directory is unusable.
                if let Err(e) = std::fs::create_dir_all(parent) {
                    eprintln!(
                        "Warning: could not create output directory {}: {e}",
                        parent.display()
                    );
                }
            }
            match qm.save(out_path) {
                Ok(()) => eprintln!("Saved quantised model to {output}"),
                Err(e) => {
                    eprintln!("Save failed: {e}");
                    std::process::exit(1);
                }
            }
        }

        Commands::Download { dataset, dest } => {
            use sensorlm::data::download::{download_file, find_dataset};
            use std::path::PathBuf;

            match find_dataset(&dataset) {
                Some(entry) => {
                    // The archive is stored as `<dest>/<dataset name>.zip`.
                    let dest_path = PathBuf::from(&dest).join(format!("{}.zip", entry.name));
                    eprintln!("Downloading {} to {}", entry.name, dest_path.display());
                    match download_file(entry.url, &dest_path, entry.sha256) {
                        Ok(()) => eprintln!("Download complete."),
                        Err(e) => {
                            eprintln!("Download failed: {e}");
                            std::process::exit(1);
                        }
                    }
                }
                None => {
                    eprintln!(
                        "Unknown dataset '{}'. Available datasets: {}",
                        dataset,
                        sensorlm::data::download::KNOWN_DATASETS
                            .iter()
                            .map(|d| d.name)
                            .collect::<Vec<_>>()
                            .join(", ")
                    );
                    std::process::exit(1);
                }
            }
        }

        Commands::GenerateCaptions { input, caption_type, seed } => {
            use sensorlm::config::CaptionKey;
            use sensorlm::data::captioning::{generate_caption, CaptionContext};
            use ndarray::Array2;
            use rand::{rngs::StdRng, SeedableRng};
            use sensorlm::constants::NUM_CHANNELS;

            eprintln!("Generating '{caption_type}' caption for {input}");

            // Demonstration: `input` is only echoed above; a zero-valued
            // 1440 × NUM_CHANNELS array stands in for the sensor data.
            let x = Array2::<f64>::zeros((1440, NUM_CHANNELS));
            let ctx = CaptionContext::new();
            let key = match caption_type.as_str() {
                "low"          => CaptionKey::LowLevel,
                "middle"       => CaptionKey::MiddleLevel,
                "high-summary" => CaptionKey::HighLevelSummary,
                "high-all"     => CaptionKey::HighLevelAll,
                "middle-low"   => CaptionKey::MiddleLow,
                "high-low"     => CaptionKey::HighLow,
                "high-middle"  => CaptionKey::HighMiddle,
                "all"          => CaptionKey::HighMiddleLow,
                // Unknown types are a warning, not an error: fall back to the
                // same default the CLI uses.
                other => {
                    eprintln!("Unknown caption type '{other}', defaulting to high-summary");
                    CaptionKey::HighLevelSummary
                }
            };

            // Seeded RNG so the same seed is used for template selection.
            let mut rng = StdRng::seed_from_u64(seed);
            match generate_caption(&x.view(), None, &ctx, key, &mut rng) {
                // The caption is the only output written to stdout.
                Ok(caption) => println!("{caption}"),
                Err(e) => {
                    eprintln!("Caption generation failed: {e}");
                    std::process::exit(1);
                }
            }
        }
    }
}

/// Prints the outcome of a training run to stderr.
///
/// On success prints "Training complete."; on failure prints the error and
/// terminates the process with exit status 1.
fn report_training_outcome<E: std::fmt::Display>(outcome: Result<(), E>) {
    match outcome {
        Ok(()) => eprintln!("Training complete."),
        Err(e) => {
            eprintln!("Training failed: {e}");
            std::process::exit(1);
        }
    }
}