// god_graph/transformer/perf.rs
1//! Performance optimization utilities for Transformer inference
2//!
3//! This module provides:
4//! - Memory pool integration for reduced allocation overhead
5//! - Optimized kernels for attention and FFN
6//! - Benchmark utilities for performance measurement
7
8#[cfg(feature = "simd")]
9use wide::f64x4;
10
11use crate::tensor::DenseTensor;
12use crate::tensor::traits::TensorBase;
13
/// Memory pool for Transformer inference
///
/// Reuses intermediate buffers to reduce allocation overhead during inference.
/// Typical use case: allocate once, reuse across multiple forward passes.
/// Buffers are allocated lazily on first access and zero-initialized.
#[derive(Debug)]
pub struct TransformerMemoryPool {
    /// Buffer for attention scores [batch, num_heads, seq_len, seq_len]
    attn_score_buffer: Option<Vec<f64>>,
    /// Buffer for attention weights [batch, num_heads, seq_len, seq_len]
    attn_weight_buffer: Option<Vec<f64>>,
    /// Buffer for QKV projections [batch, seq_len, hidden_dim]
    qkv_buffer: Option<Vec<f64>>,
    /// Buffer for output [batch, seq_len, hidden_dim]
    output_buffer: Option<Vec<f64>>,
    /// Current batch size
    batch_size: usize,
    /// Current sequence length
    seq_len: usize,
    /// Current hidden dimension
    hidden_dim: usize,
    /// Number of attention heads
    num_heads: usize,
}

impl TransformerMemoryPool {
    /// Create a new memory pool with specified dimensions.
    ///
    /// Construction is cheap: no buffer is allocated until the first
    /// `get_*_buffer` call.
    pub fn new(batch_size: usize, seq_len: usize, hidden_dim: usize, num_heads: usize) -> Self {
        Self {
            attn_score_buffer: None,
            attn_weight_buffer: None,
            qkv_buffer: None,
            output_buffer: None,
            batch_size,
            seq_len,
            hidden_dim,
            num_heads,
        }
    }

    /// Update pool dimensions if needed.
    ///
    /// When any dimension changes, all cached buffers are dropped so the next
    /// `get_*_buffer` call reallocates them at the new size. A no-op when the
    /// dimensions are unchanged, preserving existing allocations.
    pub fn resize(
        &mut self,
        batch_size: usize,
        seq_len: usize,
        hidden_dim: usize,
        num_heads: usize,
    ) {
        let unchanged = self.batch_size == batch_size
            && self.seq_len == seq_len
            && self.hidden_dim == hidden_dim
            && self.num_heads == num_heads;
        if unchanged {
            return;
        }

        self.batch_size = batch_size;
        self.seq_len = seq_len;
        self.hidden_dim = hidden_dim;
        self.num_heads = num_heads;

        // Clear buffers to force reallocation at the new size.
        self.attn_score_buffer = None;
        self.attn_weight_buffer = None;
        self.qkv_buffer = None;
        self.output_buffer = None;
    }

    /// Element count of an attention score/weight buffer
    /// ([batch, num_heads, seq_len, seq_len]).
    fn attn_elems(&self) -> usize {
        self.batch_size * self.num_heads * self.seq_len * self.seq_len
    }

    /// Element count of a QKV/output buffer ([batch, seq_len, hidden_dim]).
    fn hidden_elems(&self) -> usize {
        self.batch_size * self.seq_len * self.hidden_dim
    }

    /// Get or allocate the attention score buffer
    /// ([batch, num_heads, seq_len, seq_len], zero-initialized on first use).
    #[must_use]
    pub fn get_attn_score_buffer(&mut self) -> &mut Vec<f64> {
        let size = self.attn_elems();
        self.attn_score_buffer.get_or_insert_with(|| vec![0.0f64; size])
    }

    /// Get or allocate the attention weight buffer
    /// ([batch, num_heads, seq_len, seq_len], zero-initialized on first use).
    #[must_use]
    pub fn get_attn_weight_buffer(&mut self) -> &mut Vec<f64> {
        let size = self.attn_elems();
        self.attn_weight_buffer.get_or_insert_with(|| vec![0.0f64; size])
    }

    /// Get or allocate the QKV projection buffer
    /// ([batch, seq_len, hidden_dim], zero-initialized on first use).
    #[must_use]
    pub fn get_qkv_buffer(&mut self) -> &mut Vec<f64> {
        let size = self.hidden_elems();
        self.qkv_buffer.get_or_insert_with(|| vec![0.0f64; size])
    }

