entrenar/finetune/instruct_pipeline/mod.rs
1//! Instruction-following fine-tuning pipeline (GH-371)
2//!
3//! Wires Transformer + LoRA for causal language model fine-tuning on
4//! instruction-response pairs.
5//!
6//! # Architecture
7//!
8//! ```text
9//! [prompt_ids ++ response_ids] -> Transformer.forward() -> logits [seq_len, vocab_size]
10//! -> causal_lm_loss(logits[prompt_len..], response_ids) -> scalar loss
11//! ```
12//!
13//! # Contract
14//!
15//! - F-INST-002: Loss computed only on response tokens (prompt tokens masked)
16//! - F-INST-003: Perplexity = exp(avg_loss) reported per epoch
17//! - F-INST-004: LoRA adapters saved in APR format
18
19mod accessors;
20mod backward;
21mod constructors;
22mod cuda_forward;
23mod cuda_init;
24mod generate;
25mod training;
26mod wgpu;
27
28#[cfg(test)]
29mod tests;
30#[cfg(test)]
31mod tests_cov3;
32#[cfg(test)]
33mod tests_cov3b;
34
35use crate::lora::LoRALayer;
36use crate::optim::{clip_grad_norm_refs, AdamW, Optimizer};
37use crate::tokenizer::HfTokenizer;
38use crate::train::transformer_trainer::step_profiler::StepProfiler;
39use crate::transformer::{Transformer, TransformerConfig};
40use crate::Tensor;
41use std::path::{Path, PathBuf};
42
43#[cfg(feature = "cuda")]
44use crate::autograd::cuda_training::CudaTrainer;
45#[cfg(feature = "cuda")]
46use crate::gpu::guard::VramGuard;
47#[cfg(feature = "cuda")]
48use crate::transformer::{
49 CudaBlock, CudaBlockScratch, CudaLoraGradWorkspace, GpuLoraOptimizerState,
50};
51#[cfg(feature = "cuda")]
52use trueno_gpu::driver::GpuBuffer;
53
/// Hyper-parameters for LoRA / QLoRA instruction fine-tuning.
#[derive(Clone, Debug)]
pub struct InstructConfig {
    /// Rank `r` of each LoRA adapter.
    pub lora_rank: usize,
    /// LoRA scaling numerator (`alpha`).
    pub lora_alpha: f32,
    /// Optimizer learning rate.
    pub learning_rate: f32,
    /// How many passes over the training data to run.
    pub epochs: usize,
    /// Upper bound on the combined prompt + response token count.
    pub max_seq_len: usize,
    /// Gradient-norm clipping threshold; `None` turns clipping off.
    pub gradient_clip_norm: Option<f32>,
    /// Store frozen base weights as NF4 (4-bit) for QLoRA training (default: false).
    ///
    /// When set, `CudaNf4TransformerBlock` (~8x VRAM compression) replaces
    /// `CudaTransformerBlock`, and the GPU backward pass only updates the LoRA adapters.
    pub quantize_nf4: bool,
}

