Skip to main content

cuda_rust_wasm/runtime/
tensor_ops.rs

1//! Tensor Core / Matrix Multiply-Accumulate (MMA) operations
2//!
3//! Emulates NVIDIA Tensor Core operations for mixed-precision matrix
4//! multiplication. On real hardware (SM 7.0+), these map to WMMA/MMA
5//! PTX instructions. In CPU fallback, we provide functionally-correct
6//! tiled matrix multiply with the same API semantics.
7//!
8//! Supports: fp16×fp16→fp32, bf16×bf16→fp32, fp32→fp32, int8×int8→int32.
9
10use super::half::Half;
11use super::bfloat16::BFloat16;
12use std::fmt;
13
/// Precision mode for tensor core operations.
///
/// Each variant names an input-type / accumulator-type pairing. The mode is
/// stored on the engine and echoed back in the statistics returned by `gemm`.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this lets it be used as a key in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MmaPrecision {
    /// fp16 inputs, fp32 accumulation (HMMA)
    Fp16Fp32,
    /// bf16 inputs, fp32 accumulation
    Bf16Fp32,
    /// fp32 inputs, fp32 accumulation (TF32 on Ampere+)
    Tf32,
    /// int8 inputs, int32 accumulation (IMMA)
    Int8Int32,
    /// Full fp32 (no tensor cores, standard GEMM)
    Fp32,
}
28
/// Fragment shape for WMMA operations.
/// Maps to hardware-supported shapes like 16×16×16, 8×32×16, etc.
///
/// For `D = A · B + C` the operands are A: (m × k), B: (k × n) and
/// C, D: (m × n).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FragmentShape {
    /// Rows of A and of the accumulator/result.
    pub m: usize,
    /// Columns of B and of the accumulator/result.
    pub n: usize,
    /// Shared (reduction) dimension: columns of A, rows of B.
    pub k: usize,
}

impl FragmentShape {
    /// Standard 16×16×16 (SM 7.0+ Volta)
    pub const M16N16K16: Self = Self { m: 16, n: 16, k: 16 };
    /// Ampere 16×8×16
    pub const M16N8K16: Self = Self { m: 16, n: 8, k: 16 };
    /// INT8: 8×32×16
    pub const M8N32K16: Self = Self { m: 8, n: 32, k: 16 };

