use super::half::Half;
use super::bfloat16::BFloat16;
use std::fmt;
/// Precision-mode tag attached to an MMA/GEMM engine and reported in its stats.
///
/// Variant names appear to follow an `<operand><accumulator>` convention
/// (e.g. `Fp16Fp32` presumably means fp16 operands with fp32 accumulation) —
/// confirm against callers. Note that the software engine in this module only
/// records the mode; it performs all arithmetic in `f32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MmaPrecision {
    /// Presumably fp16 operands, fp32 accumulation.
    Fp16Fp32,
    /// Presumably bfloat16 operands, fp32 accumulation.
    Bf16Fp32,
    /// Presumably TF32 operands.
    Tf32,
    /// Presumably int8 operands, int32 accumulation.
    Int8Int32,
    /// Presumably plain fp32 throughout.
    Fp32,
}
/// Dimensions `(m, n, k)` of a single matrix-multiply-accumulate tile:
/// operand `A` is `m`×`k`, `B` is `k`×`n`, and the accumulator `C` and
/// result `D` are `m`×`n`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FragmentShape {
    /// Rows of `A`, `C` and `D`.
    pub m: usize,
    /// Columns of `B`, `C` and `D`.
    pub n: usize,
    /// Shared inner dimension: columns of `A`, rows of `B`.
    pub k: usize,
}

impl FragmentShape {
    /// 16×16×16 tile.
    pub const M16N16K16: Self = Self::new(16, 16, 16);
    /// 16×8×16 tile.
    pub const M16N8K16: Self = Self::new(16, 8, 16);
    /// 8×32×16 tile.
    pub const M8N32K16: Self = Self::new(8, 32, 16);

