1use crate::half_precision::{BF16, F16};
7
8#[cfg(feature = "no-std")]
9extern crate alloc;
10
11pub struct BatchNorm {
13 epsilon: f32,
14}
15
16impl BatchNorm {
17 pub fn new(epsilon: f32) -> Self {
19 Self { epsilon }
20 }
21
22 #[allow(clippy::too_many_arguments)] pub fn forward(
27 &self,
28 input: &[f32],
29 mean: &[f32],
30 variance: &[f32],
31 gamma: &[f32],
32 beta: &[f32],
33 output: &mut [f32],
34 batch_size: usize,
35 features: usize,
36 ) {
37 assert_eq!(input.len(), batch_size * features);
38 assert_eq!(output.len(), batch_size * features);
39 assert_eq!(mean.len(), features);
40 assert_eq!(variance.len(), features);
41 assert_eq!(gamma.len(), features);
42 assert_eq!(beta.len(), features);
43
44 for batch in 0..batch_size {
45 for feat in 0..features {
46 let idx = batch * features + feat;
47 let x = input[idx];
48 let normalized = (x - mean[feat]) / (variance[feat] + self.epsilon).sqrt();
49 output[idx] = gamma[feat] * normalized + beta[feat];
50 }
51 }
52 }
53
54 #[allow(clippy::too_many_arguments)] pub fn forward_f16(
57 &self,
58 input: &[F16],
59 mean: &[F16],
60 variance: &[F16],
61 gamma: &[F16],
62 beta: &[F16],
63 output: &mut [F16],
64 batch_size: usize,
65 features: usize,
66 ) {
67 assert_eq!(input.len(), batch_size * features);
68 assert_eq!(output.len(), batch_size * features);
69 assert_eq!(mean.len(), features);
70 assert_eq!(variance.len(), features);
71 assert_eq!(gamma.len(), features);
72 assert_eq!(beta.len(), features);
73
74 for batch in 0..batch_size {
75 for feat in 0..features {
76 let idx = batch * features + feat;
77 let x = input[idx].to_f32();
78 let m = mean[feat].to_f32();
79 let v = variance[feat].to_f32();
80 let g = gamma[feat].to_f32();
81 let b = beta[feat].to_f32();
82
83 let normalized = (x - m) / (v + self.epsilon).sqrt();
84 let result = g * normalized + b;
85 output[idx] = F16::from_f32(result);
86 }
87 }
88 }
89
90 pub fn compute_stats(
92 input: &[f32],
93 mean: &mut [f32],
94 variance: &mut [f32],
95 batch_size: usize,
96 features: usize,
97 ) {
98 assert_eq!(input.len(), batch_size * features);
99 assert_eq!(mean.len(), features);
100 assert_eq!(variance.len(), features);
101
102 for (feat, m) in mean.iter_mut().enumerate() {
104 let mut sum = 0.0;
105 for batch in 0..batch_size {
106 sum += input[batch * features + feat];
107 }
108 *m = sum / batch_size as f32;
109 }
110
111 for (feat, v) in variance.iter_mut().enumerate() {
113 let mut sum_sq_diff = 0.0;
114 for batch in 0..batch_size {
115 let diff = input[batch * features + feat] - mean[feat];
116 sum_sq_diff += diff * diff;
117 }
118 *v = sum_sq_diff / batch_size as f32;
119 }
120 }
121}
122
/// Layer normalization over flat, row-major `[batch_size, features]` buffers.
///
/// Each row is normalized with its own mean and variance; the struct stores
/// only the constant added to the variance before the square root.
pub struct LayerNorm {
    /// Added to each row variance before taking the square root.
    epsilon: f32,
}

impl LayerNorm {
    /// Creates a normalizer that adds `epsilon` to every row variance.
    pub fn new(epsilon: f32) -> Self {
        LayerNorm { epsilon }
    }