    /// Custom shape.
    ///
    /// This is a `const fn`, so custom shapes can be declared as constants
    /// alongside the presets above. No validation is performed here.
    pub const fn new(m: usize, n: usize, k: usize) -> Self {
        Self { m, n, k }
    }
}
50
/// Matrix fragment — a tile of a matrix stored in registers.
/// On GPU, these map to warp-distributed register fragments.
///
/// Elements are held in row-major order: element `(row, col)` lives at
/// `data[row * cols + col]`. Two fragments compare equal when their
/// dimensions and every element compare equal (`f32` semantics, so a
/// fragment containing NaN is never equal to anything).
#[derive(Debug, Clone, PartialEq)]
pub struct Fragment {
    /// Data stored as f32 (accumulator format).
    pub data: Vec<f32>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
}
62
63impl Fragment {
64    /// Create a zero-initialized fragment.
65    pub fn zeros(rows: usize, cols: usize) -> Self {
66        Self {
67            data: vec![0.0; rows * cols],
68            rows,
69            cols,
70        }
71    }
72
73    /// Create from f32 data.
74    pub fn from_f32(data: &[f32], rows: usize, cols: usize) -> crate::Result<Self> {
75        if data.len() != rows * cols {
76            return Err(crate::error::CudaRustError::RuntimeError(
77                format!("Fragment size mismatch: {}×{} needs {} elements, got {}",
78                    rows, cols, rows * cols, data.len()),
79            ));
80        }
81        Ok(Self {
82            data: data.to_vec(),
83            rows,
84            cols,
85        })
86    }
87
88    /// Load from fp16 data (converting to f32 accumulator format).
89    pub fn from_half(data: &[Half], rows: usize, cols: usize) -> crate::Result<Self> {
90        if data.len() != rows * cols {
91            return Err(crate::error::CudaRustError::RuntimeError(
92                format!("Fragment size mismatch: expected {} elements, got {}", rows * cols, data.len()),
93            ));
94        }
95        Ok(Self {
96            data: data.iter().map(|h| h.to_f32()).collect(),
97            rows,
98            cols,
99        })
100    }
101
102    /// Load from bf16 data.
103    pub fn from_bf16(data: &[BFloat16], rows: usize, cols: usize) -> crate::Result<Self> {
104        if data.len() != rows * cols {
105            return Err(crate::error::CudaRustError::RuntimeError(
106                format!("Fragment size mismatch: expected {} elements, got {}", rows * cols, data.len()),
107            ));
108        }
109        Ok(Self {
110            data: data.iter().map(|b| b.to_f32()).collect(),
111            rows,
112            cols,
113        })
114    }
115
116    /// Get element at (row, col).
117    pub fn get(&self, row: usize, col: usize) -> f32 {
118        self.data[row * self.cols + col]
119    }
120
121    /// Set element at (row, col).
122    pub fn set(&mut self, row: usize, col: usize, val: f32) {
123        self.data[row * self.cols + col] = val;
124    }
125
126    /// Store to fp16.
127    pub fn to_half(&self) -> Vec<Half> {
128        self.data.iter().map(|&v| Half::from_f32(v)).collect()
129    }
130
131    /// Store to bf16.
132    pub fn to_bf16(&self) -> Vec<BFloat16> {
133        self.data.iter().map(|&v| BFloat16::from_f32(v)).collect()
134    }
135}
136
137/// Tensor Core MMA engine.
138///
139/// Provides `mma()` (D = A·B + C) matching the semantics of CUDA's
140/// `nvcuda::wmma::mma_sync` and PTX `mma.sync` instructions.
141pub struct TensorCoreEngine {
142    precision: MmaPrecision,
143    shape: FragmentShape,
144}
145
146impl TensorCoreEngine {
147    /// Create a new engine with specified precision and fragment shape.
148    pub fn new(precision: MmaPrecision, shape: FragmentShape) -> Self {
149        Self { precision, shape }
150    }
151
152    /// Matrix multiply-accumulate: D = A · B + C
153    ///
154    /// A: (m × k), B: (k × n), C: (m × n) → D: (m × n)
155    pub fn mma(&self, a: &Fragment, b: &Fragment, c: &Fragment) -> crate::Result<Fragment> {
156        if a.rows != self.shape.m || a.cols != self.shape.k {
157            return Err(crate::error::CudaRustError::RuntimeError(
158                format!("Fragment A shape {}×{} doesn't match MMA {}×{}",
159                    a.rows, a.cols, self.shape.m, self.shape.k),
160            ));
161        }
162        if b.rows != self.shape.k || b.cols != self.shape.n {
163            return Err(crate::error::CudaRustError::RuntimeError(
164                format!("Fragment B shape {}×{} doesn't match MMA {}×{}",
165                    b.rows, b.cols, self.shape.k, self.shape.n),
166            ));
167        }
168        if c.rows != self.shape.m || c.cols != self.shape.n {
169            return Err(crate::error::CudaRustError::RuntimeError(
170                format!("Fragment C shape {}×{} doesn't match MMA {}×{}",
171                    c.rows, c.cols, self.shape.m, self.shape.n),
172            ));
173        }
174
175        let m = self.shape.m;
176        let n = self.shape.n;
177        let k = self.shape.k;
178
179        let mut d = Fragment::zeros(m, n);
180
181        // D = A · B + C (standard GEMM)
182        for i in 0..m {
183            for j in 0..n {
184                let mut acc = c.get(i, j);
185                for p in 0..k {
186                    acc += a.get(i, p) * b.get(p, j);
187                }
188                d.set(i, j, acc);
189            }
190        }
191
192        Ok(d)
193    }
194
195    /// Full GEMM using tiled MMA: C = alpha * A · B + beta * C
196    ///
197    /// A: (m × k), B: (k × n), C: (m × n) — arbitrary sizes, tiled internally.
198    pub fn gemm(
199        &self,
200        a: &[f32], b: &[f32], c: &mut [f32],
201        m: usize, n: usize, k: usize,
202        alpha: f32, beta: f32,
203    ) -> crate::Result<GemmStats> {
204        if a.len() != m * k || b.len() != k * n || c.len() != m * n {
205            return Err(crate::error::CudaRustError::RuntimeError("GEMM dimension mismatch".into()));
206        }
207
208        let tm = self.shape.m;
209        let tn = self.shape.n;
210        let tk = self.shape.k;
211        let mut mma_count = 0u64;
212
213        // Scale C by beta
214        for val in c.iter_mut() {
215            *val *= beta;
216        }
217
218        // Tile over M, N, K
219        let m_tiles = (m + tm - 1) / tm;
220        let n_tiles = (n + tn - 1) / tn;
221        let k_tiles = (k + tk - 1) / tk;
222
223        for mi in 0..m_tiles {
224            let m_start = mi * tm;
225            let m_end = (m_start + tm).min(m);
226            let actual_m = m_end - m_start;
227
228            for ni in 0..n_tiles {
229                let n_start = ni * tn;
230                let n_end = (n_start + tn).min(n);
231                let actual_n = n_end - n_start;
232
233                for ki in 0..k_tiles {
234                    let k_start = ki * tk;
235                    let k_end = (k_start + tk).min(k);
236                    let actual_k = k_end - k_start;
237
238                    // Extract tiles
239                    for i in 0..actual_m {
240                        for j in 0..actual_n {
241                            let mut acc = 0.0f32;
242                            for p in 0..actual_k {
243                                acc += a[(m_start + i) * k + (k_start + p)]
244                                     * b[(k_start + p) * n + (n_start + j)];
245                            }
246                            c[(m_start + i) * n + (n_start + j)] += alpha * acc;
247                        }
248                    }
249                    mma_count += 1;
250                }
251            }
252        }
253
254        let flops = 2 * (m as u64) * (n as u64) * (k as u64);
255        Ok(GemmStats { mma_count, flops, precision: self.precision })
256    }
257}
258
259/// Statistics from a GEMM operation.
260#[derive(Debug, Clone)]
261pub struct GemmStats {
262    /// Number of MMA (tile) operations performed.
263    pub mma_count: u64,
264    /// Total floating-point operations.
265    pub flops: u64,
266    /// Precision mode used.
267    pub precision: MmaPrecision,
268}
269
270impl fmt::Display for GemmStats {
271    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272        write!(f, "GEMM: {} MMA ops, {:.2}M FLOPs, {:?}",
273            self.mma_count, self.flops as f64 / 1e6, self.precision)
274    }
275}
276
277// ── Tests ──────────────────────────────────────────────────────────
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that two f32 slices have equal length and agree element-wise
    /// within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (idx, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < tol,
                "element {}: got {}, want {}",
                idx, got, want
            );
        }
    }

    #[test]
    fn test_fragment_zeros() {
        let frag = Fragment::zeros(4, 4);
        assert_eq!(frag.rows, 4);
        assert_eq!(frag.cols, 4);
        assert_eq!(frag.data.len(), 16);
        assert!(frag.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_fragment_from_f32() {
        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let frag = Fragment::from_f32(&data, 4, 4).unwrap();
        assert_eq!(frag.get(0, 0), 0.0);
        assert_eq!(frag.get(1, 2), 6.0);
        assert_eq!(frag.get(3, 3), 15.0);
    }

    #[test]
    fn test_fragment_size_mismatch() {
        // Three elements cannot fill a 2×2 fragment.
        assert!(Fragment::from_f32(&[1.0, 2.0, 3.0], 2, 2).is_err());
        assert!(Fragment::from_half(&[Half::from_f32(1.0)], 2, 2).is_err());
        assert!(Fragment::from_bf16(&[BFloat16::from_f32(1.0)], 2, 2).is_err());
    }

    #[test]
    fn test_fragment_set_get() {
        let mut frag = Fragment::zeros(2, 3);
        frag.set(1, 2, 7.5);
        assert_eq!(frag.get(1, 2), 7.5);
        // Row-major: (1, 2) in a 2×3 fragment is the last element.
        assert_eq!(frag.data[5], 7.5);
        assert_eq!(frag.get(0, 0), 0.0);
    }

    #[test]
    fn test_mma_identity() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));

        // A = I (identity)
        let a = Fragment::from_f32(&[1.0, 0.0, 0.0, 1.0], 2, 2).unwrap();
        // B = some matrix
        let b = Fragment::from_f32(&[5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        // C = zeros
        let c = Fragment::zeros(2, 2);

        let d = engine.mma(&a, &b, &c).unwrap();
        assert_close(&d.data, &[5.0, 6.0, 7.0, 8.0], 1e-6);
    }

    #[test]
    fn test_mma_accumulate() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp16Fp32, FragmentShape::new(2, 2, 2));

        let a = Fragment::from_f32(&[1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Fragment::from_f32(&[5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let c = Fragment::from_f32(&[10.0, 10.0, 10.0, 10.0], 2, 2).unwrap();

        // D = A·B + C = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]] + 10
        //             = [[19, 22], [43, 50]] + 10
        let d = engine.mma(&a, &b, &c).unwrap();
        assert_close(&d.data, &[29.0, 32.0, 53.0, 60.0], 1e-6);
    }

    #[test]
    fn test_mma_shape_validation() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(4, 4, 4));
        let good = Fragment::zeros(4, 4);
        let bad = Fragment::zeros(2, 2);

        // Each operand is validated independently.
        assert!(engine.mma(&bad, &good, &good).is_err());
        assert!(engine.mma(&good, &bad, &good).is_err());
        assert!(engine.mma(&good, &good, &bad).is_err());
        assert!(engine.mma(&good, &good, &good).is_ok());
    }

    #[test]
    fn test_gemm_basic() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        let a = vec![1.0, 2.0, 3.0, 4.0]; // 2×2
        let b = vec![5.0, 6.0, 7.0, 8.0]; // 2×2
        let mut c = vec![0.0; 4]; // 2×2

        let stats = engine.gemm(&a, &b, &mut c, 2, 2, 2, 1.0, 0.0).unwrap();
        assert_close(&c, &[19.0, 22.0, 43.0, 50.0], 1e-4);
        assert_eq!(stats.flops, 16); // 2*2*2*2
        assert_eq!(stats.mma_count, 1); // a single 2×2×2 tile covers everything
        assert_eq!(stats.precision, MmaPrecision::Fp32);
    }

    #[test]
    fn test_gemm_alpha_beta() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        let a = vec![1.0, 0.0, 0.0, 1.0]; // Identity
        let b = vec![1.0, 2.0, 3.0, 4.0];
        let mut c = vec![10.0, 10.0, 10.0, 10.0];

        // C = 2.0 * I * B + 0.5 * C = 2*B + 5
        engine.gemm(&a, &b, &mut c, 2, 2, 2, 2.0, 0.5).unwrap();
        assert_close(&c, &[7.0, 9.0, 11.0, 13.0], 1e-4);
    }

    #[test]
    fn test_gemm_non_square() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        // A: 3×2, B: 2×4 → C: 3×4
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0; 12];

        let stats = engine.gemm(&a, &b, &mut c, 3, 4, 2, 1.0, 0.0).unwrap();
        // Row 0: [1*1+2*5, 1*2+2*6, 1*3+2*7, 1*4+2*8] = [11, 14, 17, 20]
        // Row 1: [3*1+4*5, 3*2+4*6, 3*3+4*7, 3*4+4*8] = [23, 30, 37, 44]
        // Row 2: [5*1+6*5, 5*2+6*6, 5*3+6*7, 5*4+6*8] = [35, 46, 57, 68]
        let expected = [
            11.0, 14.0, 17.0, 20.0,
            23.0, 30.0, 37.0, 44.0,
            35.0, 46.0, 57.0, 68.0,
        ];
        assert_close(&c, &expected, 1e-4);
        // ceil(3/2) * ceil(4/2) * ceil(2/2) = 2 * 2 * 1 tiles
        assert_eq!(stats.mma_count, 4);
        assert_eq!(stats.flops, 48); // 2*3*4*2
    }

    #[test]
    fn test_gemm_dimension_mismatch() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        let a = vec![1.0, 2.0, 3.0]; // one element short of 2×2
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0; 4];

        assert!(engine.gemm(&a, &b, &mut c, 2, 2, 2, 1.0, 0.0).is_err());
    }

    #[test]
    fn test_gemm_empty_k() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        let a: Vec<f32> = Vec::new(); // 2×0
        let b: Vec<f32> = Vec::new(); // 0×2
        let mut c = vec![1.0, 2.0, 3.0, 4.0];

        // With no reduction dimension only the beta scaling applies.
        let stats = engine.gemm(&a, &b, &mut c, 2, 2, 0, 1.0, 0.5).unwrap();
        assert_close(&c, &[0.5, 1.0, 1.5, 2.0], 1e-6);
        assert_eq!(stats.mma_count, 0);
        assert_eq!(stats.flops, 0);
    }

    #[test]
    fn test_fragment_half_roundtrip() {
        let data = vec![Half::from_f32(1.0), Half::from_f32(2.0), Half::from_f32(3.0), Half::from_f32(4.0)];
        let frag = Fragment::from_half(&data, 2, 2).unwrap();
        let back = frag.to_half();
        assert_eq!(back.len(), data.len());
        for (restored, original) in back.iter().zip(&data) {
            assert!((restored.to_f32() - original.to_f32()).abs() < 0.01);
        }
    }

    #[test]
    fn test_fragment_bf16_roundtrip() {
        let data = vec![BFloat16::from_f32(1.5), BFloat16::from_f32(2.5)];
        let frag = Fragment::from_bf16(&data, 1, 2).unwrap();
        let back = frag.to_bf16();
        assert_eq!(back.len(), 2);
        assert!((back[0].to_f32() - 1.5).abs() < 0.1);
        assert!((back[1].to_f32() - 2.5).abs() < 0.1);
    }

    #[test]
    fn test_gemm_stats_display() {
        let stats = GemmStats { mma_count: 64, flops: 1_000_000, precision: MmaPrecision::Fp16Fp32 };
        let s = format!("{}", stats);
        assert!(s.contains("64 MMA"));
        assert!(s.contains("1.00M FLOPs"));
        assert!(s.contains("Fp16Fp32"));
    }
}