    /// Creates a shape from explicit dimensions.
    ///
    /// No validation is performed here; a zero dimension is representable and
    /// must be rejected by whichever consumer cannot handle it.
    pub const fn new(m: usize, n: usize, k: usize) -> Self {
        Self { m, n, k }
    }
}
#[derive(Debug, Clone)]
pub struct Fragment {
pub data: Vec<f32>,
pub rows: usize,
pub cols: usize,
}
impl Fragment {
pub fn zeros(rows: usize, cols: usize) -> Self {
Self {
data: vec![0.0; rows * cols],
rows,
cols,
}
}
pub fn from_f32(data: &[f32], rows: usize, cols: usize) -> crate::Result<Self> {
if data.len() != rows * cols {
return Err(crate::error::CudaRustError::RuntimeError(
format!("Fragment size mismatch: {}×{} needs {} elements, got {}",
rows, cols, rows * cols, data.len()),
));
}
Ok(Self {
data: data.to_vec(),
rows,
cols,
})
}
pub fn from_half(data: &[Half], rows: usize, cols: usize) -> crate::Result<Self> {
if data.len() != rows * cols {
return Err(crate::error::CudaRustError::RuntimeError(
format!("Fragment size mismatch: expected {} elements, got {}", rows * cols, data.len()),
));
}
Ok(Self {
data: data.iter().map(|h| h.to_f32()).collect(),
rows,
cols,
})
}
pub fn from_bf16(data: &[BFloat16], rows: usize, cols: usize) -> crate::Result<Self> {
if data.len() != rows * cols {
return Err(crate::error::CudaRustError::RuntimeError(
format!("Fragment size mismatch: expected {} elements, got {}", rows * cols, data.len()),
));
}
Ok(Self {
data: data.iter().map(|b| b.to_f32()).collect(),
rows,
cols,
})
}
pub fn get(&self, row: usize, col: usize) -> f32 {
self.data[row * self.cols + col]
}
pub fn set(&mut self, row: usize, col: usize, val: f32) {
self.data[row * self.cols + col] = val;
}
pub fn to_half(&self) -> Vec<Half> {
self.data.iter().map(|&v| Half::from_f32(v)).collect()
}
pub fn to_bf16(&self) -> Vec<BFloat16> {
self.data.iter().map(|&v| BFloat16::from_f32(v)).collect()
}
}
pub struct TensorCoreEngine {
precision: MmaPrecision,
shape: FragmentShape,
}
impl TensorCoreEngine {
pub fn new(precision: MmaPrecision, shape: FragmentShape) -> Self {
Self { precision, shape }
}
pub fn mma(&self, a: &Fragment, b: &Fragment, c: &Fragment) -> crate::Result<Fragment> {
if a.rows != self.shape.m || a.cols != self.shape.k {
return Err(crate::error::CudaRustError::RuntimeError(
format!("Fragment A shape {}×{} doesn't match MMA {}×{}",
a.rows, a.cols, self.shape.m, self.shape.k),
));
}
if b.rows != self.shape.k || b.cols != self.shape.n {
return Err(crate::error::CudaRustError::RuntimeError(
format!("Fragment B shape {}×{} doesn't match MMA {}×{}",
b.rows, b.cols, self.shape.k, self.shape.n),
));
}
if c.rows != self.shape.m || c.cols != self.shape.n {
return Err(crate::error::CudaRustError::RuntimeError(
format!("Fragment C shape {}×{} doesn't match MMA {}×{}",
c.rows, c.cols, self.shape.m, self.shape.n),
));
}
let m = self.shape.m;
let n = self.shape.n;
let k = self.shape.k;
let mut d = Fragment::zeros(m, n);
for i in 0..m {
for j in 0..n {
let mut acc = c.get(i, j);
for p in 0..k {
acc += a.get(i, p) * b.get(p, j);
}
d.set(i, j, acc);
}
}
Ok(d)
}
pub fn gemm(
&self,
a: &[f32], b: &[f32], c: &mut [f32],
m: usize, n: usize, k: usize,
alpha: f32, beta: f32,
) -> crate::Result<GemmStats> {
if a.len() != m * k || b.len() != k * n || c.len() != m * n {
return Err(crate::error::CudaRustError::RuntimeError("GEMM dimension mismatch".into()));
}
let tm = self.shape.m;
let tn = self.shape.n;
let tk = self.shape.k;
let mut mma_count = 0u64;
for val in c.iter_mut() {
*val *= beta;
}
let m_tiles = (m + tm - 1) / tm;
let n_tiles = (n + tn - 1) / tn;
let k_tiles = (k + tk - 1) / tk;
for mi in 0..m_tiles {
let m_start = mi * tm;
let m_end = (m_start + tm).min(m);
let actual_m = m_end - m_start;
for ni in 0..n_tiles {
let n_start = ni * tn;
let n_end = (n_start + tn).min(n);
let actual_n = n_end - n_start;
for ki in 0..k_tiles {
let k_start = ki * tk;
let k_end = (k_start + tk).min(k);
let actual_k = k_end - k_start;
for i in 0..actual_m {
for j in 0..actual_n {
let mut acc = 0.0f32;
for p in 0..actual_k {
acc += a[(m_start + i) * k + (k_start + p)]
* b[(k_start + p) * n + (n_start + j)];
}
c[(m_start + i) * n + (n_start + j)] += alpha * acc;
}
}
mma_count += 1;
}
}
}
let flops = 2 * (m as u64) * (n as u64) * (k as u64);
Ok(GemmStats { mma_count, flops, precision: self.precision })
}
}
/// Summary of one `gemm` call: how many tile MMAs ran and the nominal work done.
#[derive(Debug, Clone)]
pub struct GemmStats {
    /// Number of tile-level MMA steps executed, one per visited tile triple.
    pub mma_count: u64,
    /// Nominal floating-point operation count (`2·m·n·k` as computed by `gemm`).
    pub flops: u64,
    /// Precision tag of the engine that produced these stats.
    pub precision: MmaPrecision,
}