    /// Normalizes every row of `input` and applies the per-feature affine
    /// transform `gamma * normalized + beta`, writing the result to `output`.
    ///
    /// The row variance is the biased one (it divides by `features`).
    ///
    /// # Panics
    ///
    /// Panics if `input` or `output` is not `batch_size * features` long, or if
    /// `gamma` or `beta` is not `features` long.
    pub fn forward(
        &self,
        input: &[f32],
        gamma: &[f32],
        beta: &[f32],
        output: &mut [f32],
        batch_size: usize,
        features: usize,
    ) {
        assert_eq!(input.len(), batch_size * features);
        assert_eq!(output.len(), batch_size * features);
        assert_eq!(gamma.len(), features);
        assert_eq!(beta.len(), features);

        let width = features as f32;

        for row in 0..batch_size {
            let offset = row * features;
            let src = &input[offset..offset + features];
            let dst = &mut output[offset..offset + features];

            let mean = src.iter().sum::<f32>() / width;
            let variance = src.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / width;
            let std_dev = (variance + self.epsilon).sqrt();

            for (((out, &x), &scale), &shift) in dst.iter_mut().zip(src).zip(gamma).zip(beta) {
                let normalized = (x - mean) / std_dev;
                *out = scale * normalized + shift;
            }
        }
    }
}
171
172pub mod batch_matmul {
174 use super::*;
175
176 pub fn batch_matmul_f32(
179 a: &[f32],
180 b: &[f32],
181 c: &mut [f32],
182 batch_size: usize,
183 m: usize,
184 n: usize,
185 k: usize,
186 ) {
187 assert_eq!(a.len(), batch_size * m * k);
188 assert_eq!(b.len(), batch_size * k * n);
189 assert_eq!(c.len(), batch_size * m * n);
190
191 for batch in 0..batch_size {
192 let a_offset = batch * m * k;
193 let b_offset = batch * k * n;
194 let c_offset = batch * m * n;
195
196 for i in 0..m {
197 for j in 0..n {
198 let mut sum = 0.0;
199 for l in 0..k {
200 let a_idx = a_offset + i * k + l;
201 let b_idx = b_offset + l * n + j;
202 sum += a[a_idx] * b[b_idx];
203 }
204 let c_idx = c_offset + i * n + j;
205 c[c_idx] = sum;
206 }
207 }
208 }
209 }
210
211 pub fn batch_matmul_broadcast_f32(
214 a: &[f32],
215 b: &[f32],
216 c: &mut [f32],
217 batch_size: usize,
218 m: usize,
219 n: usize,
220 k: usize,
221 ) {
222 assert_eq!(a.len(), batch_size * m * k);
223 assert_eq!(b.len(), k * n);
224 assert_eq!(c.len(), batch_size * m * n);
225
226 for batch in 0..batch_size {
227 let a_offset = batch * m * k;
228 let c_offset = batch * m * n;
229
230 for i in 0..m {
231 for j in 0..n {
232 let mut sum = 0.0;
233 for l in 0..k {
234 let a_idx = a_offset + i * k + l;
235 let b_idx = l * n + j;
236 sum += a[a_idx] * b[b_idx];
237 }
238 let c_idx = c_offset + i * n + j;
239 c[c_idx] = sum;
240 }
241 }
242 }
243 }
244
245 pub fn batch_matmul_f16(
247 a: &[F16],
248 b: &[F16],
249 c: &mut [F16],
250 batch_size: usize,
251 m: usize,
252 n: usize,
253 k: usize,
254 ) {
255 assert_eq!(a.len(), batch_size * m * k);
256 assert_eq!(b.len(), batch_size * k * n);
257 assert_eq!(c.len(), batch_size * m * n);
258
259 for batch in 0..batch_size {
260 let a_offset = batch * m * k;
261 let b_offset = batch * k * n;
262 let c_offset = batch * m * n;
263
264 for i in 0..m {
265 for j in 0..n {
266 let mut sum = 0.0f32;
267 for l in 0..k {
268 let a_idx = a_offset + i * k + l;
269 let b_idx = b_offset + l * n + j;
270 sum += a[a_idx].to_f32() * b[b_idx].to_f32();
271 }
272 let c_idx = c_offset + i * n + j;
273 c[c_idx] = F16::from_f32(sum);
274 }
275 }
276 }
277 }
278
279 pub fn batch_matmul_bf16(
281 a: &[BF16],
282 b: &[BF16],
283 c: &mut [BF16],
284 batch_size: usize,
285 m: usize,
286 n: usize,
287 k: usize,
288 ) {
289 assert_eq!(a.len(), batch_size * m * k);
290 assert_eq!(b.len(), batch_size * k * n);
291 assert_eq!(c.len(), batch_size * m * n);
292
293 for batch in 0..batch_size {
294 let a_offset = batch * m * k;
295 let b_offset = batch * k * n;
296 let c_offset = batch * m * n;
297
298 for i in 0..m {
299 for j in 0..n {
300 let mut sum = 0.0f32;
301 for l in 0..k {
302 let a_idx = a_offset + i * k + l;
303 let b_idx = b_offset + l * n + j;
304 sum += a[a_idx].to_f32() * b[b_idx].to_f32();
305 }
306 let c_idx = c_offset + i * n + j;
307 c[c_idx] = BF16::from_f32(sum);
308 }
309 }
310 }
311 }
312}
313
/// Attention kernels over flat, row-major `f32` buffers.
pub mod attention {
    /// Scaled dot-product attention.
    ///
    /// `query`, `key`, `value` and `output` are indexed as
    /// `(batch * seq_len + position) * d_model + channel`. For every query
    /// position the scores `dot(q, k) / sqrt(d_model)` are soft-maxed across
    /// the key positions and used to blend the rows of `value`.
    ///
    /// `mask`, when present, is indexed as `query_pos * seq_len + key_pos` and
    /// is shared by every batch; a `false` entry hides that key from that
    /// query. A query whose keys are all hidden produces an all-zero output
    /// row, because a softmax over nothing but `-inf` scores would be NaN.
    ///
    /// # Panics
    ///
    /// Panics if any tensor slice is not `batch_size * seq_len * d_model` long
    /// or if `mask` holds fewer than `seq_len * seq_len` entries.
    #[allow(clippy::too_many_arguments)] // flat-buffer API: each dimension is its own argument
    pub fn scaled_dot_product_attention(
        query: &[f32],
        key: &[f32],
        value: &[f32],
        output: &mut [f32],
        batch_size: usize,
        seq_len: usize,
        d_model: usize,
        mask: Option<&[bool]>,
    ) {
        let scale = 1.0 / (d_model as f32).sqrt();

        let tensor_len = batch_size * seq_len * d_model;
        assert_eq!(query.len(), tensor_len);
        assert_eq!(key.len(), tensor_len);
        assert_eq!(value.len(), tensor_len);
        assert_eq!(output.len(), tensor_len);
        if let Some(mask) = mask {
            // Fail up front with a clear message instead of an index panic mid-computation.
            assert!(
                mask.len() >= seq_len * seq_len,
                "mask has {} entries but seq_len * seq_len = {} are required",
                mask.len(),
                seq_len * seq_len
            );
        }

        // Only one row of attention weights is live at a time, so a single
        // `seq_len`-long scratch row is reused for every query position.
        #[cfg(not(feature = "no-std"))]
        let mut weights = vec![0.0f32; seq_len];
        #[cfg(feature = "no-std")]
        let mut weights = alloc::vec![0.0f32; seq_len];

        for batch in 0..batch_size {
            let batch_base = batch * seq_len * d_model;

            for i in 0..seq_len {
                let q_start = batch_base + i * d_model;
                let q_row = &query[q_start..q_start + d_model];

                // Raw scores for this query; hidden keys get -inf so they
                // receive zero weight after the softmax.
                let mut any_visible = false;
                for (j, weight) in weights.iter_mut().enumerate() {
                    let visible = match mask {
                        Some(mask) => mask[i * seq_len + j],
                        None => true,
                    };
                    *weight = if visible {
                        any_visible = true;
                        let k_start = batch_base + j * d_model;
                        let k_row = &key[k_start..k_start + d_model];
                        let mut dot = 0.0f32;
                        for (&q, &k) in q_row.iter().zip(k_row) {
                            dot += q * k;
                        }
                        dot * scale
                    } else {
                        f32::NEG_INFINITY
                    };
                }

                let out_row = &mut output[q_start..q_start + d_model];
                if !any_visible {
                    // Every key is masked: emit zeros rather than NaN.
                    out_row.fill(0.0);
                    continue;
                }

                // Numerically stable softmax: subtract the row maximum first.
                let max_val = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                for w in weights.iter_mut() {
                    *w = (*w - max_val).exp();
                }
                let sum_exp: f32 = weights.iter().sum();
                for w in weights.iter_mut() {
                    *w /= sum_exp;
                }

                // Blend the value rows with the attention weights.
                for (channel, out) in out_row.iter_mut().enumerate() {
                    let mut acc = 0.0f32;
                    for (j, &w) in weights.iter().enumerate() {
                        acc += w * value[batch_base + j * d_model + channel];
                    }
                    *out = acc;
                }
            }
        }
    }

