// trueno/brick/fused_ops.rs
1//! Fused Operations for Transformer Inference
2//!
3//! This module contains fused compute operations that combine multiple
4//! operations into single passes for improved performance.
5//!
6//! # Operations
7//!
8//! - `FusedQKVOp`: Fused Query/Key/Value projection (3 matmuls → 1)
9//! - `FusedGateUpOp`: Fused Gate+Up FFN projection with SiLU (SwiGLU)
10//!
11//! # Performance Impact
12//!
13//! Fusing operations provides:
14//! - Reduced kernel launches (GPU)
15//! - Better cache utilization (data loaded once)
16//! - Eliminated intermediate memory traffic
17
18use super::{Backend, ComputeOp};
19use crate::error::TruenoError;
20
21// ============================================================================
22// Fused Q/K/V Projection (PMAT-PERF-009)
23// ============================================================================
24
/// Weights for fused QKV projection.
///
/// All matrices are stored row-major, one row per OUTPUT element; `execute`
/// computes each output as `dot(x, row_i)` where rows have length `hidden_size`.
#[derive(Debug, Clone)]
pub struct FusedQKVWeights {
    /// Q projection weights, row-major `[hidden_size, hidden_size]`
    /// (row i of length `hidden_size` produces Q[i]).
    pub q_weight: Vec<f32>,
    /// K projection weights, row-major `[kv_dim, hidden_size]` — `kv_dim`
    /// rows, each of length `hidden_size` (see indexing in `execute`).
    pub k_weight: Vec<f32>,
    /// V projection weights, row-major `[kv_dim, hidden_size]`, same layout as K.
    pub v_weight: Vec<f32>,
}
35
/// Fused Q/K/V projection operation for transformer attention.
///
/// Computes Q, K, V projections in a single pass over the input:
/// - Q = x * W_q (hidden_size → hidden_size)
/// - K = x * W_k (hidden_size → kv_dim)
/// - V = x * W_v (hidden_size → kv_dim)
///
/// Each projection is a row-major matvec: output element i is the SIMD dot
/// product of `x` with the i-th weight row of length `hidden_size`.
///
/// # Performance Impact
///
/// Fusing 3 separate matmuls into 1 operation provides:
/// - 3x reduction in kernel launches (GPU)
/// - Better cache utilization (input x loaded once)
/// - Expected speedup: 2-3x for decode phase
///
/// # Five-Whys Root Cause (PMAT-PERF-009)
///
/// ```text
/// Why 1: Why is decode throughput 131 tok/s vs 400 tok/s target?
/// → 280+ kernel launches per token (10+ per layer × 28 layers)
///
/// Why 2: Why so many kernel launches?
/// → Q, K, V computed as 3 separate GEMV operations
///
/// Why 3: Why separate operations?
/// → Original implementation didn't consider launch overhead
///
/// Why 4: Why does launch overhead matter?
/// → GPU kernel launch: ~5-10µs, 280 launches = 1.4-2.8ms overhead/token
///
/// Why 5: ROOT CAUSE
/// → Kernel launch overhead (2.8ms) exceeds compute time for small batch decode
/// → FIX: Fuse Q/K/V into single kernel, reducing launches by 2/3
/// ```
#[derive(Debug, Clone)]
pub struct FusedQKVOp {
    /// Hidden dimension size
    pub hidden_size: usize,
    /// KV dimension (num_kv_heads * head_dim, may differ from hidden_size for GQA)
    pub kv_dim: usize,
    /// Number of attention heads
    pub num_heads: usize,
    /// Head dimension
    pub head_dim: usize,
}
80
81impl FusedQKVOp {
82    /// Create a new fused QKV operation.
83    ///
84    /// # Arguments
85    /// * `hidden_size` - Hidden dimension (e.g., 3584 for Qwen 3B)
86    /// * `num_heads` - Number of attention heads
87    /// * `num_kv_heads` - Number of KV heads (may differ for GQA)
88    pub fn new(hidden_size: usize, num_heads: usize, num_kv_heads: usize) -> Self {
89        let head_dim = hidden_size / num_heads;
90        let kv_dim = num_kv_heads * head_dim;
91        Self { hidden_size, kv_dim, num_heads, head_dim }
92    }
93}
94
95#[allow(clippy::needless_range_loop)] // Matrix indexing is clearer with explicit loops
96impl ComputeOp for FusedQKVOp {
97    type Input = (Vec<f32>, FusedQKVWeights);
98    type Output = (Vec<f32>, Vec<f32>, Vec<f32>); // (Q, K, V)
99
100    fn name(&self) -> &'static str {
101        "fused_qkv"
102    }
103
104    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
105        let (x, weights) = input;
106
107        // Validate input dimensions
108        if x.len() != self.hidden_size {
109            return Err(TruenoError::SizeMismatch { expected: self.hidden_size, actual: x.len() });
110        }
111
112        let h = self.hidden_size;
113
114        // Q projection: x @ W_q^T -> [hidden_size]
115        // CGP-DBUF: SIMD dot product per row (was scalar nested loop).
116        let mut q: Vec<f32> = Vec::with_capacity(h);
117        // SAFETY: Each q[i] is SET to simd_dot result before any read.
118        unsafe {
119            q.set_len(h);
120        }
121        for i in 0..h {
122            q[i] =
123                super::attention::AttentionOp::simd_dot(&x, &weights.q_weight[i * h..(i + 1) * h]);
124        }
125
126        // K projection: x @ W_k^T -> [kv_dim]
127        let mut k: Vec<f32> = Vec::with_capacity(self.kv_dim);
128        // SAFETY: Each k[i] is SET to simd_dot result before any read.
129        unsafe {
130            k.set_len(self.kv_dim);
131        }
132        for i in 0..self.kv_dim {
133            k[i] =
134                super::attention::AttentionOp::simd_dot(&x, &weights.k_weight[i * h..(i + 1) * h]);
135        }
136
137        // V projection: x @ W_v^T -> [kv_dim]
138        let mut v: Vec<f32> = Vec::with_capacity(self.kv_dim);
139        // SAFETY: Each v[i] is SET to simd_dot result before any read.
140        unsafe {
141            v.set_len(self.kv_dim);
142        }
143        for i in 0..self.kv_dim {
144            v[i] =
145                super::attention::AttentionOp::simd_dot(&x, &weights.v_weight[i * h..(i + 1) * h]);
146        }
147
148        Ok((q, k, v))
149    }
150
151    fn tokens(&self, _input: &Self::Input) -> usize {
152        // Output tokens = Q + K + V dimensions
153        self.hidden_size + 2 * self.kv_dim
154    }
155}
156
157// ============================================================================
158// Fused Gate+Up FFN Projection (PMAT-PERF-009)
159// ============================================================================
160
/// Weights for fused gate+up FFN projection.
///
/// Both matrices are row-major with one row per OUTPUT element: `execute`
/// takes `intermediate_size` rows, each of length `hidden_size`.
#[derive(Debug, Clone)]
pub struct FusedGateUpWeights {
    /// Gate projection weights, row-major `[intermediate_size, hidden_size]`
    /// (row i of length `hidden_size` produces gate output i).
    pub gate_weight: Vec<f32>,
    /// Up projection weights, row-major `[intermediate_size, hidden_size]`,
    /// same layout as the gate weights.
    pub up_weight: Vec<f32>,
}
169
/// Fused Gate+Up FFN projection with SiLU activation.
///
/// Computes gate and up projections in a single pass:
/// - gate = x * W_gate
/// - up = x * W_up
/// - output = SiLU(gate) * up (SwiGLU activation)
///
/// Each output element is two SIMD dot products (gate row, up row) followed
/// by the fused SwiGLU combine — no intermediate gate/up vectors are stored.
///
/// # Performance Impact
///
/// Fusing 2 separate matmuls + activation provides:
/// - 2x reduction in kernel launches (GPU)
/// - Fused SiLU avoids intermediate memory traffic
/// - Expected speedup: 1.5-2x for decode phase
///
/// # Five-Whys Root Cause (PMAT-PERF-009)
///
/// ```text
/// Why 1: Why is FFN phase slow?
/// → 3 kernel launches: gate_proj, up_proj, SiLU activation
///
/// Why 2: Why separate kernels?
/// → Traditional implementation pattern from training frameworks
///
/// Why 3: Why does this matter for inference?
/// → Inference is memory-bound; kernel launch overhead dominates
///
/// Why 4: Why not fuse earlier?
/// → Requires custom kernel development
///
/// Why 5: ROOT CAUSE
/// → SwiGLU requires gate*up pattern that naturally fuses
/// → FIX: Fuse gate+up+SiLU into single operation
/// ```
#[derive(Debug, Clone)]
pub struct FusedGateUpOp {
    /// Hidden dimension size
    pub hidden_size: usize,
    /// Intermediate FFN dimension
    pub intermediate_size: usize,
}
210
impl FusedGateUpOp {
    /// Create a new fused gate+up operation.
    ///
    /// # Arguments
    /// * `hidden_size` - Hidden dimension (e.g., 3584 for Qwen 3B)
    /// * `intermediate_size` - FFN intermediate dimension (e.g., 18944)
    pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
        Self { hidden_size, intermediate_size }
    }

    /// SiLU activation: x * sigmoid(x)
    ///
    /// ONE PATH: Delegates to `crate::activations::silu_scalar` (UCBD §4).
    /// The `contract_pre_silu!`/`contract_post_silu!` macros bracket the
    /// call with runtime contract checks; the statement order here is part
    /// of that protocol and must not be rearranged.
    #[inline]
    pub fn silu(x: f32) -> f32 {
        contract_pre_silu!();
        let result = crate::activations::silu_scalar(x);
        contract_post_silu!(&[result]);
        result
    }
}
232
233impl ComputeOp for FusedGateUpOp {
234    type Input = (Vec<f32>, FusedGateUpWeights);
235    type Output = Vec<f32>; // SwiGLU output [intermediate_size]
236
237    fn name(&self) -> &'static str {
238        "fused_gate_up"
239    }
240
241    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
242        let (x, weights) = input;
243
244        // Validate input dimensions
245        if x.len() != self.hidden_size {
246            return Err(TruenoError::SizeMismatch { expected: self.hidden_size, actual: x.len() });
247        }
248
249        // SIMD-optimized fused gate + up + SwiGLU
250        // CGP-DBUF: Use slice-based SIMD dot directly — eliminates 2×intermediate_size
251        // Vector allocations per call (was ~38K allocs for Qwen 3B).
252        // Uninit: output[i] = silu(gate) * up (SET) for every i.
253        let mut output: Vec<f32> = Vec::with_capacity(self.intermediate_size);
254        // SAFETY: Loop writes output[i] = silu(gate_sum) * up_sum for all i.
255        unsafe {
256            output.set_len(self.intermediate_size);
257        }
258
259        let h = self.hidden_size;
260        for i in 0..self.intermediate_size {
261            let row_start = i * h;
262            let row_end = row_start + h;
263
264            // Direct SIMD dot on slices — no Vector allocation
265            let gate_sum = super::attention::AttentionOp::simd_dot(
266                &x,
267                &weights.gate_weight[row_start..row_end],
268            );
269            let up_sum =
270                super::attention::AttentionOp::simd_dot(&x, &weights.up_weight[row_start..row_end]);
271
272            // SwiGLU: SiLU(gate) * up
273            output[i] = Self::silu(gate_sum) * up_sum;
274        }
275
276        Ok(output)
277    }
278
279    fn tokens(&self, _input: &Self::Input) -> usize {
280        self.intermediate_size
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fused_qkv_basic() {
        // hidden=4, num_heads=2, kv_heads=1 → head_dim=2, kv_dim=2
        let op = FusedQKVOp::new(4, 2, 1);

        let x = vec![1.0, 2.0, 3.0, 4.0];
        let weights = FusedQKVWeights {
            q_weight: vec![1.0; 16], // hidden_size x hidden_size = 4x4 = 16
            k_weight: vec![1.0; 8],  // kv_dim x hidden_size = 2x4 = 8
            v_weight: vec![1.0; 8],  // kv_dim x hidden_size = 2x4 = 8
        };

        let (q, k, v) = op.execute((x, weights), Backend::Scalar).unwrap();

        assert_eq!(q.len(), 4);
        assert_eq!(k.len(), 2);
        assert_eq!(v.len(), 2);

        // All-ones weights → every output element = sum(x) = 10.
        // (Previously only lengths were asserted; values went unchecked.)
        for &val in q.iter().chain(k.iter()).chain(v.iter()) {
            assert!((val - 10.0).abs() < 1e-5, "expected 10.0, got {val}");
        }
    }

    #[test]
    fn test_fused_qkv_dimension_mismatch() {
        let op = FusedQKVOp::new(4, 2, 2);
        let x = vec![1.0, 2.0]; // Wrong size
        let weights = FusedQKVWeights {
            q_weight: vec![1.0; 16],
            k_weight: vec![1.0; 8],
            v_weight: vec![1.0; 8],
        };

        let result = op.execute((x, weights), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_fused_gate_up_basic() {
        let op = FusedGateUpOp::new(4, 2);

        let x = vec![1.0, 2.0, 3.0, 4.0];
        let weights = FusedGateUpWeights {
            gate_weight: vec![1.0; 8], // 2x4
            up_weight: vec![1.0; 8],   // 2x4
        };

        let output = op.execute((x, weights), Backend::Scalar).unwrap();
        assert_eq!(output.len(), 2);

        // Output should be SiLU(gate_sum) * up_sum
        // gate_sum = up_sum = 1+2+3+4 = 10
        // SiLU(10) ≈ 10 * sigmoid(10) ≈ 10 * 0.99995 ≈ 10
        // output ≈ 10 * 10 = 100 — check BOTH rows (identical weights).
        for &val in &output {
            assert!(val > 90.0 && val < 110.0, "expected ≈100, got {val}");
        }
    }

    #[test]
    fn test_fused_gate_up_dimension_mismatch() {
        let op = FusedGateUpOp::new(4, 2);
        let x = vec![1.0, 2.0]; // Wrong size
        let weights = FusedGateUpWeights { gate_weight: vec![1.0; 8], up_weight: vec![1.0; 8] };

        let result = op.execute((x, weights), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_silu_values() {
        // SiLU(0) = 0
        assert!((FusedGateUpOp::silu(0.0) - 0.0).abs() < 1e-6);

        // SiLU(x) → x for large positive x
        assert!((FusedGateUpOp::silu(10.0) - 10.0).abs() < 0.01);

        // SiLU(x) → 0 for large negative x
        assert!(FusedGateUpOp::silu(-10.0).abs() < 0.01);
    }

    #[test]
    fn test_fused_qkv_tokens() {
        // hidden=128, heads=8, kv_heads=4 → head_dim=16, kv_dim=64
        let op = FusedQKVOp::new(128, 8, 4);
        let weights = FusedQKVWeights { q_weight: vec![], k_weight: vec![], v_weight: vec![] };
        // tokens = hidden + 2 * kv_dim = 128 + 2 * 64 = 256
        assert_eq!(op.tokens(&(vec![], weights)), 256);
    }

    #[test]
    fn test_fused_gate_up_tokens() {
        let op = FusedGateUpOp::new(128, 256);
        let weights = FusedGateUpWeights { gate_weight: vec![], up_weight: vec![] };
        assert_eq!(op.tokens(&(vec![], weights)), 256);
    }
}