// sensorlm-rs 0.1.0
//
// SensorLM – wearable sensor foundation model in Rust (Burn + WGPU)
//! INT8 post-training quantisation (PTQ).
//!
//! # Overview
//!
//! Post-training quantisation converts the FP32 weights of a trained model to
//! INT8 with minimal accuracy degradation.  The procedure is:
//!
//! 1. **Calibration** – run the model on a small representative dataset and
//!    collect the per-layer minimum and maximum activation / weight values.
//! 2. **Quantise** – for each `Linear` weight matrix compute a symmetric
//!    INT8 scale (the zero-point is fixed at 0 in a symmetric scheme), then
//!    represent weights as `i8` values in `[-127, 127]`.
//! 3. **Dequantise at runtime** – before the matrix multiplication, convert
//!    `i8` back to `f32` using the stored scale.
//!
//! # Quantisation formula (symmetric per-tensor)
//!
//! ```text
//! scale = max(|W|) / 127
//! W_q   = round(W / scale)  ∈ [-127, 127]
//! W_dq  = W_q * scale        (approx. original W)
//! ```
//!
//! # INT8 linear layer
//!
//! The [`QuantizedLinear`] module stores weights as `i8` with a per-tensor
//! scale and reconstructs `f32` on the fly.  Using one byte per weight instead
//! of four cuts weight storage to roughly a quarter of FP32 (biases stay in
//! FP32), and can accelerate inference when hardware INT8 GEMM is available.
//!
//! # Limitations
//!
//! * This is a **weight-only** quantisation scheme (activations remain FP32).
//! * Full activation quantisation would require inserting `QuantizeAct` nodes
//!   throughout the graph – left as a future extension.
//! * The Burn framework does not yet expose native INT8 GEMM kernels; the
//!   dequantise-then-multiply approach used here demonstrates correctness but
//!   does not provide a runtime speedup until WGPU INT8 kernels land.

use std::{fs, path::Path};


use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// Per-layer calibration statistics
// ---------------------------------------------------------------------------

/// Calibration data collected from one `Linear` layer.
///
/// [`LayerCalibration::weight_scale`] converts the recorded weight range into
/// a symmetric per-tensor INT8 scale.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerCalibration {
    /// Layer name (e.g. `"sensor_encoder.blocks.0.attn.q_proj"`).
    pub name: String,
    /// Minimum observed weight value.
    pub w_min: f32,
    /// Maximum observed weight value.
    pub w_max: f32,
    /// Minimum observed input activation value.
    ///
    /// NOTE: [`Calibrator::record_layer`] fills this with a `-1.0`
    /// placeholder, because the current scheme is weight-only.
    pub act_min: f32,
    /// Maximum observed input activation value.
    ///
    /// NOTE: [`Calibrator::record_layer`] fills this with a `1.0`
    /// placeholder, because the current scheme is weight-only.
    pub act_max: f32,
}

impl LayerCalibration {
    /// Symmetric INT8 scale for this layer's weights.
    ///
    /// The largest weight magnitude seen during calibration is mapped onto
    /// the INT8 limit of 127:
    /// `scale = max(|w_min|, |w_max|) / 127`.
    pub fn weight_scale(&self) -> f32 {
        const INT8_LIMIT: f32 = 127.0;
        let peak = f32::max(self.w_min.abs(), self.w_max.abs());
        peak / INT8_LIMIT
    }
}

// ---------------------------------------------------------------------------
// Quantised weight representation
// ---------------------------------------------------------------------------

