cbtop/quantize/
weights.rs1use std::collections::HashMap;
4use std::fmt;
5
6use super::format::QuantFormat;
7
/// A single layer's quantized weight tensor plus its metadata.
#[derive(Debug, Clone)]
pub struct QuantizedWeights {
    // Quantization format the raw bytes are encoded in.
    pub format: QuantFormat,
    // Raw quantized payload; byte layout is presumably defined by `format` —
    // confirm against the encoder in `super::format`.
    pub data: Vec<u8>,
    // Logical 2-D shape; total element count is shape.0 * shape.1.
    pub shape: (usize, usize),
    // Name of the model layer these weights belong to.
    pub layer_name: String,
}
20
21impl QuantizedWeights {
22 pub fn new(format: QuantFormat, data: Vec<u8>, shape: (usize, usize), name: &str) -> Self {
24 Self {
25 format,
26 data,
27 shape,
28 layer_name: name.to_string(),
29 }
30 }
31
32 pub fn num_weights(&self) -> usize {
34 self.shape.0 * self.shape.1
35 }
36
37 pub fn memory_bytes(&self) -> usize {
39 self.data.len()
40 }
41
42 pub fn f16_memory_bytes(&self) -> usize {
44 self.num_weights() * 2
45 }
46
47 pub fn compression_ratio(&self) -> f64 {
49 self.f16_memory_bytes() as f64 / self.memory_bytes() as f64
50 }
51
52 pub fn actual_bits_per_weight(&self) -> f64 {
54 (self.data.len() * 8) as f64 / self.num_weights() as f64
55 }
56}
57
/// Per-layer summary row recorded by `QuantStats::add_layer`.
#[derive(Debug, Clone)]
pub struct LayerQuantStats {
    // Layer name, copied from `QuantizedWeights::layer_name`.
    pub name: String,
    // Quantization format used for this layer.
    pub format: QuantFormat,
    // Number of logical weight elements in the layer.
    pub weights: usize,
    // Quantized payload size in bytes.
    pub memory_bytes: usize,
    // Compression factor versus the f16 baseline for this layer alone.
    pub compression_ratio: f64,
}
72
/// Aggregate quantization statistics accumulated over many layers.
#[derive(Debug, Clone, Default)]
pub struct QuantStats {
    // Sum of logical weight elements across all added layers.
    pub total_weights: usize,
    // Sum of quantized payload bytes across all added layers.
    pub total_memory_bytes: usize,
    // Sum of the f16-baseline byte sizes (2 bytes per weight).
    pub f16_memory_bytes: usize,
    // Weight-element counts bucketed by quantization format.
    pub weights_by_format: HashMap<QuantFormat, usize>,
    // Quantized byte counts bucketed by quantization format.
    pub memory_by_format: HashMap<QuantFormat, usize>,
    // One summary row per layer, in insertion order.
    pub layer_stats: Vec<LayerQuantStats>,
}
89
90impl QuantStats {
91 pub fn new() -> Self {
93 Self::default()
94 }
95
96 pub fn add_layer(&mut self, weights: &QuantizedWeights) {
98 self.total_weights += weights.num_weights();
99 self.total_memory_bytes += weights.memory_bytes();
100 self.f16_memory_bytes += weights.f16_memory_bytes();
101
102 *self.weights_by_format.entry(weights.format).or_default() += weights.num_weights();
103 *self.memory_by_format.entry(weights.format).or_default() += weights.memory_bytes();
104
105 self.layer_stats.push(LayerQuantStats {
106 name: weights.layer_name.clone(),
107 format: weights.format,
108 weights: weights.num_weights(),
109 memory_bytes: weights.memory_bytes(),
110 compression_ratio: weights.compression_ratio(),
111 });
112 }
113
114 pub fn compression_ratio(&self) -> f64 {
116 if self.total_memory_bytes == 0 {
117 1.0
118 } else {
119 self.f16_memory_bytes as f64 / self.total_memory_bytes as f64
120 }
121 }
122
123 pub fn avg_bits_per_weight(&self) -> f64 {
125 if self.total_weights == 0 {
126 0.0
127 } else {
128 (self.total_memory_bytes * 8) as f64 / self.total_weights as f64
129 }
130 }
131
132 pub fn dominant_format(&self) -> Option<QuantFormat> {
134 self.weights_by_format
135 .iter()
136 .max_by_key(|(_, count)| *count)
137 .map(|(format, _)| *format)
138 }
139}
140
141impl fmt::Display for QuantStats {
142 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143 writeln!(f, "Quantization Statistics")?;
144 writeln!(f, "======================")?;
145 writeln!(f, "Total Weights: {}", self.total_weights)?;
146 writeln!(
147 f,
148 "Total Memory: {:.2} MB (quantized)",
149 self.total_memory_bytes as f64 / 1_000_000.0
150 )?;
151 writeln!(
152 f,
153 "F16 Memory: {:.2} MB (baseline)",
154 self.f16_memory_bytes as f64 / 1_000_000.0
155 )?;
156 writeln!(f, "Compression: {:.2}x", self.compression_ratio())?;
157 writeln!(f, "Avg Bits/Weight: {:.2}", self.avg_bits_per_weight())?;
158
159 if !self.weights_by_format.is_empty() {
160 writeln!(f)?;
161 writeln!(f, "By Format:")?;
162 for (format, weights) in &self.weights_by_format {
163 let memory = self.memory_by_format.get(format).unwrap_or(&0);
164 writeln!(
165 f,
166 " {}: {} weights, {:.2} MB",
167 format,
168 weights,
169 *memory as f64 / 1_000_000.0
170 )?;
171 }
172 }
173
174 Ok(())
175 }
176}