1#[cfg(feature = "simd")]
9use wide::f64x4;
10
11use crate::tensor::DenseTensor;
12use crate::tensor::traits::TensorBase;
13
/// Reusable scratch buffers for transformer forward passes.
///
/// Buffers are allocated lazily on first request and then reused across
/// calls, avoiding per-step heap allocation. [`TransformerMemoryPool::resize`]
/// drops all cached buffers whenever any dimension changes so they are
/// re-allocated at the new size on next use.
#[derive(Debug)]
pub struct TransformerMemoryPool {
    /// Attention-score scratch, `batch * heads * seq * seq` elements.
    attn_score_buffer: Option<Vec<f64>>,
    /// Attention-weight scratch, `batch * heads * seq * seq` elements.
    attn_weight_buffer: Option<Vec<f64>>,
    /// Q/K/V projection scratch, `batch * seq * hidden` elements.
    qkv_buffer: Option<Vec<f64>>,
    /// Output scratch, `batch * seq * hidden` elements.
    output_buffer: Option<Vec<f64>>,
    batch_size: usize,
    seq_len: usize,
    hidden_dim: usize,
    num_heads: usize,
}

impl TransformerMemoryPool {
    /// Creates an empty pool for the given dimensions.
    ///
    /// No memory is allocated until one of the `get_*_buffer` accessors is
    /// called.
    pub fn new(batch_size: usize, seq_len: usize, hidden_dim: usize, num_heads: usize) -> Self {
        Self {
            attn_score_buffer: None,
            attn_weight_buffer: None,
            qkv_buffer: None,
            output_buffer: None,
            batch_size,
            seq_len,
            hidden_dim,
            num_heads,
        }
    }

    /// Updates the pool dimensions.
    ///
    /// If any dimension actually changed, all cached buffers are dropped so
    /// they get re-allocated at the new size on next access. Calling with the
    /// current dimensions is a no-op and keeps existing allocations.
    pub fn resize(
        &mut self,
        batch_size: usize,
        seq_len: usize,
        hidden_dim: usize,
        num_heads: usize,
    ) {
        let needs_resize = self.batch_size != batch_size
            || self.seq_len != seq_len
            || self.hidden_dim != hidden_dim
            || self.num_heads != num_heads;

        if needs_resize {
            self.batch_size = batch_size;
            self.seq_len = seq_len;
            self.hidden_dim = hidden_dim;
            self.num_heads = num_heads;

            self.attn_score_buffer = None;
            self.attn_weight_buffer = None;
            self.qkv_buffer = None;
            self.output_buffer = None;
        }
    }

    /// Returns the attention-score buffer, zero-filled on first allocation.
    ///
    /// Note: contents from a previous use are NOT cleared on reuse.
    #[must_use]
    pub fn get_attn_score_buffer(&mut self) -> &mut Vec<f64> {
        let size = self.batch_size * self.num_heads * self.seq_len * self.seq_len;
        self.attn_score_buffer.get_or_insert_with(|| vec![0.0f64; size])
    }

    /// Returns the attention-weight buffer, zero-filled on first allocation.
    ///
    /// Note: contents from a previous use are NOT cleared on reuse.
    #[must_use]
    pub fn get_attn_weight_buffer(&mut self) -> &mut Vec<f64> {
        let size = self.batch_size * self.num_heads * self.seq_len * self.seq_len;
        self.attn_weight_buffer.get_or_insert_with(|| vec![0.0f64; size])
    }

    /// Returns the Q/K/V projection buffer, zero-filled on first allocation.
    ///
    /// Note: contents from a previous use are NOT cleared on reuse.
    #[must_use]
    pub fn get_qkv_buffer(&mut self) -> &mut Vec<f64> {
        let size = self.batch_size * self.seq_len * self.hidden_dim;
        self.qkv_buffer.get_or_insert_with(|| vec![0.0f64; size])
    }