/// INT8 quantised weights for a single `Linear` layer.
///
/// Produced by [`QuantizedWeights::from_f32`] with symmetric per-tensor
/// quantisation. [`QuantizedWeights::dequantize`] reconstructs approximate
/// FP32 values as `q * scale`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedWeights {
    /// Layer name.
    pub name: String,
    /// Quantised weight values in row-major order.
    ///
    /// Each value is stored widened to `i32`. Values produced by
    /// [`QuantizedWeights::from_f32`] are clamped to `[-127, 127]`, so they
    /// always fit in an `i8`.
    pub weights_i8: Vec<i32>,
    /// Weight tensor shape `[out_features, in_features]`.
    ///
    /// NOTE(review): stored as given; nothing here checks that its product
    /// equals `weights_i8.len()`.
    pub shape: Vec<usize>,
    /// Per-tensor scale factor (`W ≈ q * scale`).
    pub scale: f32,
    /// Optional bias vector, kept in FP32 (not quantised).
    pub bias: Option<Vec<f32>>,
}

impl QuantizedWeights {
    /// Quantise an FP32 weight matrix with symmetric per-tensor INT8.
    ///
    /// Each weight `w` becomes `round(w / scale)` clamped to `[-127, 127]`.
    ///
    /// # Arguments
    ///
    /// * `name`   – Layer name.
    /// * `w`      – FP32 weights in row-major layout.
    /// * `shape`  – `[out_features, in_features]`. It is stored as given and
    ///   not checked against `w.len()`.
    /// * `bias`   – Optional FP32 bias vector, stored unchanged.
    /// * `scale`  – If `None`, computed as `max(|w|) / 127`. The scale
    ///   actually used is floored at `1e-8`. This keeps an all-zero tensor,
    ///   or a non-positive / NaN explicit scale, from causing a division by
    ///   zero.
    pub fn from_f32(
        name: String,
        w: &[f32],
        shape: Vec<usize>,
        bias: Option<Vec<f32>>,
        scale: Option<f32>,
    ) -> Self {
        // Scan the weights for their peak magnitude only when the caller did
        // not already supply a scale (e.g. from calibration).
        let s = scale
            .unwrap_or_else(|| w.iter().map(|x| x.abs()).fold(0.0f32, f32::max) / 127.0)
            .max(1e-8);

        let weights_i8: Vec<i32> = w
            .iter()
            .map(|&x| (x / s).round().clamp(-127.0, 127.0) as i32)
            .collect();

        Self { name, weights_i8, shape, scale: s, bias }
    }

    /// Dequantise to FP32 (`q * scale` for every stored value).
    pub fn dequantize(&self) -> Vec<f32> {
        self.weights_i8
            .iter()
            .map(|&q| q as f32 * self.scale)
            .collect()
    }

    /// Notional INT8 storage for the weights in bytes: one byte per weight.
    ///
    /// The values are held in memory as `i32`. This counts what an `i8`
    /// encoding would need. Biases are not included.
    pub fn size_bytes(&self) -> usize {
        self.weights_i8.len()
    }

    /// Bytes the same weights would occupy as FP32 (four bytes per weight).
    pub fn original_size_bytes(&self) -> usize {
        self.weights_i8.len() * 4
    }

    /// Compression ratio of the quantised weights vs. FP32.
    ///
    /// Returns `4.0` for any non-empty layer. The denominator is floored at 1,
    /// matching [`QuantizedModel::compression_ratio`]. An empty layer
    /// therefore yields `0.0` instead of `NaN` from `0 / 0`.
    pub fn compression_ratio(&self) -> f32 {
        self.original_size_bytes() as f32 / self.size_bytes().max(1) as f32
    }
}

// ---------------------------------------------------------------------------
// Quantised model manifest
// ---------------------------------------------------------------------------

/// Collection of quantised weights for an entire model.
///
/// Can be written to and read from a JSON file with [`QuantizedModel::save`]
/// and [`QuantizedModel::load`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedModel {
    /// Model configuration JSON (serialised [`crate::config::SensorLMConfig`]).
    pub config_json: String,
    /// Quantised weights for every Linear layer in the model.
    pub layers: Vec<QuantizedWeights>,
    /// Identifier of the quantisation scheme used.
    ///
    /// [`quantize_model_weights`] sets this to `"symmetric_per_tensor_int8"`.
    pub scheme: String,
    /// Total size of quantised weights in bytes.
    ///
    /// As computed by [`quantize_model_weights`], this is the notional INT8
    /// size: one byte per weight, with biases excluded.
    pub total_quantized_bytes: usize,
    /// Total size of original FP32 weights in bytes.
    ///
    /// As computed by [`quantize_model_weights`], this is four bytes per
    /// weight, with biases excluded.
    pub total_fp32_bytes: usize,
}

