#[cfg(feature = "cuda")]
use crate::autograd::cuda_optim::{
clip_scale_reduce_cuda, gradient_clip_gpu_scale_cuda, squared_sum_launch_into, FusedClipState,
};
#[cfg(feature = "cuda")]
use crate::transformer::cuda_block::CudaLoraGradWorkspace;
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::SquaredSumKernel;
#[cfg(feature = "cuda")]
/// Allocates the device-side scratch state needed by
/// `clip_lora_gradients_fused`: one contiguous buffer holding the
/// block-level squared-sum partials of all six LoRA gradient buffers,
/// plus a small buffer for the derived clip scale.
///
/// Returns `None` if either GPU allocation fails.
pub(crate) fn init_lora_fused_clip(
    ws: &CudaLoraGradWorkspace,
    ctx: &std::sync::Arc<CudaContext>,
) -> Option<FusedClipState> {
    // Element counts of the six gradient buffers, in the same order the
    // launches happen in `clip_lora_gradients_fused`.
    let element_counts: [u32; 6] = [
        ws.grad_lora_a_q.len() as u32,
        ws.grad_lora_b_q.len() as u32,
        ws.grad_lora_a_v.len() as u32,
        ws.grad_lora_b_v.len() as u32,
        ws.grad_input_norm.len() as u32,
        ws.grad_post_attn_norm.len() as u32,
    ];
    // Prefix-sum the per-buffer block counts so each squared-sum launch
    // gets its own disjoint slice of the shared partials buffer.
    // NOTE(review): the arrays are sized 9 (per `FusedClipState`) but only
    // the first 6 slots are filled here, and `num_blocks` is left all-zero
    // even though per-kernel block counts are computed — presumably the
    // state type serves a larger configuration elsewhere; confirm against
    // its definition.
    let mut starts = [0u32; 9];
    let mut running_total = 0u32;
    for (slot, &count) in element_counts.iter().enumerate() {
        starts[slot] = running_total;
        running_total += SquaredSumKernel::new(count).num_blocks();
    }
    // Allocate the scratch buffers; `.ok()?` turns either failure into None.
    let partials_buf = GpuBuffer::<f32>::new(ctx, running_total as usize).ok()?;
    let scale_buf = GpuBuffer::<f32>::new(ctx, 2).ok()?;
    Some(FusedClipState {
        partials_buf,
        scale_buf,
        offsets: starts,
        num_blocks: [0; 9],
        total_partials: running_total,
    })
}
#[cfg(feature = "cuda")]
/// Clips all six LoRA gradient buffers to `max_norm` entirely on-device,
/// in three phases:
/// 1. per-buffer squared-sum reductions into `state.partials_buf`,
/// 2. a single reduction of all partials into a clip scale in
///    `state.scale_buf`,
/// 3. in-place scaling of every gradient buffer by that device-resident
///    scale.
///
/// NOTE(review): all launch results are discarded with `let _ =`, so a
/// failed launch silently leaves gradients unclipped — confirm this
/// best-effort behavior is intended. Also, if a buffer is empty its
/// launch is skipped here but its partials slice may still have been
/// reserved by `init_lora_fused_clip`; verify that slice cannot be read
/// uninitialized by the reduction.
pub(crate) fn clip_lora_gradients_fused(
    ws: &mut CudaLoraGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // Phase 1: squared-sum partials, one launch per non-empty buffer,
    // each writing into its precomputed slice of the shared buffer.
    {
        let sources: [&GpuBuffer<f32>; 6] = [
            &ws.grad_lora_a_q,
            &ws.grad_lora_b_q,
            &ws.grad_lora_a_v,
            &ws.grad_lora_b_v,
            &ws.grad_input_norm,
            &ws.grad_post_attn_norm,
        ];
        for (slot, src) in sources.iter().enumerate() {
            let count = src.len() as u32;
            if count > 0 {
                // Offsets are in f32 elements; device pointers are byte
                // addresses, hence the *4.
                let dst_ptr =
                    state.partials_buf.as_ptr() + u64::from(state.offsets[slot]) * 4;
                let _ = squared_sum_launch_into(src, count, dst_ptr, stream);
            }
        }
    }
    // Phase 2: collapse every partial into the global norm and derive the
    // clip scale on-device.
    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );
    // Phase 3: scale each non-empty gradient buffer in place.
    let scale_ptr = state.scale_buf.as_ptr();
    let targets: [&mut GpuBuffer<f32>; 6] = [
        &mut ws.grad_lora_a_q,
        &mut ws.grad_lora_b_q,
        &mut ws.grad_lora_a_v,
        &mut ws.grad_lora_b_v,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for buf in targets {
        let count = buf.len() as u32;
        if count > 0 {
            let _ = gradient_clip_gpu_scale_cuda(buf, scale_ptr, count, stream);
        }
    }
}