use serde::{Deserialize, Serialize};
use crate::sparse::SparseIndex;
/// Backing storage for a ternary layer's weights: either a dense packed
/// buffer or a sparse index.
///
/// NOTE(review): this type derives serde `Serialize`/`Deserialize`. Variant
/// and field order can affect the encoding in positional (non-self-describing)
/// formats, so reorder with care.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayerStorage {
/// Dense `rows x cols` weight matrix stored in a packed byte buffer.
///
/// `TernLayer::memory_bytes` counts `ceil(rows * cols / 4)` bytes for this
/// variant, which suggests four weights per byte. The exact packing layout
/// is not visible here — confirm against the packer.
Dense {
/// Number of rows in the weight matrix.
rows: usize,
/// Number of columns in the weight matrix.
cols: usize,
/// Packed weight bytes (layout defined elsewhere).
packed: Vec<u8>,
},
/// Sparse representation. The wrapped `SparseIndex` exposes `rows`, `cols`
/// and `memory_bytes()`, which this file uses for parameter and size counts.
Sparse(SparseIndex),
}
/// A single ternary-quantized layer together with its metadata.
///
/// NOTE(review): serde-derived; field order may matter for positional
/// serialization formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TernLayer {
/// Layer identifier — presumably the tensor name from the source model; confirm.
pub name: String,
/// Per-layer scale factor — presumably used to dequantize the ternary
/// weights back to real values; verify against the quantizer.
pub scale: f32,
/// Dense or sparse weight storage.
pub storage: LayerStorage,
/// Sparsity of the layer, expected as a fraction: `TernModel::summary`
/// multiplies the mean by 100 and prints it as a percent.
pub sparsity: f64,
/// Name of the weights' dtype before quantization (e.g. "f16") — assumed
/// from the field name; confirm the exact strings used.
pub original_dtype: String,
}
impl TernLayer {
    /// Returns the number of weights in this layer (`rows * cols`).
    ///
    /// The count is the same whichever storage variant is in use.
    pub fn num_params(&self) -> usize {
        let (r, c) = match &self.storage {
            LayerStorage::Dense { rows, cols, .. } => (*rows, *cols),
            LayerStorage::Sparse(index) => (index.rows, index.cols),
        };
        r * c
    }

    /// Returns the storage footprint of this layer's weights, in bytes.
    ///
    /// Dense layers are counted at four weights per byte, rounded up. Sparse
    /// layers report whatever their index's `memory_bytes()` returns.
    pub fn memory_bytes(&self) -> usize {
        match &self.storage {
            LayerStorage::Sparse(index) => index.memory_bytes(),
            LayerStorage::Dense { rows, cols, .. } => {
                let weights = rows * cols;
                // Ceiling division: a partially filled trailing byte still costs a byte.
                (weights + 3) / 4
            }
        }
    }
}
/// A ternary-quantized model: its layers plus descriptive metadata.
///
/// NOTE(review): serde-derived; field order may matter for positional
/// serialization formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TernModel {
/// Identifier of the model this was converted from (printed by `summary`).
pub source_model: String,
/// Version of this model format — presumably bumped on incompatible
/// serialization changes; confirm.
pub format_version: u32,
/// Quantized layers. The `total_*` and `mean_sparsity` aggregates iterate
/// over these.
pub layers: Vec<TernLayer>,
/// Architecture name — presumably copied from the source model config; confirm.
pub architecture: String,
/// Vocabulary size — presumably from the source model config; confirm.
pub vocab_size: usize,
/// Hidden dimension — presumably from the source model config; confirm.
pub hidden_size: usize,
/// NOTE(review): presumably the source model's block/layer count from its
/// config, which need not equal `layers.len()` — confirm.
pub num_layers: usize,
}
impl TernModel {
    /// Total number of weights across all layers.
    pub fn total_params(&self) -> usize {
        self.layers.iter().map(|l| l.num_params()).sum()
    }

    /// Total storage footprint of all layers' weights, in bytes.
    pub fn total_memory_bytes(&self) -> usize {
        self.layers.iter().map(|l| l.memory_bytes()).sum()
    }

    /// Unweighted mean of the per-layer `sparsity` values.
    ///
    /// Every layer counts equally regardless of its parameter count.
    /// Returns `0.0` for a model with no layers.
    pub fn mean_sparsity(&self) -> f64 {
        if self.layers.is_empty() {
            return 0.0;
        }
        let sum: f64 = self.layers.iter().map(|l| l.sparsity).sum();
        sum / self.layers.len() as f64
    }

    /// Ratio of the model's size at 2 bytes per parameter (f16) to its ternary size.
    ///
    /// Returns `1.0` when there are no parameters. If parameters exist but
    /// the layers report 0 bytes of storage, float division yields `+inf`.
    pub fn compression_ratio_vs_f16(&self) -> f64 {
        let f16_bytes = self.total_params() * 2;
        if f16_bytes == 0 {
            return 1.0;
        }
        f16_bytes as f64 / self.total_memory_bytes() as f64
    }

    /// Human-readable multi-line summary of the model.
    ///
    /// Memory is shown in units of 2^20 bytes (labelled "MB") with two
    /// decimals. Sub-megabyte models no longer display as "0 MB", and larger
    /// ones are not truncated.
    pub fn summary(&self) -> String {
        const BYTES_PER_MB: f64 = 1_048_576.0;
        let mem_mb = self.total_memory_bytes() as f64 / BYTES_PER_MB;
        format!(
            "TernModel: {}\n Params: {:>12}\n Ternary mem: {:>12.2} MB\n Mean sparsity: {:.1}%\n vs f16: {:.1}× smaller",
            self.source_model,
            self.total_params(),
            mem_mb,
            self.mean_sparsity() * 100.0,
            self.compression_ratio_vs_f16(),
        )
    }
}