entrenar/finetune/
fused_lora_clip.rs1#[cfg(feature = "cuda")]
18use crate::autograd::cuda_optim::{
19 clip_scale_reduce_cuda, gradient_clip_gpu_scale_cuda, squared_sum_launch_into, FusedClipState,
20};
21#[cfg(feature = "cuda")]
22use crate::transformer::cuda_block::CudaLoraGradWorkspace;
23#[cfg(feature = "cuda")]
24use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
25#[cfg(feature = "cuda")]
26use trueno_gpu::kernels::SquaredSumKernel;
27
/// Builds the [`FusedClipState`] used to clip the LoRA gradient workspace.
///
/// For each of the six gradient buffers in `ws` (LoRA A/B for Q and V, followed by the
/// input-norm and post-attention-norm gradients) this records the offset, counted in
/// partial-sum elements, at which that buffer's partial results start inside one shared
/// partials buffer. The number of partials reserved per buffer is whatever
/// `SquaredSumKernel::new(len).num_blocks()` reports for that buffer's length.
///
/// # Returns
///
/// `Some(state)` on success, or `None` when
/// - a buffer length does not fit in `u32`,
/// - the running total of partials would overflow `u32`, or
/// - either device allocation fails.
#[cfg(feature = "cuda")]
pub(crate) fn init_lora_fused_clip(
    ws: &CudaLoraGradWorkspace,
    ctx: &std::sync::Arc<CudaContext>,
) -> Option<FusedClipState> {
    let lens = [
        ws.grad_lora_a_q.len(),
        ws.grad_lora_b_q.len(),
        ws.grad_lora_a_v.len(),
        ws.grad_lora_b_v.len(),
        ws.grad_input_norm.len(),
        ws.grad_post_attn_norm.len(),
    ];

    // `FusedClipState` carries 9 slots; only the first 6 are populated here, the rest
    // stay at zero.
    let mut offsets = [0u32; 9];
    let mut total = 0u32;
    for (slot, &len) in offsets.iter_mut().zip(lens.iter()) {
        // A plain `as u32` would silently truncate an oversized buffer and lead to an
        // undersized partials allocation; refuse to build the state instead.
        let n = u32::try_from(len).ok()?;
        *slot = total;
        // Checked so that a wrapped total can never produce overlapping offsets.
        total = total.checked_add(SquaredSumKernel::new(n).num_blocks())?;
    }

    let partials_buf = GpuBuffer::<f32>::new(ctx, total as usize).ok()?;
    // Two-element buffer handed to the scale kernels; its layout is defined by those
    // kernels and is not visible from this function.
    let scale_buf = GpuBuffer::<f32>::new(ctx, 2).ok()?;

    Some(FusedClipState {
        partials_buf,
        scale_buf,
        offsets,
        // Per-buffer block counts are left zeroed; only `offsets` and `total_partials`
        // are filled in by this constructor.
        num_blocks: [0; 9],
        total_partials: total,
    })
}
65
/// Clips the six LoRA gradient buffers in `ws` using pre-allocated scratch from `state`.
///
/// All work is issued on `stream` in three steps:
/// 1. for every non-empty gradient buffer, `squared_sum_launch_into` writes into that
///    buffer's slice of `state.partials_buf` (located via `state.offsets`);
/// 2. `clip_scale_reduce_cuda` is called once with all `state.total_partials` partials,
///    `max_norm`, and `state.scale_buf`;
/// 3. for every non-empty gradient buffer, `gradient_clip_gpu_scale_cuda` is called with
///    the device pointer of `state.scale_buf`.
///
/// The results of all of these calls are discarded, so a failed launch is not reported
/// to the caller.
#[cfg(feature = "cuda")]
pub(crate) fn clip_lora_gradients_fused(
    ws: &mut CudaLoraGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // Offsets are counted in `f32` elements, while the destination is a raw device
    // address in bytes.
    const F32_BYTES: u64 = 4;

    // Step 1: per-buffer partial sums, each into its own region of the partials buffer.
    let partials_base = state.partials_buf.as_ptr();
    let grads: [&GpuBuffer<f32>; 6] = [
        &ws.grad_lora_a_q,
        &ws.grad_lora_b_q,
        &ws.grad_lora_a_v,
        &ws.grad_lora_b_v,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];
    for (grad, &offset) in grads.iter().zip(state.offsets.iter()) {
        let len = grad.len() as u32;
        if len != 0 {
            let dst = partials_base + u64::from(offset) * F32_BYTES;
            let _ = squared_sum_launch_into(grad, len, dst, stream);
        }
    }

    // Step 2: one reduction over every partial, producing the contents of `scale_buf`.
    let _ = clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    );

    // Step 3: apply the device-resident scale to each gradient buffer in place.
    let scale_ptr = state.scale_buf.as_ptr();
    let grads_mut: [&mut GpuBuffer<f32>; 6] = [
        &mut ws.grad_lora_a_q,
        &mut ws.grad_lora_b_q,
        &mut ws.grad_lora_a_v,
        &mut ws.grad_lora_b_v,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for grad in grads_mut {
        let len = grad.len() as u32;
        if len != 0 {
            let _ = gradient_clip_gpu_scale_cuda(grad, scale_ptr, len, stream);
        }
    }
}