entrenar/train/transformer_trainer/config.rs
1//! Configuration for transformer training
2
3use crate::autograd::{CheckpointConfig, MixedPrecisionConfig};
4use crate::train::TrainConfig;
5use crate::transformer::TransformerConfig;
6use std::net::SocketAddr;
7
/// The part a node plays in a distributed pretraining run.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum DistributedRole {
    /// Orchestrating node: AllReduces gradients and manages checkpoints.
    ///
    /// This is the variant returned by `Default`.
    #[default]
    Coordinator,
    /// Compute node: runs forward/backward on the shard assigned to it.
    Worker,
}
17
/// Which compute backend a distributed worker runs on.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum DistributedBackend {
    /// NVIDIA CUDA.
    Cuda,
    /// wgpu, the cross-platform option.
    Wgpu,
    /// Auto-detect the best backend available.
    ///
    /// This is the variant returned by `Default`.
    #[default]
    Auto,
}
29
30/// Configuration for distributed pretraining (DDP).
31///
32/// Specifies this worker's role, rank, and communication topology.
33/// All workers must agree on `world_size`. The coordinator address
34/// is where workers connect and where AllReduce is orchestrated.
35///
36/// # Contract
37///
38/// C-DDP-001: After AllReduce + optimizer step, all workers hold identical weights.
39#[derive(Debug, Clone)]
40pub struct DistributedTrainConfig {
41 /// Total number of workers participating
42 pub world_size: usize,
43 /// This worker's global rank (0-indexed)
44 pub rank: usize,
45 /// This worker's local rank on its machine (for multi-GPU)
46 pub local_rank: usize,
47 /// Role: coordinator (rank 0) or worker
48 pub role: DistributedRole,
49 /// Address for coordinator to bind / workers to connect
50 pub coordinator_addr: SocketAddr,
51 /// Compute backend for this worker
52 pub backend: DistributedBackend,
53}
54
55/// Configuration for transformer training
56#[allow(clippy::struct_excessive_bools)]
57#[derive(Debug, Clone)]
58pub struct TransformerTrainConfig {
59 /// Base training configuration
60 pub base: TrainConfig,
61 /// Transformer architecture configuration
62 pub model_config: TransformerConfig,
63 /// Checkpoint configuration for memory efficiency
64 pub checkpoint_config: CheckpointConfig,
65 /// Mixed-precision configuration
66 pub precision_config: MixedPrecisionConfig,
67 /// Maximum sequence length
68 pub max_seq_len: usize,
69 /// Accumulation steps for gradient accumulation
70 pub accumulation_steps: usize,
71 /// Warmup steps for learning rate scheduler
72 pub warmup_steps: usize,
73 /// Learning rate
74 pub lr: f32,
75 /// Maximum training steps (stop after this many optimizer steps)
76 pub max_steps: Option<usize>,
77 /// Use CUDA GPU training when available (default: true = auto-detect)
78 pub use_cuda: bool,
79 /// AdamW beta1 (default: 0.9)
80 pub beta1: f32,
81 /// AdamW beta2 (default: 0.999)
82 pub beta2: f32,
83 /// AdamW weight decay (default: 0.01)
84 pub weight_decay: f32,
85 /// Distributed training configuration (None = single-GPU)
86 pub distributed: Option<DistributedTrainConfig>,
87 /// Enable bitwise deterministic training (CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic)
88 /// Contract: C-DETERM-001
89 pub deterministic: bool,
90 /// Random seed for reproducibility
91 pub seed: u64,
92 /// KAIZEN-047: Step profiler report interval (0 = disabled, N = print every N steps)
93 pub profile_interval: usize,
94 /// LoRA rank (None = full fine-tuning, Some(r) = LoRA with rank r)
95 pub lora_rank: Option<usize>,
96 /// LoRA alpha scaling factor (default: 2 * rank)
97 pub lora_alpha: Option<f32>,
98 /// LoRA target modules (e.g., `q_proj`, `v_proj`)
99 pub lora_target_modules: Option<Vec<String>>,
100 /// LoRA+ ratio: LR multiplier for B matrices (ENT-LoRA-006)
101 /// Default 1.0 = standard LoRA. 16.0 = LoRA+ (Hayou et al. ICML 2024)
102 pub lora_plus_ratio: f32,
103 /// Double quantization for QLoRA (ENT-LoRA-008)
104 /// Quantizes FP32 absmax constants to 8-bit, saving ~0.37 bits/param
105 pub double_quantize: bool,
106 /// Quantize frozen base weights to NF4 (4-bit) for QLoRA pretraining (ENT-263)
107 ///
108 /// When enabled with LoRA, uses `CudaNf4TransformerBlock` instead of
109 /// `CudaTransformerBlock`, achieving ~8x VRAM compression on frozen weights.
110 /// Only LoRA adapters and norm weights remain trainable in fp32.
111 pub quantize_nf4: bool,
112}
113
114impl TransformerTrainConfig {
115 /// Create new config with defaults
116 pub fn new(model_config: TransformerConfig) -> Self {
117 Self {
118 base: TrainConfig::default(),
119 model_config,
120 checkpoint_config: CheckpointConfig::disabled(),
121 precision_config: MixedPrecisionConfig::fp32(),
122 max_seq_len: 512,
123 accumulation_steps: 1,
124 warmup_steps: 0,
125 lr: 0.001,
126 max_steps: None,
127 use_cuda: true,
128 beta1: 0.9,
129 beta2: 0.999,
130 weight_decay: 0.01,
131 distributed: None,
132 deterministic: false,
133 seed: 42,
134 profile_interval: 0,
135 lora_rank: None,
136 lora_alpha: None,
137 lora_target_modules: None,
138 lora_plus_ratio: 1.0,
139 double_quantize: false,
140 quantize_nf4: false,
141 }
142 }
143
144 /// Enable gradient checkpointing
145 pub fn with_checkpointing(mut self, num_segments: usize) -> Self {
146 self.checkpoint_config = CheckpointConfig::enabled(num_segments);
147 self
148 }
149
150 /// Enable bf16 mixed precision
151 pub fn with_bf16(mut self) -> Self {
152 self.precision_config = MixedPrecisionConfig::bf16();
153 self
154 }
155
156 /// Enable fp16 mixed precision with dynamic loss scaling
157 pub fn with_fp16(mut self) -> Self {
158 self.precision_config = MixedPrecisionConfig::fp16();
159 self
160 }
161
162 /// Set maximum sequence length
163 pub fn with_max_seq_len(mut self, len: usize) -> Self {
164 self.max_seq_len = len;
165 self
166 }
167
168 /// Set gradient accumulation steps
169 pub fn with_accumulation_steps(mut self, steps: usize) -> Self {
170 self.accumulation_steps = steps.max(1);
171 self
172 }
173
174 /// Set warmup steps
175 pub fn with_warmup_steps(mut self, steps: usize) -> Self {
176 self.warmup_steps = steps;
177 self
178 }
179
180 /// Set learning rate
181 pub fn with_lr(mut self, lr: f32) -> Self {
182 self.lr = lr;
183 self
184 }
185
186 /// Set gradient clipping
187 pub fn with_grad_clip(mut self, clip: f32) -> Self {
188 self.base.max_grad_norm = Some(clip);
189 self
190 }
191
192 /// Set maximum training steps
193 pub fn with_max_steps(mut self, steps: usize) -> Self {
194 self.max_steps = Some(steps);
195 self
196 }
197
198 /// Enable or disable CUDA GPU training (default: true = auto-detect)
199 pub fn with_use_cuda(mut self, use_cuda: bool) -> Self {
200 self.use_cuda = use_cuda;
201 self
202 }
203
204 /// Set AdamW beta2 (default: 0.999)
205 pub fn with_beta2(mut self, beta2: f32) -> Self {
206 self.beta2 = beta2;
207 self
208 }
209
210 /// Set AdamW weight decay (default: 0.01)
211 pub fn with_weight_decay(mut self, wd: f32) -> Self {
212 self.weight_decay = wd;
213 self
214 }
215
216 /// Enable bitwise deterministic training (C-DETERM-001)
217 ///
218 /// Sets CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic mode, and disables
219 /// cuDNN benchmark. May reduce throughput but guarantees reproducibility.
220 pub fn with_deterministic(mut self, deterministic: bool) -> Self {
221 self.deterministic = deterministic;
222 self
223 }
224
225 /// Set random seed for reproducibility
226 pub fn with_seed(mut self, seed: u64) -> Self {
227 self.seed = seed;
228 self
229 }
230
231 /// Apply deterministic settings to the CUDA environment.
232 ///
233 /// Must be called before any cuBLAS/cuDNN operations.
234 /// Uses `ReproducibilityConfig` from finetune infrastructure.
235 ///
236 /// # Contract (C-DETERM-001)
237 ///
238 /// After calling this, `CUBLAS_WORKSPACE_CONFIG=:4096:8` and
239 /// `CUDNN_DETERMINISTIC=1` are guaranteed set in the process environment.
240 pub fn apply_deterministic_settings(&self) {
241 if self.deterministic {
242 use crate::finetune::ReproducibilityConfig;
243 let repro = ReproducibilityConfig::with_seed(self.seed);
244 repro.apply();
245 }
246 }
247
248 /// Set step profiler report interval (0 = disabled, N = print every N steps)
249 pub fn with_profile_interval(mut self, interval: usize) -> Self {
250 self.profile_interval = interval;
251 self
252 }
253
254 /// Enable LoRA fine-tuning with rank, alpha, and target modules
255 ///
256 /// When LoRA is enabled, only LoRA adapter weights (A, B matrices) and
257 /// layer norms are trainable. Base model weights are frozen.
258 ///
259 /// # Contract (ENT-LoRA-001)
260 /// - Base weights frozen (requires_grad=false)
261 /// - Only LoRA A/B + norms are optimizer targets
262 /// - scale = alpha / rank
263 pub fn with_lora(mut self, rank: usize, alpha: f32, target_modules: Vec<String>) -> Self {
264 self.lora_rank = Some(rank);
265 self.lora_alpha = Some(alpha);
266 self.lora_target_modules = Some(target_modules);
267 self
268 }
269
270 /// Set LoRA+ ratio (ENT-LoRA-006)
271 ///
272 /// LR multiplier for B matrices. Default 1.0 = standard LoRA.
273 /// 16.0 = LoRA+ (Hayou et al. ICML 2024) — B learns 16x faster than A.
274 pub fn with_lora_plus_ratio(mut self, ratio: f32) -> Self {
275 self.lora_plus_ratio = ratio;
276 self
277 }
278
279 /// Enable double quantization for QLoRA (ENT-LoRA-008)
280 pub fn with_double_quantize(mut self, enabled: bool) -> Self {
281 self.double_quantize = enabled;
282 self
283 }
284
285 /// Enable NF4 quantization for QLoRA pretraining (ENT-263)
286 ///
287 /// When enabled with LoRA, frozen base weights are quantized to 4-bit NF4,
288 /// achieving ~8x VRAM compression. Only LoRA adapters and norm weights are
289 /// trainable. Requires `lora_rank` to be set.
290 pub fn with_quantize_nf4(mut self, enabled: bool) -> Self {
291 self.quantize_nf4 = enabled;
292 self
293 }
294
295 /// Check if NF4 quantization is enabled for QLoRA
296 #[must_use]
297 pub fn is_nf4(&self) -> bool {
298 self.quantize_nf4
299 }
300
301 /// Check if LoRA fine-tuning is enabled
302 #[must_use]
303 pub fn is_lora(&self) -> bool {
304 self.lora_rank.is_some()
305 }
306
307 /// Enable distributed training with the given configuration
308 pub fn with_distributed(mut self, config: DistributedTrainConfig) -> Self {
309 self.distributed = Some(config);
310 self
311 }
312
313 /// Check if distributed training is enabled
314 #[must_use]
315 pub fn is_distributed(&self) -> bool {
316 self.distributed.is_some()
317 }
318
319 /// Get world size (1 for single-GPU)
320 #[must_use]
321 pub fn world_size(&self) -> usize {
322 self.distributed.as_ref().map_or(1, |d| d.world_size)
323 }
324
325 /// Get this worker's rank (0 for single-GPU)
326 #[must_use]
327 pub fn rank(&self) -> usize {
328 self.distributed.as_ref().map_or(0, |d| d.rank)
329 }
330}