    /// Returns the output buffer, zero-filled on first allocation.
    ///
    /// Note: contents from a previous use are NOT cleared on reuse.
    #[must_use]
    pub fn get_output_buffer(&mut self) -> &mut Vec<f64> {
        let size = self.batch_size * self.seq_len * self.hidden_dim;
        self.output_buffer.get_or_insert_with(|| vec![0.0f64; size])
    }

    /// Total bytes currently held by allocated buffers (unallocated buffers
    /// contribute zero).
    pub fn memory_bytes(&self) -> usize {
        [
            &self.attn_score_buffer,
            &self.attn_weight_buffer,
            &self.qkv_buffer,
            &self.output_buffer,
        ]
        .iter()
        .filter_map(|buf| buf.as_ref())
        .map(|buf| buf.len() * std::mem::size_of::<f64>())
        .sum()
    }
}
160
161impl Default for TransformerMemoryPool {
162 fn default() -> Self {
163 Self::new(1, 512, 4096, 32) }
165}
166
/// Numerically stable in-place softmax along dimension `dim`.
///
/// `data` is interpreted as a row-major tensor with the given `shape`; every
/// 1-D slice along `dim` is replaced with `exp(x - max) / sum(exp(x - max))`
/// (max-shifted to avoid overflow in `exp`). With the `simd` feature enabled
/// the slice is processed four lanes at a time via `wide::f64x4`, with a
/// scalar tail for lengths not divisible by 4; otherwise a scalar fallback
/// runs the same three-pass algorithm.
///
/// # Panics
/// Panics if `dim >= shape.len()`. In debug builds, also panics if
/// `data.len()` does not equal the product of `shape`.
pub fn softmax_inplace_simd(data: &mut [f64], shape: &[usize], dim: usize) {
    assert!(dim < shape.len(), "Invalid dimension");
    debug_assert_eq!(
        data.len(),
        shape.iter().product::<usize>(),
        "data length must match shape"
    );

    let dim_size = shape[dim];
    let outer: usize = shape[..dim].iter().product();
    let inner: usize = shape[dim + 1..].iter().product();
    // In row-major layout, consecutive elements along `dim` are `inner`
    // apart. (The previous version recomputed this same product in a loop.)
    let stride = inner;

    #[cfg(feature = "simd")]
    {
        for o in 0..outer {
            for i in 0..inner {
                let base = o * dim_size * stride + i;

                // Pass 1: max over the slice (stability shift).
                let mut max_val = f64::NEG_INFINITY;
                for d in (0..dim_size).step_by(4) {
                    if d + 4 <= dim_size {
                        let vals = [
                            data[base + d * stride],
                            data[base + (d + 1) * stride],
                            data[base + (d + 2) * stride],
                            data[base + (d + 3) * stride],
                        ];
                        let simd_vals = f64x4::new(vals);
                        let max_simd = simd_vals.max(f64x4::new([max_val; 4]));
                        let max_arr = max_simd.to_array();
                        max_val = max_arr[0].max(max_arr[1]).max(max_arr[2]).max(max_arr[3]);
                    } else {
                        // Scalar tail for slices not divisible by 4.
                        for rem_d in d..dim_size {
                            max_val = max_val.max(data[base + rem_d * stride]);
                        }
                    }
                }

                // Pass 2: exponentiate (shifted by max) and accumulate the sum.
                let mut sum_exp = 0.0;
                for d in (0..dim_size).step_by(4) {
                    if d + 4 <= dim_size {
                        let vals = [
                            (data[base + d * stride] - max_val).exp(),
                            (data[base + (d + 1) * stride] - max_val).exp(),
                            (data[base + (d + 2) * stride] - max_val).exp(),
                            (data[base + (d + 3) * stride] - max_val).exp(),
                        ];
                        let simd_vals = f64x4::new(vals);
                        sum_exp += simd_vals.reduce_add();

                        let exp_vals = simd_vals.to_array();
                        data[base + d * stride] = exp_vals[0];
                        data[base + (d + 1) * stride] = exp_vals[1];
                        data[base + (d + 2) * stride] = exp_vals[2];
                        data[base + (d + 3) * stride] = exp_vals[3];
                    } else {
                        for rem_d in d..dim_size {
                            let exp_val = (data[base + rem_d * stride] - max_val).exp();
                            sum_exp += exp_val;
                            data[base + rem_d * stride] = exp_val;
                        }
                    }
                }

                // Pass 3: normalize (multiply by the reciprocal, computed once).
                let inv_sum = 1.0 / sum_exp;
                let inv_sum_simd = f64x4::new([inv_sum; 4]);
                for d in (0..dim_size).step_by(4) {
                    if d + 4 <= dim_size {
                        let vals = [
                            data[base + d * stride],
                            data[base + (d + 1) * stride],
                            data[base + (d + 2) * stride],
                            data[base + (d + 3) * stride],
                        ];
                        let norm_vals = (f64x4::new(vals) * inv_sum_simd).to_array();
                        data[base + d * stride] = norm_vals[0];
                        data[base + (d + 1) * stride] = norm_vals[1];
                        data[base + (d + 2) * stride] = norm_vals[2];
                        data[base + (d + 3) * stride] = norm_vals[3];
                    } else {
                        for rem_d in d..dim_size {
                            data[base + rem_d * stride] *= inv_sum;
                        }
                    }
                }
            }
        }
    }

    #[cfg(not(feature = "simd"))]
    {
        for o in 0..outer {
            for i in 0..inner {
                let base = o * dim_size * stride + i;

                // Max-shift for numerical stability.
                let max_val = (0..dim_size)
                    .map(|d| data[base + d * stride])
                    .fold(f64::NEG_INFINITY, f64::max);

                // Exponentiate in place while accumulating the normalizer.
                let sum_exp: f64 = (0..dim_size)
                    .map(|d| {
                        let exp_val = (data[base + d * stride] - max_val).exp();
                        data[base + d * stride] = exp_val;
                        exp_val
                    })
                    .sum();

                let inv_sum = 1.0 / sum_exp;
                for d in 0..dim_size {
                    data[base + d * stride] *= inv_sum;
                }
            }
        }
    }
}
303
304pub fn matmul_with_buffer(a: &DenseTensor, b: &DenseTensor, buffer: &mut Vec<f64>) -> DenseTensor {
314 let m = a.shape()[0];
315 let k = a.shape()[1];
316 let n = b.shape()[1];
317
318 assert_eq!(a.shape()[1], b.shape()[0], "Inner dimensions must match");
319
320 if buffer.len() < m * n {
322 *buffer = vec![0.0; m * n];
323 }
324
325 #[cfg(feature = "simd")]
326 {
327 for i in 0..m {
329 for j in (0..n).step_by(4) {
330 if j + 4 <= n {
331 let mut sum_simd = f64x4::new([0.0; 4]);
332
333 for p in 0..k {
334 let a_val = a.data()[i * k + p];
335 let a_simd = f64x4::new([a_val; 4]);
336
337 let b_vals = [
338 b.data()[p * n + j],
339 b.data()[p * n + j + 1],
340 b.data()[p * n + j + 2],
341 b.data()[p * n + j + 3],
342 ];
343 let b_simd = f64x4::new(b_vals);
344
345 sum_simd += a_simd * b_simd;
346 }
347
348 let sums = sum_simd.to_array();
349 buffer[i * n + j] = sums[0];
350 buffer[i * n + j + 1] = sums[1];
351 buffer[i * n + j + 2] = sums[2];
352 buffer[i * n + j + 3] = sums[3];
353 } else {
354 for rem_j in j..n {
356 let mut sum = 0.0;
357 for p in 0..k {
358 sum += a.data()[i * k + p] * b.data()[p * n + rem_j];
359 }
360 buffer[i * n + rem_j] = sum;
361 }
362 }
363 }
364 }
365 }
366
367 #[cfg(not(feature = "simd"))]
368 {
369 for i in 0..m {
371 for j in 0..n {
372 let mut sum = 0.0;
373 for p in 0..k {
374 sum += a.data()[i * k + p] * b.data()[p * n + j];
375 }
376 buffer[i * n + j] = sum;
377 }
378 }
379 }
380
381 DenseTensor::new(buffer[..m * n].to_vec(), vec![m, n])
382}
383
/// Lightweight wall-clock measurement helpers.
pub mod benchmark {
    use std::time::Instant;

