use crate::error::{ScirsError, ScirsResult};
use scirs2_core::gpu::{GpuContext, GpuKernelHandle};
use scirs2_core::ndarray::{Array1, Array2};
use std::sync::Arc;
/// Tuning knobs for [`TensorCoreOptimizer`] and its mixed-precision (AMP) workflow.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorCoreOptimizationConfig {
    /// Selects the `half`-input GEMM kernel (`true`) or the `float`/TF32 GEMM kernel (`false`).
    pub mixed_precision: bool,
    /// Requested tile edge length.
    /// NOTE(review): the kernel sources hard-code 16-wide WMMA tiles and do not read this value.
    pub tile_size: usize,
    /// Whether automatic mixed precision is requested.
    pub use_amp: bool,
    /// Loss-scaling factor for mixed-precision training; can be changed after construction via
    /// `TensorCoreOptimizer::update_loss_scale`.
    pub loss_scale: f32,
    /// Gradient-clipping threshold; `None` means no clipping is requested.
    pub gradient_clip_threshold: Option<f32>,
}
impl Default for TensorCoreOptimizationConfig {
fn default() -> Self {
Self {
mixed_precision: true,
tile_size: 16, use_amp: true,
loss_scale: 65536.0,
gradient_clip_threshold: Some(1.0),
}
}
}
/// Owns a GPU context, the configuration, and the Tensor Core kernels compiled from it.
///
/// NOTE(review): the context and kernel handles are created in `new`, but no method in this
/// file reads them yet, so the lint is silenced per field rather than for the whole type.
pub struct TensorCoreOptimizer {
    #[allow(dead_code)] // not read by any method yet
    context: Arc<GpuContext>,
    config: TensorCoreOptimizationConfig,
    #[allow(dead_code)] // compiled in `new`, not launched yet
    gemm_kernel: GpuKernelHandle,
    #[allow(dead_code)] // compiled in `new`, not launched yet
    batch_gemm_kernel: GpuKernelHandle,
    #[allow(dead_code)] // compiled in `new`, not launched yet
    gradient_kernel: GpuKernelHandle,
}
impl TensorCoreOptimizer {
/// Builds an optimizer and eagerly compiles its GEMM, batched-GEMM and gradient kernels,
/// so any compilation problem is reported at construction time.
///
/// NOTE(review): no Tensor Core capability query is performed. The former placeholder check
/// was hard-wired to `true` and its error branch could never run, so it was removed.
///
/// # Errors
/// Propagates the `ScirsError::ComputationError` produced when any kernel fails to compile.
pub fn new(
    context: Arc<GpuContext>,
    config: TensorCoreOptimizationConfig,
) -> ScirsResult<Self> {
    let gemm_kernel = Self::create_gemm_kernel(&context, &config)?;
    let batch_gemm_kernel = Self::create_batch_gemm_kernel(&context, &config)?;
    let gradient_kernel = Self::create_gradient_kernel(&context, &config)?;
    Ok(Self {
        context,
        config,
        gemm_kernel,
        batch_gemm_kernel,
        gradient_kernel,
    })
}
/// Compiles the single-matrix GEMM kernel computing `C = alpha * A * B + beta * C`.
///
/// `config.mixed_precision` selects the source:
/// - `true`: `tensor_core_gemm_mixed`, `half` inputs on 16x16x16 WMMA tiles;
/// - `false`: `tensor_core_gemm_fp32`, `float` inputs rounded to TF32 on 16x16x8 tiles.
///
/// Both kernels accumulate in `float`. They index `A` as row-major `M x K`
/// (`A + row * K + col`) and `B` as row-major `K x N` (`B + row * N + col`), so both
/// operand fragments are declared `row_major` to match that indexing. A `col_major` `B`
/// fragment would misread `B`.
///
/// NOTE(review): the bounds checks only test the top-left corner of each tile, so `M`, `N`
/// and `K` are assumed to be multiples of the tile size; ragged edges would access memory
/// out of range.
///
/// # Errors
/// Returns `ScirsError::ComputationError` if the GPU compiler rejects the source.
fn create_gemm_kernel(
    context: &Arc<GpuContext>,
    config: &TensorCoreOptimizationConfig,
) -> ScirsResult<GpuKernelHandle> {
    let (kernel_name, kernel_source) = if config.mixed_precision {
        (
            "tensor_core_gemm_mixed",
            r#"
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;
extern "C" __global__ void tensor_core_gemm_mixed(
const half* A,
const half* B,
float* C,
int M, int N, int K,
float alpha, float beta
) {
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
// A and B are both addressed row-major below, so both fragments are row_major.
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
wmma::fill_fragment(acc_frag, 0.0f);
for (int i = 0; i < K; i += WMMA_K) {
int aRow = warpM * WMMA_M;
int aCol = i;
int bRow = i;
int bCol = warpN * WMMA_N;
if (aRow < M && aCol < K && bRow < K && bCol < N) {
wmma::load_matrix_sync(a_frag, A + aRow * K + aCol, K);
wmma::load_matrix_sync(b_frag, B + bRow * N + bCol, N);
wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
}
}
int cRow = warpM * WMMA_M;
int cCol = warpN * WMMA_N;
if (cRow < M && cCol < N) {
wmma::load_matrix_sync(c_frag, C + cRow * N + cCol, N, wmma::mem_row_major);
for (int i = 0; i < c_frag.num_elements; i++) {
c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
}
wmma::store_matrix_sync(C + cRow * N + cCol, c_frag, N, wmma::mem_row_major);
}
}
"#,
        )
    } else {
        (
            "tensor_core_gemm_fp32",
            r#"
#include <mma.h>
using namespace nvcuda;
extern "C" __global__ void tensor_core_gemm_fp32(
const float* A,
const float* B,
float* C,
int M, int N, int K,
float alpha, float beta
) {
// FP32 inputs executed on Tensor Cores via TF32 fragments
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 8;
int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
// A and B are both addressed row-major below, so both fragments are row_major.
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, wmma::precision::tf32, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, wmma::precision::tf32, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
wmma::fill_fragment(acc_frag, 0.0f);
for (int i = 0; i < K; i += WMMA_K) {
int aRow = warpM * WMMA_M;
int aCol = i;
int bRow = i;
int bCol = warpN * WMMA_N;
if (aRow < M && aCol < K && bRow < K && bCol < N) {
wmma::load_matrix_sync(a_frag, A + aRow * K + aCol, K);
wmma::load_matrix_sync(b_frag, B + bRow * N + bCol, N);
// TF32 fragments are loaded as float and must be rounded explicitly before mma_sync.
for (int t = 0; t < a_frag.num_elements; t++) {
a_frag.x[t] = wmma::__float_to_tf32(a_frag.x[t]);
}
for (int t = 0; t < b_frag.num_elements; t++) {
b_frag.x[t] = wmma::__float_to_tf32(b_frag.x[t]);
}
wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
}
}
int cRow = warpM * WMMA_M;
int cCol = warpN * WMMA_N;
if (cRow < M && cCol < N) {
wmma::load_matrix_sync(c_frag, C + cRow * N + cCol, N, wmma::mem_row_major);
for (int i = 0; i < c_frag.num_elements; i++) {
c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
}
wmma::store_matrix_sync(C + cRow * N + cCol, c_frag, N, wmma::mem_row_major);
}
}
"#,
        )
    };
    context
        .execute(|compiler| compiler.compile(kernel_source))
        .map_err(|e| {
            ScirsError::ComputationError(scirs2_core::error::ErrorContext::new(format!(
                "Failed to compile kernel '{kernel_name}': {e}"
            )))
        })
}
/// Compiles `tensor_core_batch_gemm`, a `half`-input / `float`-accumulate batched GEMM.
///
/// Each `blockIdx.z` selects one problem. Its matrices, sizes and `alpha`/`beta` are read
/// from the per-batch pointer and value arrays. `A` and `B` are addressed row-major
/// (`A + row * K + col`, `B + row * N + col`), so both operand fragments are `row_major`.
///
/// NOTE(review): as in the single GEMM kernel, bounds checks only test each tile's top-left
/// corner, so every problem's `M`, `N`, `K` are assumed to be multiples of 16.
///
/// `_config` is accepted for signature symmetry with the other kernel builders and is unused.
///
/// # Errors
/// Returns `ScirsError::ComputationError` if the GPU compiler rejects the source.
fn create_batch_gemm_kernel(
    context: &Arc<GpuContext>,
    _config: &TensorCoreOptimizationConfig,
) -> ScirsResult<GpuKernelHandle> {
    let kernel_source = r#"
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;
extern "C" __global__ void tensor_core_batch_gemm(
const half** A_array,
const half** B_array,
float** C_array,
int* M_array,
int* N_array,
int* K_array,
float* alpha_array,
float* beta_array,
int batch_count
) {
int batch_id = blockIdx.z;
if (batch_id >= batch_count) return;
const half* A = A_array[batch_id];
const half* B = B_array[batch_id];
float* C = C_array[batch_id];
int M = M_array[batch_id];
int N = N_array[batch_id];
int K = K_array[batch_id];
float alpha = alpha_array[batch_id];
float beta = beta_array[batch_id];
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
// A and B are both addressed row-major below, so both fragments are row_major.
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
wmma::fill_fragment(acc_frag, 0.0f);
for (int i = 0; i < K; i += WMMA_K) {
int aRow = warpM * WMMA_M;
int aCol = i;
int bRow = i;
int bCol = warpN * WMMA_N;
if (aRow < M && aCol < K && bRow < K && bCol < N) {
wmma::load_matrix_sync(a_frag, A + aRow * K + aCol, K);
wmma::load_matrix_sync(b_frag, B + bRow * N + bCol, N);
wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
}
}
int cRow = warpM * WMMA_M;
int cCol = warpN * WMMA_N;
if (cRow < M && cCol < N) {
wmma::load_matrix_sync(c_frag, C + cRow * N + cCol, N, wmma::mem_row_major);
for (int i = 0; i < c_frag.num_elements; i++) {
c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
}
wmma::store_matrix_sync(C + cRow * N + cCol, c_frag, N, wmma::mem_row_major);
}
}
"#;
    context
        .execute(|compiler| compiler.compile(kernel_source))
        .map_err(|e| {
            ScirsError::ComputationError(scirs2_core::error::ErrorContext::new(format!(
                "Failed to compile batch kernel: {}",
                e
            )))
        })
}
/// Compiles `tensor_core_gradient_computation`, intended to compute `J^T * r` from `half`
/// inputs and write the result divided by `loss_scale` into `gradients`.
///
/// NOTE(review): this kernel source looks incorrect and should be fixed before it is launched:
/// - `residuals + k` is loaded into a 16x16 `matrix_b` fragment with `ldm = 1`; WMMA requires
///   `ldm` to be a multiple of 8 for `half` data.
/// - The Jacobian tile is loaded `row_major` with `ldm = n_dims`. Assuming a row-major
///   `n_points x n_dims` Jacobian, that yields a tile of `J` rather than `J^T`, so
///   `col_major` appears to be needed.
/// - `acc_frag.x[i]` is written directly to `gradients[...]`, but the element order inside a
///   WMMA fragment is unspecified; results should go through `store_matrix_sync`.
///
/// `_config` is accepted for signature symmetry with the other kernel builders and is unused.
///
/// # Errors
/// Returns `ScirsError::ComputationError` if the GPU compiler rejects the source.
fn create_gradient_kernel(
    context: &Arc<GpuContext>,
    _config: &TensorCoreOptimizationConfig,
) -> ScirsResult<GpuKernelHandle> {
    let kernel_source = r#"
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;
extern "C" __global__ void tensor_core_gradient_computation(
const half* jacobian,
const half* residuals,
float* gradients,
int n_points,
int n_dims,
float loss_scale
) {
// Use Tensor Cores to compute J^T * r efficiently
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
int warpN = (blockIdx.y * blockDim.y + threadIdx.y);
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> jt_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> r_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
wmma::fill_fragment(acc_frag, 0.0f);
// Compute J^T * r using Tensor Cores
for (int k = 0; k < n_points; k += WMMA_K) {
if (warpM * WMMA_M < n_dims && k < n_points) {
// Load transposed Jacobian and residuals
wmma::load_matrix_sync(jt_frag, jacobian + k * n_dims + warpM * WMMA_M, n_dims);
wmma::load_matrix_sync(r_frag, residuals + k, 1);
wmma::mma_sync(acc_frag, jt_frag, r_frag, acc_frag);
}
}
// Store result with loss scaling
if (warpM * WMMA_M < n_dims) {
for (int i = 0; i < WMMA_M && warpM * WMMA_M + i < n_dims; i++) {
gradients[warpM * WMMA_M + i] = acc_frag.x[i] / loss_scale;
}
}
}
"#;
    context
        .execute(|compiler| compiler.compile(kernel_source))
        .map_err(|e| {
            ScirsError::ComputationError(scirs2_core::error::ErrorContext::new(format!(
                "Failed to compile gradient kernel: {}",
                e
            )))
        })
}
#[allow(dead_code)]
pub fn gemm(
&self,
a: &Array2<f64>,
b: &Array2<f64>,
c: &mut Array2<f64>,
alpha: f64,
beta: f64,
) -> ScirsResult<()> {
Self::gemm_cpu(a, b, c, alpha, beta)
}
/// Reference CPU implementation of `C = alpha * A * B + beta * C`.
///
/// `a` is `m x k`, `b` must be `k x n`, and `c` must already be `m x n`.
/// When `beta == 0.0` the previous contents of `c` are never read, so a NaN-filled or
/// otherwise garbage `c` cannot leak into the result.
///
/// # Errors
/// `ScirsError::InvalidInput` if the inner dimensions disagree or `c` has the wrong shape.
fn gemm_cpu(
    a: &Array2<f64>,
    b: &Array2<f64>,
    c: &mut Array2<f64>,
    alpha: f64,
    beta: f64,
) -> ScirsResult<()> {
    let (rows, inner_a) = a.dim();
    let (inner_b, cols) = b.dim();
    if inner_a != inner_b {
        return Err(ScirsError::InvalidInput(
            scirs2_core::error::ErrorContext::new(format!(
                "GEMM dimension mismatch: A is ({rows}x{inner_a}) but B is ({inner_b}x{cols}), inner dims must match"
            )),
        ));
    }
    let actual = c.dim();
    if actual != (rows, cols) {
        return Err(ScirsError::InvalidInput(
            scirs2_core::error::ErrorContext::new(format!(
                "GEMM dimension mismatch: C must be ({rows}x{cols}) but is {actual:?}"
            )),
        ));
    }
    let product = a.dot(b);
    // beta == 0 must ignore C entirely (0.0 * NaN would otherwise poison the output).
    let ignore_existing = beta == 0.0;
    c.zip_mut_with(&product, |out, &p| {
        *out = if ignore_existing {
            alpha * p
        } else {
            alpha * p + beta * *out
        };
    });
    Ok(())
}
/// Applies `C_i = alpha_i * A_i * B_i + beta_i * C_i` to every element of the batch on the CPU.
///
/// All five slices must have the same length. Elements are processed in order; the first
/// per-element shape error is returned immediately, and any earlier `C_i` have already been
/// overwritten by then.
///
/// # Errors
/// `ScirsError::InvalidInput` on a slice-length mismatch or any per-element shape mismatch.
#[allow(dead_code)]
pub fn batch_gemm(
    &self,
    a_batch: &[&Array2<f64>],
    b_batch: &[&Array2<f64>],
    c_batch: &mut [&mut Array2<f64>],
    alpha_batch: &[f64],
    beta_batch: &[f64],
) -> ScirsResult<()> {
    let expected = a_batch.len();
    let lengths_agree = [
        b_batch.len(),
        c_batch.len(),
        alpha_batch.len(),
        beta_batch.len(),
    ]
    .iter()
    .all(|&len| len == expected);
    if !lengths_agree {
        return Err(ScirsError::InvalidInput(
            scirs2_core::error::ErrorContext::new(format!(
                "Batch GEMM: all slices must have the same length, got a={}, b={}, c={}, alpha={}, beta={}",
                expected,
                b_batch.len(),
                c_batch.len(),
                alpha_batch.len(),
                beta_batch.len(),
            )),
        ));
    }
    for (idx, c) in c_batch.iter_mut().enumerate() {
        Self::gemm_cpu(a_batch[idx], b_batch[idx], c, alpha_batch[idx], beta_batch[idx])?;
    }
    Ok(())
}
#[allow(dead_code)]
pub fn compute_gradients(
&self,
_jacobian: &Array2<f64>,
_residuals: &Array1<f64>,
) -> ScirsResult<Array1<f64>> {
Err(ScirsError::NotImplementedError(
scirs2_core::error::ErrorContext::new(
"Gradient computation not yet implemented".to_string(),
),
))
}
#[allow(dead_code)]
pub fn clip_gradients(&self, _gradients: &mut Array1<f64>) -> ScirsResult<()> {
Ok(())
}
/// Borrows the configuration this optimizer holds, including any loss-scale change
/// applied through `update_loss_scale`.
pub fn config(&self) -> &TensorCoreOptimizationConfig {
    &self.config
}
/// Replaces the loss scale stored in the configuration.
///
/// The value is stored exactly as given; no range or finiteness check is applied.
pub fn update_loss_scale(&mut self, loss_scale: f32) {
    let config = &mut self.config;
    config.loss_scale = loss_scale;
}
/// Reports whether `tensor` contains any non-finite value (NaN or ±infinity).
///
/// A `true` result is the `found_overflow` signal `AMPManager::update` expects, which
/// makes it back off the loss scale.
///
/// # Errors
/// Currently never fails; the `Result` is kept for API stability.
#[allow(dead_code)]
pub fn check_overflow(&self, tensor: &Array2<f64>) -> ScirsResult<bool> {
    Ok(tensor.iter().any(|value| !value.is_finite()))
}
}
/// Dynamic loss-scale controller for automatic mixed precision (AMP).
#[derive(Debug, Clone)]
pub struct AMPManager {
    /// Current loss scale returned by `update` / `loss_scale`.
    loss_scale: f32,
    /// Multiplier applied after `growth_interval` consecutive overflow-free updates.
    growth_factor: f32,
    /// Multiplier applied whenever an update reports an overflow.
    backoff_factor: f32,
    /// Number of consecutive overflow-free updates required before growing the scale.
    growth_interval: u32,
    /// Overflow-free updates seen since the scale last changed.
    consecutive_unskipped: u32,
}
impl AMPManager {
pub fn new() -> Self {
Self {
loss_scale: 65536.0,
growth_factor: 2.0,
backoff_factor: 0.5,
growth_interval: 2000,
consecutive_unskipped: 0,
}
}
pub fn update(&mut self, found_overflow: bool) -> f32 {
if found_overflow {
self.loss_scale *= self.backoff_factor;
self.consecutive_unskipped = 0;
} else {
self.consecutive_unskipped += 1;
if self.consecutive_unskipped >= self.growth_interval {
self.loss_scale *= self.growth_factor;
self.consecutive_unskipped = 0;
}
}
self.loss_scale = self.loss_scale.max(1.0).min(2_f32.powi(20));
self.loss_scale
}
pub fn loss_scale(&self) -> f32 {
self.loss_scale
}
}
impl Default for AMPManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_core_config() {
        let config = TensorCoreOptimizationConfig::default();
        assert!(config.mixed_precision);
        assert_eq!(config.tile_size, 16);
        assert!(config.use_amp);
        assert_eq!(config.loss_scale, 65536.0);
    }

    #[test]
    fn test_amp_manager() {
        let mut manager = AMPManager::new();
        assert_eq!(manager.loss_scale(), 65536.0);
        let new_scale = manager.update(true);
        assert_eq!(new_scale, 32768.0);
        for _ in 0..2000 {
            manager.update(false);
        }
        let grown_scale = manager.loss_scale();
        assert!(grown_scale > 32768.0);
    }

    #[test]
    fn test_amp_manager_scale_floor() {
        let mut manager = AMPManager::new();
        // 2^16 halved 20 times would be 2^-4; the scale must bottom out at 1.0.
        for _ in 0..20 {
            manager.update(true);
        }
        assert_eq!(manager.loss_scale(), 1.0);
    }

    #[test]
    fn test_amp_manager_scale_ceiling() {
        let mut manager = AMPManager::new();
        // Six growth intervals would double 2^16 up to 2^22; the scale must cap at 2^20.
        for _ in 0..(2000 * 6) {
            manager.update(false);
        }
        assert_eq!(manager.loss_scale(), 2_f32.powi(20));
    }

    #[test]
    #[ignore = "Requires Tensor Core capable GPU"]
    fn test_tensor_core_optimizer() {
        // Placeholder: constructing a `TensorCoreOptimizer` needs a live `GpuContext`.
    }

    #[test]
    fn test_gemm_cpu_basic() {
        use scirs2_core::ndarray::array;
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let mut c = scirs2_core::ndarray::Array2::<f64>::zeros((2, 2));
        TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0).expect("gemm_cpu should succeed");
        assert!((c[[0, 0]] - 58.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 64.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 139.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 154.0).abs() < 1e-10);
    }

    #[test]
    fn test_gemm_cpu_alpha_beta() {
        use scirs2_core::ndarray::array;
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![[3.0, 4.0], [5.0, 6.0]];
        let mut c = array![[1.0, 1.0], [1.0, 1.0]];
        TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 2.0, 3.0).expect("gemm_cpu alpha/beta");
        assert!((c[[0, 0]] - (2.0 * 3.0 + 3.0 * 1.0)).abs() < 1e-10);
        assert!((c[[0, 1]] - (2.0 * 4.0 + 3.0 * 1.0)).abs() < 1e-10);
        assert!((c[[1, 0]] - (2.0 * 5.0 + 3.0 * 1.0)).abs() < 1e-10);
        assert!((c[[1, 1]] - (2.0 * 6.0 + 3.0 * 1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_gemm_cpu_dimension_mismatch_inner() {
        use scirs2_core::ndarray::Array2;
        let a = Array2::<f64>::zeros((2, 3));
        let b = Array2::<f64>::zeros((4, 2));
        let mut c = Array2::<f64>::zeros((2, 2));
        let result = TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0);
        assert!(result.is_err(), "expected dimension mismatch error");
    }

    #[test]
    fn test_gemm_cpu_dimension_mismatch_output() {
        use scirs2_core::ndarray::Array2;
        let a = Array2::<f64>::zeros((2, 3));
        let b = Array2::<f64>::zeros((3, 4));
        let mut c = Array2::<f64>::zeros((2, 3));
        let result = TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0);
        assert!(result.is_err(), "expected output dimension mismatch error");
    }

    #[test]
    fn test_gemm_cpu_batch_basic() {
        use scirs2_core::ndarray::array;
        let a0 = array![[1.0, 0.0], [0.0, 1.0]];
        let b0 = array![[2.0, 3.0], [4.0, 5.0]];
        let mut c0 = scirs2_core::ndarray::Array2::<f64>::zeros((2, 2));
        let a1 = array![[2.0, 0.0], [0.0, 2.0]];
        let b1 = array![[1.0, 1.0], [1.0, 1.0]];
        let mut c1 = scirs2_core::ndarray::Array2::<f64>::zeros((2, 2));
        let a_batch: Vec<&scirs2_core::ndarray::Array2<f64>> = vec![&a0, &a1];
        let b_batch: Vec<&scirs2_core::ndarray::Array2<f64>> = vec![&b0, &b1];
        let mut c_batch: Vec<&mut scirs2_core::ndarray::Array2<f64>> = vec![&mut c0, &mut c1];
        let alphas = [1.0, 1.0];
        let betas = [0.0, 0.0];
        for i in 0..2 {
            TensorCoreOptimizer::gemm_cpu(a_batch[i], b_batch[i], c_batch[i], alphas[i], betas[i])
                .expect("batch element gemm_cpu");
        }
        assert!((c0[[0, 0]] - 2.0).abs() < 1e-10);
        assert!((c0[[0, 1]] - 3.0).abs() < 1e-10);
        assert!((c1[[0, 0]] - 2.0).abs() < 1e-10);
        assert!((c1[[1, 1]] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_gemm_cpu_square_zeros() {
        // `batch_gemm`'s slice-length validation needs a `TensorCoreOptimizer` (and therefore a
        // GPU context), so it is not covered here; this test checks the square path of
        // `gemm_cpu` and that beta == 0 overwrites stale contents of C.
        use scirs2_core::ndarray::Array2;
        let a = Array2::<f64>::zeros((2, 2));
        let b = Array2::<f64>::zeros((2, 2));
        let mut c = Array2::<f64>::from_elem((2, 2), 5.0);
        TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0).expect("square gemm_cpu");
        assert!(c.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_gemm_cpu_beta_zero_nan_init() {
        use scirs2_core::ndarray::{array, Array2};
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[1.0, 0.0], [0.0, 1.0]];
        let mut c = Array2::from_elem((2, 2), f64::NAN);
        TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 0.0)
            .expect("beta=0 with NaN-init C must not produce NaN");
        assert!(
            (c[[0, 0]] - 1.0).abs() < 1e-10,
            "c[0,0] expected 1.0, got {}",
            c[[0, 0]]
        );
        assert!(
            (c[[0, 1]] - 2.0).abs() < 1e-10,
            "c[0,1] expected 2.0, got {}",
            c[[0, 1]]
        );
        assert!(
            (c[[1, 0]] - 3.0).abs() < 1e-10,
            "c[1,0] expected 3.0, got {}",
            c[[1, 0]]
        );
        assert!(
            (c[[1, 1]] - 4.0).abs() < 1e-10,
            "c[1,1] expected 4.0, got {}",
            c[[1, 1]]
        );
    }

    #[test]
    fn test_gemm_cpu_beta_nonzero_reads_c() {
        // Contrast with the beta == 0 case: a nonzero beta must fold in C, NaNs included.
        use scirs2_core::ndarray::{array, Array2};
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![[1.0, 2.0], [3.0, 4.0]];
        let mut c = Array2::from_elem((2, 2), f64::NAN);
        TensorCoreOptimizer::gemm_cpu(&a, &b, &mut c, 1.0, 1.0).expect("beta=1 gemm_cpu");
        assert!(c.iter().all(|v| v.is_nan()));
    }
}