use burn_core as burn;
use burn_core::module::Module;
use burn_flex::Flex;
use burn_nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
use burn_store::{BurnpackStore, HalfPrecisionAdapter, ModuleSnapshot};
use burn_tensor::backend::Backend;
/// Small three-layer demo network used to compare Burnpack file sizes
/// at different storage precisions (F32, F16, mixed).
///
/// NOTE(review): field order is preserved as-is — the `Module` derive
/// presumably walks fields in declaration order when snapshotting, so
/// reordering could change the saved tensor layout; confirm before touching.
#[derive(Module, Debug)]
struct DemoModel<B: Backend> {
    // First projection: 128 -> 64 (see `new`).
    linear1: Linear<B>,
    // Normalizes the 64-dim hidden activations; kept in F32 in the
    // "mixed" save below via `without_module("LayerNorm")`.
    norm: LayerNorm<B>,
    // Output head: 64 -> 10.
    linear2: Linear<B>,
}
impl<B: Backend> DemoModel<B> {
    /// Initializes the demo network (128 -> 64, LayerNorm(64), 64 -> 10)
    /// with freshly initialized parameters on `device`.
    fn new(device: &B::Device) -> Self {
        // Build each layer up front so the construction order is explicit.
        let linear1 = LinearConfig::new(128, 64).init(device);
        let norm = LayerNormConfig::new(64).init(device);
        let linear2 = LinearConfig::new(64, 10).init(device);

        Self {
            linear1,
            norm,
            linear2,
        }
    }
}
/// Saves the demo model three ways — full F32, all-F16, and mixed precision
/// (LayerNorm kept in F32) — compares the resulting `.bpk` file sizes, then
/// round-trips the F16 file back into a fresh model instance.
fn main() {
    type B = Flex;
    let device = Default::default();
    let model = DemoModel::<B>::new(&device);

    // All artifacts go into a temp dir that is cleaned up on drop.
    let dir = tempfile::tempdir().expect("Failed to create temp dir");
    let path_f32 = dir.path().join("model_f32");
    let path_f16 = dir.path().join("model_f16");
    let path_mixed = dir.path().join("model_mixed");

    // 1) Full-precision baseline.
    let mut store = BurnpackStore::from_file(path_f32.to_str().unwrap()).overwrite(true);
    model.save_into(&mut store).expect("Failed to save F32");
    let size_f32 = saved_size(&path_f32);

    // 2) Half precision for every module the adapter targets by default.
    let adapter = HalfPrecisionAdapter::new();
    let mut store = BurnpackStore::from_file(path_f16.to_str().unwrap())
        .overwrite(true)
        .with_to_adapter(adapter.clone());
    model.save_into(&mut store).expect("Failed to save F16");
    let size_f16 = saved_size(&path_f16);

    // 3) Mixed precision: exclude LayerNorm so its weights stay F32.
    let adapter_no_norm = HalfPrecisionAdapter::new().without_module("LayerNorm");
    let mut store = BurnpackStore::from_file(path_mixed.to_str().unwrap())
        .overwrite(true)
        .with_to_adapter(adapter_no_norm);
    model.save_into(&mut store).expect("Failed to save mixed");
    let size_mixed = saved_size(&path_mixed);

    println!("F32 (full precision): {} bytes", size_f32);
    println!("F16 (default modules): {} bytes", size_f16);
    println!("Mixed (norm stays F32): {} bytes", size_mixed);
    // Guard the baseline: `saved_size` returns 0 on a metadata read failure,
    // and dividing by a zero baseline would print "-inf%" / NaN.
    if size_f32 > 0 {
        println!(
            "F16 savings: {:.1}%",
            (1.0 - size_f16 as f64 / size_f32 as f64) * 100.0
        );
    }

    // Round-trip: load the F16 file back, converting through the same adapter.
    let mut load_store =
        BurnpackStore::from_file(path_f16.to_str().unwrap()).with_from_adapter(adapter);
    let mut model2 = DemoModel::<B>::new(&device);
    let result = model2.load_from(&mut load_store).expect("Failed to load");
    println!(
        "\nRound-trip loaded {} tensors successfully",
        result.applied.len()
    );
}

/// Returns the on-disk size in bytes of the `.bpk` file written for `path`,
/// or 0 if the metadata cannot be read (preserves the original best-effort
/// behavior instead of panicking).
fn saved_size(path: &std::path::Path) -> u64 {
    std::fs::metadata(format!("{}.bpk", path.display()))
        .map(|m| m.len())
        .unwrap_or(0)
}