Skip to main content

ternlang_compress/
model.rs

1// TernModel — in-memory representation of a fully quantized LLM.
2//
3// One TernModel holds all compressed layers.  Each TernLayer can store
4// weights in either dense (TritMatrix) or sparse (SparseIndex) form,
5// chosen automatically based on measured sparsity.
6
7use serde::{Deserialize, Serialize};
8use crate::sparse::SparseIndex;
9
10/// Storage format chosen per layer.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub enum LayerStorage {
13    /// Dense 2-bit-packed ternary.  Better for sparsity < ~75%.
14    Dense {
15        rows: usize,
16        cols: usize,
17        /// Trit values packed 4-per-byte (2 bits each: 00=0, 01=+1, 10=-1).
18        packed: Vec<u8>,
19    },
20    /// CSR sparse.  Better for sparsity ≥ ~75%.
21    Sparse(SparseIndex),
22}
23
24/// One compressed weight matrix (one transformer projection, etc.).
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct TernLayer {
27    /// Layer name (e.g. "model.layers.0.self_attn.q_proj.weight").
28    pub name: String,
29    /// Per-layer scale α = mean(|W_original|).
30    pub scale: f32,
31    /// Weight storage (dense or sparse).
32    pub storage: LayerStorage,
33    /// Fraction of zero weights.
34    pub sparsity: f64,
35    /// Original dtype (for metadata / roundtrip info).
36    pub original_dtype: String,
37}
38
39impl TernLayer {
40    /// Total number of parameters in this layer.
41    pub fn num_params(&self) -> usize {
42        match &self.storage {
43            LayerStorage::Dense { rows, cols, .. } => rows * cols,
44            LayerStorage::Sparse(idx) => idx.rows * idx.cols,
45        }
46    }
47
48    /// Memory used by ternary weights in bytes.
49    pub fn memory_bytes(&self) -> usize {
50        match &self.storage {
51            LayerStorage::Dense { rows, cols, .. } => (rows * cols + 3) / 4,
52            LayerStorage::Sparse(idx) => idx.memory_bytes(),
53        }
54    }
55}
56
57/// A complete ternary-quantized LLM.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct TernModel {
60    /// Source model identifier (e.g. "meta-llama/Llama-3.2-1B").
61    pub source_model: String,
62    /// ternlang-compress format version.
63    pub format_version: u32,
64    /// All compressed weight layers.
65    pub layers: Vec<TernLayer>,
66    /// Model architecture metadata (passed through from source).
67    pub architecture: String,
68    /// Vocabulary size.
69    pub vocab_size: usize,
70    /// Hidden dimension.
71    pub hidden_size: usize,
72    /// Number of transformer layers.
73    pub num_layers: usize,
74}
75
76impl TernModel {
77    /// Total number of parameters across all layers.
78    pub fn total_params(&self) -> usize {
79        self.layers.iter().map(|l| l.num_params()).sum()
80    }
81
82    /// Total ternary weight memory in bytes.
83    pub fn total_memory_bytes(&self) -> usize {
84        self.layers.iter().map(|l| l.memory_bytes()).sum()
85    }
86
87    /// Average sparsity across all layers.
88    pub fn mean_sparsity(&self) -> f64 {
89        if self.layers.is_empty() { return 0.0; }
90        let sum: f64 = self.layers.iter().map(|l| l.sparsity).sum();
91        sum / self.layers.len() as f64
92    }
93
94    /// Compression ratio vs original f16 storage.
95    pub fn compression_ratio_vs_f16(&self) -> f64 {
96        let f16_bytes = self.total_params() * 2;
97        if f16_bytes == 0 { return 1.0; }
98        f16_bytes as f64 / self.total_memory_bytes() as f64
99    }
100
101    /// Human-readable summary.
102    pub fn summary(&self) -> String {
103        format!(
104            "TernModel: {}\n  Params:      {:>12}\n  Ternary mem: {:>12} MB\n  Mean sparsity: {:.1}%\n  vs f16:      {:.1}× smaller",
105            self.source_model,
106            self.total_params(),
107            self.total_memory_bytes() / 1_048_576,
108            self.mean_sparsity() * 100.0,
109            self.compression_ratio_vs_f16(),
110        )
111    }
112}