use realizar::quantize::{
fused_q4k_parallel_matvec_into, fused_q4k_q8k_parallel_matvec_into,
quantize_activations_q8k_into, QK_K,
};
use std::time::Instant;
/// Micro-benchmark for the FFN down-projection path: compares the fused
/// Q4_K×Q8_K integer matvec against the Q4_K×f32 baseline, then extrapolates
/// a matmul-bound tokens/sec estimate for a full 28-layer forward pass.
fn main() {
    // Keep the benchmarked results observable so the optimizer cannot
    // elide the timed work.
    use std::hint::black_box;

    // FFN down projection dimensions (Qwen2-1.5B-class model).
    let hidden_dim: usize = 1536;
    let intermediate_dim: usize = 8960;
    // Bytes per Q4_K super-block (each super-block packs QK_K = 256 weights).
    const Q4K_SUPER_BLOCK_BYTES: usize = 144;

    // Weight matrix is [hidden_dim rows x intermediate_dim cols], each row
    // stored as ceil(cols / QK_K) Q4_K super-blocks. 0x55 is an arbitrary
    // non-zero fill pattern.
    let ffn_down_rows = hidden_dim;
    let ffn_down_cols = intermediate_dim;
    let ffn_down_super_blocks = ffn_down_cols.div_ceil(QK_K);
    let ffn_down_weight: Vec<u8> =
        vec![0x55; ffn_down_rows * ffn_down_super_blocks * Q4K_SUPER_BLOCK_BYTES];

    // Deterministic, non-trivial synthetic activations.
    let activations_f32: Vec<f32> = (0..intermediate_dim)
        .map(|i| (i as f32 * 0.001).sin())
        .collect();

    // Q8_K activation buffers, padded up to super-block (256) granularity:
    // one f32 scale per super-block plus one i8 quant per element.
    let padded_len = intermediate_dim.next_multiple_of(256);
    let num_sb = padded_len / 256;
    let mut q8k_scales = vec![0.0f32; num_sb];
    let mut q8k_quants = vec![0i8; padded_len];
    let mut output = vec![0.0f32; hidden_dim];

    let iters = 100;
    println!(
        "=== Micro-profiling FFN Down ({}x{}) ===\n",
        hidden_dim, intermediate_dim
    );

    // --- Activation quantization f32 -> Q8_K ---
    let start = Instant::now();
    for _ in 0..iters {
        quantize_activations_q8k_into(&activations_f32, &mut q8k_scales, &mut q8k_quants).unwrap();
    }
    let q8k_time = start.elapsed();
    black_box((&q8k_scales, &q8k_quants));
    let q8k_us = q8k_time.as_micros() as f64 / iters as f64;

    // --- Fused Q4_K weights x Q8_K activations matvec ---
    let start = Instant::now();
    for _ in 0..iters {
        fused_q4k_q8k_parallel_matvec_into(
            &ffn_down_weight,
            &q8k_scales,
            &q8k_quants,
            intermediate_dim,
            hidden_dim,
            &mut output,
        )
        .unwrap();
    }
    let q4k_q8k_time = start.elapsed();
    black_box(&output);
    let q4k_q8k_us = q4k_q8k_time.as_micros() as f64 / iters as f64;

    // Report quantization cost as its real share of the quant+matmul path
    // (previously printed a hardcoded 0.0 placeholder).
    println!(
        "Q8K quantization: {:>7.1} µs/iter ({:.2}%)",
        q8k_us,
        100.0 * q8k_us / (q8k_us + q4k_q8k_us)
    );
    println!("Q4K×Q8K matmul: {:>7.1} µs/iter", q4k_q8k_us);

    // --- Baseline: fused Q4_K weights x f32 activations matvec ---
    let start = Instant::now();
    for _ in 0..iters {
        fused_q4k_parallel_matvec_into(
            &ffn_down_weight,
            &activations_f32,
            intermediate_dim,
            hidden_dim,
            &mut output,
        )
        .unwrap();
    }
    let q4k_f32_time = start.elapsed();
    black_box(&output);
    let q4k_f32_us = q4k_f32_time.as_micros() as f64 / iters as f64;
    println!("Q4K×f32 matmul: {:>7.1} µs/iter", q4k_f32_us);
    println!("\nQ8K speedup: {:.2}x", q4k_f32_us / q4k_q8k_us);

    // Throughput: 2 FLOPs (mul + add) per weight.
    // FLOP / (µs * 1000) == FLOP / ns == GFLOP/s numerically.
    let flops = 2.0 * hidden_dim as f64 * intermediate_dim as f64;
    let gflops_q8k = flops / (q4k_q8k_us * 1000.0);
    let gflops_f32 = flops / (q4k_f32_us * 1000.0);
    println!("\n=== Throughput ===");
    println!("Q4K×Q8K: {:.1} GFLOP/s", gflops_q8k);
    println!("Q4K×f32: {:.1} GFLOP/s", gflops_f32);

    // Effective bandwidth: weights dominate; activations are 4 B/elem for
    // f32, 1 B/elem for Q8_K quants (per-super-block scales ignored).
    let weight_bytes = ffn_down_weight.len() as f64;
    let activation_bytes = intermediate_dim as f64 * 4.0;
    let q8k_activation_bytes = intermediate_dim as f64 * 1.0;
    let total_bytes_f32 = weight_bytes + activation_bytes;
    let total_bytes_q8k = weight_bytes + q8k_activation_bytes;
    let bw_f32 = total_bytes_f32 / (q4k_f32_us * 1000.0);
    let bw_q8k = total_bytes_q8k / (q4k_q8k_us * 1000.0);
    println!("\n=== Memory Bandwidth ===");
    println!("Q4K×f32 effective: {:.1} GB/s", bw_f32);
    println!("Q4K×Q8K effective: {:.1} GB/s", bw_q8k);

    // Arithmetic intensity (FLOP per byte moved).
    let ai_f32 = flops / total_bytes_f32;
    let ai_q8k = flops / total_bytes_q8k;
    println!("\n=== Arithmetic Intensity ===");
    println!("Q4K×f32: {:.2} FLOP/byte", ai_f32);
    println!("Q4K×Q8K: {:.2} FLOP/byte", ai_q8k);

    // Per-layer matmul FLOPs: QKV (3 square), attention output (1 square),
    // gate+up (2 rectangular), down (1 rectangular); 2 FLOPs per weight.
    let qkv_flops = 3.0 * (hidden_dim as f64).powi(2) * 2.0;
    let attn_out_flops = (hidden_dim as f64).powi(2) * 2.0;
    let ffn_gate_up_flops = 2.0 * hidden_dim as f64 * intermediate_dim as f64 * 2.0;
    let ffn_down_flops = intermediate_dim as f64 * hidden_dim as f64 * 2.0;
    let layer_flops = qkv_flops + attn_out_flops + ffn_gate_up_flops + ffn_down_flops;
    let num_layers = 28;
    let total_flops = layer_flops * num_layers as f64;

    // total_flops [FLOP] / gflops [FLOP/ns] = ns; /1000 -> µs.
    let estimated_time_us = (total_flops / gflops_q8k) / 1000.0;
    let estimated_tok_s = 1_000_000.0 / estimated_time_us;
    println!("\n=== Estimated Full Forward Pass ===");
    println!("Total FLOPs per token: {:.2}B", total_flops / 1e9);
    println!(
        "If all matmuls at Q4K×Q8K speed: {:.1} µs",
        estimated_time_us
    );
    println!("Estimated throughput: {:.0} tok/s", estimated_tok_s);
    println!("\nActual measured: ~15-19 tok/s");
    println!(
        "Gap: {:.1}x slower than matmul-limited",
        estimated_tok_s / 17.0
    );
}