// trueno/brick/fused_ops.rs
1//! Fused Operations for Transformer Inference
2//!
3//! This module contains fused compute operations that combine multiple
4//! operations into single passes for improved performance.
5//!
6//! # Operations
7//!
8//! - `FusedQKVOp`: Fused Query/Key/Value projection (3 matmuls → 1)
9//! - `FusedGateUpOp`: Fused Gate+Up FFN projection with SiLU (SwiGLU)
10//!
11//! # Performance Impact
12//!
13//! Fusing operations provides:
14//! - Reduced kernel launches (GPU)
15//! - Better cache utilization (data loaded once)
16//! - Eliminated intermediate memory traffic
17
18use super::{Backend, ComputeOp};
19use crate::error::TruenoError;
20
21// ============================================================================
22// Fused Q/K/V Projection (PMAT-PERF-009)
23// ============================================================================
24
25/// Weights for fused QKV projection
#[derive(Debug, Clone)]
pub struct FusedQKVWeights {
    /// Q projection weights, row-major `[hidden_size, hidden_size]`
    /// (row `i` is dotted with the input to produce output element `i`).
    pub q_weight: Vec<f32>,
    /// K projection weights, row-major `[kv_dim, hidden_size]`.
    pub k_weight: Vec<f32>,
    /// V projection weights, row-major `[kv_dim, hidden_size]`.
    pub v_weight: Vec<f32>,
}
35
36/// Fused Q/K/V projection operation for transformer attention.
37///
38/// Computes Q, K, V projections in a single pass over the input:
39/// - Q = x * W_q (hidden_size → hidden_size)
40/// - K = x * W_k (hidden_size → kv_dim)
41/// - V = x * W_v (hidden_size → kv_dim)
42///
43/// # Performance Impact
44///
45/// Fusing 3 separate matmuls into 1 operation provides:
46/// - 3x reduction in kernel launches (GPU)
47/// - Better cache utilization (input x loaded once)
48/// - Expected speedup: 2-3x for decode phase
49///
50/// # Five-Whys Root Cause (PMAT-PERF-009)
51///
52/// ```text
53/// Why 1: Why is decode throughput 131 tok/s vs 400 tok/s target?
54/// → 280+ kernel launches per token (10+ per layer × 28 layers)
55///
56/// Why 2: Why so many kernel launches?
57/// → Q, K, V computed as 3 separate GEMV operations
58///
59/// Why 3: Why separate operations?
60/// → Original implementation didn't consider launch overhead
61///
62/// Why 4: Why does launch overhead matter?
63/// → GPU kernel launch: ~5-10µs, 280 launches = 1.4-2.8ms overhead/token
64///
65/// Why 5: ROOT CAUSE
66/// → Kernel launch overhead (2.8ms) exceeds compute time for small batch decode
67/// → FIX: Fuse Q/K/V into single kernel, reducing launches by 2/3
68/// ```
#[derive(Debug, Clone)]
pub struct FusedQKVOp {
    /// Hidden dimension size (input width; also the Q output width)
    pub hidden_size: usize,
    /// KV output dimension (num_kv_heads * head_dim; smaller than
    /// hidden_size under grouped-query attention)
    pub kv_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Per-head dimension (hidden_size / num_heads)
    pub head_dim: usize,
}
80
81impl FusedQKVOp {
82 /// Create a new fused QKV operation.
83 ///
84 /// # Arguments
85 /// * `hidden_size` - Hidden dimension (e.g., 3584 for Qwen 3B)
86 /// * `num_heads` - Number of attention heads
87 /// * `num_kv_heads` - Number of KV heads (may differ for GQA)
88 pub fn new(hidden_size: usize, num_heads: usize, num_kv_heads: usize) -> Self {
89 let head_dim = hidden_size / num_heads;
90 let kv_dim = num_kv_heads * head_dim;
91 Self { hidden_size, kv_dim, num_heads, head_dim }
92 }
93}
94
95#[allow(clippy::needless_range_loop)] // Matrix indexing is clearer with explicit loops
96impl ComputeOp for FusedQKVOp {
97 type Input = (Vec<f32>, FusedQKVWeights);
98 type Output = (Vec<f32>, Vec<f32>, Vec<f32>); // (Q, K, V)
99
100 fn name(&self) -> &'static str {
101 "fused_qkv"
102 }
103
104 fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
105 let (x, weights) = input;
106
107 // Validate input dimensions
108 if x.len() != self.hidden_size {
109 return Err(TruenoError::SizeMismatch { expected: self.hidden_size, actual: x.len() });
110 }
111
112 let h = self.hidden_size;
113
114 // Q projection: x @ W_q^T -> [hidden_size]
115 // CGP-DBUF: SIMD dot product per row (was scalar nested loop).
116 let mut q: Vec<f32> = Vec::with_capacity(h);
117 // SAFETY: Each q[i] is SET to simd_dot result before any read.
118 unsafe {
119 q.set_len(h);
120 }
121 for i in 0..h {
122 q[i] =
123 super::attention::AttentionOp::simd_dot(&x, &weights.q_weight[i * h..(i + 1) * h]);
124 }
125
126 // K projection: x @ W_k^T -> [kv_dim]
127 let mut k: Vec<f32> = Vec::with_capacity(self.kv_dim);
128 // SAFETY: Each k[i] is SET to simd_dot result before any read.
129 unsafe {
130 k.set_len(self.kv_dim);
131 }
132 for i in 0..self.kv_dim {
133 k[i] =
134 super::attention::AttentionOp::simd_dot(&x, &weights.k_weight[i * h..(i + 1) * h]);
135 }
136
137 // V projection: x @ W_v^T -> [kv_dim]
138 let mut v: Vec<f32> = Vec::with_capacity(self.kv_dim);
139 // SAFETY: Each v[i] is SET to simd_dot result before any read.
140 unsafe {
141 v.set_len(self.kv_dim);
142 }
143 for i in 0..self.kv_dim {
144 v[i] =
145 super::attention::AttentionOp::simd_dot(&x, &weights.v_weight[i * h..(i + 1) * h]);
146 }
147
148 Ok((q, k, v))
149 }
150
151 fn tokens(&self, _input: &Self::Input) -> usize {
152 // Output tokens = Q + K + V dimensions
153 self.hidden_size + 2 * self.kv_dim
154 }
155}
156
157// ============================================================================
158// Fused Gate+Up FFN Projection (PMAT-PERF-009)
159// ============================================================================
160
161/// Weights for fused gate+up FFN projection
#[derive(Debug, Clone)]
pub struct FusedGateUpWeights {
    /// Gate projection weights, row-major `[intermediate_size, hidden_size]`
    /// (row `i` is dotted with the input to produce gate element `i`).
    pub gate_weight: Vec<f32>,
    /// Up projection weights, row-major `[intermediate_size, hidden_size]`.
    pub up_weight: Vec<f32>,
}
169
170/// Fused Gate+Up FFN projection with SiLU activation.
171///
172/// Computes gate and up projections in a single pass:
173/// - gate = x * W_gate
174/// - up = x * W_up
175/// - output = SiLU(gate) * up (SwiGLU activation)
176///
177/// # Performance Impact
178///
179/// Fusing 2 separate matmuls + activation provides:
180/// - 2x reduction in kernel launches (GPU)
181/// - Fused SiLU avoids intermediate memory traffic
182/// - Expected speedup: 1.5-2x for decode phase
183///
184/// # Five-Whys Root Cause (PMAT-PERF-009)
185///
186/// ```text
187/// Why 1: Why is FFN phase slow?
188/// → 3 kernel launches: gate_proj, up_proj, SiLU activation
189///
190/// Why 2: Why separate kernels?
191/// → Traditional implementation pattern from training frameworks
192///
193/// Why 3: Why does this matter for inference?
194/// → Inference is memory-bound; kernel launch overhead dominates
195///
196/// Why 4: Why not fuse earlier?
197/// → Requires custom kernel development
198///
199/// Why 5: ROOT CAUSE
200/// → SwiGLU requires gate*up pattern that naturally fuses
201/// → FIX: Fuse gate+up+SiLU into single operation
202/// ```
#[derive(Debug, Clone)]
pub struct FusedGateUpOp {
    /// Hidden dimension size (input width)
    pub hidden_size: usize,
    /// Intermediate FFN dimension (output width of the gate/up projections)
    pub intermediate_size: usize,
}
210
impl FusedGateUpOp {
    /// Create a new fused gate+up operation.
    ///
    /// # Arguments
    /// * `hidden_size` - Hidden dimension (e.g., 3584 for Qwen 3B)
    /// * `intermediate_size` - FFN intermediate dimension (e.g., 18944)
    pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
        Self { hidden_size, intermediate_size }
    }

    /// SiLU activation: x * sigmoid(x)
    ///
    /// ONE PATH: Delegates to `crate::activations::silu_scalar` (UCBD §4).
    /// The surrounding `contract_pre_silu!`/`contract_post_silu!` macros are
    /// project-level invariant checks defined elsewhere; NOTE(review): their
    /// evaluation order around the computation appears deliberate — do not
    /// reorder. Presumably compiled out in release builds — TODO confirm.
    #[inline]
    pub fn silu(x: f32) -> f32 {
        contract_pre_silu!();
        let result = crate::activations::silu_scalar(x);
        contract_post_silu!(&[result]);
        result
    }
}
232
233impl ComputeOp for FusedGateUpOp {
234 type Input = (Vec<f32>, FusedGateUpWeights);
235 type Output = Vec<f32>; // SwiGLU output [intermediate_size]
236
237 fn name(&self) -> &'static str {
238 "fused_gate_up"
239 }
240
241 fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
242 let (x, weights) = input;
243
244 // Validate input dimensions
245 if x.len() != self.hidden_size {
246 return Err(TruenoError::SizeMismatch { expected: self.hidden_size, actual: x.len() });
247 }
248
249 // SIMD-optimized fused gate + up + SwiGLU
250 // CGP-DBUF: Use slice-based SIMD dot directly — eliminates 2×intermediate_size
251 // Vector allocations per call (was ~38K allocs for Qwen 3B).
252 // Uninit: output[i] = silu(gate) * up (SET) for every i.
253 let mut output: Vec<f32> = Vec::with_capacity(self.intermediate_size);
254 // SAFETY: Loop writes output[i] = silu(gate_sum) * up_sum for all i.
255 unsafe {
256 output.set_len(self.intermediate_size);
257 }
258
259 let h = self.hidden_size;
260 for i in 0..self.intermediate_size {
261 let row_start = i * h;
262 let row_end = row_start + h;
263
264 // Direct SIMD dot on slices — no Vector allocation
265 let gate_sum = super::attention::AttentionOp::simd_dot(
266 &x,
267 &weights.gate_weight[row_start..row_end],
268 );
269 let up_sum =
270 super::attention::AttentionOp::simd_dot(&x, &weights.up_weight[row_start..row_end]);
271
272 // SwiGLU: SiLU(gate) * up
273 output[i] = Self::silu(gate_sum) * up_sum;
274 }
275
276 Ok(output)
277 }
278
279 fn tokens(&self, _input: &Self::Input) -> usize {
280 self.intermediate_size
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fused_qkv_basic() {
        // hidden=4, num_heads=2, kv_heads=1 → head_dim=2, kv_dim=2
        let op = FusedQKVOp::new(4, 2, 1);

        let x = vec![1.0, 2.0, 3.0, 4.0];
        let weights = FusedQKVWeights {
            q_weight: vec![1.0; 16], // hidden_size x hidden_size = 4x4 = 16
            k_weight: vec![1.0; 8],  // kv_dim x hidden_size = 2x4 = 8
            v_weight: vec![1.0; 8],  // kv_dim x hidden_size = 2x4 = 8
        };

        let (q, k, v) = op.execute((x, weights), Backend::Scalar).unwrap();

        assert_eq!(q.len(), 4);
        assert_eq!(k.len(), 2);
        assert_eq!(v.len(), 2);

        // With all-ones weights every output element is sum(x) = 10.0,
        // exactly representable in f32 regardless of summation order —
        // verify the values, not just the lengths.
        assert!(q.iter().all(|&qi| (qi - 10.0).abs() < 1e-5), "q = {q:?}");
        assert!(k.iter().all(|&ki| (ki - 10.0).abs() < 1e-5), "k = {k:?}");
        assert!(v.iter().all(|&vi| (vi - 10.0).abs() < 1e-5), "v = {v:?}");
    }

    #[test]
    fn test_fused_qkv_dimension_mismatch() {
        let op = FusedQKVOp::new(4, 2, 2);
        let x = vec![1.0, 2.0]; // Wrong size
        let weights = FusedQKVWeights {
            q_weight: vec![1.0; 16],
            k_weight: vec![1.0; 8],
            v_weight: vec![1.0; 8],
        };

        let result = op.execute((x, weights), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_fused_gate_up_basic() {
        let op = FusedGateUpOp::new(4, 2);

        let x = vec![1.0, 2.0, 3.0, 4.0];
        let weights = FusedGateUpWeights {
            gate_weight: vec![1.0; 8], // 2x4
            up_weight: vec![1.0; 8],   // 2x4
        };

        let output = op.execute((x, weights), Backend::Scalar).unwrap();
        assert_eq!(output.len(), 2);

        // Output should be SiLU(gate_sum) * up_sum
        // gate_sum = up_sum = 1+2+3+4 = 10
        // SiLU(10) ≈ 10 * sigmoid(10) ≈ 10 * 0.99995 ≈ 10
        // output ≈ 10 * 10 = 100
        assert!(output[0] > 90.0 && output[0] < 110.0);
        // Both rows are identical, so both outputs must agree exactly.
        assert!((output[0] - output[1]).abs() < 1e-5);
    }

    #[test]
    fn test_fused_gate_up_dimension_mismatch() {
        let op = FusedGateUpOp::new(4, 2);
        let x = vec![1.0, 2.0]; // Wrong size
        let weights = FusedGateUpWeights { gate_weight: vec![1.0; 8], up_weight: vec![1.0; 8] };

        let result = op.execute((x, weights), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_silu_values() {
        // SiLU(0) = 0
        assert!((FusedGateUpOp::silu(0.0) - 0.0).abs() < 1e-6);

        // SiLU(x) → x for large positive x
        assert!((FusedGateUpOp::silu(10.0) - 10.0).abs() < 0.01);

        // SiLU(x) → 0 for large negative x
        assert!(FusedGateUpOp::silu(-10.0).abs() < 0.01);
    }

    #[test]
    fn test_fused_qkv_tokens() {
        // hidden=128, heads=8, kv_heads=4 → head_dim=16, kv_dim=64
        let op = FusedQKVOp::new(128, 8, 4);
        let weights = FusedQKVWeights { q_weight: vec![], k_weight: vec![], v_weight: vec![] };
        // tokens = hidden + 2 * kv_dim = 128 + 2 * 64 = 256
        assert_eq!(op.tokens(&(vec![], weights)), 256);
    }

    #[test]
    fn test_fused_gate_up_tokens() {
        let op = FusedGateUpOp::new(128, 256);
        let weights = FusedGateUpWeights { gate_weight: vec![], up_weight: vec![] };
        assert_eq!(op.tokens(&(vec![], weights)), 256);
    }
}
378}