use sensorlm::{
config::SensorLMConfig,
quantization::int8::{quantize_model_weights, Calibrator, QuantizedModel},
};
use std::path::Path;
/// Maximum number of layers listed in each preview section of the report.
const PREVIEW_LAYERS: usize = 3;

/// Location the quantised model is written to and then read back from.
const OUTPUT_PATH: &str = "/tmp/sensorlm_int8.json";

/// `(name, out_features, in_features)` for every simulated weight matrix.
const LAYER_SPECS: &[(&str, usize, usize)] = &[
    ("sensor_encoder.patch_embed.proj.weight", 768, 100),
    ("sensor_encoder.blocks.0.attn.q_proj.weight", 768, 768),
    ("sensor_encoder.blocks.0.attn.k_proj.weight", 768, 768),
    ("sensor_encoder.blocks.0.attn.v_proj.weight", 768, 768),
    ("sensor_encoder.blocks.0.attn.out_proj.weight", 768, 768),
    ("sensor_encoder.blocks.0.mlp.fc1.weight", 3072, 768),
    ("sensor_encoder.blocks.0.mlp.fc2.weight", 768, 3072),
    ("text_encoder.proj.weight", 768, 768),
    ("text_encoder.blocks.0.attn.q_proj.weight", 768, 768),
    ("text_encoder.blocks.0.mlp.fc1.weight", 3072, 768),
];

/// Deterministic synthetic weight for flat element index `i`.
///
/// This is the single source of truth for the fake weights: the same function
/// is used to build the layers and to regenerate the reference values for the
/// accuracy check, so the two can never drift apart.
fn synthetic_weight(i: usize) -> f32 {
    ((i as f32 * 0.001) % 2.0 - 1.0) * 0.02
}

/// Returns the longest prefix of `name` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// Plain byte slicing (`&name[..n]`) panics when `n` falls inside a multi-byte
/// character; this backs off to the previous boundary instead.
fn name_prefix(name: &str, max_bytes: usize) -> &str {
    let mut end = max_bytes.min(name.len());
    // Index 0 is always a boundary, so this loop terminates.
    while !name.is_char_boundary(end) {
        end -= 1;
    }
    &name[..end]
}

/// Demonstrates INT8 post-training quantisation on synthetic weights.
///
/// Steps: build deterministic weights for every entry in [`LAYER_SPECS`],
/// record each layer with a `Calibrator`, quantise via
/// `quantize_model_weights`, print the size and compression figures reported
/// by the resulting model, print the maximum absolute dequantisation error for
/// the first few layers, then save the model to [`OUTPUT_PATH`] and load it
/// back. Save and reload failures are reported on stderr rather than panicking.
fn main() {
    println!("=== SensorLM INT8 Post-Training Quantisation ===\n");

    let model_cfg = SensorLMConfig::default();
    // Failing to serialise the default config would be a programming error,
    // not a runtime condition, so a panic with a clear message is appropriate.
    let config_json = serde_json::to_string_pretty(&model_cfg)
        .expect("default SensorLMConfig should serialise to JSON");

    println!("Simulating weight extraction from a ViT-B SensorLM model…");
    let mut calibrator = Calibrator::new();
    let mut layers = Vec::with_capacity(LAYER_SPECS.len());
    for &(name, out_features, in_features) in LAYER_SPECS {
        let weights: Vec<f32> = (0..out_features * in_features)
            .map(synthetic_weight)
            .collect();
        calibrator.record_layer(name.to_string(), &weights);
        layers.push((
            name.to_string(),
            weights,
            vec![out_features, in_features],
            None::<Vec<f32>>,
        ));
    }

    let calibrations = calibrator.finish();
    println!("Calibrated {} layers.", calibrations.len());
    // `take` instead of `[..3]`: never panics when fewer layers are present.
    for cal in calibrations.iter().take(PREVIEW_LAYERS) {
        println!(
            " {} : w_scale = {:.6}",
            name_prefix(&cal.name, 40),
            cal.weight_scale()
        );
    }

    println!("\nQuantising to INT8…");
    let qm = quantize_model_weights(config_json, layers.into_iter());

    println!("\n--- Quantisation results ---");
    println!("Layers quantised : {}", qm.layers.len());
    println!(
        "Original FP32 size : {:.1} MB",
        qm.total_fp32_bytes as f64 / 1024.0 / 1024.0
    );
    println!(
        "Quantised INT8 size : {:.1} MB",
        qm.total_quantized_bytes as f64 / 1024.0 / 1024.0
    );
    println!("Compression ratio : {:.1}x", qm.compression_ratio());

    println!("\n--- Dequantisation accuracy ---");
    for qw in qm.layers.iter().take(PREVIEW_LAYERS) {
        let dq = qw.dequantize();
        // The original weights were moved into the quantiser, so the reference
        // values are regenerated lazily from the same generator (no temp Vec).
        let max_err = (0..qw.weights_i8.len())
            .map(synthetic_weight)
            .zip(dq.iter())
            .map(|(expected, actual)| (expected - actual).abs())
            .fold(0.0f32, f32::max);
        println!(
            " {:<50} scale={:.6} max_err={:.6}",
            name_prefix(&qw.name, 50),
            qw.scale,
            max_err
        );
    }

    let out_path = Path::new(OUTPUT_PATH);
    match qm.save(out_path) {
        Ok(()) => {
            println!("\nSaved quantised model to {}", out_path.display());
            // Reloading is file I/O and can fail independently of the save;
            // report it like a save failure instead of panicking.
            match QuantizedModel::load(out_path) {
                Ok(loaded) => println!(
                    "Reloaded model: {} layers, {:.1}x compression",
                    loaded.layers.len(),
                    loaded.compression_ratio()
                ),
                // Debug formatting: it is the only bound the previous
                // `unwrap()` guaranteed for this error type.
                Err(e) => eprintln!("Reload failed: {e:?}"),
            }
        }
        Err(e) => eprintln!("Save failed: {e}"),
    }

    println!("\n=== Done ===");
}