impl fmt::Display for GemmStats {
    /// Renders a one-line summary such as `GEMM: 64 MMA ops, 1.00M FLOPs, Fp16Fp32`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mega_flops = self.flops as f64 / 1e6;
        write!(
            f,
            "GEMM: {} MMA ops, {:.2}M FLOPs, {:?}",
            self.mma_count, mega_flops, self.precision
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} within {}, got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_fragment_zeros() {
        let zeros = Fragment::zeros(4, 4);
        assert_eq!(zeros.data.len(), 16);
        assert!(!zeros.data.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_fragment_from_f32() {
        let ramp: Vec<f32> = (0..16u8).map(f32::from).collect();
        let frag = Fragment::from_f32(&ramp, 4, 4).unwrap();
        let probes = [(0usize, 0usize, 0.0f32), (1, 2, 6.0), (3, 3, 15.0)];
        for &(row, col, expected) in probes.iter() {
            assert_eq!(frag.get(row, col), expected);
        }
    }

    #[test]
    fn test_mma_identity() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        let identity = Fragment::from_f32(&[1.0, 0.0, 0.0, 1.0], 2, 2).unwrap();
        let rhs = Fragment::from_f32(&[5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let product = engine.mma(&identity, &rhs, &Fragment::zeros(2, 2)).unwrap();
        let expected = [[5.0f32, 6.0], [7.0, 8.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_close(product.get(i, j), want, 1e-6);
            }
        }
    }

    #[test]
    fn test_mma_accumulate() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp16Fp32, FragmentShape::new(2, 2, 2));
        let lhs = Fragment::from_f32(&[1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let rhs = Fragment::from_f32(&[5.0, 6.0, 7.0, 8.0], 2, 2).unwrap();
        let bias = Fragment::from_f32(&[10.0; 4], 2, 2).unwrap();
        let out = engine.mma(&lhs, &rhs, &bias).unwrap();
        let expected = [[29.0f32, 32.0], [53.0, 60.0]];
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_close(out.get(i, j), want, 1e-6);
            }
        }
    }

    #[test]
    fn test_mma_shape_validation() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(4, 4, 4));
        let (a, b, c) = (Fragment::zeros(2, 2), Fragment::zeros(4, 4), Fragment::zeros(4, 4));
        assert!(engine.mma(&a, &b, &c).is_err());
    }

    #[test]
    fn test_gemm_basic() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];
        let mut c = [0.0f32; 4];
        let stats = engine.gemm(&a, &b, &mut c, 2, 2, 2, 1.0, 0.0).unwrap();
        for (&got, &want) in c.iter().zip([19.0f32, 22.0, 43.0, 50.0].iter()) {
            assert_close(got, want, 1e-4);
        }
        assert_eq!(stats.flops, 16);
    }

    #[test]
    fn test_gemm_alpha_beta() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        let a = [1.0f32, 0.0, 0.0, 1.0];
        let b = [1.0f32, 2.0, 3.0, 4.0];
        let mut c = [10.0f32; 4];
        engine.gemm(&a, &b, &mut c, 2, 2, 2, 2.0, 0.5).unwrap();
        assert_close(c[0], 7.0, 1e-4);
        assert_close(c[1], 9.0, 1e-4);
    }

    #[test]
    fn test_gemm_non_square() {
        let engine = TensorCoreEngine::new(MmaPrecision::Fp32, FragmentShape::new(2, 2, 2));
        let a: Vec<f32> = (1..=6u8).map(f32::from).collect();
        let b: Vec<f32> = (1..=8u8).map(f32::from).collect();
        let mut c = vec![0.0f32; 12];
        engine.gemm(&a, &b, &mut c, 3, 4, 2, 1.0, 0.0).unwrap();
        for (&got, &want) in c.iter().zip([11.0f32, 14.0, 17.0, 20.0].iter()) {
            assert_close(got, want, 1e-4);
        }
    }

    #[test]
    fn test_fragment_half_roundtrip() {
        let halves: Vec<Half> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .map(|&v| Half::from_f32(v))
            .collect();
        let frag = Fragment::from_half(&halves, 2, 2).unwrap();
        let restored = frag.to_half();
        for idx in 0..4 {
            let drift = restored[idx].to_f32() - halves[idx].to_f32();
            assert!(drift.abs() < 0.01);
        }
    }

    #[test]
    fn test_fragment_bf16_roundtrip() {
        let source = vec![BFloat16::from_f32(1.5), BFloat16::from_f32(2.5)];
        let frag = Fragment::from_bf16(&source, 1, 2).unwrap();
        let restored = frag.to_bf16();
        for (idx, &want) in [1.5f32, 2.5].iter().enumerate() {
            assert!((restored[idx].to_f32() - want).abs() < 0.1);
        }
    }

    #[test]
    fn test_gemm_stats_display() {
        let rendered = GemmStats {
            mma_count: 64,
            flops: 1_000_000,
            precision: MmaPrecision::Fp16Fp32,
        }
        .to_string();
        assert!(rendered.contains("64 MMA"));
        assert!(rendered.contains("Fp16Fp32"));
    }
}