impl Default for InstructConfig {
    /// Rank-16 LoRA with alpha 32, learning rate 2e-4, 3 epochs, 512-token sequences,
    /// gradient clipping at norm 1.0, and no NF4 quantization.
    fn default() -> Self {
        InstructConfig {
            epochs: 3,
            learning_rate: 2e-4,
            lora_alpha: 32.0,
            lora_rank: 16,
            max_seq_len: 512,
            gradient_clip_norm: Some(1.0),
            quantize_nf4: false,
        }
    }
}
89
/// Outcome of a single instruction-response training step.
#[derive(Clone, Debug)]
pub struct InstructStepResult {
    /// Cross-entropy loss, computed over response tokens only.
    pub loss: f32,
    /// How many response tokens contributed to `loss`.
    pub num_response_tokens: usize,
    /// `exp(loss)`.
    pub perplexity: f32,
}
100
/// Outcome of one mini-batch of instruction samples.
#[derive(Clone, Debug)]
pub struct InstructBatchResult {
    /// Batch-averaged cross-entropy loss (response tokens only).
    pub avg_loss: f32,
    /// Response tokens summed over every sample in the batch.
    pub total_response_tokens: usize,
    /// `exp(avg_loss)`.
    pub perplexity: f32,
    /// Gradient norm measured before any clipping was applied.
    pub grad_norm: f32,
}
113
114/// Instruction fine-tuning pipeline.
115///
116/// Owns the transformer and LoRA adapters. Uses `Transformer::forward()`
117/// for causal LM logits and computes loss on response tokens only.
118/// GPU-resident training state for NF4 QLoRA backward pass.
119///
120/// Holds per-layer activation snapshots and scratch buffers needed for
121/// activation checkpointing during NF4 backward.
122#[cfg(feature = "cuda")]
123pub(super) struct InstructGpuTrainingState {
124 /// Saved input to each block during forward [num_layers][max_seq_len * hidden_size]
125 layer_inputs: Vec<GpuBuffer<f32>>,
126 /// Final RMSNorm weight uploaded to GPU [hidden_size]
127 final_norm_weight: GpuBuffer<f32>,
128 /// Blocks output saved on GPU for final norm backward [max_seq_len * hidden_size]
129 blocks_output: GpuBuffer<f32>,
130 /// Gradient scratch buffer A [max_seq_len * hidden_size]
131 grad_buf_a: GpuBuffer<f32>,
132 /// Gradient scratch buffer B [max_seq_len * hidden_size]
133 grad_buf_b: GpuBuffer<f32>,
134 /// Gradient for final RMSNorm weight [hidden_size]
135 grad_final_norm_weight: GpuBuffer<f32>,
136 embed_transposed: GpuBuffer<f32>, // [hidden*vocab] lm_head forward
137 embed_original: GpuBuffer<f32>, // [vocab*hidden] lm_head backward (KAIZEN-068)
138 /// GPU scratch for logits [max_seq_len * vocab_size]
139 logits_buf: GpuBuffer<f32>,
140 /// GPU scratch for grad_hidden [max_seq_len * hidden_size]
141 grad_hidden_buf: GpuBuffer<f32>,
142 /// KAIZEN-045: Pre-allocated scratch buffer for activation checkpointing in backward
143 output_scratch: GpuBuffer<f32>,
144 /// KAIZEN-045: Pre-allocated upload buffer for gradient H2D transfer in backward
145 grad_upload_buf: GpuBuffer<f32>,
146 /// KAIZEN-062: Pre-allocated forward ping-pong buffer A
147 fwd_scratch_a: GpuBuffer<f32>,
148 /// KAIZEN-062: Pre-allocated forward ping-pong buffer B
149 fwd_scratch_b: GpuBuffer<f32>,
150 /// KAIZEN-062: Pre-allocated lm_head hidden input buffer
151 lm_head_hidden_buf: GpuBuffer<f32>,
152 /// PMAT-464: Cached CUDA graph for forward pass replay.
153 forward_graph_exec: Option<trueno_gpu::driver::CudaGraphExec>,
154 graph_cached_seq_len: usize,
155 /// PMAT-488: Cached CUDA graph for backward pass replay.
156 backward_graph_state: Option<super::backward_graph::BackwardGraphState>,
157 /// PMAT-063: cuBLAS workspace buffer (must outlive CUDA graph)
158 cublas_workspace: Option<GpuBuffer<f32>>,
159 /// PMAT-483: Per-layer forward timing (microseconds per layer per step)
160 profiler_layer_fwd_us: Vec<u64>,
161 /// PMAT-483: Per-layer backward timing (microseconds per layer per step)
162 profiler_layer_bwd_us: Vec<u64>,
163 /// PMAT-483: Temporary layer start timestamp
164 profiler_layer_start: Option<std::time::Instant>,
165 /// PMAT-483/entrenar#328: Per-operation timing within layers (accumulated per step)
166 /// Index matches StepProfiler::OP_* constants. Reset each step.
167 profiler_op_us: [u64; 16],
168 /// Per-operation start timestamp
169 profiler_op_start: Option<std::time::Instant>,
170}
171
172pub struct InstructPipeline {
173 /// Base transformer model
174 pub model: Transformer,
175 /// LoRA adapters applied to Q/V attention projections
176 pub lora_layers: Vec<LoRALayer>,
177 /// Pipeline configuration
178 pub config: InstructConfig,
179 /// AdamW optimizer for trainable parameters
180 optimizer: AdamW,
181 /// Optional BPE tokenizer
182 tokenizer: Option<HfTokenizer>,
183 /// Path to base model (for checkpoint provenance)
184 model_dir: Option<PathBuf>,
185 /// PMAT-483: Per-step profiler for scientific training measurement.
186 /// Zero-overhead when disabled. Enable via --profile-interval N.
187 pub profiler: StepProfiler,
188 /// CUDA trainer for GPU memory management
189 #[cfg(feature = "cuda")]
190 cuda_trainer: Option<CudaTrainer>,
191 /// CUDA-accelerated transformer blocks -- one per layer
192 #[cfg(feature = "cuda")]
193 cuda_blocks: Option<Vec<CudaBlock>>,
194 /// Shared scratch buffers for NF4 forward pass
195 #[cfg(feature = "cuda")]
196 shared_scratch: Option<CudaBlockScratch>,
197 /// Count of GPU forward passes that produced NaN/Inf
198 #[cfg(feature = "cuda")]
199 #[allow(dead_code)]
200 cuda_nan_count: usize,
201 /// GPU training state for NF4 QLoRA backward pass
202 #[cfg(feature = "cuda")]
203 gpu_training: Option<InstructGpuTrainingState>,
204 /// Shared LoRA gradient workspace for NF4 QLoRA backward
205 #[cfg(feature = "cuda")]
206 cuda_lora_grad_workspace: Option<CudaLoraGradWorkspace>,
207 /// PMAT-477: Fused clip state -- zero D2H sync gradient clipping
208 #[cfg(feature = "cuda")]
209 lora_fused_clip: Option<crate::autograd::cuda_optim::FusedClipState>,
210 /// Per-layer LoRA optimizer states for NF4 QLoRA training
211 #[cfg(feature = "cuda")]
212 cuda_lora_optimizer_states: Option<Vec<GpuLoraOptimizerState>>,
213 /// NF4 LoRA optimizer step counter
214 #[cfg(feature = "cuda")]
215 nf4_lora_step: u32,
216 /// VRAM reservation guard (GPU-SHARE-002). Releases ledger entry on Drop.
217 #[cfg(feature = "cuda")]
218 #[allow(dead_code)]
219 vram_guard: Option<VramGuard>,
220 /// wgpu training pipeline (zero unsafe alternative to CUDA)
221 #[cfg(feature = "gpu")]
222 wgpu_training: Option<WgpuTrainingState>,
223}
224
/// State owned by the wgpu-based training path (`WgpuTrainingPipeline`).
///
/// Bundles the forward pass, the cross-entropy kernel, the trainer, the GPU
/// buffers they share, and the model dimensions the forward pass needs.
#[cfg(feature = "gpu")]
struct WgpuTrainingState {
    /// GPU forward pass with persistent weight buffers + tiled GEMM
    fwd: trueno::backends::gpu::WgslForwardPass,
    /// WGSL cross-entropy kernel
    cross_entropy: crate::autograd::wgpu_cross_entropy::WgslCrossEntropy,
    /// wgpu trainer
    trainer: crate::autograd::wgpu_training::WgpuTrainer,
    /// GPU buffer for logits
    logits_buf: trueno::backends::gpu::wgpu::Buffer,
    /// GPU buffer for labels
    labels_buf: trueno::backends::gpu::wgpu::Buffer,
    /// GPU buffer for per-token losses
    losses_buf: trueno::backends::gpu::wgpu::Buffer,
    /// GPU buffer for logsumexp values
    logsumexp_buf: trueno::backends::gpu::wgpu::Buffer,
    /// Precomputed lm_head weights on the GPU
    lm_head_gpu: trueno::backends::gpu::wgpu::Buffer,
    /// Precomputed lm_head weights on the GPU (presumably transposed, per the `_t` suffix)
    lm_head_t_gpu: trueno::backends::gpu::wgpu::Buffer,
    /// Number of transformer layers (model config used by the forward pass)
    num_layers: usize,
    /// Hidden dimension (model config used by the forward pass)
    hidden_dim: usize,
    /// Vocabulary size (model config used by the forward pass)
    vocab_size: usize,
}
245
/// Settings that drive autoregressive text generation.
#[derive(Clone, Debug)]
pub struct GenerateConfig {
    /// Cap on the number of newly generated tokens (default: 256).
    pub max_new_tokens: usize,
    /// Sampling temperature: `0.0` selects greedy argmax, values above zero sample.
    pub temperature: f32,
    /// Top-k cutoff: `0` disables it, `k > 0` keeps only the `k` largest logits.
    pub top_k: usize,
    /// Extra stop token ids; generation halts on EOS or on any id listed here.
    pub stop_tokens: Vec<u32>,
}
258
/// Sample a token id from `logits` using temperature scaling and optional top-k filtering.
///
/// * `temperature <= 0.0` or `top_k == 1`: greedy argmax. On ties, the last maximal index
///   wins (`Iterator::max_by` semantics). Returns `0` for empty `logits`.
/// * Otherwise: the logits are divided by `temperature` and restricted to the `top_k` largest
///   (`top_k == 0` or `top_k >= logits.len()` keeps all of them). They are then
///   softmax-normalised and one index is drawn using [`simple_random`]. Returns `0` for empty
///   `logits`.
///
/// On the sampling path, NaN logits are treated as `-inf`. They are never preferred, and the
/// ordering used for sorting stays total, so the sort cannot panic.
fn sample_token(logits: &[f32], temperature: f32, top_k: usize) -> u32 {
    if temperature <= 0.0 || top_k == 1 {
        // Greedy: argmax
        return logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx as u32);
    }
    if logits.is_empty() {
        // Mirror the greedy path instead of indexing into an empty candidate list.
        return 0;
    }

    /// Orders `(index, scaled_logit)` pairs largest-logit first. `total_cmp` is a total
    /// order; `partial_cmp().unwrap_or(Equal)` is not when NaNs are present, and the std
    /// sorts may panic on a non-total comparator.
    fn by_logit_desc(a: &(usize, f32), b: &(usize, f32)) -> std::cmp::Ordering {
        b.1.total_cmp(&a.1)
    }

    // Temperature scaling (NaN -> -inf so it gets zero probability mass).
    let mut candidates: Vec<(usize, f32)> = logits
        .iter()
        .map(|&l| {
            let scaled = l / temperature;
            if scaled.is_nan() {
                f32::NEG_INFINITY
            } else {
                scaled
            }
        })
        .enumerate()
        .collect();

    // Top-k filtering: partition in O(V) so only the k survivors need sorting,
    // instead of sorting the whole vocabulary for every generated token.
    let k = if top_k > 0 { top_k.min(candidates.len()) } else { candidates.len() };
    if k < candidates.len() {
        candidates.select_nth_unstable_by(k - 1, by_logit_desc);
        candidates.truncate(k);
    }
    candidates.sort_unstable_by(by_logit_desc);

    // Softmax over the surviving candidates (max-subtracted for numerical stability).
    let max_logit = candidates[0].1;
    let weights: Vec<f32> = candidates.iter().map(|&(_, l)| (l - max_logit).exp()).collect();
    let total: f32 = weights.iter().sum();

    // Inverse-CDF sampling via a linear scan.
    let r = simple_random();
    let mut cumulative = 0.0_f32;
    for (&(idx, _), &w) in candidates.iter().zip(&weights) {
        cumulative += w / total;
        if r < cumulative {
            return idx as u32;
        }
    }

    // Reached only through rounding or non-finite weights: fall back to the top-1 candidate.
    candidates[0].0 as u32
}

/// Pseudo-random `f32` in `[0, 1)` from a thread-local xorshift64 generator.
///
/// Each thread seeds itself from the system clock, falling back to `42` if the clock reads
/// before the Unix epoch. Not cryptographically secure, but sufficient for sampling.
fn simple_random() -> f32 {
    use std::cell::Cell;
    thread_local! {
        static STATE: Cell<u64> = Cell::new({
            let seed = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(42);
            // 0 is a fixed point of xorshift (it would yield 0.0 forever); forcing the
            // low bit keeps the state non-zero.
            seed | 1
        });
    }
    STATE.with(|s| {
        // xorshift64 (13, 7, 17)
        let mut x = s.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        s.set(x);
        // Top 24 bits fit exactly in an f32 mantissa, giving a uniform value in [0, 1).
        (x >> 40) as f32 / (1u64 << 24) as f32
    })
}