entrenar/train/transformer_trainer/
trainer.rs1use crate::autograd::{checkpoint, GradScaler};
4use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
5use crate::lora::LoRALayer;
6use crate::optim::{AdamW, Optimizer};
7use crate::train::{CausalLMLoss, LossFn, MetricsTracker};
8use crate::transformer::Transformer;
9use crate::Tensor;
10use std::path::Path;
11
12use super::batch::LMBatch;
13use super::config::TransformerTrainConfig;
14
15pub struct TransformerTrainer {
17 model: Transformer,
19 loss_fn: CausalLMLoss,
21 optimizer: AdamW,
23 grad_scaler: GradScaler,
25 config: TransformerTrainConfig,
27 pub metrics: MetricsTracker,
29 step: usize,
31 accumulated_loss: f32,
33 accumulated_batches: usize,
35 lora_layers: Option<Vec<LoRALayer>>,
38}
39
40impl TransformerTrainer {
41 pub fn new(config: TransformerTrainConfig) -> Self {
43 let seed_guard = crate::transformer::init::lock_init_seed(config.seed);
51 let model = Transformer::new(&config.model_config);
52 drop(seed_guard);
53 Self::build(model, config)
54 }
55
56 pub fn with_model(model: Transformer, config: TransformerTrainConfig) -> Self {
58 Self::build(model, config)
59 }
60
61 fn build(model: Transformer, config: TransformerTrainConfig) -> Self {
63 let loss_fn = CausalLMLoss::new(config.model_config.vocab_size);
64 let optimizer = AdamW::default_params(config.lr);
65 let grad_scaler = GradScaler::from_config(&config.precision_config);
66
67 let lora_layers = if let Some(rank) = config.lora_rank {
69 let alpha = config.lora_alpha.unwrap_or(rank as f32 * 2.0);
70 let default_targets = vec!["q_proj".to_string(), "v_proj".to_string()];
71 let raw_targets = config.lora_target_modules.as_deref().unwrap_or(&default_targets);
73 let expanded = crate::lora::LoRAConfig::expand_shorthand(raw_targets);
74 let target_modules = expanded.as_slice();
75
76 let mut layers = Vec::new();
77 let hidden_size = config.model_config.hidden_size;
78 let num_kv_heads = config.model_config.num_kv_heads;
79 let head_dim = config.model_config.head_dim();
80 let q_dim = config.model_config.q_dim();
81 let kv_hidden_size = num_kv_heads * head_dim;
82
83 let intermediate = config.model_config.intermediate_size;
84
85 for block in &model.layers {
86 if target_modules.iter().any(|m| m == "q_proj") {
88 layers.push(LoRALayer::new(
89 block.self_attn.w_q.clone(),
90 q_dim,
91 hidden_size,
92 rank,
93 alpha,
94 ));
95 }
96 if target_modules.iter().any(|m| m == "k_proj") {
97 layers.push(LoRALayer::new(
98 block.self_attn.w_k.clone(),
99 kv_hidden_size,
100 hidden_size,
101 rank,
102 alpha,
103 ));
104 }
105 if target_modules.iter().any(|m| m == "v_proj") {
106 layers.push(LoRALayer::new(
107 block.self_attn.w_v.clone(),
108 kv_hidden_size,
109 hidden_size,
110 rank,
111 alpha,
112 ));
113 }
114 if target_modules.iter().any(|m| m == "o_proj") {
115 layers.push(LoRALayer::new(
116 block.self_attn.w_o.clone(),
117 hidden_size,
118 q_dim,
119 rank,
120 alpha,
121 ));
122 }
123 if target_modules.iter().any(|m| m == "gate_proj") {
125 layers.push(LoRALayer::new(
126 block.ffn.w_gate.clone(),
127 intermediate,
128 hidden_size,
129 rank,
130 alpha,
131 ));
132 }
133 if target_modules.iter().any(|m| m == "up_proj") {
134 layers.push(LoRALayer::new(
135 block.ffn.w_up.clone(),
136 intermediate,
137 hidden_size,
138 rank,
139 alpha,
140 ));
141 }
142 if target_modules.iter().any(|m| m == "down_proj") {
143 layers.push(LoRALayer::new(
144 block.ffn.w_down.clone(),
145 hidden_size,
146 intermediate,
147 rank,
148 alpha,
149 ));
150 }
151 }
152
153 let lora_param_count: usize =
154 layers.iter().map(|l| l.rank() * (l.d_in() + l.d_out())).sum();
155 let total_params: usize = model.parameters().iter().map(|p| p.len()).sum();
156 println!(
157 " LoRA enabled: rank={rank}, alpha={alpha}, \
158 {lora_param_count} trainable params ({:.2}% of {total_params})",
159 100.0 * lora_param_count as f64 / total_params as f64
160 );
161
162 Some(layers)
163 } else {
164 None
165 };
166
167 Self {
168 model,
169 loss_fn,
170 optimizer,
171 grad_scaler,
172 config,
173 metrics: MetricsTracker::new(),
174 step: 0,
175 accumulated_loss: 0.0,
176 accumulated_batches: 0,
177 lora_layers,
178 }
179 }
180
181 pub fn forward_single(&self, input_ids: &[u32], target_ids: &[u32]) -> (f32, Tensor, Tensor) {
187 let logits = if let Some(ref lora) = self.lora_layers {
189 self.model.forward_with_lora(input_ids, lora)
191 } else if self.config.checkpoint_config.enabled {
192 checkpoint(|_| self.model.forward(input_ids), &Tensor::zeros(1, false))
193 } else {
194 self.model.forward(input_ids)
195 };
196
197 let targets = Tensor::from_vec(target_ids.iter().map(|&id| id as f32).collect(), false);
199 let loss = self.loss_fn.forward(&logits, &targets);
200 let loss_val = loss.data()[0];
201
202 (loss_val, loss, logits)
203 }
204
205 fn compute_batch_gradients(&self, batch: &LMBatch) -> f32 {
207 let mut total_loss = 0.0;
208
209 for i in 0..batch.batch_size {
210 let Some(input_ids) = batch.get_input(i) else {
211 continue;
212 };
213 let Some(target_ids) = batch.get_target(i) else {
214 continue;
215 };
216
217 let (loss_val, loss, _logits) = self.forward_single(input_ids, target_ids);
218
219 if let Some(backward_op) = loss.backward_op() {
220 backward_op.backward();
221 }
222
223 total_loss += loss_val / self.config.accumulation_steps as f32;
224 }
225
226 total_loss / batch.batch_size as f32
227 }
228
229 fn clip_and_step(&mut self) {
231 if let Some(max_norm) = self.config.base.max_grad_norm {
232 let params = if let Some(ref lora) = self.lora_layers {
233 lora.iter().flat_map(|l| vec![l.lora_a(), l.lora_b()]).collect::<Vec<_>>()
234 } else {
235 self.model.parameters()
236 };
237 let total_norm: f32 = params
238 .iter()
239 .filter_map(|p| p.grad())
240 .map(|g| g.iter().map(|x| x * x).sum::<f32>())
241 .sum::<f32>()
242 .sqrt();
243
244 if total_norm > max_norm {
245 let scale = max_norm / (total_norm + 1e-6);
246 let _ = scale;
247 }
248 }
249
250 if let Some(ref mut lora) = self.lora_layers {
252 let ratio = self.config.lora_plus_ratio;
254 if ratio != 1.0 {
255 for layer in lora.iter_mut() {
256 if let Some(grad) = layer.lora_b_mut().grad() {
257 let scaled = grad.mapv(|g| g * ratio);
258 layer.lora_b_mut().set_grad(scaled);
259 }
260 }
261 }
262
263 let mut params: Vec<&mut Tensor> =
264 lora.iter_mut().flat_map(|l| l.trainable_params()).collect();
265 for layer in &mut self.model.layers {
267 params.push(&mut layer.input_norm.weight);
268 params.push(&mut layer.post_attn_norm.weight);
269 }
270 params.push(&mut self.model.norm.weight);
271 self.optimizer.step_refs(&mut params);
272 } else {
273 let mut params = self.model.parameters_mut();
274 self.optimizer.step_refs(&mut params);
275 }
276
277 self.step += 1;
278 self.metrics.losses.push(self.accumulated_loss);
279 self.metrics.increment_step();
280
281 self.accumulated_loss = 0.0;
282 self.accumulated_batches = 0;
283 }
284
285 pub fn train_batch(&mut self, batch: &LMBatch) -> f32 {
289 if batch.batch_size == 0 {
290 return 0.0;
291 }
292
293 if self.accumulated_batches == 0 {
294 if let Some(ref mut lora) = self.lora_layers {
296 let mut params: Vec<&mut Tensor> =
297 lora.iter_mut().flat_map(|l| l.trainable_params()).collect();
298 for layer in &mut self.model.layers {
299 params.push(&mut layer.input_norm.weight);
300 params.push(&mut layer.post_attn_norm.weight);
301 }
302 params.push(&mut self.model.norm.weight);
303 self.optimizer.zero_grad_refs(&mut params);
304 } else {
305 let mut params = self.model.parameters_mut();
306 self.optimizer.zero_grad_refs(&mut params);
307 }
308 }
309
310 let avg_loss = self.compute_batch_gradients(batch);
311
312 self.accumulated_loss += avg_loss;
313 self.accumulated_batches += 1;
314
315 if self.accumulated_batches >= self.config.accumulation_steps {
316 self.clip_and_step();
317 }
318
319 avg_loss
320 }
321
322 pub fn train_epoch(&mut self, batches: &[LMBatch]) -> f32 {
324 self.train_epoch_with_callback(batches, |_, _, _| {})
325 }
326
327 pub fn train_epoch_with_callback<F>(&mut self, batches: &[LMBatch], mut on_batch: F) -> f32
335 where
336 F: FnMut(usize, f32, &Self),
337 {
338 if batches.is_empty() {
339 return 0.0;
340 }
341
342 let mut total_loss = 0.0;
343 let mut batches_processed = 0;
344
345 for (i, batch) in batches.iter().enumerate() {
346 if let Some(max) = self.config.max_steps {
348 if self.step >= max {
349 break;
350 }
351 }
352
353 let batch_loss = self.train_batch(batch);
354 total_loss += batch_loss;
355 batches_processed += 1;
356 on_batch(i, batch_loss, self);
357 }
358
359 total_loss / batches_processed.max(1) as f32
360 }
361
362 pub fn reached_max_steps(&self) -> bool {
364 self.config.max_steps.is_some_and(|max| self.step >= max)
365 }
366
367 pub fn step(&self) -> usize {
369 self.step
370 }
371
372 pub fn model(&self) -> &Transformer {
374 &self.model
375 }
376
377 pub fn model_mut(&mut self) -> &mut Transformer {
379 &mut self.model
380 }
381
382 pub fn current_lr(&self) -> f32 {
384 let base_lr = self.config.lr;
385
386 if self.step < self.config.warmup_steps {
387 base_lr * (self.step as f32 / self.config.warmup_steps as f32)
389 } else {
390 base_lr
391 }
392 }
393
394 pub fn grad_scaler_stats(&self) -> (f32, usize, usize) {
396 (
397 self.grad_scaler.scale(),
398 self.grad_scaler.overflow_count(),
399 self.grad_scaler.successful_steps(),
400 )
401 }
402
403 pub fn is_mixed_precision(&self) -> bool {
405 self.config.precision_config.is_mixed()
406 }
407
408 pub fn is_checkpointing(&self) -> bool {
410 self.config.checkpoint_config.enabled
411 }
412
413 pub fn is_lora(&self) -> bool {
415 self.lora_layers.is_some()
416 }
417
418 pub fn lora_layers(&self) -> Option<&[LoRALayer]> {
420 self.lora_layers.as_deref()
421 }
422
423 pub fn lora_layers_mut(&mut self) -> Option<&mut Vec<LoRALayer>> {
425 self.lora_layers.as_mut()
426 }
427
428 pub fn save_lora_adapter(
440 &self,
441 output_dir: impl AsRef<Path>,
442 base_model_name: Option<&str>,
443 ) -> crate::Result<()> {
444 let lora = self.lora_layers.as_ref().ok_or_else(|| {
445 crate::error::Error::ConfigError("Cannot save adapter: LoRA not enabled".into())
446 })?;
447
448 let rank = self.config.lora_rank.unwrap_or(8);
449 let alpha = self.config.lora_alpha.unwrap_or(rank as f32 * 2.0);
450 let target_modules = self
451 .config
452 .lora_target_modules
453 .clone()
454 .unwrap_or_else(|| vec!["q_proj".to_string(), "v_proj".to_string()]);
455
456 let expanded = crate::lora::LoRAConfig::expand_shorthand(&target_modules);
458 let lora_config = crate::lora::LoRAConfig::new(rank, alpha)
459 .target_modules(&expanded.iter().map(String::as_str).collect::<Vec<_>>());
460
461 let num_layers = self.model.layers.len();
464
465 let module_paths: Vec<(&str, &str)> = [
467 ("q_proj", "self_attn.q_proj"),
468 ("k_proj", "self_attn.k_proj"),
469 ("v_proj", "self_attn.v_proj"),
470 ("o_proj", "self_attn.o_proj"),
471 ("gate_proj", "mlp.gate_proj"),
472 ("up_proj", "mlp.up_proj"),
473 ("down_proj", "mlp.down_proj"),
474 ]
475 .iter()
476 .filter(|(name, _)| expanded.iter().any(|t| t == *name))
477 .copied()
478 .collect();
479
480 let all_names: Vec<String> = (0..num_layers)
482 .flat_map(|i| {
483 module_paths.iter().map(move |(_, path)| format!("model.layers.{i}.{path}"))
484 })
485 .collect();
486
487 let mut adapters: Vec<(&str, &LoRALayer)> = Vec::new();
488 for (idx, layer) in lora.iter().enumerate() {
489 if idx < all_names.len() {
490 adapters.push((&all_names[idx], layer));
491 }
492 }
493
494 crate::lora::save_adapter_peft(&adapters, &lora_config, base_model_name, output_dir)
495 .map_err(|e| crate::error::Error::Io(e.to_string()))
496 }
497
498 pub fn save(
513 &self,
514 path: impl AsRef<Path>,
515 name: &str,
516 architecture: &str,
517 ) -> crate::Result<()> {
518 let params: Vec<(String, Tensor)> = self
520 .model
521 .named_parameters()
522 .into_iter()
523 .map(|(name, tensor)| (name, tensor.clone()))
524 .collect();
525
526 let metadata = ModelMetadata::new(name, architecture);
527 let model = Model::new(metadata, params);
528 let config = SaveConfig::new(ModelFormat::SafeTensors);
529
530 save_model(&model, path, &config)
531 }
532
533 pub fn save_apr(
548 &self,
549 path: impl AsRef<Path>,
550 name: &str,
551 architecture: &str,
552 ) -> crate::Result<()> {
553 let params: Vec<(String, Tensor)> =
554 self.model.named_parameters().into_iter().map(|(n, t)| (n, t.clone())).collect();
555 let metadata = ModelMetadata::new(name, architecture);
556 let model = Model::new(metadata, params);
557 let config = SaveConfig::new(ModelFormat::Apr);
558 save_model(&model, path, &config)
559 }
560
561 #[must_use]
572 pub fn optimizer_state_sha256(&self) -> String {
573 use sha2::{Digest, Sha256};
574 let mut hasher = Sha256::new();
575 hasher.update(b"aprender-train:adamw:optstate:v1");
576 hasher.update(self.optimizer.step_count().to_le_bytes());
577 let moment_streams: [(&[u8], &[Option<ndarray::Array1<f32>>]); 2] =
578 [(b"m", self.optimizer.first_moments()), (b"v", self.optimizer.second_moments())];
579 for (tag, buffers) in moment_streams {
580 hasher.update(tag);
581 hasher.update((buffers.len() as u64).to_le_bytes());
582 for slot in buffers {
583 match slot {
584 Some(arr) => {
585 hasher.update(b"some");
586 hasher.update((arr.len() as u64).to_le_bytes());
587 let bytes: &[u8] = bytemuck::cast_slice(
588 arr.as_slice().expect("AdamW buffers are contiguous"),
589 );
590 hasher.update(bytes);
591 }
592 None => hasher.update(b"none"),
593 }
594 }
595 }
596 format!("{:x}", hasher.finalize())
597 }
598}