impl QuantizedModel {
    /// Overall FP32 → INT8 compression ratio across all layers.
    ///
    /// The quantised byte count is floored at 1, so a model with no weights
    /// yields `0.0` rather than dividing by zero.
    pub fn compression_ratio(&self) -> f32 {
        let int8_bytes = self.total_quantized_bytes.max(1) as f32;
        let fp32_bytes = self.total_fp32_bytes as f32;
        fp32_bytes / int8_bytes
    }

    /// Write the model to `path` as pretty-printed JSON.
    ///
    /// # Errors
    ///
    /// Serialisation and file-system failures are converted into the
    /// crate's error type.
    pub fn save(&self, path: &Path) -> crate::error::Result<()> {
        let pretty = serde_json::to_string_pretty(self)?;
        fs::write(path, pretty).map_err(Into::into)
    }

    /// Read a model previously written by [`QuantizedModel::save`].
    ///
    /// # Errors
    ///
    /// File-system and JSON parse failures are converted into the crate's
    /// error type.
    pub fn load(path: &Path) -> crate::error::Result<Self> {
        let text = fs::read_to_string(path)?;
        Ok(serde_json::from_str::<Self>(&text)?)
    }
}

// ---------------------------------------------------------------------------
// PTQ pipeline
// ---------------------------------------------------------------------------

/// Collects per-layer weight statistics for post-training quantisation.
///
/// In a production implementation this would hook into the model's forward
/// pass via observers and run over a representative dataset. Here the
/// statistics are taken directly from the stored weight tensors. That is
/// sufficient for symmetric per-tensor weight quantisation, which does not
/// need input activation statistics.
#[derive(Debug, Default)]
pub struct Calibrator {
    calibrations: Vec<LayerCalibration>,
}

impl Calibrator {
    /// Create an empty calibrator.
    pub fn new() -> Self {
        Self { calibrations: Vec::new() }
    }

    /// Record the weight range of one linear layer.
    ///
    /// `NaN` entries are ignored. If `weights` is empty or contains only
    /// `NaN`s, the range is recorded as `[0.0, 0.0]` instead of the
    /// `[+inf, -inf]` fold identity. This keeps the record finite, which
    /// matters because serde_json writes non-finite floats as `null` and
    /// cannot read that back into an `f32`.
    ///
    /// The activation bounds are fixed placeholders (`-1.0`, `1.0`) because
    /// this is weight-only PTQ.
    pub fn record_layer(&mut self, name: String, weights: &[f32]) {
        // Single pass over the data. `f32::min`/`f32::max` return the non-NaN
        // operand, so NaN entries never affect the range.
        let (lo, hi) = weights
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &x| {
                (lo.min(x), hi.max(x))
            });
        // `lo > hi` only when no comparable value was seen (empty / all-NaN).
        let (w_min, w_max) = if lo > hi { (0.0, 0.0) } else { (lo, hi) };
        self.calibrations.push(LayerCalibration {
            name,
            w_min,
            w_max,
            act_min: -1.0, // placeholder (weight-only PTQ)
            act_max: 1.0,
        });
    }

    /// Consume the calibrator and return the recorded per-layer statistics
    /// in recording order.
    ///
    /// Use [`LayerCalibration::weight_scale`] to derive each layer's INT8
    /// scale.
    pub fn finish(self) -> Vec<LayerCalibration> {
        self.calibrations
    }
}

// ---------------------------------------------------------------------------
// Quantise a flat representation of model weights
// ---------------------------------------------------------------------------

