1use super::half::Half;
11use super::bfloat16::BFloat16;
12use std::fmt;
13
/// Precision mode tag carried by a `TensorCoreEngine` and reported in `GemmStats`.
///
/// In the code visible in this file the value is only stored and echoed back;
/// the arithmetic in `mma`/`gemm` is always performed in `f32`.
///
/// NOTE(review): the per-variant descriptions below are inferred from the
/// variant names only — confirm against the intended hardware semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MmaPrecision {
    /// Presumably FP16 inputs with FP32 accumulation — TODO confirm.
    Fp16Fp32,
    /// Presumably BF16 inputs with FP32 accumulation — TODO confirm.
    Bf16Fp32,
    /// Presumably TensorFloat-32 — TODO confirm.
    Tf32,
    /// Presumably INT8 inputs with INT32 accumulation — TODO confirm.
    Int8Int32,
    /// Presumably plain FP32 — TODO confirm.
    Fp32,
}
28
/// Tile dimensions of one matrix-multiply-accumulate step.
///
/// `TensorCoreEngine::mma` requires A to be `m × k`, B to be `k × n` and the
/// accumulator C to be `m × n`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FragmentShape {
    /// Rows of A and of the accumulator/output tile.
    pub m: usize,
    /// Columns of B and of the accumulator/output tile.
    pub n: usize,
    /// Shared (reduction) dimension: columns of A and rows of B.
    pub k: usize,
}

impl FragmentShape {
    /// 16×16×16 tile.
    pub const M16N16K16: Self = Self::new(16, 16, 16);
    /// 16×8×16 tile.
    pub const M16N8K16: Self = Self::new(16, 8, 16);
    /// 8×32×16 tile.
    pub const M8N32K16: Self = Self::new(8, 32, 16);

