Skip to main content

entrenar/finetune/instruct_pipeline/
mod.rs

1//! Instruction-following fine-tuning pipeline (GH-371)
2//!
3//! Wires Transformer + LoRA for causal language model fine-tuning on
4//! instruction-response pairs.
5//!
6//! # Architecture
7//!
8//! ```text
9//! [prompt_ids ++ response_ids] -> Transformer.forward() -> logits [seq_len, vocab_size]
10//!   -> causal_lm_loss(logits[prompt_len..], response_ids) -> scalar loss
11//! ```
12//!
13//! # Contract
14//!
15//! - F-INST-002: Loss computed only on response tokens (prompt tokens masked)
16//! - F-INST-003: Perplexity = exp(avg_loss) reported per epoch
17//! - F-INST-004: LoRA adapters saved in APR format
18
19mod accessors;
20mod backward;
21mod constructors;
22mod cuda_forward;
23mod cuda_init;
24mod generate;
25mod training;
26mod wgpu;
27
28#[cfg(test)]
29mod tests;
30#[cfg(test)]
31mod tests_cov3;
32#[cfg(test)]
33mod tests_cov3b;
34
35use crate::lora::LoRALayer;
36use crate::optim::{clip_grad_norm_refs, AdamW, Optimizer};
37use crate::tokenizer::HfTokenizer;
38use crate::train::transformer_trainer::step_profiler::StepProfiler;
39use crate::transformer::{Transformer, TransformerConfig};
40use crate::Tensor;
41use std::path::{Path, PathBuf};
42
43#[cfg(feature = "cuda")]
44use crate::autograd::cuda_training::CudaTrainer;
45#[cfg(feature = "cuda")]
46use crate::gpu::guard::VramGuard;
47#[cfg(feature = "cuda")]
48use crate::transformer::{
49    CudaBlock, CudaBlockScratch, CudaLoraGradWorkspace, GpuLoraOptimizerState,
50};
51#[cfg(feature = "cuda")]
52use trueno_gpu::driver::GpuBuffer;
53
/// Hyperparameters for an instruction fine-tuning run.
#[derive(Debug, Clone)]
pub struct InstructConfig {
    /// Rank of the LoRA adapters.
    pub lora_rank: usize,
    /// LoRA alpha value.
    pub lora_alpha: f32,
    /// Learning rate.
    pub learning_rate: f32,
    /// How many epochs to train for.
    pub epochs: usize,
    /// Upper bound on sequence length, counting prompt and response together.
    pub max_seq_len: usize,
    /// Maximum gradient norm used for clipping.
    pub gradient_clip_norm: Option<f32>,
    /// Whether frozen weights are quantized to NF4 (4-bit) for QLoRA training
    /// (default: false).
    ///
    /// When set, `CudaNf4TransformerBlock` (~8x VRAM compression) is used in place
    /// of `CudaTransformerBlock`, and the GPU backward pass updates only the LoRA
    /// adapters.
    pub quantize_nf4: bool,
}
75
76impl Default for InstructConfig {
77    fn default() -> Self {
78        Self {
79            lora_rank: 16,
80            lora_alpha: 32.0,
81            learning_rate: 2e-4,
82            epochs: 3,
83            max_seq_len: 512,
84            gradient_clip_norm: Some(1.0),
85            quantize_nf4: false,
86        }
87    }
88}
89
/// Outcome of processing a single instruction-response pair.
#[derive(Debug, Clone)]
pub struct InstructStepResult {
    /// Cross-entropy loss measured on the response tokens.
    pub loss: f32,
    /// How many response tokens the pair contained.
    pub num_response_tokens: usize,
    /// Perplexity, i.e. `exp(loss)`.
    pub perplexity: f32,
}
100
/// Outcome of processing one mini-batch of instruction samples.
#[derive(Debug, Clone)]
pub struct InstructBatchResult {
    /// Cross-entropy loss averaged over the batch (response tokens only).
    pub avg_loss: f32,
    /// Number of response tokens summed over the batch.
    pub total_response_tokens: usize,
    /// Perplexity, i.e. `exp(avg_loss)`.
    pub perplexity: f32,
    /// Gradient norm measured before clipping.
    pub grad_norm: f32,
}
113
114/// Instruction fine-tuning pipeline.
115///
116/// Owns the transformer and LoRA adapters. Uses `Transformer::forward()`
117/// for causal LM logits and computes loss on response tokens only.
118/// GPU-resident training state for NF4 QLoRA backward pass.
119///
120/// Holds per-layer activation snapshots and scratch buffers needed for
121/// activation checkpointing during NF4 backward.
122#[cfg(feature = "cuda")]
123pub(super) struct InstructGpuTrainingState {
124    /// Saved input to each block during forward [num_layers][max_seq_len * hidden_size]
125    layer_inputs: Vec<GpuBuffer<f32>>,
126    /// Final RMSNorm weight uploaded to GPU [hidden_size]
127    final_norm_weight: GpuBuffer<f32>,
128    /// Blocks output saved on GPU for final norm backward [max_seq_len * hidden_size]
129    blocks_output: GpuBuffer<f32>,
130    /// Gradient scratch buffer A [max_seq_len * hidden_size]
131    grad_buf_a: GpuBuffer<f32>,
132    /// Gradient scratch buffer B [max_seq_len * hidden_size]
133    grad_buf_b: GpuBuffer<f32>,
134    /// Gradient for final RMSNorm weight [hidden_size]
135    grad_final_norm_weight: GpuBuffer<f32>,
136    embed_transposed: GpuBuffer<f32>, // [hidden*vocab] lm_head forward
137    embed_original: GpuBuffer<f32>,   // [vocab*hidden] lm_head backward (KAIZEN-068)
138    /// GPU scratch for logits [max_seq_len * vocab_size]
139    logits_buf: GpuBuffer<f32>,
140    /// GPU scratch for grad_hidden [max_seq_len * hidden_size]
141    grad_hidden_buf: GpuBuffer<f32>,
142    /// KAIZEN-045: Pre-allocated scratch buffer for activation checkpointing in backward
143    output_scratch: GpuBuffer<f32>,
144    /// KAIZEN-045: Pre-allocated upload buffer for gradient H2D transfer in backward
145    grad_upload_buf: GpuBuffer<f32>,
146    /// KAIZEN-062: Pre-allocated forward ping-pong buffer A
147    fwd_scratch_a: GpuBuffer<f32>,
148    /// KAIZEN-062: Pre-allocated forward ping-pong buffer B
149    fwd_scratch_b: GpuBuffer<f32>,
150    /// KAIZEN-062: Pre-allocated lm_head hidden input buffer
151    lm_head_hidden_buf: GpuBuffer<f32>,
152    /// PMAT-464: Cached CUDA graph for forward pass replay.
153    forward_graph_exec: Option<trueno_gpu::driver::CudaGraphExec>,
154    graph_cached_seq_len: usize,
155    /// PMAT-488: Cached CUDA graph for backward pass replay.
156    backward_graph_state: Option<super::backward_graph::BackwardGraphState>,
157    /// PMAT-063: cuBLAS workspace buffer (must outlive CUDA graph)
158    cublas_workspace: Option<GpuBuffer<f32>>,
159    /// PMAT-483: Per-layer forward timing (microseconds per layer per step)
160    profiler_layer_fwd_us: Vec<u64>,
161    /// PMAT-483: Per-layer backward timing (microseconds per layer per step)
162    profiler_layer_bwd_us: Vec<u64>,
163    /// PMAT-483: Temporary layer start timestamp
164    profiler_layer_start: Option<std::time::Instant>,
165    /// PMAT-483/entrenar#328: Per-operation timing within layers (accumulated per step)
166    /// Index matches StepProfiler::OP_* constants. Reset each step.
167    profiler_op_us: [u64; 16],
168    /// Per-operation start timestamp
169    profiler_op_start: Option<std::time::Instant>,
170}
171
172pub struct InstructPipeline {
173    /// Base transformer model
174    pub model: Transformer,
175    /// LoRA adapters applied to Q/V attention projections
176    pub lora_layers: Vec<LoRALayer>,
177    /// Pipeline configuration
178    pub config: InstructConfig,
179    /// AdamW optimizer for trainable parameters
180    optimizer: AdamW,
181    /// Optional BPE tokenizer
182    tokenizer: Option<HfTokenizer>,
183    /// Path to base model (for checkpoint provenance)
184    model_dir: Option<PathBuf>,
185    /// PMAT-483: Per-step profiler for scientific training measurement.
186    /// Zero-overhead when disabled. Enable via --profile-interval N.
187    pub profiler: StepProfiler,
188    /// CUDA trainer for GPU memory management
189    #[cfg(feature = "cuda")]
190    cuda_trainer: Option<CudaTrainer>,
191    /// CUDA-accelerated transformer blocks -- one per layer
192    #[cfg(feature = "cuda")]
193    cuda_blocks: Option<Vec<CudaBlock>>,
194    /// Shared scratch buffers for NF4 forward pass
195    #[cfg(feature = "cuda")]
196    shared_scratch: Option<CudaBlockScratch>,
197    /// Count of GPU forward passes that produced NaN/Inf
198    #[cfg(feature = "cuda")]
199    #[allow(dead_code)]
200    cuda_nan_count: usize,
201    /// GPU training state for NF4 QLoRA backward pass
202    #[cfg(feature = "cuda")]
203    gpu_training: Option<InstructGpuTrainingState>,
204    /// Shared LoRA gradient workspace for NF4 QLoRA backward
205    #[cfg(feature = "cuda")]
206    cuda_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
207    /// PMAT-477: Fused clip state -- zero D2H sync gradient clipping
208    #[cfg(feature = "cuda")]
209    lora_fused_clip: Option<crate::autograd::cuda_optim::FusedClipState>,
210    /// Per-layer LoRA optimizer states for NF4 QLoRA training
211    #[cfg(feature = "cuda")]
212    cuda_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
213    /// NF4 LoRA optimizer step counter
214    #[cfg(feature = "cuda")]
215    nf4_lora_step: u32,
216    /// VRAM reservation guard (GPU-SHARE-002). Releases ledger entry on Drop.
217    #[cfg(feature = "cuda")]
218    #[allow(dead_code)]
219    vram_guard: Option<VramGuard>,
220    /// wgpu training pipeline (zero unsafe alternative to CUDA)
221    #[cfg(feature = "gpu")]
222    wgpu_training: Option<WgpuTrainingState>,
223}
224
/// Everything the wgpu-based training path (WgpuTrainingPipeline) keeps alive
/// between steps.
#[cfg(feature = "gpu")]
struct WgpuTrainingState {
    /// GPU forward pass with persistent weight buffers + tiled GEMM
    fwd: trueno::backends::gpu::WgslForwardPass,
    /// WGSL cross-entropy helper (`WgslCrossEntropy`)
    cross_entropy: crate::autograd::wgpu_cross_entropy::WgslCrossEntropy,
    /// wgpu trainer (`WgpuTrainer`)
    trainer: crate::autograd::wgpu_training::WgpuTrainer,
    /// GPU buffer for logits
    logits_buf: trueno::backends::gpu::wgpu::Buffer,
    /// GPU buffer for labels
    labels_buf: trueno::backends::gpu::wgpu::Buffer,
    /// GPU buffer for losses
    losses_buf: trueno::backends::gpu::wgpu::Buffer,
    /// GPU buffer for logsumexp values
    logsumexp_buf: trueno::backends::gpu::wgpu::Buffer,
    /// Precomputed lm_head GPU buffer
    lm_head_gpu: trueno::backends::gpu::wgpu::Buffer,
    /// Second precomputed lm_head GPU buffer.
    /// NOTE(review): the `_t` suffix suggests a transposed copy -- confirm.
    lm_head_t_gpu: trueno::backends::gpu::wgpu::Buffer,
    /// Model config needed for forward pass: layer count
    num_layers: usize,
    /// Model config needed for forward pass: hidden dimension
    hidden_dim: usize,
    /// Model config needed for forward pass: vocabulary size
    vocab_size: usize,
}
245
/// Knobs controlling autoregressive text generation.
#[derive(Debug, Clone)]
pub struct GenerateConfig {
    /// Upper bound on the number of newly generated tokens (default: 256).
    pub max_new_tokens: usize,
    /// Sampling temperature: 0.0 selects greedy/argmax decoding, values above
    /// zero sample stochastically.
    pub temperature: f32,
    /// Top-k filtering: 0 turns it off, a positive value keeps only the k
    /// largest logits.
    pub top_k: usize,
    /// Extra stop token IDs; generation stops on EOS or on any of these.
    pub stop_tokens: Vec<u32>,
}
258
259/// Sample a token from logits with temperature and top-k filtering.
260fn sample_token(logits: &[f32], temperature: f32, top_k: usize) -> u32 {
261    if temperature <= 0.0 || top_k == 1 {
262        // Greedy: argmax
263        return logits
264            .iter()
265            .enumerate()
266            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
267            .map_or(0, |(idx, _)| idx as u32);
268    }
269
270    // Temperature scaling
271    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
272
273    // Top-k filtering
274    let mut indices_and_logits: Vec<(usize, f32)> = scaled.iter().copied().enumerate().collect();
275    indices_and_logits
276        .sort_unstable_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
277
278    let k = if top_k > 0 && top_k < indices_and_logits.len() {
279        top_k
280    } else {
281        indices_and_logits.len()
282    };
283    let top = &indices_and_logits[..k];
284
285    // Softmax over top-k
286    let max_logit = top[0].1;
287    let exps: Vec<f32> = top.iter().map(|(_, l)| (l - max_logit).exp()).collect();
288    let sum: f32 = exps.iter().sum();
289    let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
290
291    // Sample from distribution (simple linear scan)
292    let r: f32 = simple_random();
293    let mut cumulative = 0.0;
294    for (i, &p) in probs.iter().enumerate() {
295        cumulative += p;
296        if r < cumulative {
297            return top[i].0 as u32;
298        }
299    }
300
301    // Fallback to top-1
302    top[0].0 as u32
303}
304
/// Draw a pseudo-random `f32` in `[0, 1)` from a per-thread xorshift64 generator.
///
/// The generator state lives in a thread-local cell and is seeded on first use
/// from the wall clock (nanoseconds since the Unix epoch, or 42 if the clock
/// reads earlier than the epoch). Not cryptographically secure, but sufficient
/// for sampling.
fn simple_random() -> f32 {
    use std::cell::Cell;
    use std::time::{SystemTime, UNIX_EPOCH};

    thread_local! {
        static STATE: Cell<u64> = Cell::new(
            match SystemTime::now().duration_since(UNIX_EPOCH) {
                Ok(since_epoch) => since_epoch.as_nanos() as u64,
                Err(_) => 42,
            }
        );
    }

    STATE.with(|cell| {
        // One xorshift64 round (shifts 13 / 7 / 17).
        let s0 = cell.get();
        let s1 = s0 ^ (s0 << 13);
        let s2 = s1 ^ (s1 >> 7);
        let next = s2 ^ (s2 << 17);
        cell.set(next);

        // Keep the top 24 bits: every such integer is exactly representable
        // in an f32, so the quotient stays strictly below 1.0.
        let top_bits = next >> 40;
        top_bits as f32 / (1u64 << 24) as f32
    })
}