    /// Get or allocate the output buffer
    /// ([batch, seq_len, hidden_dim], zero-initialized on first use).
    #[must_use]
    pub fn get_output_buffer(&mut self) -> &mut Vec<f64> {
        let size = self.hidden_elems();
        self.output_buffer.get_or_insert_with(|| vec![0.0f64; size])
    }

    /// Get estimated memory usage in bytes across all currently allocated
    /// buffers (unallocated buffers contribute nothing).
    pub fn memory_bytes(&self) -> usize {
        let elem_bytes = std::mem::size_of::<f64>();
        [
            &self.attn_score_buffer,
            &self.attn_weight_buffer,
            &self.qkv_buffer,
            &self.output_buffer,
        ]
        .iter()
        .filter_map(|buf| buf.as_ref())
        .map(|buf| buf.len() * elem_bytes)
        .sum()
    }
}
160
161impl Default for TransformerMemoryPool {
162    fn default() -> Self {
163        Self::new(1, 512, 4096, 32) // Default: LLaMA-7B inference
164    }
165}
166
/// Optimized softmax implementation using SIMD
///
/// Computes softmax in place along dimension `dim` of a row-major tensor.
///
/// # Arguments
/// * `data` - Input data (will be overwritten with softmax result)
/// * `shape` - Tensor shape (row-major)
/// * `dim` - Dimension along which to compute softmax
///
/// # Panics
/// Panics if `dim >= shape.len()`.
pub fn softmax_inplace_simd(data: &mut [f64], shape: &[usize], dim: usize) {
    assert!(dim < shape.len(), "Invalid dimension");

    let dim_size = shape[dim];

    // Row-major layout: the stride of `dim` is the product of all trailing
    // dimensions, which is also the inner iteration count (the two are the
    // same value, previously computed twice).
    let stride: usize = shape[dim + 1..].iter().product();
    let inner = stride;
    // Outer iterations cover all leading dimensions.
    let outer: usize = shape[..dim].iter().product();

    #[cfg(feature = "simd")]
    {
        // SIMD-optimized softmax: process the softmax dimension in lanes of 4,
        // with a scalar tail for the final < 4 elements.
        for o in 0..outer {
            for i in 0..inner {
                let base = o * dim_size * stride + i;

                // Find max for numerical stability (SIMD)
                let mut max_val = f64::NEG_INFINITY;
                for d in (0..dim_size).step_by(4) {
                    if d + 4 <= dim_size {
                        let vals = [
                            data[base + d * stride],
                            data[base + (d + 1) * stride],
                            data[base + (d + 2) * stride],
                            data[base + (d + 3) * stride],
                        ];
                        let simd_vals = f64x4::new(vals);
                        let max_simd = simd_vals.max(f64x4::new([max_val; 4]));
                        let max_arr = max_simd.to_array();
                        max_val = max_arr[0].max(max_arr[1]).max(max_arr[2]).max(max_arr[3]);
                    } else {
                        // Scalar tail
                        for rem_d in d..dim_size {
                            max_val = max_val.max(data[base + rem_d * stride]);
                        }
                    }
                }

                // Compute exp(x - max) and sum (SIMD)
                let mut sum_exp = 0.0;
                for d in (0..dim_size).step_by(4) {
                    if d + 4 <= dim_size {
                        let vals = [
                            (data[base + d * stride] - max_val).exp(),
                            (data[base + (d + 1) * stride] - max_val).exp(),
                            (data[base + (d + 2) * stride] - max_val).exp(),
                            (data[base + (d + 3) * stride] - max_val).exp(),
                        ];
                        let simd_vals = f64x4::new(vals);
                        sum_exp += simd_vals.reduce_add();

                        // Store back
                        let exp_vals = simd_vals.to_array();
                        data[base + d * stride] = exp_vals[0];
                        data[base + (d + 1) * stride] = exp_vals[1];
                        data[base + (d + 2) * stride] = exp_vals[2];
                        data[base + (d + 3) * stride] = exp_vals[3];
                    } else {
                        // Scalar tail
                        for rem_d in d..dim_size {
                            let exp_val = (data[base + rem_d * stride] - max_val).exp();
                            sum_exp += exp_val;
                            data[base + rem_d * stride] = exp_val;
                        }
                    }
                }

                // Normalize (SIMD)
                let inv_sum = 1.0 / sum_exp;
                let inv_sum_simd = f64x4::new([inv_sum; 4]);
                for d in (0..dim_size).step_by(4) {
                    if d + 4 <= dim_size {
                        let vals = [
                            data[base + d * stride],
                            data[base + (d + 1) * stride],
                            data[base + (d + 2) * stride],
                            data[base + (d + 3) * stride],
                        ];
                        let simd_vals = f64x4::new(vals) * inv_sum_simd;
                        let norm_vals = simd_vals.to_array();
                        data[base + d * stride] = norm_vals[0];
                        data[base + (d + 1) * stride] = norm_vals[1];
                        data[base + (d + 2) * stride] = norm_vals[2];
                        data[base + (d + 3) * stride] = norm_vals[3];
                    } else {
                        // Scalar tail
                        for rem_d in d..dim_size {
                            data[base + rem_d * stride] *= inv_sum;
                        }
                    }
                }
            }
        }
    }

    #[cfg(not(feature = "simd"))]
    {
        // Fallback: naive scalar implementation (same three passes).
        for o in 0..outer {
            for i in 0..inner {
                let base = o * dim_size * stride + i;

                // Find max for numerical stability
                let max_val = (0..dim_size)
                    .map(|d| data[base + d * stride])
                    .fold(f64::NEG_INFINITY, f64::max);

                // Compute exp(x - max) and sum
                let sum_exp: f64 = (0..dim_size)
                    .map(|d| {
                        let exp_val = (data[base + d * stride] - max_val).exp();
                        data[base + d * stride] = exp_val;
                        exp_val
                    })
                    .sum();

                // Normalize
                let inv_sum = 1.0 / sum_exp;
                for d in 0..dim_size {
                    data[base + d * stride] *= inv_sum;
                }
            }
        }
    }
}
303
304/// Optimized matrix multiplication with pre-allocated buffer
305///
306/// # Arguments
307/// * `a` - Matrix A [M, K]
308/// * `b` - Matrix B [K, N]
309/// * `buffer` - Pre-allocated output buffer [M * N]
310///
311/// # Returns
312/// Result matrix [M, N]
313pub fn matmul_with_buffer(a: &DenseTensor, b: &DenseTensor, buffer: &mut Vec<f64>) -> DenseTensor {
314    let m = a.shape()[0];
315    let k = a.shape()[1];
316    let n = b.shape()[1];
317
318    assert_eq!(a.shape()[1], b.shape()[0], "Inner dimensions must match");
319
320    // Ensure buffer is large enough
321    if buffer.len() < m * n {
322        *buffer = vec![0.0; m * n];
323    }
324
325    #[cfg(feature = "simd")]
326    {
327        // SIMD-optimized matmul
328        for i in 0..m {
329            for j in (0..n).step_by(4) {
330                if j + 4 <= n {
331                    let mut sum_simd = f64x4::new([0.0; 4]);
332
333                    for p in 0..k {
334                        let a_val = a.data()[i * k + p];
335                        let a_simd = f64x4::new([a_val; 4]);
336
337                        let b_vals = [
338                            b.data()[p * n + j],
339                            b.data()[p * n + j + 1],
340                            b.data()[p * n + j + 2],
341                            b.data()[p * n + j + 3],
342                        ];
343                        let b_simd = f64x4::new(b_vals);
344
345                        sum_simd += a_simd * b_simd;
346                    }
347
348                    let sums = sum_simd.to_array();
349                    buffer[i * n + j] = sums[0];
350                    buffer[i * n + j + 1] = sums[1];
351                    buffer[i * n + j + 2] = sums[2];
352                    buffer[i * n + j + 3] = sums[3];
353                } else {
354                    // Handle remainder
355                    for rem_j in j..n {
356                        let mut sum = 0.0;
357                        for p in 0..k {
358                            sum += a.data()[i * k + p] * b.data()[p * n + rem_j];
359                        }
360                        buffer[i * n + rem_j] = sum;
361                    }
362                }
363            }
364        }
365    }
366
367    #[cfg(not(feature = "simd"))]
368    {
369        // Fallback: naive implementation
370        for i in 0..m {
371            for j in 0..n {
372                let mut sum = 0.0;
373                for p in 0..k {
374                    sum += a.data()[i * k + p] * b.data()[p * n + j];
375                }
376                buffer[i * n + j] = sum;
377            }
378        }
379    }
380
381    DenseTensor::new(buffer[..m * n].to_vec(), vec![m, n])
382}
383
/// Benchmark utilities for measuring inference performance
pub mod benchmark {
    use std::time::Instant;