    /// Builds a shape from explicit `m`, `n` and `k` extents.
    ///
    /// This is a `const fn` so custom shapes can be declared as constants,
    /// just like the predefined ones. No validation is performed on the values.
    pub const fn new(m: usize, n: usize, k: usize) -> Self {
        Self { m, n, k }
    }
}
50
/// A dense matrix tile stored as `f32` values.
///
/// Elements are laid out row-major: the accessors on this type address
/// element `(row, col)` at `data[row * cols + col]`.
///
/// All fields are public, so code constructing a `Fragment` by hand is
/// responsible for keeping `data.len() == rows * cols`.
#[derive(Debug, Clone, PartialEq)]
pub struct Fragment {
    /// Backing storage, `rows * cols` values in row-major order.
    pub data: Vec<f32>,
    /// Number of rows.
    pub rows: usize,
    /// Number of columns (also the row stride inside `data`).
    pub cols: usize,
}
62
63impl Fragment {
64 pub fn zeros(rows: usize, cols: usize) -> Self {
66 Self {
67 data: vec![0.0; rows * cols],
68 rows,
69 cols,
70 }
71 }
72
73 pub fn from_f32(data: &[f32], rows: usize, cols: usize) -> crate::Result<Self> {
75 if data.len() != rows * cols {
76 return Err(crate::error::CudaRustError::RuntimeError(
77 format!("Fragment size mismatch: {}×{} needs {} elements, got {}",
78 rows, cols, rows * cols, data.len()),
79 ));
80 }
81 Ok(Self {
82 data: data.to_vec(),
83 rows,
84 cols,
85 })
86 }
87
88 pub fn from_half(data: &[Half], rows: usize, cols: usize) -> crate::Result<Self> {
90 if data.len() != rows * cols {
91 return Err(crate::error::CudaRustError::RuntimeError(
92 format!("Fragment size mismatch: expected {} elements, got {}", rows * cols, data.len()),
93 ));
94 }
95 Ok(Self {
96 data: data.iter().map(|h| h.to_f32()).collect(),
97 rows,
98 cols,
99 })
100 }
101
102 pub fn from_bf16(data: &[BFloat16], rows: usize, cols: usize) -> crate::Result<Self> {
104 if data.len() != rows * cols {
105 return Err(crate::error::CudaRustError::RuntimeError(
106 format!("Fragment size mismatch: expected {} elements, got {}", rows * cols, data.len()),
107 ));
108 }
109 Ok(Self {
110 data: data.iter().map(|b| b.to_f32()).collect(),
111 rows,
112 cols,
113 })
114 }
115
116 pub fn get(&self, row: usize, col: usize) -> f32 {
118 self.data[row * self.cols + col]
119 }
120
121 pub fn set(&mut self, row: usize, col: usize, val: f32) {
123 self.data[row * self.cols + col] = val;
124 }
125
126 pub fn to_half(&self) -> Vec<Half> {
128 self.data.iter().map(|&v| Half::from_f32(v)).collect()
129 }
130
131 pub fn to_bf16(&self) -> Vec<BFloat16> {
133 self.data.iter().map(|&v| BFloat16::from_f32(v)).collect()
134 }
135}
136
137pub struct TensorCoreEngine {
142 precision: MmaPrecision,
143 shape: FragmentShape,
144}
145
146impl TensorCoreEngine {
147 pub fn new(precision: MmaPrecision, shape: FragmentShape) -> Self {
149 Self { precision, shape }
150 }
151
152 pub fn mma(&self, a: &Fragment, b: &Fragment, c: &Fragment) -> crate::Result<Fragment> {
156 if a.rows != self.shape.m || a.cols != self.shape.k {
157 return Err(crate::error::CudaRustError::RuntimeError(
158 format!("Fragment A shape {}×{} doesn't match MMA {}×{}",
159 a.rows, a.cols, self.shape.m, self.shape.k),
160 ));
161 }
162 if b.rows != self.shape.k || b.cols != self.shape.n {
163 return Err(crate::error::CudaRustError::RuntimeError(
164 format!("Fragment B shape {}×{} doesn't match MMA {}×{}",
165 b.rows, b.cols, self.shape.k, self.shape.n),
166 ));
167 }
168 if c.rows != self.shape.m || c.cols != self.shape.n {
169 return Err(crate::error::CudaRustError::RuntimeError(
170 format!("Fragment C shape {}×{} doesn't match MMA {}×{}",
171 c.rows, c.cols, self.shape.m, self.shape.n),
172 ));
173 }
174
175 let m = self.shape.m;
176 let n = self.shape.n;
177 let k = self.shape.k;
178
179 let mut d = Fragment::zeros(m, n);
180
181 for i in 0..m {
183 for j in 0..n {
184 let mut acc = c.get(i, j);
185 for p in 0..k {
186 acc += a.get(i, p) * b.get(p, j);
187 }
188 d.set(i, j, acc);
189 }
190 }
191
192 Ok(d)
193 }
194
195 pub fn gemm(
199 &self,
200 a: &[f32], b: &[f32], c: &mut [f32],
201 m: usize, n: usize, k: usize,
202 alpha: f32, beta: f32,
203 ) -> crate::Result<GemmStats> {
204 if a.len() != m * k || b.len() != k * n || c.len() != m * n {
205 return Err(crate::error::CudaRustError::RuntimeError("GEMM dimension mismatch".into()));
206 }
207
208 let tm = self.shape.m;
209 let tn = self.shape.n;
210 let tk = self.shape.k;
211 let mut mma_count = 0u64;
212
213 for val in c.iter_mut() {
215 *val *= beta;
216 }
217
218 let m_tiles = (m + tm - 1) / tm;
220 let n_tiles = (n + tn - 1) / tn;
221 let k_tiles = (k + tk - 1) / tk;
222
223 for mi in 0..m_tiles {
224 let m_start = mi * tm;
225 let m_end = (m_start + tm).min(m);
226 let actual_m = m_end - m_start;
227
228 for ni in 0..n_tiles {
229 let n_start = ni * tn;
230 let n_end = (n_start + tn).min(n);
231 let actual_n = n_end - n_start;
232
233 for ki in 0..k_tiles {
234 let k_start = ki * tk;
235 let k_end = (k_start + tk).min(k);
236 let actual_k = k_end - k_start;
237
238 for i in 0..actual_m {
240 for j in 0..actual_n {
241 let mut acc = 0.0f32;
242 for p in 0..actual_k {
243 acc += a[(m_start + i) * k + (k_start + p)]
244 * b[(k_start + p) * n + (n_start + j)];
245 }
246 c[(m_start + i) * n + (n_start + j)] += alpha * acc;
247 }
248 }
249 mma_count += 1;
250 }
251 }
252 }
253
254 let flops = 2 * (m as u64) * (n as u64) * (k as u64);
255 Ok(GemmStats { mma_count, flops, precision: self.precision })
256 }
257}
258
259#[derive(Debug, Clone)]
261pub struct GemmStats {
262 pub mma_count: u64,
264 pub flops: u64,
266 pub precision: MmaPrecision,
268}
269
270impl fmt::Display for GemmStats {
271 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272 write!(f, "GEMM: {} MMA ops, {:.2}M FLOPs, {:?}",
273 self.mma_count, self.flops as f64 / 1e6, self.precision)
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// A 2×2×2 engine, the tile shape shared by most tests below.
    fn engine_2x2x2(precision: MmaPrecision) -> TensorCoreEngine {
        TensorCoreEngine::new(precision, FragmentShape::new(2, 2, 2))
    }

    #[test]
    fn test_fragment_zeros() {
        let frag = Fragment::zeros(4, 4);
        assert_eq!(frag.data.len(), 16);
        assert!(frag.data.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_fragment_from_f32() {
        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let frag = Fragment::from_f32(&data, 4, 4).unwrap();
        assert_eq!(frag.get(0, 0), 0.0);
        assert_eq!(frag.get(1, 2), 6.0);
        assert_eq!(frag.get(3, 3), 15.0);
    }

    #[test]
    fn test_mma_identity() {
        let engine = engine_2x2x2(MmaPrecision::Fp32);

        // Identity · B + 0 must reproduce B.
        let a = Fragment::from_f32(&[1.0, 0.0, 0.0, 1.0], 2, 2).unwrap();
        let b = Fragment::from_f32(&[5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let c = Fragment::zeros(2, 2);

        let d = engine.mma(&a, &b, &c).unwrap();
        assert_close(d.get(0, 0), 5.0, 1e-6);
        assert_close(d.get(0, 1), 6.0, 1e-6);
        assert_close(d.get(1, 0), 7.0, 1e-6);
        assert_close(d.get(1, 1), 8.0, 1e-6);
    }

    #[test]
    fn test_mma_accumulate() {
        let engine = engine_2x2x2(MmaPrecision::Fp16Fp32);

        let a = Fragment::from_f32(&[1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let b = Fragment::from_f32(&[5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let c = Fragment::from_f32(&[10.0, 10.0, 10.0, 10.0], 2, 2).unwrap();

        // A · B = [[19, 22], [43, 50]]; adding 10 everywhere gives the values below.
        let d = engine.mma(&a, &b, &c).unwrap();
        assert_close(d.get(0, 0), 29.0, 1e-6);
        assert_close(d.get(0, 1), 32.0, 1e-6);
        assert_close(d.get(1, 0), 53.0, 1e-6);
        assert_close(d.get(1, 1), 60.0, 1e-6);
    }

    #[test]
    fn test_mma_shape_validation() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(4, 4, 4));
        // A is 2×2 while the engine expects 4×4.
        let a = Fragment::zeros(2, 2);
        let b = Fragment::zeros(4, 4);
        let c = Fragment::zeros(4, 4);
        assert!(engine.mma(&a, &b, &c).is_err());
    }

    #[test]
    fn test_gemm_basic() {
        let engine = engine_2x2x2(MmaPrecision::Fp32);
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0; 4];

        let stats = engine.gemm(&a, &b, &mut c, 2, 2, 2, 1.0, 0.0).unwrap();

        assert_close(c[0], 19.0, 1e-4);
        assert_close(c[1], 22.0, 1e-4);
        assert_close(c[2], 43.0, 1e-4);
        assert_close(c[3], 50.0, 1e-4);
        // 2 · m · n · k = 2 · 2 · 2 · 2
        assert_eq!(stats.flops, 16);
    }

    #[test]
    fn test_gemm_alpha_beta() {
        let engine = engine_2x2x2(MmaPrecision::Fp32);
        let a = vec![1.0, 0.0, 0.0, 1.0];
        let b = vec![1.0, 2.0, 3.0, 4.0];
        let mut c = vec![10.0, 10.0, 10.0, 10.0];

        engine.gemm(&a, &b, &mut c, 2, 2, 2, 2.0, 0.5).unwrap();

        // 2.0 · (I · B) + 0.5 · 10
        assert_close(c[0], 7.0, 1e-4);
        assert_close(c[1], 9.0, 1e-4);
    }

    #[test]
    fn test_gemm_non_square() {
        let engine = engine_2x2x2(MmaPrecision::Fp32);
        // A is 3×2, B is 2×4, C is 3×4.
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut c = vec![0.0; 12];

        engine.gemm(&a, &b, &mut c, 3, 4, 2, 1.0, 0.0).unwrap();

        let expected_first_row = [11.0, 14.0, 17.0, 20.0];
        for (idx, &want) in expected_first_row.iter().enumerate() {
            assert_close(c[idx], want, 1e-4);
        }
    }

    #[test]
    fn test_fragment_half_roundtrip() {
        let data: Vec<Half> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&v| Half::from_f32(v))
            .collect();
        let frag = Fragment::from_half(&data, 2, 2).unwrap();
        let back = frag.to_half();
        for i in 0..4 {
            assert_close(back[i].to_f32(), data[i].to_f32(), 0.01);
        }
    }

    #[test]
    fn test_fragment_bf16_roundtrip() {
        let data = vec![BFloat16::from_f32(1.5), BFloat16::from_f32(2.5)];
        let frag = Fragment::from_bf16(&data, 1, 2).unwrap();
        let back = frag.to_bf16();
        assert_close(back[0].to_f32(), 1.5, 0.1);
        assert_close(back[1].to_f32(), 2.5, 0.1);
    }

    #[test]
    fn test_gemm_stats_display() {
        let stats = GemmStats {
            mma_count: 64,
            flops: 1_000_000,
            precision: MmaPrecision::Fp16Fp32,
        };
        let s = format!("{}", stats);
        assert!(s.contains("64 MMA"));
        assert!(s.contains("Fp16Fp32"));
    }
}