Skip to main content

entrenar/train/transformer_trainer/
config.rs

1//! Configuration for transformer training
2
3use crate::autograd::{CheckpointConfig, MixedPrecisionConfig};
4use crate::train::TrainConfig;
5use crate::transformer::TransformerConfig;
6use std::net::SocketAddr;
7
/// The part a node plays in a distributed pretraining run.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DistributedRole {
    /// Drives the run: AllReduces gradients and manages checkpoints.
    ///
    /// This is the variant returned by `Default`.
    #[default]
    Coordinator,
    /// Runs forward/backward on the shard assigned to it.
    Worker,
}
17
/// Compute backend a distributed worker runs on.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DistributedBackend {
    /// NVIDIA CUDA.
    Cuda,
    /// wgpu (cross-platform).
    Wgpu,
    /// Auto-detect the best available backend.
    ///
    /// This is the variant returned by `Default`.
    #[default]
    Auto,
}
29
30/// Configuration for distributed pretraining (DDP).
31///
32/// Specifies this worker's role, rank, and communication topology.
33/// All workers must agree on `world_size`. The coordinator address
34/// is where workers connect and where AllReduce is orchestrated.
35///
36/// # Contract
37///
38/// C-DDP-001: After AllReduce + optimizer step, all workers hold identical weights.
39#[derive(Debug, Clone)]
40pub struct DistributedTrainConfig {
41    /// Total number of workers participating
42    pub world_size: usize,
43    /// This worker's global rank (0-indexed)
44    pub rank: usize,
45    /// This worker's local rank on its machine (for multi-GPU)
46    pub local_rank: usize,
47    /// Role: coordinator (rank 0) or worker
48    pub role: DistributedRole,
49    /// Address for coordinator to bind / workers to connect
50    pub coordinator_addr: SocketAddr,
51    /// Compute backend for this worker
52    pub backend: DistributedBackend,
53}
54
55/// Configuration for transformer training
56#[allow(clippy::struct_excessive_bools)]
57#[derive(Debug, Clone)]
58pub struct TransformerTrainConfig {
59    /// Base training configuration
60    pub base: TrainConfig,
61    /// Transformer architecture configuration
62    pub model_config: TransformerConfig,
63    /// Checkpoint configuration for memory efficiency
64    pub checkpoint_config: CheckpointConfig,
65    /// Mixed-precision configuration
66    pub precision_config: MixedPrecisionConfig,
67    /// Maximum sequence length
68    pub max_seq_len: usize,
69    /// Accumulation steps for gradient accumulation
70    pub accumulation_steps: usize,
71    /// Warmup steps for learning rate scheduler
72    pub warmup_steps: usize,
73    /// Learning rate
74    pub lr: f32,
75    /// Maximum training steps (stop after this many optimizer steps)
76    pub max_steps: Option<usize>,
77    /// Use CUDA GPU training when available (default: true = auto-detect)
78    pub use_cuda: bool,
79    /// AdamW beta1 (default: 0.9)
80    pub beta1: f32,
81    /// AdamW beta2 (default: 0.999)
82    pub beta2: f32,
83    /// AdamW weight decay (default: 0.01)
84    pub weight_decay: f32,
85    /// Distributed training configuration (None = single-GPU)
86    pub distributed: Option<DistributedTrainConfig>,
87    /// Enable bitwise deterministic training (CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic)
88    /// Contract: C-DETERM-001
89    pub deterministic: bool,
90    /// Random seed for reproducibility
91    pub seed: u64,
92    /// KAIZEN-047: Step profiler report interval (0 = disabled, N = print every N steps)
93    pub profile_interval: usize,
94    /// LoRA rank (None = full fine-tuning, Some(r) = LoRA with rank r)
95    pub lora_rank: Option<usize>,
96    /// LoRA alpha scaling factor (default: 2 * rank)
97    pub lora_alpha: Option<f32>,
98    /// LoRA target modules (e.g., `q_proj`, `v_proj`)
99    pub lora_target_modules: Option<Vec<String>>,
100    /// LoRA+ ratio: LR multiplier for B matrices (ENT-LoRA-006)
101    /// Default 1.0 = standard LoRA. 16.0 = LoRA+ (Hayou et al. ICML 2024)
102    pub lora_plus_ratio: f32,
103    /// Double quantization for QLoRA (ENT-LoRA-008)
104    /// Quantizes FP32 absmax constants to 8-bit, saving ~0.37 bits/param
105    pub double_quantize: bool,
106    /// Quantize frozen base weights to NF4 (4-bit) for QLoRA pretraining (ENT-263)
107    ///
108    /// When enabled with LoRA, uses `CudaNf4TransformerBlock` instead of
109    /// `CudaTransformerBlock`, achieving ~8x VRAM compression on frozen weights.
110    /// Only LoRA adapters and norm weights remain trainable in fp32.
111    pub quantize_nf4: bool,
112}
113
114impl TransformerTrainConfig {
115    /// Create new config with defaults
116    pub fn new(model_config: TransformerConfig) -> Self {
117        Self {
118            base: TrainConfig::default(),
119            model_config,
120            checkpoint_config: CheckpointConfig::disabled(),
121            precision_config: MixedPrecisionConfig::fp32(),
122            max_seq_len: 512,
123            accumulation_steps: 1,
124            warmup_steps: 0,
125            lr: 0.001,
126            max_steps: None,
127            use_cuda: true,
128            beta1: 0.9,
129            beta2: 0.999,
130            weight_decay: 0.01,
131            distributed: None,
132            deterministic: false,
133            seed: 42,
134            profile_interval: 0,
135            lora_rank: None,
136            lora_alpha: None,
137            lora_target_modules: None,
138            lora_plus_ratio: 1.0,
139            double_quantize: false,
140            quantize_nf4: false,
141        }
142    }
143
144    /// Enable gradient checkpointing
145    pub fn with_checkpointing(mut self, num_segments: usize) -> Self {
146        self.checkpoint_config = CheckpointConfig::enabled(num_segments);
147        self
148    }
149
150    /// Enable bf16 mixed precision
151    pub fn with_bf16(mut self) -> Self {
152        self.precision_config = MixedPrecisionConfig::bf16();
153        self
154    }
155
156    /// Enable fp16 mixed precision with dynamic loss scaling
157    pub fn with_fp16(mut self) -> Self {
158        self.precision_config = MixedPrecisionConfig::fp16();
159        self
160    }
161
162    /// Set maximum sequence length
163    pub fn with_max_seq_len(mut self, len: usize) -> Self {
164        self.max_seq_len = len;
165        self
166    }
167
168    /// Set gradient accumulation steps
169    pub fn with_accumulation_steps(mut self, steps: usize) -> Self {
170        self.accumulation_steps = steps.max(1);
171        self
172    }
173
174    /// Set warmup steps
175    pub fn with_warmup_steps(mut self, steps: usize) -> Self {
176        self.warmup_steps = steps;
177        self
178    }
179
180    /// Set learning rate
181    pub fn with_lr(mut self, lr: f32) -> Self {
182        self.lr = lr;
183        self
184    }
185
186    /// Set gradient clipping
187    pub fn with_grad_clip(mut self, clip: f32) -> Self {
188        self.base.max_grad_norm = Some(clip);
189        self
190    }
191
192    /// Set maximum training steps
193    pub fn with_max_steps(mut self, steps: usize) -> Self {
194        self.max_steps = Some(steps);
195        self
196    }
197
198    /// Enable or disable CUDA GPU training (default: true = auto-detect)
199    pub fn with_use_cuda(mut self, use_cuda: bool) -> Self {
200        self.use_cuda = use_cuda;
201        self
202    }
203
204    /// Set AdamW beta2 (default: 0.999)
205    pub fn with_beta2(mut self, beta2: f32) -> Self {
206        self.beta2 = beta2;
207        self
208    }
209
210    /// Set AdamW weight decay (default: 0.01)
211    pub fn with_weight_decay(mut self, wd: f32) -> Self {
212        self.weight_decay = wd;
213        self
214    }
215
216    /// Enable bitwise deterministic training (C-DETERM-001)
217    ///
218    /// Sets CUBLAS_WORKSPACE_CONFIG, cuDNN deterministic mode, and disables
219    /// cuDNN benchmark. May reduce throughput but guarantees reproducibility.
220    pub fn with_deterministic(mut self, deterministic: bool) -> Self {
221        self.deterministic = deterministic;
222        self
223    }
224
225    /// Set random seed for reproducibility
226    pub fn with_seed(mut self, seed: u64) -> Self {
227        self.seed = seed;
228        self
229    }
230
231    /// Apply deterministic settings to the CUDA environment.
232    ///
233    /// Must be called before any cuBLAS/cuDNN operations.
234    /// Uses `ReproducibilityConfig` from finetune infrastructure.
235    ///
236    /// # Contract (C-DETERM-001)
237    ///
238    /// After calling this, `CUBLAS_WORKSPACE_CONFIG=:4096:8` and
239    /// `CUDNN_DETERMINISTIC=1` are guaranteed set in the process environment.
240    pub fn apply_deterministic_settings(&self) {
241        if self.deterministic {
242            use crate::finetune::ReproducibilityConfig;
243            let repro = ReproducibilityConfig::with_seed(self.seed);
244            repro.apply();
245        }
246    }
247
248    /// Set step profiler report interval (0 = disabled, N = print every N steps)
249    pub fn with_profile_interval(mut self, interval: usize) -> Self {
250        self.profile_interval = interval;
251        self
252    }
253
254    /// Enable LoRA fine-tuning with rank, alpha, and target modules
255    ///
256    /// When LoRA is enabled, only LoRA adapter weights (A, B matrices) and
257    /// layer norms are trainable. Base model weights are frozen.
258    ///
259    /// # Contract (ENT-LoRA-001)
260    /// - Base weights frozen (requires_grad=false)
261    /// - Only LoRA A/B + norms are optimizer targets
262    /// - scale = alpha / rank
263    pub fn with_lora(mut self, rank: usize, alpha: f32, target_modules: Vec<String>) -> Self {
264        self.lora_rank = Some(rank);
265        self.lora_alpha = Some(alpha);
266        self.lora_target_modules = Some(target_modules);
267        self
268    }
269
270    /// Set LoRA+ ratio (ENT-LoRA-006)
271    ///
272    /// LR multiplier for B matrices. Default 1.0 = standard LoRA.
273    /// 16.0 = LoRA+ (Hayou et al. ICML 2024) — B learns 16x faster than A.
274    pub fn with_lora_plus_ratio(mut self, ratio: f32) -> Self {
275        self.lora_plus_ratio = ratio;
276        self
277    }
278
279    /// Enable double quantization for QLoRA (ENT-LoRA-008)
280    pub fn with_double_quantize(mut self, enabled: bool) -> Self {
281        self.double_quantize = enabled;
282        self
283    }
284
285    /// Enable NF4 quantization for QLoRA pretraining (ENT-263)
286    ///
287    /// When enabled with LoRA, frozen base weights are quantized to 4-bit NF4,
288    /// achieving ~8x VRAM compression. Only LoRA adapters and norm weights are
289    /// trainable. Requires `lora_rank` to be set.
290    pub fn with_quantize_nf4(mut self, enabled: bool) -> Self {
291        self.quantize_nf4 = enabled;
292        self
293    }
294
295    /// Check if NF4 quantization is enabled for QLoRA
296    #[must_use]
297    pub fn is_nf4(&self) -> bool {
298        self.quantize_nf4
299    }
300
301    /// Check if LoRA fine-tuning is enabled
302    #[must_use]
303    pub fn is_lora(&self) -> bool {
304        self.lora_rank.is_some()
305    }
306
307    /// Enable distributed training with the given configuration
308    pub fn with_distributed(mut self, config: DistributedTrainConfig) -> Self {
309        self.distributed = Some(config);
310        self
311    }
312
313    /// Check if distributed training is enabled
314    #[must_use]
315    pub fn is_distributed(&self) -> bool {
316        self.distributed.is_some()
317    }
318
319    /// Get world size (1 for single-GPU)
320    #[must_use]
321    pub fn world_size(&self) -> usize {
322        self.distributed.as_ref().map_or(1, |d| d.world_size)
323    }
324
325    /// Get this worker's rank (0 for single-GPU)
326    #[must_use]
327    pub fn rank(&self) -> usize {
328        self.distributed.as_ref().map_or(0, |d| d.rank)
329    }
330}