#![allow(dead_code)]
// Fixed architecture dimensions for a Qwen-family transformer.
// NOTE(review): the values appear to match the published Qwen2.5-0.5B config — confirm.
// `dead_code` is allowed module-wide, presumably because not every constant is used.
const QWEN_HIDDEN_SIZE: usize = 896;
const QWEN_NUM_ATTENTION_HEADS: usize = 14;
const QWEN_NUM_KV_HEADS: usize = 2;
const QWEN_INTERMEDIATE_SIZE: usize = 4_864;
const QWEN_NUM_HIDDEN_LAYERS: usize = 24;
const QWEN_VOCAB_SIZE: usize = 151_936;
const QWEN_MAX_POSITION_EMBEDDINGS: usize = 32_768;
// RoPE base frequency.
const QWEN_ROPE_THETA: f64 = 1.0e6;
13
14use super::batches::load_training_batches;
15use crate::config::schema::{ModelMode, TrainSpec};
16use crate::config::validate::validate_config;
17use crate::error::{Error, Result};
18use crate::monitor::tui::state::{TrainingSnapshot, TrainingState, TrainingStatus};
19use crate::storage::{ExperimentStorage, ParameterValue, RunStatus, SqliteBackend};
20use crate::tokenizer::HfTokenizer;
21use crate::trace::TRACER;
22#[cfg(feature = "cuda")]
23use crate::train::CudaTransformerTrainer;
24use crate::train::{LMBatch, TransformerTrainConfig, TransformerTrainer};
25use crate::transformer::{
26 load_safetensors_weights, Architecture, ModelArchitecture, Transformer, TransformerConfig,
27};
28use crate::yaml_mode;
29use std::fs;
30use std::path::{Path, PathBuf};
31use std::time::{SystemTime, UNIX_EPOCH};
32
33pub fn train_from_yaml<P: AsRef<Path>>(config_path: P) -> Result<()> {
51 let spec = load_config(config_path)?;
52
53 match spec.model.mode {
55 ModelMode::Transformer => train_transformer_from_spec(&spec),
56 ModelMode::Tabular => train_tabular_from_spec(&spec),
57 }
58}
59
60fn build_train_config(
62 model_config: crate::transformer::TransformerConfig,
63 spec: &TrainSpec,
64) -> TransformerTrainConfig {
65 let mut config = TransformerTrainConfig::new(model_config)
66 .with_lr(spec.optimizer.lr)
67 .with_warmup_steps(spec.training.warmup_steps)
68 .with_max_seq_len({
69 let seq_len = spec.data.seq_len.unwrap_or_else(|| {
70 eprintln!("Warning: seq_len not specified in config, defaulting to 512");
71 512
72 });
73 seq_len
74 });
75
76 if let Some(clip) = spec.training.grad_clip {
77 config = config.with_grad_clip(clip);
78 }
79
80 if let Some(v) = spec.optimizer.params.get("beta2").and_then(serde_json::Value::as_f64) {
82 config = config.with_beta2(v as f32);
83 }
84 if let Some(v) = spec.optimizer.params.get("weight_decay").and_then(serde_json::Value::as_f64) {
85 config = config.with_weight_decay(v as f32);
86 }
87
88 if let Some(accum) = spec.training.gradient_accumulation {
89 config = config.with_accumulation_steps(accum);
90 if accum > 1 {
91 let eff_batch = spec.data.batch_size * accum * spec.data.seq_len.unwrap_or(1024);
92 println!(" Gradient accumulation: {accum} (effective batch: {eff_batch} tokens/step)");
93 }
94 }
95
96 if let Some(max_steps) = spec.training.max_steps {
97 config = config.with_max_steps(max_steps);
98 }
99
100 if let Some(ref precision) = spec.training.mixed_precision {
102 match precision.as_str() {
103 "bf16" => config = config.with_bf16(),
104 "fp16" => config = config.with_fp16(),
105 "fp32" => {}
106 other => {
107 eprintln!("Warning: unknown mixed_precision value '{other}', defaulting to fp32");
108 }
109 }
110 }
111
112 if let Some(num_segments) = spec.training.checkpoints {
114 config = config.with_checkpointing(num_segments);
115 }
116
117 if spec.training.deterministic {
119 config = config.with_deterministic(true);
120 }
121 if let Some(seed) = spec.training.seed {
122 config = config.with_seed(seed);
123 }
124
125 if spec.training.profile_interval > 0 {
127 config = config.with_profile_interval(spec.training.profile_interval);
128 }
129
130 if let Some(ref lora) = spec.lora {
132 config = config.with_lora(lora.rank, lora.alpha, lora.target_modules.clone());
133 if lora.lora_plus_ratio != 1.0 {
135 config = config.with_lora_plus_ratio(lora.lora_plus_ratio);
136 }
137 if lora.double_quantize {
139 config = config.with_double_quantize(true);
140 }
141 if lora.quantize_base {
143 config = config.with_quantize_nf4(true);
144 }
145 }
146
147 if let Some(ref dist) = spec.training.distributed {
149 use crate::train::{DistributedBackend, DistributedRole, DistributedTrainConfig};
150
151 let role = match dist.role.as_str() {
152 "worker" => DistributedRole::Worker,
153 _ => DistributedRole::Coordinator,
154 };
155 let backend = match dist.backend.as_str() {
156 "cuda" => DistributedBackend::Cuda,
157 "wgpu" => DistributedBackend::Wgpu,
158 _ => DistributedBackend::Auto,
159 };
160 let addr: std::net::SocketAddr =
161 dist.coordinator_addr.parse().unwrap_or_else(|_| "0.0.0.0:9000".parse().unwrap());
162
163 config = config.with_distributed(DistributedTrainConfig {
164 world_size: dist.world_size,
165 rank: dist.rank,
166 local_rank: dist.local_rank,
167 role,
168 coordinator_addr: addr,
169 backend,
170 });
171 }
172
173 config
174}
175
176fn train_transformer_from_spec(spec: &TrainSpec) -> Result<()> {
180 println!("✓ Config loaded and validated (Transformer mode)");
181 println!(" Model: {}", spec.model.path.display());
182 println!(" Optimizer: {} (lr={})", spec.optimizer.name, spec.optimizer.lr);
183 println!(" Batch size: {}", spec.data.batch_size);
184 println!(" Epochs: {}", spec.training.epochs);
185 println!(" Training mode: {:?}", spec.training.mode);
186
187 if let Some(lora) = &spec.lora {
188 println!(" LoRA: rank={}, alpha={}", lora.rank, lora.alpha);
189 if lora.quantize_base {
190 println!(" QLoRA: NF4 quantized base weights (~8x VRAM compression)");
191 }
192 }
193 println!();
194
195 let model_config = build_transformer_config_from_spec(spec)?;
197
198 let resolved_path = resolve_model_path(&spec.model.path)?;
200
201 crate::transformer::init::set_init_seed(spec.training.seed.unwrap_or(42));
203
204 #[cfg(feature = "cuda")]
207 let (transformer, checkpoint_step) =
208 load_transformer_model(&resolved_path, &model_config, &spec.training.output_dir)?;
209 #[cfg(not(feature = "cuda"))]
210 let (transformer, _checkpoint_step) =
211 load_transformer_model(&resolved_path, &model_config, &spec.training.output_dir)?;
212
213 let train_config = build_train_config(model_config, spec);
215
216 train_config.apply_deterministic_settings();
218
219 println!("Loading training data...");
221 let batches = load_lm_batches(spec)?;
222 println!("✓ {} LM batches created", batches.len());
223 println!();
224
225 #[cfg(feature = "cuda")]
227 if train_config.use_cuda {
228 let cuda_config = train_config.clone();
229 let cuda_result = match transformer {
230 Some(loaded_model) => CudaTransformerTrainer::with_model(loaded_model, cuda_config),
231 None => CudaTransformerTrainer::new(cuda_config),
232 };
233
234 match cuda_result {
235 Ok(mut cuda_trainer) => {
236 if checkpoint_step > 0 {
238 cuda_trainer.set_initial_step(checkpoint_step);
239 println!(
240 " Resumed at step {checkpoint_step} (lr={:.2e})",
241 cuda_trainer.current_lr()
242 );
243 let apr_loaded = find_latest_apr_checkpoint(&spec.training.output_dir)
245 .is_some_and(|p| {
246 let (restored, total) = cuda_trainer.restore_lora_from_apr(&p);
248 if restored > 0 {
249 println!(" ✓ LoRA adapters restored ({restored}/{total} layers)");
250 }
251 cuda_trainer.load_optimizer_state_apr(&p)
252 });
253 if apr_loaded {
254 println!(" ✓ Embedding optimizer state restored (APR)");
255 } else if cuda_trainer.load_optimizer_state(&spec.training.output_dir) {
256 println!(" ✓ Embedding optimizer state restored (JSON)");
257 }
258 }
259 println!("✓ CudaTransformerTrainer initialized (GPU: {})", cuda_trainer.gpu_name());
260 if train_config.distributed.is_some() {
262 return train_loop_cuda_distributed(cuda_trainer, &batches, spec);
263 }
264 return train_loop_cuda(&mut cuda_trainer, &batches, spec);
265 }
266 Err(e) => {
267 eprintln!("Warning: CUDA training failed ({e}), falling back to CPU");
268 let mut trainer = TransformerTrainer::new(train_config);
270 println!("✓ TransformerTrainer initialized (CPU fallback)");
271 println!(" Mixed precision: {}", trainer.is_mixed_precision());
272 println!(" Checkpointing: {}", trainer.is_checkpointing());
273 println!();
274 return train_loop_cpu(&mut trainer, &batches, spec);
275 }
276 }
277 }
278
279 let mut trainer = if let Some(loaded_model) = transformer {
281 TransformerTrainer::with_model(loaded_model, train_config)
282 } else {
283 TransformerTrainer::new(train_config)
284 };
285 println!("✓ TransformerTrainer initialized (CPU)");
286 println!(" Mixed precision: {}", trainer.is_mixed_precision());
287 println!(" Checkpointing: {}", trainer.is_checkpointing());
288 println!();
289
290 train_loop_cpu(&mut trainer, &batches, spec)
291}
292
293fn train_loop_cpu(
295 trainer: &mut TransformerTrainer,
296 batches: &[LMBatch],
297 spec: &TrainSpec,
298) -> Result<()> {
299 println!("Starting transformer training (CPU)...");
300 println!();
301
302 TRACER.enable();
303 TRACER.clear();
304
305 let num_batches = batches.len();
306 let start_time = std::time::Instant::now();
307 let log_interval = (num_batches / 100).clamp(1, 100);
308
309 let state = TrainingState::new(&spec.training.output_dir);
311 let start_ms = now_ms();
312 let total_epochs = spec.training.epochs;
313
314 let mut tracker = PretrainTracker::open(spec, "CPU");
316
317 write_training_snapshot(
318 &state,
319 start_ms,
320 0,
321 total_epochs,
322 0,
323 num_batches,
324 0.0,
325 &[],
326 0.0,
327 0.0,
328 TrainingStatus::Initializing,
329 spec,
330 "CPU",
331 );
332
333 if let Some(max_steps) = spec.training.max_steps {
334 println!(" max_steps: {max_steps} (will stop early when reached)");
335 }
336
337 let mut loss_history: Vec<f32> = Vec::new();
338
339 for epoch in 0..spec.training.epochs {
340 let epoch_start = std::time::Instant::now();
341 let avg_loss =
342 trainer.train_epoch_with_callback(batches, |batch_idx, batch_loss, trainer| {
343 loss_history.push(batch_loss);
344 if loss_history.len() > 100 {
345 loss_history.remove(0);
346 }
347
348 if (batch_idx + 1) % log_interval == 0 || batch_idx == 0 {
349 let elapsed = epoch_start.elapsed().as_secs_f64();
350 let batches_done = batch_idx + 1;
351 let seq_len = spec.data.seq_len.unwrap_or(128);
352 let tokens_done = batches_done * spec.data.batch_size * seq_len;
353 let batch_per_sec = batches_done as f64 / elapsed.max(0.001);
354 let remaining = (num_batches - batches_done) as f64 / batch_per_sec.max(0.001);
355 let tok_per_sec = tokens_done as f64 / elapsed.max(0.001);
356 println!(
357 " [{}/{} batches] step={} loss={:.4} lr={:.2e} tok/s={:.0} eta={:.0}s",
358 batches_done,
359 num_batches,
360 trainer.step(),
361 batch_loss,
362 trainer.current_lr(),
363 tok_per_sec,
364 remaining,
365 );
366
367 write_training_snapshot(
369 &state,
370 start_ms,
371 epoch + 1,
372 total_epochs,
373 trainer.step(),
374 num_batches,
375 batch_loss,
376 &loss_history,
377 trainer.current_lr(),
378 tok_per_sec as f32,
379 TrainingStatus::Running,
380 spec,
381 "CPU",
382 );
383
384 tracker.log_step(
386 trainer.step() as u64,
387 batch_loss,
388 trainer.current_lr(),
389 tok_per_sec as f32,
390 );
391 }
392 });
393 let ppl = crate::train::perplexity(avg_loss);
394 println!(
395 "Epoch {}/{}: loss={:.6}, perplexity={:.2}, time={:.1}s",
396 epoch + 1,
397 spec.training.epochs,
398 avg_loss,
399 ppl,
400 epoch_start.elapsed().as_secs_f64(),
401 );
402
403 if trainer.reached_max_steps() {
404 println!(
405 "Reached max_steps={}, stopping training.",
406 spec.training.max_steps.unwrap_or(0)
407 );
408 break;
409 }
410 }
411
412 let total_time = start_time.elapsed();
413 println!("Total training time: {:.1}s", total_time.as_secs_f64());
414 println!("{}", TRACER.report());
415
416 let final_loss = trainer.metrics.losses.last().copied().unwrap_or(0.0);
418 write_training_snapshot(
419 &state,
420 start_ms,
421 total_epochs,
422 total_epochs,
423 trainer.step(),
424 num_batches,
425 final_loss,
426 &loss_history,
427 trainer.current_lr(),
428 0.0,
429 TrainingStatus::Completed,
430 spec,
431 "CPU",
432 );
433
434 tracker.complete();
436
437 save_trained_model_cpu(trainer, spec)
438}
439
/// Milliseconds since the Unix epoch, or `0` if the system clock reads
/// earlier than the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
444
445fn query_gpu_telemetry(device_name: &str) -> Option<crate::monitor::tui::state::GpuTelemetry> {
451 let output = std::process::Command::new("nvidia-smi")
452 .args([
453 "--query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw,power.limit",
454 "--format=csv,noheader,nounits",
455 ])
456 .output()
457 .ok()?;
458
459 if !output.status.success() {
460 return None;
461 }
462
463 let stdout = String::from_utf8_lossy(&output.stdout);
464 let line = stdout.lines().next()?.trim();
465 let fields: Vec<&str> = line.split(',').map(str::trim).collect();
466 if fields.len() < 6 {
467 return None;
468 }
469
470 Some(crate::monitor::tui::state::GpuTelemetry {
471 device_name: device_name.to_string(),
472 utilization_percent: fields[0].parse().unwrap_or(0.0),
473 vram_used_gb: fields[1].parse::<f32>().unwrap_or(0.0) / 1024.0, vram_total_gb: fields[2].parse::<f32>().unwrap_or(0.0) / 1024.0,
475 temperature_celsius: fields[3].parse().unwrap_or(0.0),
476 power_watts: fields[4].parse().unwrap_or(0.0),
477 power_limit_watts: fields[5].parse().unwrap_or(0.0),
478 processes: Vec::new(),
479 })
480}
481
482fn write_training_snapshot(
488 state: &TrainingState,
489 start_ms: u64,
490 epoch: usize,
491 total_epochs: usize,
492 step: usize,
493 steps_per_epoch: usize,
494 loss: f32,
495 loss_history: &[f32],
496 lr: f32,
497 tokens_per_second: f32,
498 status: TrainingStatus,
499 spec: &TrainSpec,
500 gpu_name: &str,
501) {
502 let snapshot = TrainingSnapshot {
503 timestamp_ms: now_ms(),
504 epoch,
505 total_epochs,
506 step,
507 steps_per_epoch,
508 loss,
509 loss_history: loss_history.to_vec(),
510 learning_rate: lr,
511 lr_history: Vec::new(),
512 gradient_norm: 0.0, tokens_per_second,
514 start_timestamp_ms: start_ms,
515 gpu: query_gpu_telemetry(gpu_name).or_else(|| {
516 Some(crate::monitor::tui::state::GpuTelemetry {
517 device_name: gpu_name.to_string(),
518 ..Default::default()
519 })
520 }),
521 sample: None,
522 status,
523 experiment_id: spec.training.output_dir.display().to_string(),
524 model_name: spec.model.path.display().to_string(),
525 model_path: spec.model.path.display().to_string(),
526 optimizer_name: spec.optimizer.name.clone(),
527 batch_size: spec.data.batch_size,
528 checkpoint_path: spec.training.output_dir.display().to_string(),
529 executable_path: String::new(),
530 accuracy: 0.0,
531 samples_per_second: 0.0,
532 };
533 if let Err(e) = state.write(&snapshot) {
534 eprintln!("[ALB-045] Failed to write training_state.json: {e}");
535 }
536}
537
538struct PretrainTracker {
550 local: Option<SqliteBackend>,
551 global: Option<SqliteBackend>,
552 run_id: Option<String>,
553 global_run_id: Option<String>,
554}
555
556impl PretrainTracker {
557 fn open(spec: &TrainSpec, device: &str) -> Self {
559 let exp_name =
560 spec.training.output_dir.file_name().and_then(|n| n.to_str()).unwrap_or("pretrain");
561
562 let config_json = serde_json::json!({
563 "task": "pretrain",
564 "model": spec.model.path.display().to_string(),
565 "optimizer": &spec.optimizer.name,
566 "lr": spec.optimizer.lr,
567 "epochs": spec.training.epochs,
568 "batch_size": spec.data.batch_size,
569 "seq_len": spec.data.seq_len,
570 "max_steps": spec.training.max_steps,
571 "device": device,
572 "output_dir": spec.training.output_dir.display().to_string(),
573 });
574
575 let local = SqliteBackend::open_project(&spec.training.output_dir).ok();
577
578 let global = dirs::home_dir().map(|h| h.join(".entrenar")).and_then(|p| {
580 fs::create_dir_all(&p).ok()?;
581 SqliteBackend::open(p.join("experiments.db").to_string_lossy().as_ref()).ok()
582 });
583
584 let mut tracker = Self { local, global, run_id: None, global_run_id: None };
585
586 if let Some(store) = tracker.local.as_mut() {
588 if let Ok(eid) = store.create_experiment(exp_name, Some(config_json.clone())) {
589 if let Ok(rid) = store.create_run(&eid) {
590 let _ = store.start_run(&rid);
591 log_run_params(store, &rid, spec, device);
592 tracker.run_id = Some(rid);
593 }
594 }
595 }
596
597 if let Some(store) = tracker.global.as_mut() {
599 if let Ok(eid) = store.create_experiment(exp_name, Some(config_json)) {
600 if let Ok(rid) = store.create_run(&eid) {
601 let _ = store.start_run(&rid);
602 log_run_params(store, &rid, spec, device);
603 tracker.global_run_id = Some(rid);
604 }
605 }
606 }
607
608 tracker
609 }
610
611 fn log_step(&mut self, step: u64, loss: f32, lr: f32, tok_per_sec: f32) {
613 for (store, run_id) in [
614 (self.local.as_mut(), self.run_id.as_deref()),
615 (self.global.as_mut(), self.global_run_id.as_deref()),
616 ] {
617 if let (Some(s), Some(rid)) = (store, run_id) {
618 let _ = s.log_metric(rid, "loss", step, f64::from(loss));
619 let _ = s.log_metric(rid, "learning_rate", step, f64::from(lr));
620 let _ = s.log_metric(rid, "tokens_per_second", step, f64::from(tok_per_sec));
621 }
622 }
623 }
624
625 fn complete(&mut self) {
627 for (store, run_id) in [
628 (self.local.as_mut(), self.run_id.as_deref()),
629 (self.global.as_mut(), self.global_run_id.as_deref()),
630 ] {
631 if let (Some(s), Some(rid)) = (store, run_id) {
632 let _ = s.complete_run(rid, RunStatus::Success);
633 }
634 }
635 }
636
637 #[allow(dead_code)]
639 fn fail(&mut self) {
640 for (store, run_id) in [
641 (self.local.as_mut(), self.run_id.as_deref()),
642 (self.global.as_mut(), self.global_run_id.as_deref()),
643 ] {
644 if let (Some(s), Some(rid)) = (store, run_id) {
645 let _ = s.complete_run(rid, RunStatus::Failed);
646 }
647 }
648 }
649}
650
651fn log_run_params(store: &SqliteBackend, run_id: &str, spec: &TrainSpec, device: &str) {
653 let _ = store.log_param(run_id, "task", ParameterValue::String("pretrain".into()));
654 let _ = store.log_param(
655 run_id,
656 "model",
657 ParameterValue::String(spec.model.path.display().to_string()),
658 );
659 let _ =
660 store.log_param(run_id, "optimizer", ParameterValue::String(spec.optimizer.name.clone()));
661 let _ = store.log_param(
662 run_id,
663 "learning_rate",
664 ParameterValue::Float(f64::from(spec.optimizer.lr)),
665 );
666 let _ = store.log_param(run_id, "epochs", ParameterValue::Int(spec.training.epochs as i64));
667 let _ = store.log_param(run_id, "batch_size", ParameterValue::Int(spec.data.batch_size as i64));
668 let _ = store.log_param(run_id, "device", ParameterValue::String(device.to_string()));
669 let _ = store.log_param(
670 run_id,
671 "output_dir",
672 ParameterValue::String(spec.training.output_dir.display().to_string()),
673 );
674 if let Some(seq_len) = spec.data.seq_len {
675 let _ = store.log_param(run_id, "seq_len", ParameterValue::Int(seq_len as i64));
676 }
677 if let Some(max_steps) = spec.training.max_steps {
678 let _ = store.log_param(run_id, "max_steps", ParameterValue::Int(max_steps as i64));
679 }
680}
681
/// Announces the configured step limit; prints nothing when there is none.
fn print_max_steps(max_steps: Option<usize>) {
    match max_steps {
        Some(limit) => println!(" max_steps: {limit} (will stop early when reached)"),
        None => {}
    }
}
689
/// GPU-resident training loop for a [`CudaTransformerTrainer`].
///
/// For each batch it:
/// - trains the batch and skips non-finite losses;
/// - updates the loss-spike, z-clip and noise-scale monitors;
/// - writes a heartbeat file;
/// - periodically logs metrics, saves checkpoints and runs validation, with
///   optional early stopping.
///
/// A SIGINT sets a flag. On the next batch the loop calls
/// `handle_graceful_shutdown` and returns `Ok(())`.
///
/// When the trainer was restored from a checkpoint (`trainer.step() > 0`),
/// training resumes at the epoch and in-epoch batch position implied by that
/// step. This assumes each optimizer step consumed `gradient_accumulation`
/// batches.
///
/// # Errors
/// Returns any error from saving the final trained model.
#[cfg(feature = "cuda")]
fn train_loop_cuda(
    trainer: &mut CudaTransformerTrainer,
    batches: &[LMBatch],
    spec: &TrainSpec,
) -> Result<()> {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    println!("Starting transformer training (CUDA GPU-resident)...");
    println!();

    let num_batches = batches.len();

    // Without an explicit limit, cap the schedule at one step per batch per epoch.
    if spec.training.max_steps.is_none() {
        let total_steps = spec.training.epochs * num_batches;
        trainer.set_max_steps(total_steps);
        println!(
            " max_steps: {total_steps} (auto: {epochs}×{num_batches})",
            epochs = spec.training.epochs
        );
    }

    let start_time = std::time::Instant::now();
    let log_interval = (num_batches / 100).clamp(1, 100);
    let save_interval = spec.training.save_interval;
    let max_checkpoints = spec.training.max_checkpoints;

    // Validation follows the checkpoint cadence unless eval_interval is set (> 0).
    let eval_interval =
        if spec.training.eval_interval > 0 { spec.training.eval_interval } else { save_interval };
    let patience = spec.training.patience;
    let mut best_val_loss: f32 = f32::INFINITY;
    let mut evals_without_improvement: usize = 0;
    let mut last_eval_step: usize = 0;

    let state = TrainingState::new(&spec.training.output_dir);
    let start_ms = now_ms();
    let gpu_name = trainer.gpu_name();
    let total_epochs = spec.training.epochs;

    let mut tracker = PretrainTracker::open(spec, &gpu_name);

    let num_params = trainer.num_params();
    let seq_len = spec.data.seq_len.unwrap_or(128);
    let tokens_per_batch = spec.data.batch_size * seq_len;
    // Peak device throughput (FLOP/s) handed to `log_step_metrics`.
    // NOTE(review): hard-coded for a single GPU model — confirm which one.
    let gpu_peak_tflops: f64 = 82.58e12;

    // Append-only JSONL event log. The handle is an Option, so training
    // continues if the file cannot be opened.
    let jsonl_path = spec.training.output_dir.join("training_log.jsonl");
    std::fs::create_dir_all(&spec.training.output_dir).ok();
    let mut jsonl_file =
        std::fs::OpenOptions::new().create(true).append(true).open(&jsonl_path).ok();
    write_jsonl_event_json(
        &mut jsonl_file,
        &serde_json::json!({
            "type": "config",
            "num_params": num_params,
            "batch_size": spec.data.batch_size,
            "seq_len": seq_len,
            "max_steps": spec.training.max_steps,
            "epochs": spec.training.epochs,
            "lr": spec.optimizer.lr,
            "gpu": &gpu_name,
            "timestamp": now_ms(),
        }),
    );

    // Ctrl-C only raises a flag; the loop below performs the actual shutdown.
    let shutdown_flag = Arc::new(AtomicBool::new(false));
    {
        let flag = shutdown_flag.clone();
        let _ = ctrlc::set_handler(move || {
            flag.store(true, Ordering::SeqCst);
            eprintln!("\n[SIGINT] Graceful shutdown requested. Saving checkpoint...");
        });
    }

    write_training_snapshot(
        &state,
        start_ms,
        0,
        total_epochs,
        0,
        num_batches,
        0.0,
        &[],
        0.0,
        0.0,
        TrainingStatus::Initializing,
        spec,
        &gpu_name,
    );

    print_max_steps(spec.training.max_steps);

    if eval_interval != save_interval {
        println!(" eval_interval: {eval_interval} (decoupled from save_interval={save_interval})");
    }
    if patience > 0 {
        println!(" early_stopping: patience={patience} eval intervals");
    }

    let mut scaling_predictor = ScalingLawPredictor::new();
    let tokens_per_step = tokens_per_batch * spec.training.gradient_accumulation.unwrap_or(1);

    let mut loss_history: Vec<f32> = Vec::new();
    let mut last_save_step: usize = 0;

    let model_name = spec
        .model
        .path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("entrenar-model")
        .to_string();

    let shuffle = spec.training.shuffle;
    let seed = spec.training.seed.unwrap_or(42);

    let val_batches = load_val_batches(spec);

    let mut nan_skips: usize = 0;

    // Gradient-norm EMA statistics consumed by `zclip_update`.
    let mut gnorm_ema: f64 = 0.0;
    let mut gnorm_ema_sq: f64 = 0.0;
    let zclip_alpha: f64 = 0.05;
    let zclip_threshold: f64 = 2.0;
    let heartbeat_path = spec.training.output_dir.join("heartbeat");

    // Loss EMA for spike detection; 0.0 means "not yet initialised".
    let mut loss_ema: f64 = 0.0;
    let loss_ema_alpha: f64 = 0.05;
    let loss_spike_threshold: f64 = 3.0;
    let mut rollback_count: usize = 0;
    let max_rollbacks: usize = 3;

    // Window of recent gradient norms for periodic noise-scale estimation.
    let mut gnorm_window: Vec<f64> = Vec::with_capacity(100);
    let noise_scale_interval: usize = 100;
    let mut last_noise_scale_step: usize = usize::MAX;
    write_config_provenance(&mut jsonl_file, spec);

    let mut curriculum_stage: usize = 0;
    let curriculum = spec.training.curriculum.as_deref();
    print_curriculum_stages(curriculum);

    let grad_accum = spec.training.gradient_accumulation.unwrap_or(1);
    // Global number of batches already consumed before this run (non-zero
    // only when resuming from a checkpoint).
    let resume_batch_idx = trainer.step() * grad_accum;
    // Split the global position into (epoch, offset within that epoch).
    // Fully-consumed epochs are skipped, the partial epoch skips only its
    // already-trained prefix, and later epochs run in full.
    let (resume_epoch, resume_offset) = if num_batches == 0 {
        (0, 0)
    } else {
        (resume_batch_idx / num_batches, resume_batch_idx % num_batches)
    };

    'outer: for epoch in 0..spec.training.epochs {
        if epoch < resume_epoch {
            continue;
        }
        let skip_batches = if epoch == resume_epoch { resume_offset } else { 0 };

        let epoch_start = std::time::Instant::now();
        let mut total_loss = 0.0;
        let mut batches_processed = 0;

        // Order is derived from (seed, epoch), so a resumed epoch replays the
        // same sequence and the skipped prefix is the one already trained.
        let batch_order = shuffled_batch_order(num_batches, shuffle, seed, epoch);

        for (iter_idx, &batch_idx) in batch_order.iter().enumerate() {
            if iter_idx < skip_batches {
                continue;
            }
            if shutdown_flag.load(Ordering::SeqCst) {
                handle_graceful_shutdown(
                    trainer,
                    spec,
                    &state,
                    &mut tracker,
                    start_ms,
                    epoch,
                    iter_idx,
                    total_epochs,
                    num_batches,
                    &loss_history,
                    &model_name,
                    &gpu_name,
                    seed,
                    loss_ema,
                );
                return Ok(());
            }

            if reached_max_steps(spec.training.max_steps, trainer.step()) {
                break 'outer;
            }

            curriculum_stage = check_curriculum_transition(
                curriculum,
                curriculum_stage,
                trainer.step(),
                &mut jsonl_file,
            );

            let batch = &batches[batch_idx];
            let step_start = std::time::Instant::now();
            let batch_loss = trainer.train_batch(batch);
            let step_elapsed = step_start.elapsed();

            if !batch_loss.is_finite() {
                nan_skips += 1;
                println!(
                    " [WARN] NaN/Inf loss at step {} (skip #{}) — skipping",
                    trainer.step(),
                    nan_skips
                );
                continue;
            }
            total_loss += batch_loss;
            batches_processed += 1;

            // Seed the EMA with the first finite loss.
            if loss_ema == 0.0 {
                loss_ema = f64::from(batch_loss);
            }

            detect_loss_spike(
                batch_loss,
                trainer.step(),
                &mut loss_ema,
                loss_ema_alpha,
                loss_spike_threshold,
                &mut rollback_count,
                max_rollbacks,
                &mut jsonl_file,
            );

            zclip_update(
                f64::from(trainer.last_grad_norm()),
                trainer.step(),
                &mut gnorm_ema,
                &mut gnorm_ema_sq,
                zclip_alpha,
                zclip_threshold,
            );

            update_noise_scale(
                f64::from(trainer.last_grad_norm()),
                trainer.step(),
                &mut gnorm_window,
                noise_scale_interval,
                &mut last_noise_scale_step,
                &mut jsonl_file,
            );

            write_heartbeat(&heartbeat_path, trainer.step());

            push_capped(&mut loss_history, batch_loss, 100);

            if should_log(iter_idx, log_interval) {
                log_step_metrics(
                    trainer,
                    &state,
                    &mut tracker,
                    &mut jsonl_file,
                    &epoch_start,
                    &start_time,
                    &step_elapsed,
                    epoch,
                    total_epochs,
                    iter_idx,
                    num_batches,
                    tokens_per_batch,
                    num_params,
                    gpu_peak_tflops,
                    start_ms,
                    batch_loss,
                    &loss_history,
                    spec,
                    &gpu_name,
                );
            }

            let current_step = trainer.step();
            // Don't immediately re-save the checkpoint we just resumed from.
            let is_resume_step =
                resume_batch_idx > 0 && current_step == resume_batch_idx / grad_accum;
            let do_save = !is_resume_step
                && should_save_checkpoint(current_step, last_save_step, save_interval);
            let do_eval = current_step > 0
                && current_step != last_eval_step
                && current_step.is_multiple_of(eval_interval);

            if do_save {
                save_and_validate_checkpoint(
                    trainer,
                    spec,
                    &model_name,
                    current_step,
                    epoch,
                    iter_idx,
                    max_checkpoints,
                    seed,
                    loss_ema,
                );
                last_save_step = current_step;
            }

            if do_eval {
                last_eval_step = current_step;
                let eval_val_loss = run_validation_eval(
                    trainer,
                    &val_batches,
                    current_step,
                    &mut jsonl_file,
                    &mut scaling_predictor,
                    tokens_per_step,
                    spec.training.max_steps,
                );
                if let Some(val_loss) = eval_val_loss {
                    if val_loss < best_val_loss {
                        best_val_loss = val_loss;
                        evals_without_improvement = 0;
                        save_best_model(trainer, spec, &model_name, current_step);
                    } else {
                        evals_without_improvement += 1;
                    }
                    if patience > 0 && evals_without_improvement >= patience {
                        println!(
                            " [early-stop] No improvement for {evals_without_improvement} evals (patience={patience}). \
                             Best val_loss={best_val_loss:.4}. Stopping.",
                        );
                        write_jsonl_event_json(
                            &mut jsonl_file,
                            &serde_json::json!({
                                "type": "early_stop",
                                "step": current_step,
                                "best_val_loss": best_val_loss,
                                "evals_without_improvement": evals_without_improvement,
                                "patience": patience,
                                "timestamp": now_ms(),
                            }),
                        );
                        break 'outer;
                    }
                }
            }
        }

        let avg_loss = total_loss / batches_processed.max(1) as f32;
        let ppl = crate::train::perplexity(avg_loss);
        println!(
            "Epoch {}/{}: loss={:.6}, perplexity={:.2}, time={:.1}s",
            epoch + 1,
            spec.training.epochs,
            avg_loss,
            ppl,
            epoch_start.elapsed().as_secs_f64(),
        );

        if reached_max_steps(spec.training.max_steps, trainer.step()) {
            break;
        }
    }

    let total_time = start_time.elapsed();
    println!("Total training time: {:.1}s", total_time.as_secs_f64());

    trainer.print_profiler_report();

    let final_loss = trainer.metrics.losses.last().copied().unwrap_or(0.0);
    write_training_snapshot(
        &state,
        start_ms,
        total_epochs,
        total_epochs,
        trainer.step(),
        num_batches,
        final_loss,
        &loss_history,
        trainer.current_lr(),
        0.0,
        TrainingStatus::Completed,
        spec,
        &gpu_name,
    );

    tracker.complete();

    write_jsonl_event_json(
        &mut jsonl_file,
        &serde_json::json!({
            "type": "complete",
            "step": trainer.step(),
            "final_loss": final_loss,
            "total_time_s": total_time.as_secs_f64(),
            "timestamp": now_ms(),
        }),
    );

    save_trained_model_cuda(trainer, spec)
}
1131
/// Binds the DDP gradient server on `coord_addr` and runs it on a background
/// thread.
///
/// The thread first waits for `world_size` workers. Then, for each of
/// `total_steps` steps, it:
/// 1. reduces and broadcasts every per-block gradient, highest block index
///    first;
/// 2. reduces and broadcasts non-block components 0, 1 and 2.
///
/// A communication error is logged and ends the thread instead of panicking it.
///
/// # Errors
/// Returns `Error::ConfigError` if the server cannot bind.
#[cfg(feature = "cuda")]
fn spawn_coordinator_thread(
    coord_addr: std::net::SocketAddr,
    world_size: usize,
    num_blocks: usize,
    total_steps: usize,
) -> Result<std::thread::JoinHandle<()>> {
    use crate::finetune::distributed::DistributedConfig;
    use crate::finetune::GradientServer;

    let server_config = DistributedConfig::coordinator(coord_addr, world_size);
    let mut server = GradientServer::bind(server_config)
        .map_err(|e| Error::ConfigError(format!("GradientServer bind failed: {e}")))?;
    println!(" ✓ GradientServer bound on {coord_addr}");

    Ok(std::thread::spawn(move || {
        if let Err(e) = server.wait_for_workers() {
            eprintln!("[coordinator] Waiting for workers failed: {e:?}");
            return;
        }
        eprintln!("[coordinator] All {world_size} workers connected");

        for step in 0..total_steps {
            let step_id = step as u64;
            for block_idx in (0..num_blocks).rev() {
                let reduced = match server.collect_and_reduce_block(step_id, block_idx as u32) {
                    Ok(reduced) => reduced,
                    Err(e) => {
                        eprintln!(
                            "[coordinator] Reducing block {block_idx} at step {step} failed: {e:?}"
                        );
                        return;
                    }
                };
                if let Err(e) = server.broadcast_averaged_block(step_id, &reduced) {
                    eprintln!(
                        "[coordinator] Broadcasting block {block_idx} at step {step} failed: {e:?}"
                    );
                    return;
                }
            }
            for component in [0u8, 1, 2] {
                let reduced = match server.collect_and_reduce_non_block(step_id, component) {
                    Ok(reduced) => reduced,
                    Err(e) => {
                        eprintln!(
                            "[coordinator] Reducing component {component} at step {step} failed: {e:?}"
                        );
                        return;
                    }
                };
                if let Err(e) = server.broadcast_averaged_non_block(step_id, &reduced) {
                    eprintln!(
                        "[coordinator] Broadcasting component {component} at step {step} failed: {e:?}"
                    );
                    return;
                }
            }
        }
        eprintln!("[coordinator] Training complete ({total_steps} steps)");
    }))
}
1173
/// Data-parallel (DDP) training loop.
///
/// Every rank connects to the coordinator as a worker and trains on its own
/// shard of `batches`. Rank 0 additionally:
/// - hosts the gradient coordinator thread;
/// - prints progress;
/// - writes checkpoints;
/// - saves the final model.
///
/// # Errors
/// Returns `Error::ConfigError` in these cases:
/// - the distributed config is missing;
/// - the coordinator cannot bind;
/// - the worker cannot connect.
///
/// Also propagates errors from saving the final model on rank 0.
#[cfg(feature = "cuda")]
fn train_loop_cuda_distributed(
    mut cuda_trainer: CudaTransformerTrainer,
    batches: &[LMBatch],
    spec: &TrainSpec,
) -> Result<()> {
    use crate::finetune::distributed::DistributedConfig;
    use crate::finetune::WorkerClient;
    use crate::train::{shard_batches, DistributedComm, DistributedCudaTrainer};

    let dist_config = cuda_trainer
        .config()
        .distributed
        .clone()
        .ok_or_else(|| Error::ConfigError("missing distributed config".into()))?;

    let rank = dist_config.rank;
    let world_size = dist_config.world_size;
    let coord_addr = dist_config.coordinator_addr;

    println!("Starting distributed training (DDP)...");
    println!(" rank: {rank}/{world_size}");
    println!(" coordinator: {coord_addr}");

    cuda_trainer.ensure_grad_accum();

    let num_blocks = cuda_trainer
        .grad_accum_ref()
        .map_or(0, crate::train::PerBlockGradientAccumulator::num_blocks);

    let server_handle = if rank == 0 {
        let max_steps = spec.training.max_steps.unwrap_or(usize::MAX);
        // NOTE(review): the coordinator expects ceil(len / world_size) steps per
        // epoch. Confirm `shard_batches` never gives a rank fewer batches, or the
        // final reduction of an epoch could wait forever.
        let batches_per_worker = batches.len().div_ceil(world_size);
        let total_steps = std::cmp::min(spec.training.epochs * batches_per_worker, max_steps);
        Some(spawn_coordinator_thread(coord_addr, world_size, num_blocks, total_steps)?)
    } else {
        // Give rank 0 a head start to bind the coordinator before connecting.
        std::thread::sleep(std::time::Duration::from_millis(100));
        None
    };

    let worker_config = DistributedConfig::worker(coord_addr);
    let client = WorkerClient::connect(worker_config, 1, "cuda")
        .map_err(|e| Error::ConfigError(format!("WorkerClient connect failed: {e}")))?;
    println!(" ✓ Connected as worker {} (id={})", rank, client.worker_id());

    let comm = DistributedComm::Remote { client };
    let mut ddp_trainer = DistributedCudaTrainer::new(cuda_trainer, comm, dist_config);

    let num_batches = batches.len();
    let start_time = std::time::Instant::now();
    let log_interval = std::cmp::max(num_batches / (world_size * 100).max(1), 1);
    let save_interval = spec.training.save_interval;
    let max_checkpoints = spec.training.max_checkpoints;
    let seed = spec.training.seed.unwrap_or(42);
    let seq_len = spec.data.seq_len.unwrap_or(128);
    // Tokens processed across all ranks for one local batch.
    let global_tokens_per_batch = spec.data.batch_size * seq_len * world_size;

    let model_name = spec
        .model
        .path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("entrenar-model")
        .to_string();

    // The shard depends only on (num_batches, rank, world_size): compute it once.
    let my_batch_indices = shard_batches(num_batches, rank, world_size);

    let mut last_save_step: usize = 0;

    for epoch in 0..spec.training.epochs {
        let epoch_start = std::time::Instant::now();
        let mut total_loss = 0.0;
        let mut batches_processed = 0;

        for (iter_idx, &batch_idx) in my_batch_indices.iter().enumerate() {
            if ddp_trainer.reached_max_steps() {
                break;
            }

            let batch = &batches[batch_idx];
            let step_start = std::time::Instant::now();
            let batch_loss = ddp_trainer.train_batch(batch);
            let step_elapsed = step_start.elapsed();

            // Non-finite losses are left out of the epoch average.
            if !batch_loss.is_finite() {
                continue;
            }
            total_loss += batch_loss;
            batches_processed += 1;

            // Logging and checkpointing are rank-0 duties.
            if rank != 0 {
                continue;
            }

            let current_step = ddp_trainer.step();
            if should_log(iter_idx, log_interval) {
                let elapsed = epoch_start.elapsed().as_secs_f64();
                let tokens_done = (iter_idx + 1) * global_tokens_per_batch;
                let tok_per_sec = tokens_done as f64 / elapsed.max(0.001);
                println!(
                    " [DDP rank 0] step={} loss={:.4} tok/s={:.0} step_time={:.1}ms",
                    current_step,
                    batch_loss,
                    tok_per_sec,
                    step_elapsed.as_secs_f64() * 1000.0,
                );
            }

            if should_save_checkpoint(current_step, last_save_step, save_interval) {
                save_and_validate_checkpoint(
                    ddp_trainer.trainer_mut(),
                    spec,
                    &model_name,
                    current_step,
                    epoch,
                    iter_idx,
                    max_checkpoints,
                    seed,
                    0.0,
                );
                last_save_step = current_step;
            }
        }

        if rank == 0 && batches_processed > 0 {
            let avg_loss = total_loss / batches_processed as f32;
            let ppl = crate::train::perplexity(avg_loss);
            println!(
                "Epoch {}/{}: loss={:.6}, perplexity={:.2}, time={:.1}s",
                epoch + 1,
                spec.training.epochs,
                avg_loss,
                ppl,
                epoch_start.elapsed().as_secs_f64(),
            );
        }

        if ddp_trainer.reached_max_steps() {
            break;
        }
    }

    if rank == 0 {
        let total_time = start_time.elapsed();
        println!("Total distributed training time: {:.1}s", total_time.as_secs_f64());
        save_trained_model_cuda(ddp_trainer.trainer_mut(), spec)?;
    }

    if let Some(handle) = server_handle {
        if handle.join().is_err() {
            eprintln!("[coordinator] Gradient server thread panicked");
        }
    }

    Ok(())
}
1351
1352include!("helpers.rs");
1354
1355include!("data.rs");
1357
1358#[cfg(test)]
1359mod tests;