    /// Multi-head attention without learned projections.
    ///
    /// The `d_model` channels are split into `num_heads` contiguous groups of
    /// `d_model / num_heads` channels. Each group is run through
    /// [`scaled_dot_product_attention`] independently (with the same `mask`)
    /// and the results are written back to the same channel range of
    /// `output`.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero, if `d_model` is not divisible by
    /// `num_heads`, if any tensor slice is not
    /// `batch_size * seq_len * d_model` long, or if `mask` is too short.
    #[allow(clippy::too_many_arguments)] // flat-buffer API: each dimension is its own argument
    pub fn multi_head_attention(
        query: &[f32],
        key: &[f32],
        value: &[f32],
        output: &mut [f32],
        batch_size: usize,
        seq_len: usize,
        d_model: usize,
        num_heads: usize,
        mask: Option<&[bool]>,
    ) {
        assert!(num_heads > 0, "num_heads must be non-zero");
        assert_eq!(d_model % num_heads, 0);
        let d_k = d_model / num_heads;

        let tensor_len = batch_size * seq_len * d_model;
        assert_eq!(query.len(), tensor_len);
        assert_eq!(key.len(), tensor_len);
        assert_eq!(value.len(), tensor_len);
        assert_eq!(output.len(), tensor_len);

        let tokens = batch_size * seq_len;
        let head_len = tokens * d_k;

        // One allocation, split into the four per-head buffers and reused for
        // every head instead of being reallocated on each iteration.
        #[cfg(not(feature = "no-std"))]
        let mut scratch = vec![0.0f32; 4 * head_len];
        #[cfg(feature = "no-std")]
        let mut scratch = alloc::vec![0.0f32; 4 * head_len];
        let (head_q, rest) = scratch.split_at_mut(head_len);
        let (head_k, rest) = rest.split_at_mut(head_len);
        let (head_v, head_out) = rest.split_at_mut(head_len);

        for head in 0..num_heads {
            let head_start = head * d_k;

            // Gather this head's channel range into dense per-head tensors.
            for token in 0..tokens {
                let src = token * d_model + head_start;
                let dst = token * d_k;
                head_q[dst..dst + d_k].copy_from_slice(&query[src..src + d_k]);
                head_k[dst..dst + d_k].copy_from_slice(&key[src..src + d_k]);
                head_v[dst..dst + d_k].copy_from_slice(&value[src..src + d_k]);
            }

            scaled_dot_product_attention(
                head_q, head_k, head_v, head_out, batch_size, seq_len, d_k, mask,
            );

            // Scatter the head's result straight into its channel range.
            for token in 0..tokens {
                let dst = token * d_model + head_start;
                let src = token * d_k;
                output[dst..dst + d_k].copy_from_slice(&head_out[src..src + d_k]);
            }
        }
    }
}
502
/// Direct (non-FFT, non-im2col) convolution kernels.
pub mod convolution {
    /// Batched 2-D convolution with bias over flat buffers.
    ///
    /// Layouts, all row-major:
    /// * `input`:  `[batch, in_channels, input_height, input_width]`
    /// * `weight`: `[out_channels, in_channels, kernel_height, kernel_width]`
    /// * `bias`:   `[out_channels]`
    /// * `output`: `[batch, out_channels, output_height, output_width]`
    ///
    /// where `output_height = (input_height + 2 * padding_h - kernel_height) / stride_h + 1`
    /// and likewise for the width. The kernel is applied without flipping, and
    /// taps that land in the padding border contribute nothing (implicit zero
    /// padding).
    ///
    /// # Panics
    ///
    /// Panics if a stride is zero, if the kernel is larger than the padded
    /// input, or if any slice length does not match the layout above.
    #[allow(clippy::too_many_arguments)] // flat-buffer API: each dimension is its own argument
    pub fn conv2d_batch(
        input: &[f32],
        weight: &[f32],
        bias: &[f32],
        output: &mut [f32],
        batch_size: usize,
        in_channels: usize,
        out_channels: usize,
        input_height: usize,
        input_width: usize,
        kernel_height: usize,
        kernel_width: usize,
        stride_h: usize,
        stride_w: usize,
        padding_h: usize,
        padding_w: usize,
    ) {
        // Validate before the size arithmetic: a zero stride would divide by
        // zero and an oversized kernel would underflow `usize` (silently
        // wrapping in release builds).
        assert!(stride_h > 0 && stride_w > 0, "stride must be non-zero");
        let padded_height = input_height + 2 * padding_h;
        let padded_width = input_width + 2 * padding_w;
        assert!(
            kernel_height <= padded_height && kernel_width <= padded_width,
            "kernel ({}x{}) is larger than the padded input ({}x{})",
            kernel_height,
            kernel_width,
            padded_height,
            padded_width
        );

        let output_height = (padded_height - kernel_height) / stride_h + 1;
        let output_width = (padded_width - kernel_width) / stride_w + 1;

        assert_eq!(
            input.len(),
            batch_size * in_channels * input_height * input_width
        );
        assert_eq!(
            weight.len(),
            out_channels * in_channels * kernel_height * kernel_width
        );
        assert_eq!(bias.len(), out_channels);
        assert_eq!(
            output.len(),
            batch_size * out_channels * output_height * output_width
        );

        let in_plane = input_height * input_width;
        let kernel_plane = kernel_height * kernel_width;
        let out_plane = output_height * output_width;

        for batch in 0..batch_size {
            for (out_ch, &bias_val) in bias.iter().enumerate() {
                let out_base = (batch * out_channels + out_ch) * out_plane;

                for out_y in 0..output_height {
                    for out_x in 0..output_width {
                        let mut sum = bias_val;

                        for in_ch in 0..in_channels {
                            let in_base = (batch * in_channels + in_ch) * in_plane;
                            let w_base = (out_ch * in_channels + in_ch) * kernel_plane;

                            for ky in 0..kernel_height {
                                // Row in unpadded input coordinates; `None` or an
                                // out-of-range row means the tap is in the padding.
                                let input_y = match (out_y * stride_h + ky).checked_sub(padding_h) {
                                    Some(y) if y < input_height => y,
                                    _ => continue,
                                };

                                for kx in 0..kernel_width {
                                    let input_x =
                                        match (out_x * stride_w + kx).checked_sub(padding_w) {
                                            Some(x) if x < input_width => x,
                                            _ => continue,
                                        };

                                    sum += input[in_base + input_y * input_width + input_x]
                                        * weight[w_base + ky * kernel_width + kx];
                                }
                            }
                        }

                        output[out_base + out_y * output_width + out_x] = sum;
                    }
                }
            }
        }
    }
}
593
// Unit tests for the kernels above. They only build with the standard library
// available, so `vec!` comes from the std prelude and no `alloc` import is
// needed inside this module.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Asserts that both slices have the same length and agree element-wise
    /// within `tol`.
    fn assert_all_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (idx, (&got, &want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < tol,
                "index {}: got {}, expected {}",
                idx,
                got,
                want
            );
        }
    }

    #[test]
    fn test_batch_norm() {
        let batch_norm = BatchNorm::new(1e-5);
        let batch_size = 2;
        let features = 3;

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mean = vec![2.5, 3.5, 4.5];
        let variance = vec![2.25, 2.25, 2.25];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let mut output = vec![0.0; 6];

        batch_norm.forward(
            &input,
            &mean,
            &variance,
            &gamma,
            &beta,
            &mut output,
            batch_size,
            features,
        );

        // Every element sits exactly 1.5 (one standard deviation) from its mean.
        assert_all_close(&output, &[-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], 1e-4);
    }

    #[test]
    fn test_layer_norm() {
        let layer_norm = LayerNorm::new(1e-5);
        let batch_size = 2;
        let features = 3;

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let mut output = vec![0.0; 6];

        layer_norm.forward(&input, &gamma, &beta, &mut output, batch_size, features);

        for batch in 0..batch_size {
            let row = &output[batch * features..(batch + 1) * features];
            let sample_mean: f32 = row.iter().sum::<f32>() / features as f32;
            assert!(sample_mean.abs() < 1e-5);
        }

        // Both rows are {m - 1, m, m + 1}, so they normalize to +/- sqrt(1.5).
        assert_all_close(
            &output,
            &[-1.2247, 0.0, 1.2247, -1.2247, 0.0, 1.2247],
            1e-3,
        );
    }

    #[test]
    fn test_batch_matmul() {
        let batch_size = 2;
        let m = 2;
        let n = 2;
        let k = 2;

        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        // One identity matrix per batch.
        let b = vec![1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
        let mut c = vec![0.0; batch_size * m * n];

        batch_matmul::batch_matmul_f32(&a, &b, &mut c, batch_size, m, n, k);

        assert_all_close(&c, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 1e-5);
    }

    #[test]
    fn test_batch_matmul_broadcast() {
        let batch_size = 2;
        let m = 2;
        let n = 2;
        let k = 2;

        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        // A single identity matrix shared by both batches.
        let b = vec![1.0, 0.0, 0.0, 1.0];
        let mut c = vec![0.0; batch_size * m * n];

        batch_matmul::batch_matmul_broadcast_f32(&a, &b, &mut c, batch_size, m, n, k);

        assert_all_close(&c, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 1e-5);
    }

    #[test]
    fn test_attention_basic() {
        let batch_size = 1;
        let seq_len = 3;
        let d_model = 4;

        let query = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let key = query.clone();
        let value = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let mut output = vec![0.0; batch_size * seq_len * d_model];

        attention::scaled_dot_product_attention(
            &query,
            &key,
            &value,
            &mut output,
            batch_size,
            seq_len,
            d_model,
            None,
        );

        // The rows are distinct unit vectors, so the scaled score is 0.5 on the
        // diagonal and 0 elsewhere; the softmax weights follow directly.
        let e = 0.5f32.exp();
        let on_diag = e / (e + 2.0);
        let off_diag = 1.0 / (e + 2.0);
        let mut expected = vec![0.0f32; batch_size * seq_len * d_model];
        for i in 0..seq_len {
            for c in 0..d_model {
                expected[i * d_model + c] = (0..seq_len)
                    .map(|j| {
                        let weight = if i == j { on_diag } else { off_diag };
                        weight * value[j * d_model + c]
                    })
                    .sum::<f32>();
            }
        }

        assert_all_close(&output, &expected, 1e-4);
    }

    #[test]
    fn test_attention_causal_mask() {
        let seq_len = 3;
        let d_model = 2;

        // All-zero queries give uniform weights over whichever keys are visible.
        let query = vec![0.0; seq_len * d_model];
        let key = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let value = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = vec![
            true, false, false, //
            true, true, false, //
            true, true, true,
        ];
        let mut output = vec![0.0; seq_len * d_model];

        attention::scaled_dot_product_attention(
            &query,
            &key,
            &value,
            &mut output,
            1,
            seq_len,
            d_model,
            Some(mask.as_slice()),
        );

        assert_all_close(&output, &[1.0, 2.0, 2.0, 3.0, 3.0, 4.0], 1e-5);
    }

    #[test]
    fn test_conv2d_batch_simple() {
        let batch_size = 1;
        let in_channels = 1;
        let out_channels = 1;
        let input_height = 3;
        let input_width = 3;
        let kernel_height = 2;
        let kernel_width = 2;

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        // Picks the top-left and bottom-right element of every 2x2 window.
        let weight = vec![1.0, 0.0, 0.0, 1.0];
        let bias = vec![0.0];

        let output_height = input_height - kernel_height + 1;
        let output_width = input_width - kernel_width + 1;
        let mut output = vec![0.0; batch_size * out_channels * output_height * output_width];

        convolution::conv2d_batch(
            &input,
            &weight,
            &bias,
            &mut output,
            batch_size,
            in_channels,
            out_channels,
            input_height,
            input_width,
            kernel_height,
            kernel_width,
            1,
            1,
            0,
            0,
        );

        assert_all_close(&output, &[6.0, 8.0, 12.0, 14.0], 1e-6);
    }

    #[test]
    fn test_batch_norm_f16() {
        // Larger epsilon than the f32 test.
        let batch_norm = BatchNorm::new(1e-3);
        let batch_size = 2;
        let features = 3;

        let to_f16 = |values: &[f32]| -> Vec<F16> {
            values.iter().map(|&v| F16::from_f32(v)).collect()
        };

        let input = to_f16(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let mean = to_f16(&[2.5, 3.5, 4.5]);
        let variance = to_f16(&[2.25, 2.25, 2.25]);
        let gamma = to_f16(&[1.0, 1.0, 1.0]);
        let beta = to_f16(&[0.0, 0.0, 0.0]);
        let mut output = vec![F16::from_bits(0); 6];

        batch_norm.forward_f16(
            &input,
            &mean,
            &variance,
            &gamma,
            &beta,
            &mut output,
            batch_size,
            features,
        );

        let widened: Vec<f32> = output.iter().map(|&val| val.to_f32()).collect();
        assert_all_close(&widened, &[-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], 1e-2);
    }

    #[test]
    fn test_batch_stats_computation() {
        let batch_size = 4;
        let features = 2;

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut mean = vec![0.0; features];
        let mut variance = vec![0.0; features];

        BatchNorm::compute_stats(&input, &mut mean, &mut variance, batch_size, features);

        // Columns are {1, 3, 5, 7} and {2, 4, 6, 8}.
        assert_all_close(&mean, &[4.0, 5.0], 1e-6);
        // Squared deviations 9 + 1 + 1 + 9, divided by 4.
        assert_all_close(&variance, &[5.0, 5.0], 1e-6);
    }
}