/// Quantise a named list of FP32 weight tensors.
///
/// This function takes the flat `(name, weights, shape, bias)` representation
/// that would be extracted from a `SensorLMModel` and produces a
/// [`QuantizedModel`].
///
/// # Arguments
///
/// * `config_json`  – JSON string of the model config.
/// * `layers`       – Iterator of `(name, fp32_weights, shape, optional_bias)`.
pub fn quantize_model_weights(
    config_json: String,
    layers: impl IntoIterator<Item = (String, Vec<f32>, Vec<usize>, Option<Vec<f32>>)>,
) -> QuantizedModel {
    let mut quantized_layers = Vec::new();
    let mut total_fp32_bytes = 0usize;
    let mut total_quantized_bytes = 0usize;

    for (name, weights, shape, bias) in layers {
        let qw = QuantizedWeights::from_f32(name, &weights, shape, bias, None);
        total_fp32_bytes += qw.original_size_bytes();
        total_quantized_bytes += qw.size_bytes();
        quantized_layers.push(qw);
    }

    QuantizedModel {
        config_json,
        layers: quantized_layers,
        scheme: "symmetric_per_tensor_int8".to_string(),
        total_quantized_bytes,
        total_fp32_bytes,
    }
}

// ---------------------------------------------------------------------------
// FP16 export
// ---------------------------------------------------------------------------

/// Encode FP32 weights as FP16, returning each half-precision value's raw
/// `u16` bit pattern.
///
/// The conversion is lossy because FP16 has less precision and range than
/// FP32. Storing two bytes per weight instead of four halves weight storage,
/// and FP16 is natively supported by most modern GPUs.
pub fn fp32_to_fp16_bits(weights: &[f32]) -> Vec<u16> {
    let mut encoded = Vec::with_capacity(weights.len());
    for &value in weights {
        encoded.push(half::f16::from_f32(value).to_bits());
    }
    encoded
}

/// Convert an FP16 weight vector back to FP32.
pub fn fp16_bits_to_fp32(bits: &[u16]) -> Vec<f32> {
    bits.iter()
        .map(|&b| half::f16::from_bits(b).to_f32())
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Round-to-nearest quantisation keeps every dequantised value within half
    /// a quantisation step of the original.
    #[test]
    fn test_quantize_roundtrip() {
        let original: Vec<f32> = (0..100).map(|i| i as f32 * 0.01 - 0.5).collect();
        let quantized =
            QuantizedWeights::from_f32("test".to_string(), &original, vec![10, 10], None, None);
        let restored = quantized.dequantize();
        assert_eq!(restored.len(), original.len());

        let half_step = quantized.scale / 2.0;
        let max_err = original
            .iter()
            .zip(&restored)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(max_err <= half_step + 1e-6,
            "Max quant error {max_err} > scale/2 = {}", half_step);
    }

    /// One byte per weight versus four for FP32.
    #[test]
    fn test_compression_ratio() {
        let ones = [1.0f32; 1024];
        let quantized = QuantizedWeights::from_f32("test".into(), &ones, vec![32, 32], None, None);
        let ratio = quantized.compression_ratio();
        assert!((ratio - 4.0).abs() < 1e-5,
            "INT8 should compress ~4x vs FP32");
    }

    /// A model survives a JSON save/load cycle through a temporary file.
    #[test]
    fn test_save_load_roundtrip() {
        let layer = QuantizedWeights {
            name: "l1".into(),
            weights_i8: vec![1, -2, 3],
            shape: vec![1, 3],
            scale: 0.01,
            bias: None,
        };
        let model = QuantizedModel {
            config_json: "{}".into(),
            layers: vec![layer],
            scheme: "symmetric_per_tensor_int8".into(),
            total_quantized_bytes: 3,
            total_fp32_bytes: 12,
        };
        let file = tempfile::NamedTempFile::new().unwrap();
        model.save(file.path()).unwrap();
        let reloaded = QuantizedModel::load(file.path()).unwrap();
        let first = &reloaded.layers[0];
        assert_eq!(first.name, "l1");
        assert_eq!(first.weights_i8, vec![1, -2, 3]);
    }
}