Skip to main content

entrenar/finetune/
fused_lora_clip.rs

//! PMAT-477: Fused LoRA gradient clipping — zero D2H sync for CUDA graph capture.
//!
//! Replaces the per-layer synchronous `clip_gradients()` with the ALB-078 fused pipeline:
//! 1. 6x `squared_sum_launch_into` (async, into a contiguous partials buffer)
//! 2. 1x `clip_scale_reduce_cuda` (GPU-side norm + scale computation)
//! 3. 6x `gradient_clip_gpu_scale_cuda` (GPU-side scale read, no D2H)
//!
//! This eliminates 6 `stream.synchronize()` calls per layer backward step (168 per
//! backward pass across 28 layers), enabling CUDA graph capture of the backward loop.
//!
//! # Contract: cuda-graph-backward-v1.yaml
//!
//! - F-GRAPH-BWD-001: Loss trajectory matches ungraphed within 0.1
//! - F-GRAPH-BWD-002: Graph capture succeeds (no CUDA_ERROR)
//! - F-GRAPH-BWD-003: Throughput >= 1.10x ungraphed at batch=4

17#[cfg(feature = "cuda")]
18use crate::autograd::cuda_optim::{
19    clip_scale_reduce_cuda, gradient_clip_gpu_scale_cuda, squared_sum_launch_into, FusedClipState,
20};
21#[cfg(feature = "cuda")]
22use crate::transformer::cuda_block::CudaLoraGradWorkspace;
23#[cfg(feature = "cuda")]
24use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
25#[cfg(feature = "cuda")]
26use trueno_gpu::kernels::SquaredSumKernel;
27
/// Initialize a `FusedClipState` sized for the 6 LoRA gradient buffers.
///
/// Pre-allocates the contiguous partials buffer and the scale output buffer.
/// Called once at training init and reused every backward step.
///
/// The buffers are laid out in this order, which must match the order used by
/// [`clip_lora_gradients_fused`]: `grad_lora_a_q`, `grad_lora_b_q`,
/// `grad_lora_a_v`, `grad_lora_b_v`, `grad_input_norm`, `grad_post_attn_norm`.
///
/// For buffer `i`, `offsets[i]` is the index of its first partial sum in
/// `partials_buf` and `num_blocks[i]` is how many partials its squared-sum
/// kernel produces. Empty buffers get 0 blocks. Slots 6..9 are unused and stay 0.
///
/// Returns `None` if either GPU allocation fails.
#[cfg(feature = "cuda")]
pub(crate) fn init_lora_fused_clip(
    ws: &CudaLoraGradWorkspace,
    ctx: &std::sync::Arc<CudaContext>,
) -> Option<FusedClipState> {
    let sizes: [u32; 6] = [
        ws.grad_lora_a_q.len() as u32,
        ws.grad_lora_b_q.len() as u32,
        ws.grad_lora_a_v.len() as u32,
        ws.grad_lora_b_v.len() as u32,
        ws.grad_input_norm.len() as u32,
        ws.grad_post_attn_norm.len() as u32,
    ];

    // `FusedClipState` stores per-buffer metadata in 9-slot arrays; only the first 6 are used.
    let mut offsets = [0u32; 9];
    let mut num_blocks = [0; 9];
    let mut total = 0u32;
    for (i, &n) in sizes.iter().enumerate() {
        // `clip_lora_gradients_fused` skips empty buffers, so no kernel would ever
        // write their partials. Reserve no slots for them, so that unwritten
        // entries never reach the GPU-side norm reduction.
        let blocks = if n == 0 {
            0
        } else {
            SquaredSumKernel::new(n).num_blocks()
        };
        offsets[i] = total;
        num_blocks[i] = blocks;
        total += blocks;
    }

    let partials_buf = GpuBuffer::<f32>::new(ctx, total as usize).ok()?;
    let scale_buf = GpuBuffer::<f32>::new(ctx, 2).ok()?;

    Some(FusedClipState {
        partials_buf,
        scale_buf,
        offsets,
        num_blocks,
        total_partials: total,
    })
}
65
/// Apply fused gradient clipping to a LoRA workspace — zero D2H sync.
///
/// Three-phase pipeline (all GPU-side, CUDA graph capturable):
/// 1. Launch squared-sum reductions for all 6 gradient buffers (async), each
///    writing into its slot of `state.partials_buf` at `state.offsets[i]`.
/// 2. Reduce all `state.total_partials` partials and compute the clip scale on
///    the GPU into `state.scale_buf`.
/// 3. Apply the clip scale, read from GPU memory, to all 6 buffers.
///
/// Empty buffers are skipped in phases 1 and 3. `state` is expected to come from
/// [`init_lora_fused_clip`] for a workspace with the same buffer sizes and order,
/// because the partials offsets are derived from those sizes.
///
/// If any launch in phase 1 or 2 reports an error, phase 3 is skipped and the
/// gradients are left unclipped. Otherwise they would be scaled by a value
/// computed from stale partials, or by a stale or never-written `scale_buf`.
#[cfg(feature = "cuda")]
pub(crate) fn clip_lora_gradients_fused(
    ws: &mut CudaLoraGradWorkspace,
    max_norm: f32,
    state: &FusedClipState,
    stream: &CudaStream,
) {
    // Byte stride of one partial-sum slot: `partials_buf` holds `f32`s.
    const F32_BYTES: u64 = std::mem::size_of::<f32>() as u64;

    // Phase 1: Launch all 6 squared-sum reductions (async, no sync).
    // The order must match `init_lora_fused_clip`, which assigned the offsets.
    let bufs: [&GpuBuffer<f32>; 6] = [
        &ws.grad_lora_a_q,
        &ws.grad_lora_b_q,
        &ws.grad_lora_a_v,
        &ws.grad_lora_b_v,
        &ws.grad_input_norm,
        &ws.grad_post_attn_norm,
    ];

    for (i, buf) in bufs.iter().enumerate() {
        let n = buf.len() as u32;
        if n == 0 {
            continue;
        }
        let output_ptr = state.partials_buf.as_ptr() + u64::from(state.offsets[i]) * F32_BYTES;
        if squared_sum_launch_into(buf, n, output_ptr, stream).is_err() {
            // This buffer's partial slots were not refreshed; the global norm
            // would be wrong, so do not clip with it.
            return;
        }
    }

    // Phase 2: Reduce all partials → clip_scale on GPU (no D2H).
    if clip_scale_reduce_cuda(
        &state.partials_buf,
        state.total_partials,
        max_norm,
        &state.scale_buf,
        stream,
    )
    .is_err()
    {
        // `scale_buf` was not written this step; applying it would use a stale value.
        return;
    }

    // Phase 3: Apply clip scale from GPU memory (no D2H).
    let scale_ptr = state.scale_buf.as_ptr();
    let bufs_mut: [&mut GpuBuffer<f32>; 6] = [
        &mut ws.grad_lora_a_q,
        &mut ws.grad_lora_b_q,
        &mut ws.grad_lora_a_v,
        &mut ws.grad_lora_b_v,
        &mut ws.grad_input_norm,
        &mut ws.grad_post_attn_norm,
    ];
    for buf in bufs_mut {
        let n = buf.len() as u32;
        if n == 0 {
            continue;
        }
        // Best-effort: the scale is valid at this point, and a failed launch
        // affects only this buffer, so keep clipping the remaining ones.
        let _ = gradient_clip_gpu_scale_cuda(buf, scale_ptr, n, stream);
    }
}