use crate::error::{Error, Result};
use crate::ops::cuda::kernels::{self, FUSED_GRAD_UNSCALE_CLIP_MODULE};
use crate::ops::impl_generic::training::dynamic_loss_scale_update_impl;
use crate::ops::traits::FusedFp8TrainingOps;
use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::LaunchConfig;
use numr::dtype::DType;
use numr::runtime::Device;
use numr::runtime::cuda::{CudaClient, CudaRuntime};
use numr::tensor::Tensor;
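
/// Maps a tensor dtype to the suffix used to select the matching kernel
/// entry point (e.g. `fused_grad_unscale_clip_f32`).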
fn kernel_suffix(dtype: DType) -> Result<&'static str> {
match dtype {
DType::F32 => Ok("f32"),
DType::F64 => Ok("f64"),
DType::F16 => Ok("f16"),
DType::BF16 => Ok("bf16"),
_ => Err(Error::InvalidArgument {
arg: "dtype",
reason: format!("fused_grad_unscale_clip: unsupported dtype {:?}", dtype),
}),
}
}

impl FusedFp8TrainingOps<CudaRuntime> for CudaClient {
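    /// Unscales `grad` by `1 / loss_scale`, checks for inf/NaN, and (when
    /// `max_norm > 0`) clips by the global L2 norm, all on the device.
    ///
    /// Returns the processed gradients, the gradient norm, and whether any
    /// non-finite value was found.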
fn fused_grad_unscale_clip(
&self,
grad: &Tensor<CudaRuntime>,
max_norm: f64,
loss_scale: f64,
) -> Result<(Tensor<CudaRuntime>, f64, bool)> {
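        // The kernels treat the gradient as a flat buffer of `n` elements.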
let n: usize = grad.shape().iter().product();
let dtype = grad.dtype();
let suffix = kernel_suffix(dtype)?;
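        // Output tensor for the unscaled (and possibly clipped) gradients.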
        let out = grad.clone();
        if n == 0 {
            // A zero-sized grid is invalid in CUDA, so return early for empty tensors.
            return Ok((out, 0.0, false));
        }
let device_index = grad.device().id();
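        // Fetch the compiled module for this device, loading it on first use.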
let module = kernels::get_or_load_module(
self.context(),
device_index,
FUSED_GRAD_UNSCALE_CLIP_MODULE,
)?;
let stream = self.stream_arc();
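        // Device scratch scalars: an i32 flag for non-finite detection and an
        // f32 accumulator for the squared L2 norm, both zeroed before launch.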
let mut found_inf_dev = unsafe {
stream.alloc::<i32>(1).map_err(|e| Error::KernelError {
reason: format!("alloc found_inf: {:?}", e),
})?
};
stream
.memcpy_htod(&[0i32], &mut found_inf_dev)
.map_err(|e| Error::KernelError {
reason: format!("zero found_inf: {:?}", e),
})?;
let mut norm_sq_dev = unsafe {
stream.alloc::<f32>(1).map_err(|e| Error::KernelError {
reason: format!("alloc norm_sq: {:?}", e),
})?
};
stream
.memcpy_htod(&[0.0f32], &mut norm_sq_dev)
.map_err(|e| Error::KernelError {
reason: format!("zero norm_sq: {:?}", e),
})?;
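        // One thread per element; each block reduces its partial squared norm
        // in shared memory (one f32 slot per thread).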
        let threads = 256u32;
        let blocks = n.div_ceil(threads as usize) as u32;
        let shared_bytes = threads * std::mem::size_of::<f32>() as u32;
        let cfg = LaunchConfig {
grid_dim: (blocks, 1, 1),
block_dim: (threads, 1, 1),
shared_mem_bytes: shared_bytes,
};
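        // Select the dtype-specialized entry point, e.g. fused_grad_unscale_clip_bf16.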
let kernel_name = format!("fused_grad_unscale_clip_{}", suffix);
let func = kernels::get_kernel_function(&module, &kernel_name)?;
let out_ptr = out.ptr();
let grad_ptr = grad.ptr();
        let n_i32 = i32::try_from(n).map_err(|_| Error::InvalidArgument {
            arg: "grad",
            reason: format!(
                "fused_grad_unscale_clip: {} elements exceed i32 indexing",
                n
            ),
        })?;
let inv_scale_f = (1.0 / loss_scale) as f32;
let inv_scale_d = 1.0 / loss_scale;
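        // Pass 1: multiply each element by inv_scale, raise the flag on any
        // non-finite value, and accumulate the squared L2 norm.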
unsafe {
let mut builder = self.stream().launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&grad_ptr);
builder.arg(&found_inf_dev);
builder.arg(&norm_sq_dev);
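            // The f64 kernel takes a double scalar; every other dtype takes a float.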
if dtype == DType::F64 {
builder.arg(&inv_scale_d);
} else {
builder.arg(&inv_scale_f);
}
builder.arg(&n_i32);
builder.launch(cfg).map_err(|e| Error::KernelError {
reason: format!("fused_grad_unscale_clip launch failed: {:?}", e),
})?;
}
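        // Pass 2 (optional): rescale the output in place when the global norm
        // exceeds max_norm; max_norm <= 0.0 disables clipping. Both kernels run
        // on the same stream, so the norm is complete before the clip reads it.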
if max_norm > 0.0 {
let clip_name = format!("clip_scale_{}", suffix);
let clip_func = kernels::get_kernel_function(&module, &clip_name)?;
let clip_cfg = LaunchConfig {
grid_dim: (blocks, 1, 1),
block_dim: (threads, 1, 1),
shared_mem_bytes: 0,
};
let max_norm_f = max_norm as f32;
unsafe {
let mut builder = self.stream().launch_builder(&clip_func);
builder.arg(&out_ptr);
builder.arg(&norm_sq_dev);
builder.arg(&found_inf_dev);
if dtype == DType::F64 {
builder.arg(&max_norm);
} else {
builder.arg(&max_norm_f);
}
builder.arg(&n_i32);
builder.launch(clip_cfg).map_err(|e| Error::KernelError {
reason: format!("clip_scale launch failed: {:?}", e),
})?;
}
}
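        // Copy the two scalar results back to the host.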
let mut found_inf_host = [0i32];
stream
.memcpy_dtoh(&found_inf_dev, &mut found_inf_host)
.map_err(|e| Error::KernelError {
reason: format!("read found_inf: {:?}", e),
})?;
let mut norm_sq_host = [0.0f32];
stream
.memcpy_dtoh(&norm_sq_dev, &mut norm_sq_host)
.map_err(|e| Error::KernelError {
reason: format!("read norm_sq: {:?}", e),
})?;
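        // The squared norm accumulates in f32 on device; take the root in f64.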
let found_inf = found_inf_host[0] != 0;
let norm = (norm_sq_host[0] as f64).sqrt();
Ok((out, norm, found_inf))
}
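
    /// Updates the dynamic loss scale after an optimizer step. The update is
    /// pure host-side scalar arithmetic, so it delegates to the generic
    /// implementation shared across backends.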
fn dynamic_loss_scale_update(
&self,
found_inf: bool,
loss_scale: f64,
growth_tracker: i32,
growth_interval: i32,
backoff_factor: f64,
) -> Result<(f64, i32)> {
dynamic_loss_scale_update_impl(
found_inf,
loss_scale,
growth_tracker,
growth_interval,
backoff_factor,
)
}
}