    /// Runs `f` once, prints its wall-clock time in milliseconds under
    /// `name`, and returns whatever `f` produced.
    pub fn measure_time<F, R>(name: &str, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let started = Instant::now();
        let output = f();
        let wall = started.elapsed();

        println!("{}: {:.3} ms", name, wall.as_secs_f64() * 1000.0);
        output
    }

    /// Runs `f` `iterations` times and prints aggregate throughput
    /// (operations per second) plus the mean time per operation.
    pub fn benchmark_throughput<F>(name: &str, iterations: usize, f: F)
    where
        F: Fn(),
    {
        let started = Instant::now();

        (0..iterations).for_each(|_| f());

        let total_secs = started.elapsed().as_secs_f64();

        println!(
            "{}: {:.2} ops/sec ({:.3} ms/op)",
            name,
            iterations as f64 / total_secs,
            total_secs * 1000.0 / iterations as f64
        );
    }

    /// Converts a token count over an elapsed duration (milliseconds) into a
    /// tokens-per-second rate.
    pub fn tokens_per_second(num_tokens: usize, elapsed_ms: f64) -> f64 {
        let elapsed_secs = elapsed_ms / 1000.0;
        num_tokens as f64 / elapsed_secs
    }
}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transformer::perf::benchmark;

    #[test]
    fn test_memory_pool() {
        let mut pool = TransformerMemoryPool::new(2, 128, 768, 8);

        // Attention-shaped buffers: batch * heads * seq * seq.
        assert_eq!(pool.get_attn_score_buffer().len(), 2 * 8 * 128 * 128);
        assert_eq!(pool.get_attn_weight_buffer().len(), 2 * 8 * 128 * 128);

        // Hidden-shaped buffers: batch * seq * hidden.
        assert_eq!(pool.get_qkv_buffer().len(), 2 * 128 * 768);
        assert_eq!(pool.get_output_buffer().len(), 2 * 128 * 768);

        // Resizing records the new dimensions.
        pool.resize(4, 256, 1024, 16);
        assert_eq!(pool.batch_size, 4);
        assert_eq!(pool.seq_len, 256);
    }

    #[test]
    fn test_softmax_simd() {
        let shape = vec![2, 3];
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        softmax_inplace_simd(&mut data, &shape, 1);

        // Every softmax row must be a probability distribution.
        for (i, row) in data.chunks(3).enumerate() {
            let row_sum: f64 = row.iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-5, "Row {} sum: {}", i, row_sum);
        }
    }

    #[test]
    fn test_matmul_with_buffer() {
        let lhs = DenseTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let rhs = DenseTensor::new(vec![0.5, 0.5, 0.5, 0.5], vec![2, 2]);

        let mut scratch = vec![0.0; 4];
        let result = matmul_with_buffer(&lhs, &rhs, &mut scratch);

        assert_eq!(result.shape(), &[2, 2]);
        let expected = [1.5, 1.5, 3.5, 3.5];
        for (got, want) in result.data().iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-5);
        }
    }

    #[test]
    fn test_benchmark_utils() {
        let wall_clock = std::time::Instant::now();
        benchmark::measure_time("test", || {
            std::thread::sleep(std::time::Duration::from_millis(10));
        });
        let actual_elapsed = wall_clock.elapsed().as_secs_f64() * 1000.0;

        assert!(actual_elapsed >= 10.0, "Should have slept for at least 10ms");

        // 100 tokens in one second is exactly 100 tokens/sec.
        let tps = benchmark::tokens_per_second(100, 1000.0);
        assert!((tps - 100.0).abs() < 1e-5);
    }
}