// File: cbtop/quantize/weights.rs
1//! Quantized weight storage and statistics.
2
3use std::collections::HashMap;
4use std::fmt;
5
6use super::format::QuantFormat;
7
/// Quantized weight storage for a single layer.
///
/// Bundles the raw quantized byte buffer with the format it was encoded
/// in and the logical 2-D shape of the original weight matrix.
#[derive(Debug, Clone)]
pub struct QuantizedWeights {
    /// Quantization format the `data` buffer is encoded in.
    pub format: QuantFormat,
    /// Raw quantized data bytes (interpretation is presumably
    /// format-specific — see `QuantFormat`).
    pub data: Vec<u8>,
    /// Shape: [rows, cols] for 2D weights
    pub shape: (usize, usize),
    /// Layer name (for debugging)
    pub layer_name: String,
}
20
21impl QuantizedWeights {
22    /// Create new quantized weights.
23    pub fn new(format: QuantFormat, data: Vec<u8>, shape: (usize, usize), name: &str) -> Self {
24        Self {
25            format,
26            data,
27            shape,
28            layer_name: name.to_string(),
29        }
30    }
31
32    /// Total number of weights.
33    pub fn num_weights(&self) -> usize {
34        self.shape.0 * self.shape.1
35    }
36
37    /// Memory footprint in bytes.
38    pub fn memory_bytes(&self) -> usize {
39        self.data.len()
40    }
41
42    /// Memory footprint if stored as F16.
43    pub fn f16_memory_bytes(&self) -> usize {
44        self.num_weights() * 2
45    }
46
47    /// Compression ratio (F16 / quantized).
48    pub fn compression_ratio(&self) -> f64 {
49        self.f16_memory_bytes() as f64 / self.memory_bytes() as f64
50    }
51
52    /// Effective bits per weight (actual).
53    pub fn actual_bits_per_weight(&self) -> f64 {
54        (self.data.len() * 8) as f64 / self.num_weights() as f64
55    }
56}
57
/// Per-layer quantization statistics.
///
/// A snapshot of one layer's quantization metrics, recorded by
/// `QuantStats::add_layer`.
#[derive(Debug, Clone)]
pub struct LayerQuantStats {
    /// Layer name
    pub name: String,
    /// Quantization format
    pub format: QuantFormat,
    /// Weight count
    pub weights: usize,
    /// Memory bytes (quantized storage size)
    pub memory_bytes: usize,
    /// Compression ratio (F16 bytes / quantized bytes)
    pub compression_ratio: f64,
}
72
/// Quantization statistics for a model or layer.
///
/// Aggregated totals plus per-format and per-layer breakdowns;
/// populated incrementally via `add_layer`.
#[derive(Debug, Clone, Default)]
pub struct QuantStats {
    /// Total weights across all layers
    pub total_weights: usize,
    /// Total memory (bytes) for quantized weights
    pub total_memory_bytes: usize,
    /// Memory if stored as F16
    pub f16_memory_bytes: usize,
    /// Weights per format
    pub weights_by_format: HashMap<QuantFormat, usize>,
    /// Memory per format
    pub memory_by_format: HashMap<QuantFormat, usize>,
    /// Per-layer stats, in insertion order
    pub layer_stats: Vec<LayerQuantStats>,
}
89
90impl QuantStats {
91    /// Create new empty stats.
92    pub fn new() -> Self {
93        Self::default()
94    }
95
96    /// Add layer statistics.
97    pub fn add_layer(&mut self, weights: &QuantizedWeights) {
98        self.total_weights += weights.num_weights();
99        self.total_memory_bytes += weights.memory_bytes();
100        self.f16_memory_bytes += weights.f16_memory_bytes();
101
102        *self.weights_by_format.entry(weights.format).or_default() += weights.num_weights();
103        *self.memory_by_format.entry(weights.format).or_default() += weights.memory_bytes();
104
105        self.layer_stats.push(LayerQuantStats {
106            name: weights.layer_name.clone(),
107            format: weights.format,
108            weights: weights.num_weights(),
109            memory_bytes: weights.memory_bytes(),
110            compression_ratio: weights.compression_ratio(),
111        });
112    }
113
114    /// Overall compression ratio.
115    pub fn compression_ratio(&self) -> f64 {
116        if self.total_memory_bytes == 0 {
117            1.0
118        } else {
119            self.f16_memory_bytes as f64 / self.total_memory_bytes as f64
120        }
121    }
122
123    /// Effective bits per weight (average).
124    pub fn avg_bits_per_weight(&self) -> f64 {
125        if self.total_weights == 0 {
126            0.0
127        } else {
128            (self.total_memory_bytes * 8) as f64 / self.total_weights as f64
129        }
130    }
131
132    /// Dominant format (most weights).
133    pub fn dominant_format(&self) -> Option<QuantFormat> {
134        self.weights_by_format
135            .iter()
136            .max_by_key(|(_, count)| *count)
137            .map(|(format, _)| *format)
138    }
139}
140
141impl fmt::Display for QuantStats {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        writeln!(f, "Quantization Statistics")?;
144        writeln!(f, "======================")?;
145        writeln!(f, "Total Weights: {}", self.total_weights)?;
146        writeln!(
147            f,
148            "Total Memory: {:.2} MB (quantized)",
149            self.total_memory_bytes as f64 / 1_000_000.0
150        )?;
151        writeln!(
152            f,
153            "F16 Memory: {:.2} MB (baseline)",
154            self.f16_memory_bytes as f64 / 1_000_000.0
155        )?;
156        writeln!(f, "Compression: {:.2}x", self.compression_ratio())?;
157        writeln!(f, "Avg Bits/Weight: {:.2}", self.avg_bits_per_weight())?;
158
159        if !self.weights_by_format.is_empty() {
160            writeln!(f)?;
161            writeln!(f, "By Format:")?;
162            for (format, weights) in &self.weights_by_format {
163                let memory = self.memory_by_format.get(format).unwrap_or(&0);
164                writeln!(
165                    f,
166                    "  {}: {} weights, {:.2} MB",
167                    format,
168                    weights,
169                    *memory as f64 / 1_000_000.0
170                )?;
171            }
172        }
173
174        Ok(())
175    }
176}