    /// Measure execution time of a function, printing the elapsed time.
    ///
    /// # Arguments
    /// * `name` - Benchmark name
    /// * `f` - Function to benchmark
    ///
    /// # Returns
    /// The value returned by `f`. The elapsed time is printed to stdout in
    /// milliseconds, not returned.
    pub fn measure_time<F, R>(name: &str, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let start = Instant::now();
        let result = f();
        let elapsed = start.elapsed();

        println!("{}: {:.3} ms", name, elapsed.as_secs_f64() * 1000.0);
        result
    }

    /// Benchmark throughput (operations per second).
    ///
    /// Prints ops/sec and ms/op to stdout. A no-op when `iterations` is zero,
    /// which would otherwise divide by zero in the report.
    ///
    /// # Arguments
    /// * `name` - Benchmark name
    /// * `iterations` - Number of iterations
    /// * `f` - Function to benchmark
    pub fn benchmark_throughput<F>(name: &str, iterations: usize, f: F)
    where
        F: Fn(),
    {
        if iterations == 0 {
            return;
        }

        let start = Instant::now();

        for _ in 0..iterations {
            f();
        }

        let elapsed = start.elapsed();
        let ops_per_sec = iterations as f64 / elapsed.as_secs_f64();

        println!(
            "{}: {:.2} ops/sec ({:.3} ms/op)",
            name,
            ops_per_sec,
            elapsed.as_secs_f64() * 1000.0 / iterations as f64
        );
    }

