1use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
/// Quantized GEMV kernel formats that can be profiled.
///
/// Parsed from names such as `q4k_gemv` or the short alias `q4k` via `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantKernel {
    /// Q4K GEMV: 256 weights per 144-byte super-block.
    Q4kGemv,
    /// Q5K GEMV: 256 weights per 176-byte super-block.
    Q5kGemv,
    /// Q6K GEMV: 256 weights per 210-byte super-block.
    Q6kGemv,
    /// Q8 GEMV: 256 weights per 256-byte super-block.
    Q8Gemv,
    /// NF4 GEMV: 64 weights per 32-byte super-block.
    Nf4Gemv,
}
16
17impl std::str::FromStr for QuantKernel {
18 type Err = anyhow::Error;
19 fn from_str(s: &str) -> Result<Self> {
20 match s {
21 "q4k_gemv" | "q4k" => Ok(Self::Q4kGemv),
22 "q5k_gemv" | "q5k" => Ok(Self::Q5kGemv),
23 "q6k_gemv" | "q6k" => Ok(Self::Q6kGemv),
24 "q8_gemv" | "q8" => Ok(Self::Q8Gemv),
25 "nf4_gemv" | "nf4" => Ok(Self::Nf4Gemv),
26 _ => anyhow::bail!("Unknown quant kernel: {s}. Supported: q4k_gemv, q5k_gemv, q6k_gemv, q8_gemv, nf4_gemv"),
27 }
28 }
29}
30
31impl QuantKernel {
32 pub fn superblock_elements(&self) -> u32 {
34 match self {
35 QuantKernel::Q4kGemv => 256,
36 QuantKernel::Q5kGemv => 256,
37 QuantKernel::Q6kGemv => 256,
38 QuantKernel::Q8Gemv => 256,
39 QuantKernel::Nf4Gemv => 64,
40 }
41 }
42
43 pub fn superblock_bytes(&self) -> u32 {
45 match self {
46 QuantKernel::Q4kGemv => 144, QuantKernel::Q5kGemv => 176,
48 QuantKernel::Q6kGemv => 210,
49 QuantKernel::Q8Gemv => 256, QuantKernel::Nf4Gemv => 32, }
52 }
53}
54
/// Serializable summary of one quantized-kernel profiling run.
///
/// NOTE(review): no code in this module section constructs this type; the
/// field meanings below are inferred from their names — confirm with the
/// producer of these values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantProfile {
    /// Kernel that was profiled.
    pub kernel: QuantKernel,
    /// GEMV dimensions; presumably `[M, N, K]` in the order `parse_dimensions`
    /// returns them — confirm.
    pub dimensions: [u32; 3],
    /// Super-blocks processed per second.
    pub superblocks_per_sec: f64,
    /// Effective bandwidth in GB/s (presumably over compressed bytes — confirm).
    pub effective_bandwidth_gbps: f64,
    /// NOTE(review): presumably the speedup attributable to compression
    /// relative to FP32 weights — confirm the exact definition.
    pub compression_speedup: f64,
    /// Wall-clock time, presumably in microseconds per the `_us` suffix.
    pub wall_time_us: f64,
}
69
70fn parse_dimensions(size: &str) -> Result<[u32; 3]> {
72 let parts: Vec<&str> = size.split('x').collect();
73 if parts.len() != 3 {
74 anyhow::bail!("Dimensions must be MxNxK format, got: {size}");
75 }
76 let m: u32 = parts[0]
77 .parse()
78 .map_err(|_| anyhow::anyhow!("Invalid M: {}", parts[0]))?;
79 let n: u32 = parts[1]
80 .parse()
81 .map_err(|_| anyhow::anyhow!("Invalid N: {}", parts[1]))?;
82 let k: u32 = parts[2]
83 .parse()
84 .map_err(|_| anyhow::anyhow!("Invalid K: {}", parts[2]))?;
85 Ok([m, n, k])
86}
87
88pub fn profile_quant(kernel_name: &str, size: &str) -> Result<()> {
90 let kernel: QuantKernel = kernel_name.parse()?;
91 let dims = parse_dimensions(size)?;
92
93 println!("\n=== CGP Quant Profile: {kernel_name} ({size}) ===\n");
94 println!(" Kernel: {kernel:?}");
95 println!(" Dimensions: M={}, N={}, K={}", dims[0], dims[1], dims[2]);
96 println!(
97 " Super-block: {} elements, {} bytes",
98 kernel.superblock_elements(),
99 kernel.superblock_bytes()
100 );
101
102 let total_elements = dims[0] as u64 * dims[2] as u64;
104 let num_superblocks = total_elements / kernel.superblock_elements() as u64;
105 let compressed_bytes = num_superblocks * kernel.superblock_bytes() as u64;
106 let fp32_bytes = total_elements * 4;
107
108 println!(" Total weights: {total_elements}");
109 println!(" Super-blocks: {num_superblocks}");
110 println!(" Compressed size: {:.2} MB", compressed_bytes as f64 / 1e6);
111 println!(" FP32 equivalent: {:.2} MB", fp32_bytes as f64 / 1e6);
112 println!(
113 " Compression ratio: {:.1}x",
114 fp32_bytes as f64 / compressed_bytes as f64
115 );
116
117 if let Some(timing) = parse_q4k_timing(dims[0], dims[2]) {
119 println!("\n Measured (from benchmark_matrix_suite):");
120 println!(" Time: {:.1} us", timing.time_us);
121 println!(" GFLOPS: {:.1}", timing.gflops);
122 println!(" Effective BW: {:.1} GB/s (compressed)", timing.bw_gbps);
123 let sbs_per_sec = num_superblocks as f64 / (timing.time_us / 1e6);
124 println!(" Super-blocks/sec: {:.0}", sbs_per_sec);
125
126 let flops = 2.0 * dims[0] as f64 * dims[2] as f64; let ai = flops / compressed_bytes as f64; println!("\n Roofline Analysis (compressed):");
130 println!(" Arithmetic Intensity: {:.1} FLOP/byte", ai);
131
132 let peak_bw_gbps = timing.bw_gbps; let peak_flops = timing.gflops;
135
136 let theoretical_peak_gflops = 150.0;
138 let compute_pct = peak_flops / theoretical_peak_gflops * 100.0;
139
140 let theoretical_bw_gbps = 40.0;
142 let bw_pct = peak_bw_gbps / theoretical_bw_gbps * 100.0;
143
144 println!(
145 " Compute util: {:.0}% of AVX-512 peak (~150 GFLOP/s)",
146 compute_pct
147 );
148 println!(
149 " Bandwidth util: {:.0}% of practical DRAM (~40 GB/s)",
150 bw_pct
151 );
152
153 if bw_pct > compute_pct {
154 println!(" Bottleneck: COMPUTE-BOUND (fused dequant+dot overhead)");
155 } else {
156 println!(" Bottleneck: MEMORY-BOUND (limited by DRAM read throughput)");
157 }
158
159 let token_time_ms = timing.time_us * 192.0 / 1000.0;
163 let tokens_per_sec = 1000.0 / token_time_ms;
164 println!("\n LLM Token Estimation (Llama-7B-like, {kernel_name}):");
165 println!(" Per-layer GEMV: {:.1} us", timing.time_us);
166 println!(" Est. 192 GEMVs/token: {:.1} ms", token_time_ms);
167 println!(" Est. tokens/sec: {:.1}", tokens_per_sec);
168 } else {
169 println!("\n No timing data (build benchmark: cargo build --release --example benchmark_matrix_suite --features parallel)");
170 }
171
172 println!();
173 Ok(())
174}
175
/// Layer shapes swept by `profile_quant_all`, as `(label, out_dim, in_dim)`.
///
/// Each entry is matched against benchmark output lines containing
/// `"{out_dim}x{in_dim}"` and printed under the `M` / `K` table columns.
// NOTE(review): if ffn_up maps hidden -> intermediate, its (out_dim, in_dim)
// would be e.g. (8960, 1536); the ffn_up/ffn_down labels may be swapped
// relative to the tuple order — confirm against benchmark_matrix_suite.
const STANDARD_LAYERS: &[(&str, u32, u32)] = &[
    ("ffn_up/gate (1.5B-7B)", 1536, 8960),
    ("ffn_down (1.5B-7B)", 8960, 1536),
    ("attn_qkv (1.5B-7B)", 1536, 1536),
    ("generic_4K", 4096, 4096),
    ("ffn_up (13B)", 5120, 13824),
    ("ffn_down (13B)", 13824, 5120),
    ("attn_qkv (13B)", 5120, 5120),
];
186
187pub fn profile_quant_all() -> Result<()> {
189 println!("\n=== CGP Quant Sweep: Q4K GEMV — All Standard LLM Layers ===\n");
190
191 let binary = find_bench_binary();
192 let bench_output = binary.and_then(|b| {
193 std::process::Command::new(&b)
194 .output()
195 .ok()
196 .filter(|o| o.status.success())
197 .map(|o| String::from_utf8_lossy(&o.stdout).to_string())
198 });
199
200 println!(
201 " {:25} {:>6}x{:<6} {:>10} {:>10} {:>10} {:>10}",
202 "Layer", "M", "K", "Time (us)", "GFLOPS", "BW GB/s", "tok/s est"
203 );
204 println!(" {}", "-".repeat(85));
205
206 let mut total_time_us = 0.0;
207 let mut measured_count = 0;
208
209 for (label, out_dim, in_dim) in STANDARD_LAYERS {
210 let timing = bench_output.as_ref().and_then(|stdout| {
211 let pattern = format!("{}x{}", out_dim, in_dim);
212 for line in stdout.lines() {
213 if line.contains("Q4K GEMV") && line.contains(&pattern) {
214 let time_us = extract_between(line, "...", " us")
215 .and_then(|s| s.trim().parse::<f64>().ok())?;
216 let gflops = extract_between(line, "(", " GFLOPS")
217 .and_then(|s| s.trim().parse::<f64>().ok())?;
218 let bw_gbps = extract_between(line, "GFLOPS, ", " GB/s")
219 .and_then(|s| s.trim().parse::<f64>().ok())?;
220 return Some(Q4kTiming {
221 time_us,
222 gflops,
223 bw_gbps,
224 });
225 }
226 }
227 None
228 });
229
230 if let Some(t) = timing {
231 let tok_per_sec = 1000.0 / (t.time_us * 192.0 / 1000.0);
233 println!(
234 " {:25} {:>6}x{:<6} {:>10.1} {:>10.1} {:>10.1} {:>10.1}",
235 label, out_dim, in_dim, t.time_us, t.gflops, t.bw_gbps, tok_per_sec
236 );
237 total_time_us += t.time_us;
238 measured_count += 1;
239 } else {
240 println!(
241 " {:25} {:>6}x{:<6} {:>10} {:>10} {:>10} {:>10}",
242 label, out_dim, in_dim, "-", "-", "-", "-"
243 );
244 }
245 }
246
247 if measured_count > 0 {
248 println!(" {}", "-".repeat(85));
249 let avg_gflops = STANDARD_LAYERS
250 .iter()
251 .take(4) .count();
253 println!(
254 "\n Summary ({measured_count}/{} layers measured):",
255 STANDARD_LAYERS.len()
256 );
257 let _ = avg_gflops;
258 let avg_time = total_time_us / measured_count as f64;
259 let composite_tok_s = 1000.0 / (avg_time * 192.0 / 1000.0);
260 println!(" Avg GEMV time: {:.1} us", avg_time);
261 println!(" Composite tok/s estimate: {:.1}", composite_tok_s);
262 println!(
263 " Total GEMV time (measured layers): {:.1} us",
264 total_time_us
265 );
266 } else {
267 println!("\n No benchmark data available.");
268 println!(
269 " Build: cargo build --release --example benchmark_matrix_suite --features parallel"
270 );
271 }
272
273 println!();
274 Ok(())
275}
276
277struct Q4kTiming {
279 time_us: f64,
280 gflops: f64,
281 bw_gbps: f64,
282}
283
284fn parse_q4k_timing(out_dim: u32, in_dim: u32) -> Option<Q4kTiming> {
287 let binary = find_bench_binary()?;
288 let output = std::process::Command::new(&binary).output().ok()?;
289 if !output.status.success() {
290 return None;
291 }
292 let stdout = String::from_utf8_lossy(&output.stdout);
293 let pattern = format!("{}x{}", out_dim, in_dim);
294
295 for line in stdout.lines() {
296 if line.contains("Q4K GEMV") && line.contains(&pattern) {
297 let time_us =
299 extract_between(line, "...", " us").and_then(|s| s.trim().parse::<f64>().ok())?;
300 let gflops =
301 extract_between(line, "(", " GFLOPS").and_then(|s| s.trim().parse::<f64>().ok())?;
302 let bw_gbps = extract_between(line, "GFLOPS, ", " GB/s")
303 .and_then(|s| s.trim().parse::<f64>().ok())?;
304 return Some(Q4kTiming {
305 time_us,
306 gflops,
307 bw_gbps,
308 });
309 }
310 }
311 None
312}
313
/// Returns the text between the last `start` that precedes the first `end`
/// and that `end`, or `None` if either delimiter is missing in that order.
///
/// For example, `extract_between("x (1.5 GFLOPS)", "(", " GFLOPS")` yields `"1.5"`.
fn extract_between<'a>(s: &'a str, start: &str, end: &str) -> Option<&'a str> {
    // Everything before the first `end` ...
    let (before_end, _) = s.split_once(end)?;
    // ... then whatever follows the last `start` inside that prefix.
    let (_, between) = before_end.rsplit_once(start)?;
    Some(between)
}
321
/// Locates the `benchmark_matrix_suite` example binary used for measured timings.
///
/// Candidates are checked in order and the first one that is a regular file wins:
/// 1. `$CARGO_TARGET_DIR/release/examples/benchmark_matrix_suite`, so builds
///    that use a custom cargo target directory are found;
/// 2. a hard-coded absolute target directory under `/mnt/nvme-raid0`;
/// 3. `./target/release/examples/benchmark_matrix_suite`, relative to the CWD.
///
/// Returns `None` if no candidate exists or the found path is not valid UTF-8.
fn find_bench_binary() -> Option<String> {
    const RELATIVE_BIN: &str = "release/examples/benchmark_matrix_suite";
    const FIXED_CANDIDATES: [&str; 2] = [
        "/mnt/nvme-raid0/targets/trueno/release/examples/benchmark_matrix_suite",
        "./target/release/examples/benchmark_matrix_suite",
    ];

    // cargo itself honours CARGO_TARGET_DIR, so the build hint printed on a
    // miss would place the binary there.
    let from_env = std::env::var_os("CARGO_TARGET_DIR")
        .filter(|dir| !dir.is_empty())
        .map(|dir| std::path::PathBuf::from(dir).join(RELATIVE_BIN));

    from_env
        .into_iter()
        .chain(FIXED_CANDIDATES.iter().map(|p| std::path::PathBuf::from(*p)))
        .find(|path| path.is_file())
        .and_then(|path| path.into_os_string().into_string().ok())
}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_dimensions() {
        let dims = parse_dimensions("4096x1x4096").unwrap();
        assert_eq!(dims, [4096, 1, 4096]);
    }

    #[test]
    fn test_parse_dimensions_invalid() {
        assert!(parse_dimensions("4096x4096").is_err());
        assert!(parse_dimensions("abc").is_err());
        // Empty and negative fields are not valid u32 values.
        assert!(parse_dimensions("4096x1x").is_err());
        assert!(parse_dimensions("4096x-1x4096").is_err());
    }

    #[test]
    fn test_q4k_superblock() {
        let k = QuantKernel::Q4kGemv;
        assert_eq!(k.superblock_elements(), 256);
        assert_eq!(k.superblock_bytes(), 144);
    }

    /// Compressed size of a 4096x4096 Q4K matrix, derived from the kernel's
    /// own super-block layout rather than hard-coded constants.
    #[test]
    fn test_effective_bandwidth_compressed() {
        let kernel = QuantKernel::Q4kGemv;
        let total_elements: u64 = 4096 * 4096;
        let num_superblocks = total_elements / u64::from(kernel.superblock_elements());
        let compressed_bytes = num_superblocks * u64::from(kernel.superblock_bytes());
        let expected_mb = 9.437184;
        assert!(
            (compressed_bytes as f64 / 1e6 - expected_mb).abs() < 0.01,
            "Compressed size {:.2} MB != expected {:.2} MB",
            compressed_bytes as f64 / 1e6,
            expected_mb
        );
    }

    #[test]
    fn test_kernel_from_str() {
        assert_eq!(
            "q4k_gemv".parse::<QuantKernel>().unwrap(),
            QuantKernel::Q4kGemv
        );
        assert_eq!("q6k".parse::<QuantKernel>().unwrap(), QuantKernel::Q6kGemv);
        assert!("invalid".parse::<QuantKernel>().is_err());
    }

    /// Every kernel accepts both its full `_gemv` name and its short alias.
    #[test]
    fn test_kernel_from_str_all_aliases() {
        let cases = [
            ("q4k_gemv", QuantKernel::Q4kGemv),
            ("q4k", QuantKernel::Q4kGemv),
            ("q5k_gemv", QuantKernel::Q5kGemv),
            ("q5k", QuantKernel::Q5kGemv),
            ("q6k_gemv", QuantKernel::Q6kGemv),
            ("q6k", QuantKernel::Q6kGemv),
            ("q8_gemv", QuantKernel::Q8Gemv),
            ("q8", QuantKernel::Q8Gemv),
            ("nf4_gemv", QuantKernel::Nf4Gemv),
            ("nf4", QuantKernel::Nf4Gemv),
        ];
        for &(name, expected) in cases.iter() {
            assert_eq!(name.parse::<QuantKernel>().unwrap(), expected, "{name}");
        }
        assert!("Q4K".parse::<QuantKernel>().is_err());
        assert!("_gemv".parse::<QuantKernel>().is_err());
    }

    /// The benchmark-line field extraction used for Q4K timings.
    #[test]
    fn test_extract_between() {
        let line = "Q4K GEMV 1536x8960...    12.5 us  (2.2 GFLOPS, 9.8 GB/s)";
        assert_eq!(extract_between(line, "...", " us"), Some("    12.5"));
        assert_eq!(extract_between(line, "(", " GFLOPS"), Some("2.2"));
        assert_eq!(extract_between(line, "GFLOPS, ", " GB/s"), Some("9.8"));
        assert_eq!(extract_between(line, "(", "missing"), None);
    }
}