sensorlm-rs 0.1.0

SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
Documentation
//! Example: post-training INT8 quantisation of a SensorLM model.
//!
//! Demonstrates:
//! 1. Extracting linear layer weights from the model (simulated here).
//! 2. Calibrating scales via [`Calibrator`].
//! 3. Running [`quantize_model_weights`] to produce a [`QuantizedModel`].
//! 4. Verifying the dequantised weights are close to the originals.
//! 5. Computing the compression ratio.
//! 6. Saving the quantised model to disk.
//!
//! ```sh
//! cargo run --example quantize_model
//! ```

use sensorlm::{
    config::SensorLMConfig,
    quantization::int8::{quantize_model_weights, Calibrator, QuantizedModel},
};
use std::path::Path;

/// Synthetic layer inventory: `(parameter name, out_features, in_features)`.
///
/// Only transformer block 0 of each encoder is listed, as a representative
/// subset of the full model, so the example stays fast.
const LAYER_SPECS: &[(&str, usize, usize)] = &[
    // Sensor encoder: patch embedding
    ("sensor_encoder.patch_embed.proj.weight", 768, 100),
    // Sensor encoder: transformer block 0 (4 attention projections + 2 MLP layers)
    ("sensor_encoder.blocks.0.attn.q_proj.weight", 768, 768),
    ("sensor_encoder.blocks.0.attn.k_proj.weight", 768, 768),
    ("sensor_encoder.blocks.0.attn.v_proj.weight", 768, 768),
    ("sensor_encoder.blocks.0.attn.out_proj.weight", 768, 768),
    ("sensor_encoder.blocks.0.mlp.fc1.weight", 3072, 768),
    ("sensor_encoder.blocks.0.mlp.fc2.weight", 768, 3072),
    // Text encoder: embedding projection + a subset of block 0
    ("text_encoder.proj.weight", 768, 768),
    ("text_encoder.blocks.0.attn.q_proj.weight", 768, 768),
    ("text_encoder.blocks.0.mlp.fc1.weight", 3072, 768),
];

/// Number of layers whose calibration / dequantisation details are printed.
const PREVIEW_LAYERS: usize = 3;

/// Generates `len` deterministic pseudo-weights in `[-0.02, 0.02)`.
///
/// This stands in for loading a checkpoint. Because it is a pure function of
/// `len`, the verification step can regenerate exactly the tensor that was
/// quantised, even though the original vectors were moved into the quantiser.
fn synthetic_weights(len: usize) -> Vec<f32> {
    (0..len)
        .map(|i| ((i as f32 * 0.001) % 2.0 - 1.0) * 0.02)
        .collect()
}

/// Returns the longest prefix of `s` that contains at most `max_chars` characters.
///
/// Unlike byte slicing (`&s[..n]`), this never panics on multi-byte UTF-8 input.
/// For ASCII input it is identical to `&s[..max_chars.min(s.len())]`.
fn truncate_chars(s: &str, max_chars: usize) -> &str {
    match s.char_indices().nth(max_chars) {
        Some((byte_idx, _)) => &s[..byte_idx],
        None => s,
    }
}

/// Runs the end-to-end INT8 post-training quantisation demo.
///
/// The demo generates synthetic weights, calibrates and quantises them, and
/// reports size and accuracy. It then saves the result to the platform temp
/// directory and reloads it. Save and reload failures are reported on stderr
/// rather than panicking.
fn main() {
    println!("=== SensorLM INT8 Post-Training Quantisation ===\n");

    let model_cfg = SensorLMConfig::default();
    let config_json = serde_json::to_string_pretty(&model_cfg)
        .expect("SensorLMConfig should always serialise to JSON");

    // -----------------------------------------------------------------------
    // Simulate extracting named weight tensors from a trained model.
    // In a real workflow, you would:
    //   1. Load the FP32 checkpoint.
    //   2. Iterate over model.named_parameters() (Burn API).
    //   3. Collect (name, weight_vec, shape, bias).
    // -----------------------------------------------------------------------
    println!("Simulating weight extraction from a ViT-B SensorLM model…");

    let mut calibrator = Calibrator::new();
    let mut layers = Vec::with_capacity(LAYER_SPECS.len());

    for &(name, out_features, in_features) in LAYER_SPECS {
        // Deterministic weights (in practice: load from the checkpoint).
        let weights = synthetic_weights(out_features * in_features);

        // Calibrate.
        calibrator.record_layer(name.to_string(), &weights);

        // Collect for quantisation: (name, weights, shape, optional bias).
        layers.push((
            name.to_string(),
            weights,
            vec![out_features, in_features],
            None::<Vec<f32>>,
        ));
    }

    let calibrations = calibrator.finish();
    println!("Calibrated {} layers.", calibrations.len());
    // `take` instead of `[..3]`, so a shorter layer list cannot panic.
    for cal in calibrations.iter().take(PREVIEW_LAYERS) {
        println!(
            "  {} : w_scale = {:.6}",
            truncate_chars(&cal.name, 40),
            cal.weight_scale()
        );
    }

    // -----------------------------------------------------------------------
    // Quantise.
    // -----------------------------------------------------------------------
    println!("\nQuantising to INT8…");
    let qm = quantize_model_weights(config_json, layers.into_iter());

    println!("\n--- Quantisation results ---");
    println!("Layers quantised       : {}", qm.layers.len());
    println!(
        "Original FP32 size     : {:.1} MB",
        qm.total_fp32_bytes as f64 / 1024.0 / 1024.0
    );
    println!(
        "Quantised INT8 size    : {:.1} MB",
        qm.total_quantized_bytes as f64 / 1024.0 / 1024.0
    );
    println!("Compression ratio      : {:.1}x", qm.compression_ratio());

    // -----------------------------------------------------------------------
    // Verify dequantisation error.
    // -----------------------------------------------------------------------
    println!("\n--- Dequantisation accuracy ---");
    for qw in qm.layers.iter().take(PREVIEW_LAYERS) {
        let dq = qw.dequantize();
        // `layers` was consumed by the quantiser, so regenerate the originals
        // with the same deterministic generator used during extraction.
        let orig = synthetic_weights(qw.weights_i8.len());
        let max_err = orig
            .iter()
            .zip(dq.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        println!(
            "  {:<50} scale={:.6}  max_err={:.6}",
            truncate_chars(&qw.name, 50),
            qw.scale,
            max_err
        );
    }

    // -----------------------------------------------------------------------
    // Save.
    // -----------------------------------------------------------------------
    // Use the platform temp directory rather than a hard-coded `/tmp`, which
    // does not exist on every OS.
    let out_file = std::env::temp_dir().join("sensorlm_int8.json");
    let out_path: &Path = &out_file;
    match qm.save(out_path) {
        Ok(()) => {
            println!("\nSaved quantised model to {}", out_path.display());

            // Reload and verify; report (rather than panic on) I/O failures.
            // NOTE(review): `{e:?}` is used because the load error type is only
            // known (from the previous `.unwrap()`) to implement `Debug`.
            match QuantizedModel::load(out_path) {
                Ok(loaded) => println!(
                    "Reloaded model: {} layers, {:.1}x compression",
                    loaded.layers.len(),
                    loaded.compression_ratio()
                ),
                Err(e) => eprintln!("Reload failed: {e:?}"),
            }
        }
        Err(e) => eprintln!("Save failed: {e}"),
    }

    println!("\n=== Done ===");
}