    /// Measure tokens per second for inference.
    ///
    /// Returns `num_tokens / (elapsed_ms / 1000)`; note this yields
    /// `inf`/`NaN` when `elapsed_ms` is zero.
    ///
    /// # Arguments
    /// * `num_tokens` - Number of tokens generated
    /// * `elapsed_ms` - Elapsed time in milliseconds
    pub fn tokens_per_second(num_tokens: usize, elapsed_ms: f64) -> f64 {
        num_tokens as f64 / (elapsed_ms / 1000.0)
    }
}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::perf::benchmark;

    #[test]
    fn test_memory_pool() {
        let mut pool = TransformerMemoryPool::new(2, 128, 768, 8);

        // Lazily-allocated buffers come back at the expected element counts.
        let attn_elems = 2 * 8 * 128 * 128;
        let hidden_elems = 2 * 128 * 768;
        assert_eq!(pool.get_attn_score_buffer().len(), attn_elems);
        assert_eq!(pool.get_attn_weight_buffer().len(), attn_elems);
        assert_eq!(pool.get_qkv_buffer().len(), hidden_elems);
        assert_eq!(pool.get_output_buffer().len(), hidden_elems);

        // Resizing records the new dimensions.
        pool.resize(4, 256, 1024, 16);
        assert_eq!(pool.batch_size, 4);
        assert_eq!(pool.seq_len, 256);
    }

    #[test]
    fn test_softmax_simd() {
        let shape = vec![2, 3];
        let mut values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        softmax_inplace_simd(&mut values, &shape, 1);

        // Every softmax row must sum to 1.0.
        for (row, chunk) in values.chunks(3).enumerate() {
            let total: f64 = chunk.iter().sum();
            assert!((total - 1.0).abs() < 1e-5, "Row {} sum: {}", row, total);
        }
    }

    #[test]
    fn test_matmul_with_buffer() {
        let lhs = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let rhs = DenseTensor::new(vec![0.5, 0.5, 0.5, 0.5], vec![2, 2]);

        let mut scratch = vec![0.0; 4];
        let product = matmul_with_buffer(&lhs, &rhs, &mut scratch);

        assert_eq!(product.shape(), &[2, 2]);
        // Row sums of `lhs` times 0.5, broadcast across each output row.
        let expected = [1.5, 1.5, 3.5, 3.5];
        for (idx, want) in expected.iter().enumerate() {
            assert!((product.data()[idx] - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_benchmark_utils() {
        // measure_time must run the closure to completion.
        let wall_clock = std::time::Instant::now();
        benchmark::measure_time("test", || {
            std::thread::sleep(std::time::Duration::from_millis(10));
        });
        let spent_ms = wall_clock.elapsed().as_secs_f64() * 1000.0;

        assert!(spent_ms >= 10.0, "Should have slept for at least 10ms");

        // 100 tokens over 1000 ms is exactly 100 tokens/sec.
        let rate = benchmark::tokens_per_second(100, 1000.0);
        assert!((rate - 100.0).abs() < 1e-5);
    }
}