use std::any::TypeId;
use ferrotorch_core::gpu_dispatch::gpu_backend;
use ferrotorch_core::numeric_cast::cast;
use ferrotorch_core::{Device, FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
use crate::parameter::Parameter;
/// Returns `true` when the generic float parameter `T` is exactly `f32`.
///
/// Used to pick the single-precision backend kernels at runtime.
#[inline]
fn is_f32<T: Float>() -> bool {
    let requested = TypeId::of::<T>();
    requested == TypeId::of::<f32>()
}
/// Returns `true` when the generic float parameter `T` is exactly `f64`.
///
/// Used to pick the double-precision backend kernels at runtime.
#[inline]
fn is_f64<T: Float>() -> bool {
    let requested = TypeId::of::<T>();
    requested == TypeId::of::<f64>()
}
/// Copies a device buffer to the host and decodes its first element as a little-endian `f32`.
///
/// Used to read back the scalar produced by a device-side reduction. Only the first
/// 4 bytes are inspected; any trailing bytes are ignored.
///
/// # Errors
/// Propagates any error from `backend.gpu_to_cpu`, and returns `InvalidArgument` when the
/// buffer holds fewer than 4 bytes.
fn readback_scalar_f32(
    handle: &ferrotorch_core::gpu_dispatch::GpuBufferHandle,
    backend: &dyn ferrotorch_core::gpu_dispatch::GpuBackend,
) -> FerrotorchResult<f32> {
    let bytes = backend.gpu_to_cpu(handle)?;
    match &bytes[..] {
        [b0, b1, b2, b3, ..] => Ok(f32::from_le_bytes([*b0, *b1, *b2, *b3])),
        // Too short to hold one f32 (this also covers the empty buffer).
        _ => Err(FerrotorchError::InvalidArgument {
            message: "readback_scalar_f32: buffer holds fewer than 4 bytes".into(),
        }),
    }
}
/// Copies a device buffer to the host and decodes its first element as a little-endian `f64`.
///
/// Used to read back the scalar produced by a device-side reduction. Only the first
/// 8 bytes are inspected; any trailing bytes are ignored.
///
/// # Errors
/// Propagates any error from `backend.gpu_to_cpu`, and returns `InvalidArgument` when the
/// buffer holds fewer than 8 bytes.
fn readback_scalar_f64(
    handle: &ferrotorch_core::gpu_dispatch::GpuBufferHandle,
    backend: &dyn ferrotorch_core::gpu_dispatch::GpuBackend,
) -> FerrotorchResult<f64> {
    let bytes = backend.gpu_to_cpu(handle)?;
    match &bytes[..] {
        [b0, b1, b2, b3, b4, b5, b6, b7, ..] => Ok(f64::from_le_bytes([
            *b0, *b1, *b2, *b3, *b4, *b5, *b6, *b7,
        ])),
        // Too short to hold one f64 (this also covers the empty buffer).
        _ => Err(FerrotorchError::InvalidArgument {
            message: "readback_scalar_f64: buffer holds fewer than 8 bytes".into(),
        }),
    }
}
/// Rescales the gradients of `params` in place so that their combined norm is at most `max_norm`.
///
/// The norm is taken over all gradients together, as if they were concatenated into one
/// vector. If that total norm exceeds `max_norm`, every gradient is multiplied by
/// `max_norm / total_norm`; otherwise the gradients are left untouched. Parameters without
/// a gradient are skipped.
///
/// # Arguments
/// * `params` - parameters whose gradients are clipped.
/// * `max_norm` - maximum allowed total norm; must be non-negative and not NaN.
/// * `norm_type` - order `p` of the p-norm; must be positive. `f64::INFINITY` selects the
///   max-abs norm. On CUDA only `2.0` is supported.
///
/// # Returns
/// The total norm of the gradients *before* clipping, or `0.0` if no parameter has a gradient.
///
/// # Errors
/// * `InvalidArgument` if `max_norm` is negative/NaN or `norm_type` is not positive.
/// * Any error returned while reading a parameter's gradient.
/// * `DeviceMismatch` if the gradients live on different devices.
/// * `NotImplementedOnCuda` for CUDA gradients that are not f32/f64, or when
///   `norm_type != 2.0` on CUDA.
/// * `DeviceUnavailable` for devices other than CPU / CUDA.
pub fn clip_grad_norm_<T: Float>(
    params: &[&Parameter<T>],
    max_norm: f64,
    norm_type: f64,
) -> FerrotorchResult<f64> {
    // A negative max_norm would make the clip coefficient negative (flipping every
    // gradient's sign) or NaN when the total norm is zero.
    if max_norm.is_nan() || max_norm < 0.0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "clip_grad_norm_: max_norm must be a non-negative number".into(),
        });
    }
    // For p <= 0 (or NaN) the p-norm computed below is meaningless: e.g. |x|^0 == 1 for
    // every element, including zeros.
    if norm_type.is_nan() || norm_type <= 0.0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "clip_grad_norm_: norm_type must be positive (or f64::INFINITY)".into(),
        });
    }

    // Propagate gradient-read errors instead of silently leaving that gradient out of
    // the norm (the scaling pass would otherwise surface the same error later).
    let grads: Vec<Tensor<T>> = params
        .iter()
        .filter_map(|p| p.grad().transpose())
        .collect::<FerrotorchResult<Vec<Tensor<T>>>>()?;
    if grads.is_empty() {
        return Ok(0.0);
    }

    // All gradients must share one device so a single backend can handle them.
    let common_device = grads[0].device();
    for g in &grads[1..] {
        if g.device() != common_device {
            return Err(FerrotorchError::DeviceMismatch {
                expected: common_device,
                got: g.device(),
            });
        }
    }

    match common_device {
        Device::Cpu => clip_grad_norm_cpu(params, &grads, max_norm, norm_type),
        Device::Cuda(_ordinal) => {
            if !is_f32::<T>() && !is_f64::<T>() {
                return Err(FerrotorchError::NotImplementedOnCuda {
                    op: "clip_grad_norm_ on CUDA requires f32 or f64 gradients",
                });
            }
            if norm_type != 2.0 {
                return Err(FerrotorchError::NotImplementedOnCuda {
                    op: "clip_grad_norm_ on CUDA only supports norm_type == 2.0; \
                         move gradients to CPU for other norms",
                });
            }
            clip_grad_norm_cuda(params, &grads, max_norm)
        }
        _ => Err(FerrotorchError::DeviceUnavailable),
    }
}
/// CPU implementation of [`clip_grad_norm_`].
///
/// `grads` are the existing gradients of `params`, all on the CPU. The total
/// `norm_type`-norm is accumulated in `f64`. If it exceeds `max_norm`, each parameter's
/// gradient is replaced by a new CPU tensor scaled by `max_norm / total_norm`.
///
/// A NaN anywhere in the gradients yields a NaN total norm for both the p-norm and the
/// infinity norm, so callers can detect non-finite gradients. In that case no scaling is
/// applied, because `NaN > max_norm` is false.
fn clip_grad_norm_cpu<T: Float>(
    params: &[&Parameter<T>],
    grads: &[Tensor<T>],
    max_norm: f64,
    norm_type: f64,
) -> FerrotorchResult<f64> {
    let total_norm: f64 = if norm_type == f64::INFINITY {
        // Infinity norm: largest absolute element across all gradients.
        let mut max_abs: f64 = 0.0;
        for g in grads {
            for &v in &g.data_vec()? {
                let abs_v = cast::<T, f64>(v)?.abs();
                // `abs_v > max_abs` is false for NaN, so test for it explicitly. Once
                // `max_abs` is NaN it stays NaN, because no comparison against it succeeds.
                if abs_v.is_nan() || abs_v > max_abs {
                    max_abs = abs_v;
                }
            }
        }
        max_abs
    } else {
        // p-norm: (sum |x|^p)^(1/p) over all gradient elements.
        let mut sum_pow: f64 = 0.0;
        for g in grads {
            for &v in &g.data_vec()? {
                sum_pow += cast::<T, f64>(v)?.abs().powf(norm_type);
            }
        }
        sum_pow.powf(1.0 / norm_type)
    };

    if total_norm > max_norm {
        let clip_coef = max_norm / total_norm;
        let clip_t: T = cast(clip_coef)?;
        for param in params {
            if let Some(g) = param.grad()? {
                let scaled: Vec<T> = g.data_vec()?.iter().map(|&v| v * clip_t).collect();
                let new_grad =
                    Tensor::from_storage(TensorStorage::cpu(scaled), g.shape().to_vec(), false)?;
                param.set_grad(Some(new_grad))?;
            }
        }
    }
    Ok(total_norm)
}
/// CUDA implementation of [`clip_grad_norm_`] for the L2 norm.
///
/// For each gradient, the element-wise square and its sum are computed with backend
/// kernels. The per-tensor sum is read back to the host, which means one device-to-host
/// readback per gradient tensor, and accumulated in `f64`. If the resulting L2 norm exceeds
/// `max_norm`, every gradient is rescaled on the device with `scale_f32` / `scale_f64`
/// and stays on the GPU.
///
/// The caller must have verified that `T` is `f32` or `f64`; any non-`f32` type takes the
/// `f64` kernels.
///
/// # Errors
/// `DeviceUnavailable` if no GPU backend is registered, plus any error from the backend
/// kernels, the readback, or gradient access.
fn clip_grad_norm_cuda<T: Float>(
    params: &[&Parameter<T>],
    grads: &[Tensor<T>],
    max_norm: f64,
) -> FerrotorchResult<f64> {
    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let use_f32 = is_f32::<T>();

    let mut total_sq: f64 = 0.0;
    for g in grads {
        let g_handle = g.gpu_handle()?;
        let numel = g.numel();
        total_sq += if use_f32 {
            let sq_handle = backend.mul_f32(g_handle, g_handle)?;
            let sum_handle = backend.sum_f32(&sq_handle, numel)?;
            f64::from(readback_scalar_f32(&sum_handle, backend)?)
        } else {
            let sq_handle = backend.mul_f64(g_handle, g_handle)?;
            let sum_handle = backend.sum_f64(&sq_handle, numel)?;
            readback_scalar_f64(&sum_handle, backend)?
        };
    }

    let total_norm = total_sq.sqrt();
    if total_norm > max_norm {
        let clip_coef = max_norm / total_norm;
        // Narrowing is intentional: f32 gradients are scaled by an f32 factor.
        #[allow(clippy::cast_possible_truncation)]
        let clip_coef_f32 = clip_coef as f32;
        for param in params {
            if let Some(g) = param.grad()? {
                let g_handle = g.gpu_handle()?;
                let scaled_handle = if use_f32 {
                    backend.scale_f32(g_handle, clip_coef_f32)?
                } else {
                    backend.scale_f64(g_handle, clip_coef)?
                };
                let new_grad = Tensor::from_storage(
                    TensorStorage::gpu(scaled_handle),
                    g.shape().to_vec(),
                    false,
                )?;
                param.set_grad(Some(new_grad))?;
            }
        }
    }
    Ok(total_norm)
}
/// Clamps every gradient element of `params` into `[-clip_value, clip_value]`, in place.
///
/// Parameters without a gradient are skipped. CPU gradients are clamped on the host.
/// CUDA gradients (f32/f64 only) are clamped on the device and stay there.
///
/// # Errors
/// * `InvalidArgument` if `clip_value` is negative or NaN (the clamp range would be empty).
/// * Any error returned while reading a parameter's gradient.
/// * `DeviceMismatch` if the gradients live on different devices.
/// * `NotImplementedOnCuda` for CUDA gradients whose element type is not f32/f64.
/// * `DeviceUnavailable` for devices other than CPU / CUDA.
pub fn clip_grad_value_<T: Float>(
    params: &[&Parameter<T>],
    clip_value: f64,
) -> FerrotorchResult<()> {
    // A negative bound makes lo > hi, an empty range with no meaningful clamp result;
    // reject it (and NaN) up front so neither backend has to interpret it.
    if clip_value.is_nan() || clip_value < 0.0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "clip_grad_value_: clip_value must be a non-negative number".into(),
        });
    }

    // Propagate gradient-read errors instead of silently skipping that gradient in the
    // device check.
    let grads: Vec<Tensor<T>> = params
        .iter()
        .filter_map(|p| p.grad().transpose())
        .collect::<FerrotorchResult<Vec<Tensor<T>>>>()?;
    if grads.is_empty() {
        return Ok(());
    }

    // All gradients must share one device so a single backend can handle them.
    let common_device = grads[0].device();
    for g in &grads[1..] {
        if g.device() != common_device {
            return Err(FerrotorchError::DeviceMismatch {
                expected: common_device,
                got: g.device(),
            });
        }
    }

    match common_device {
        Device::Cpu => clip_grad_value_cpu(params, clip_value),
        Device::Cuda(_) => {
            if !is_f32::<T>() && !is_f64::<T>() {
                return Err(FerrotorchError::NotImplementedOnCuda {
                    op: "clip_grad_value_ on CUDA requires f32 or f64 gradients",
                });
            }
            clip_grad_value_cuda(params, clip_value)
        }
        _ => Err(FerrotorchError::DeviceUnavailable),
    }
}
/// CPU implementation of [`clip_grad_value_`].
///
/// Every gradient element is clamped into `[-clip_value, clip_value]`, and the result is
/// stored back as a fresh CPU tensor with the same shape. NaN elements pass through
/// unchanged, because neither bound comparison holds for them.
fn clip_grad_value_cpu<T: Float>(
    params: &[&Parameter<T>],
    clip_value: f64,
) -> FerrotorchResult<()> {
    let lower: T = cast(-clip_value)?;
    let upper: T = cast(clip_value)?;

    // Explicit comparisons (rather than min/max helpers) keep NaN elements as NaN.
    let clamp_elem = |x: T| -> T {
        if x < lower {
            return lower;
        }
        if x > upper {
            return upper;
        }
        x
    };

    params.iter().try_for_each(|param| -> FerrotorchResult<()> {
        if let Some(grad) = param.grad()? {
            let clamped: Vec<T> = grad.data_vec()?.iter().copied().map(&clamp_elem).collect();
            let shape = grad.shape().to_vec();
            let replacement = Tensor::from_storage(TensorStorage::cpu(clamped), shape, false)?;
            param.set_grad(Some(replacement))?;
        }
        Ok(())
    })
}
/// CUDA implementation of [`clip_grad_value_`].
///
/// Each gradient is clamped into `[-clip_value, clip_value]` with the backend's
/// `clamp_f32` / `clamp_f64` kernel, and the clamped buffer replaces the gradient on the GPU.
/// The caller must have verified that `T` is `f32` or `f64`; any non-`f32` type takes the
/// `f64` kernel.
///
/// # Errors
/// `DeviceUnavailable` if no GPU backend is registered, plus any error from the backend
/// kernels or gradient access.
fn clip_grad_value_cuda<T: Float>(
    params: &[&Parameter<T>],
    clip_value: f64,
) -> FerrotorchResult<()> {
    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let single_precision = is_f32::<T>();
    // Narrowing is intentional: f32 gradients are clamped with f32 bounds.
    #[allow(clippy::cast_possible_truncation)]
    let bound_f32 = clip_value as f32;

    for param in params.iter() {
        let grad = match param.grad()? {
            Some(grad) => grad,
            None => continue,
        };
        let handle = grad.gpu_handle()?;
        let clamped = if single_precision {
            backend.clamp_f32(handle, -bound_f32, bound_f32)?
        } else {
            backend.clamp_f64(handle, -clip_value, clip_value)?
        };
        let updated =
            Tensor::from_storage(TensorStorage::gpu(clamped), grad.shape().to_vec(), false)?;
        param.set_grad(Some(updated))?;
    }
    Ok(())
}
// Unit tests for gradient clipping. The CPU tests always run; the GPU tests are gated on
// the `cuda` feature and initialise the CUDA backend themselves.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zero-initialised f32 parameter of `shape` whose gradient is a CPU tensor
    /// holding `grad_data`.
    fn param_with_grad(shape: &[usize], grad_data: &[f32]) -> Parameter<f32> {
        let p = Parameter::<f32>::zeros(shape).unwrap();
        let grad = Tensor::from_storage(
            TensorStorage::cpu(grad_data.to_vec()),
            shape.to_vec(),
            false,
        )
        .unwrap();
        p.set_grad(Some(grad)).unwrap();
        p
    }

    // [3, 4] has L2 norm 5. Clipping to 2.5 must report 5 and leave a gradient of norm 2.5.
    #[test]
    fn test_clip_grad_norm_reduces_norm() {
        let p = param_with_grad(&[2], &[3.0, 4.0]);
        let total = clip_grad_norm_(&[&p], 2.5, 2.0).unwrap();
        assert!((total - 5.0).abs() < 1e-6);
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        let new_norm = (d[0] as f64 * d[0] as f64 + d[1] as f64 * d[1] as f64).sqrt();
        assert!((new_norm - 2.5).abs() < 1e-4);
    }

    // A norm already below max_norm must leave the gradient untouched.
    #[test]
    fn test_clip_grad_norm_no_clip_when_below() {
        let p = param_with_grad(&[2], &[1.0, 0.0]);
        let total = clip_grad_norm_(&[&p], 10.0, 2.0).unwrap();
        assert!((total - 1.0).abs() < 1e-6);
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-6);
        assert!((d[1] - 0.0).abs() < 1e-6);
    }

    // The norm is global across parameters: [3] and [4] together have L2 norm 5.
    #[test]
    fn test_clip_grad_norm_multiple_params() {
        let p1 = param_with_grad(&[1], &[3.0]);
        let p2 = param_with_grad(&[1], &[4.0]);
        let total = clip_grad_norm_(&[&p1, &p2], 2.5, 2.0).unwrap();
        assert!((total - 5.0).abs() < 1e-6);
        let g1 = p1.grad().unwrap().unwrap().data().unwrap()[0] as f64;
        let g2 = p2.grad().unwrap().unwrap().data().unwrap()[0] as f64;
        let new_norm = (g1 * g1 + g2 * g2).sqrt();
        assert!((new_norm - 2.5).abs() < 1e-4);
    }

    // The returned value is the pre-clipping norm: sqrt(1 + 4 + 4) = 3.
    #[test]
    fn test_clip_grad_norm_returns_total_norm() {
        let p = param_with_grad(&[3], &[1.0, 2.0, 2.0]);
        let total = clip_grad_norm_(&[&p], 100.0, 2.0).unwrap();
        assert!((total - 3.0).abs() < 1e-6);
    }

    // Parameters without a gradient contribute nothing to the norm.
    #[test]
    fn test_clip_grad_norm_skips_none_grads() {
        let p_with = param_with_grad(&[2], &[3.0, 4.0]);
        let p_without = Parameter::<f32>::zeros(&[2]).unwrap();
        let total = clip_grad_norm_(&[&p_with, &p_without], 2.5, 2.0).unwrap();
        assert!((total - 5.0).abs() < 1e-6);
    }

    // Out-of-range elements are clamped to +/-clip_value; in-range ones are kept.
    #[test]
    fn test_clip_grad_value_clamps_elements() {
        let p = param_with_grad(&[4], &[-5.0, 0.5, 3.0, -0.1]);
        clip_grad_value_(&[&p], 1.0).unwrap();
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        assert!((d[0] - (-1.0)).abs() < 1e-6);
        assert!((d[1] - 0.5).abs() < 1e-6);
        assert!((d[2] - 1.0).abs() < 1e-6);
        assert!((d[3] - (-0.1)).abs() < 1e-6);
    }

    // A parameter without a gradient must not gain one.
    #[test]
    fn test_clip_grad_value_skips_none_grads() {
        let p = Parameter::<f32>::zeros(&[3]).unwrap();
        clip_grad_value_(&[&p], 1.0).unwrap();
        assert!(p.grad().unwrap().is_none());
    }

    // Values already inside [-1, 1] are preserved exactly (within float tolerance).
    #[test]
    fn test_clip_grad_value_preserves_within_range() {
        let p = param_with_grad(&[3], &[0.1, -0.2, 0.3]);
        clip_grad_value_(&[&p], 1.0).unwrap();
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        assert!((d[0] - 0.1).abs() < 1e-6);
        assert!((d[1] - (-0.2)).abs() < 1e-6);
        assert!((d[2] - 0.3).abs() < 1e-6);
    }

    /// Like `param_with_grad`, but the gradient is moved to `Device::Cuda(0)`.
    #[cfg(feature = "cuda")]
    fn param_with_gpu_grad(shape: &[usize], grad_data: &[f32]) -> Parameter<f32> {
        let p = Parameter::<f32>::zeros(shape).unwrap();
        let grad_cpu = Tensor::from_storage(
            TensorStorage::cpu(grad_data.to_vec()),
            shape.to_vec(),
            false,
        )
        .unwrap();
        let grad_gpu = grad_cpu.to(Device::Cuda(0)).unwrap();
        p.set_grad(Some(grad_gpu)).unwrap();
        p
    }

    // GPU L2 clipping (f32) must match the CPU reference and keep the gradient on CUDA.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_gpu_clip_grad_norm_l2_f32() {
        use ferrotorch_gpu::init_cuda_backend;
        init_cuda_backend().expect("CUDA init failed");
        let p_cpu = param_with_grad(&[2], &[3.0, 4.0]);
        let ref_norm = clip_grad_norm_(&[&p_cpu], 2.5, 2.0).unwrap();
        let ref_g = p_cpu.grad().unwrap().unwrap();
        let ref_d = ref_g.data().unwrap().to_vec();
        let p_gpu = param_with_gpu_grad(&[2], &[3.0, 4.0]);
        let gpu_norm = clip_grad_norm_(&[&p_gpu], 2.5, 2.0).unwrap();
        assert!(
            (gpu_norm - ref_norm).abs() < 1e-5,
            "GPU norm {gpu_norm} != CPU norm {ref_norm}"
        );
        let g_after = p_gpu.grad().unwrap().unwrap();
        assert!(
            g_after.is_cuda(),
            "gradient should stay on CUDA after clip_grad_norm_"
        );
        let gpu_d = g_after.data_vec().unwrap();
        for (i, (&gv, &rv)) in gpu_d.iter().zip(ref_d.iter()).enumerate() {
            assert!((gv - rv).abs() < 1e-5, "gradient[{i}]: GPU={gv} CPU={rv}");
        }
    }

    // GPU L2 norm on f64 gradients: [1, 2, 2] has norm 3; no clipping at max_norm 10.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_gpu_clip_grad_norm_l2_f64() {
        use ferrotorch_gpu::init_cuda_backend;
        init_cuda_backend().expect("CUDA init failed");
        let p_gpu = {
            let p = Parameter::<f64>::zeros(&[3]).unwrap();
            let grad_cpu = Tensor::<f64>::from_storage(
                TensorStorage::cpu(vec![1.0_f64, 2.0, 2.0]),
                vec![3],
                false,
            )
            .unwrap();
            let grad_gpu = grad_cpu.to(Device::Cuda(0)).unwrap();
            p.set_grad(Some(grad_gpu)).unwrap();
            p
        };
        let norm = clip_grad_norm_(&[&p_gpu], 10.0, 2.0).unwrap();
        assert!(
            (norm - 3.0).abs() < 1e-9,
            "expected L2 norm 3.0, got {norm}"
        );
        assert!(
            p_gpu.grad().unwrap().unwrap().is_cuda(),
            "f64 gradient should stay on CUDA after clip_grad_norm_"
        );
    }

    // GPU value clipping must match the CPU reference and keep the gradient on CUDA.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_gpu_clip_grad_value_f32() {
        use ferrotorch_gpu::init_cuda_backend;
        init_cuda_backend().expect("CUDA init failed");
        let data = [-5.0_f32, 0.5, 3.0, -0.1];
        let p_cpu = param_with_grad(&[4], &data);
        clip_grad_value_(&[&p_cpu], 1.0).unwrap();
        let ref_d = p_cpu.grad().unwrap().unwrap().data().unwrap().to_vec();
        let p_gpu = param_with_gpu_grad(&[4], &data);
        clip_grad_value_(&[&p_gpu], 1.0).unwrap();
        let g_after = p_gpu.grad().unwrap().unwrap();
        assert!(
            g_after.is_cuda(),
            "gradient should stay on CUDA after clip_grad_value_"
        );
        let gpu_d = g_after.data_vec().unwrap();
        for (i, (&gv, &rv)) in gpu_d.iter().zip(ref_d.iter()).enumerate() {
            assert!((gv - rv).abs() < 1e-5, "clamped[{i}]: GPU={gv} CPU={rv}");
        }
    }

    // Mixing CPU and CUDA gradients must be rejected by both clipping functions.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_mixed_device_returns_device_mismatch() {
        use ferrotorch_gpu::init_cuda_backend;
        init_cuda_backend().expect("CUDA init failed");
        let p_cpu = param_with_grad(&[2], &[1.0, 2.0]);
        let p_gpu = param_with_gpu_grad(&[2], &[3.0, 4.0]);
        let result = clip_grad_norm_(&[&p_cpu, &p_gpu], 5.0, 2.0);
        assert!(
            matches!(result, Err(FerrotorchError::DeviceMismatch { .. })),
            "expected DeviceMismatch, got {result:?}"
        );
        let result2 = clip_grad_value_(&[&p_cpu, &p_gpu], 1.0);
        assert!(
            matches!(result2, Err(FerrotorchError::DeviceMismatch { .. })),
            "expected DeviceMismatch for clip_grad_value_, got {result2:?}"
        );
    }

    // On CUDA only the L2 norm is implemented; L1 and infinity norms must error.
    #[test]
    #[cfg(feature = "cuda")]
    fn test_non_l2_cuda_returns_error() {
        use ferrotorch_gpu::init_cuda_backend;
        init_cuda_backend().expect("CUDA init failed");
        let p_gpu = param_with_gpu_grad(&[2], &[1.0, 2.0]);
        let result = clip_grad_norm_(&[&p_gpu], 5.0, 1.0);
        assert!(
            matches!(result, Err(FerrotorchError::NotImplementedOnCuda { .. })),
            "expected NotImplementedOnCuda for L1 norm on CUDA, got {result:?}"
        );
        let p_gpu2 = param_with_gpu_grad(&[2], &[1.0, 2.0]);
        let result2 = clip_grad_norm_(&[&p_gpu2], 5.0, f64::INFINITY);
        assert!(
            matches!(result2, Err(FerrotorchError::NotImplementedOnCuda { .. })),
            "expected NotImplementedOnCuda for inf norm on CUDA, got {result2:?}"
        );
    }
}