Skip to main content

entrenar/config/train/loader/
mod.rs

1#![allow(dead_code)]
2//! Main entry points for YAML-based training
3
// Default model architecture constants (Qwen2.5-Coder-0.5B)
// NOTE(review): presumably used as fallbacks when a model config omits a
// field — confirm against the code that builds `TransformerConfig`.
/// Default hidden size.
const QWEN_HIDDEN_SIZE: usize = 896;
/// Default number of attention heads.
const QWEN_NUM_ATTENTION_HEADS: usize = 14;
/// Default number of key/value heads.
const QWEN_NUM_KV_HEADS: usize = 2;
/// Default intermediate size.
const QWEN_INTERMEDIATE_SIZE: usize = 4864;
/// Default number of hidden layers.
const QWEN_NUM_HIDDEN_LAYERS: usize = 24;
/// Default vocabulary size.
const QWEN_VOCAB_SIZE: usize = 151936;
/// Default maximum number of position embeddings.
const QWEN_MAX_POSITION_EMBEDDINGS: usize = 32768;
/// Default RoPE theta.
const QWEN_ROPE_THETA: f64 = 1_000_000.0;
13
14use super::batches::load_training_batches;
15use crate::config::schema::{ModelMode, TrainSpec};
16use crate::config::validate::validate_config;
17use crate::error::{Error, Result};
18use crate::monitor::tui::state::{TrainingSnapshot, TrainingState, TrainingStatus};
19use crate::storage::{ExperimentStorage, ParameterValue, RunStatus, SqliteBackend};
20use crate::tokenizer::HfTokenizer;
21use crate::trace::TRACER;
22#[cfg(feature = "cuda")]
23use crate::train::CudaTransformerTrainer;
24use crate::train::{LMBatch, TransformerTrainConfig, TransformerTrainer};
25use crate::transformer::{
26    load_safetensors_weights, Architecture, ModelArchitecture, Transformer, TransformerConfig,
27};
28use crate::yaml_mode;
29use std::fs;
30use std::path::{Path, PathBuf};
31use std::time::{SystemTime, UNIX_EPOCH};
32
33/// Train a model from YAML configuration file
34///
35/// This is the main entry point for declarative training. It:
36/// 1. Loads and parses the YAML config
37/// 2. Validates the configuration
38/// 3. Dispatches to appropriate trainer (Tabular or Transformer)
39/// 4. Runs the training loop
40/// 5. Saves the final model
41///
42/// # Example
43///
44/// ```no_run
45/// use entrenar::config::train_from_yaml;
46///
47/// let model = train_from_yaml("config.yaml")?;
48/// # Ok::<(), Box<dyn std::error::Error>>(())
49/// ```
50pub fn train_from_yaml<P: AsRef<Path>>(config_path: P) -> Result<()> {
51    let spec = load_config(config_path)?;
52
53    // Dispatch based on model mode
54    match spec.model.mode {
55        ModelMode::Transformer => train_transformer_from_spec(&spec),
56        ModelMode::Tabular => train_tabular_from_spec(&spec),
57    }
58}
59
60/// Build a TransformerTrainConfig from YAML spec, wiring all hyperparameters.
61fn build_train_config(
62    model_config: crate::transformer::TransformerConfig,
63    spec: &TrainSpec,
64) -> TransformerTrainConfig {
65    let mut config = TransformerTrainConfig::new(model_config)
66        .with_lr(spec.optimizer.lr)
67        .with_warmup_steps(spec.training.warmup_steps)
68        .with_max_seq_len({
69            let seq_len = spec.data.seq_len.unwrap_or_else(|| {
70                eprintln!("Warning: seq_len not specified in config, defaulting to 512");
71                512
72            });
73            seq_len
74        });
75
76    if let Some(clip) = spec.training.grad_clip {
77        config = config.with_grad_clip(clip);
78    }
79
80    // Wire optimizer hyperparameters from YAML (ALB-040)
81    if let Some(v) = spec.optimizer.params.get("beta2").and_then(serde_json::Value::as_f64) {
82        config = config.with_beta2(v as f32);
83    }
84    if let Some(v) = spec.optimizer.params.get("weight_decay").and_then(serde_json::Value::as_f64) {
85        config = config.with_weight_decay(v as f32);
86    }
87
88    if let Some(accum) = spec.training.gradient_accumulation {
89        config = config.with_accumulation_steps(accum);
90        if accum > 1 {
91            let eff_batch = spec.data.batch_size * accum * spec.data.seq_len.unwrap_or(1024);
92            println!("  Gradient accumulation: {accum} (effective batch: {eff_batch} tokens/step)");
93        }
94    }
95
96    if let Some(max_steps) = spec.training.max_steps {
97        config = config.with_max_steps(max_steps);
98    }
99
100    // Enable mixed precision if specified
101    if let Some(ref precision) = spec.training.mixed_precision {
102        match precision.as_str() {
103            "bf16" => config = config.with_bf16(),
104            "fp16" => config = config.with_fp16(),
105            "fp32" => {}
106            other => {
107                eprintln!("Warning: unknown mixed_precision value '{other}', defaulting to fp32");
108            }
109        }
110    }
111
112    // R-021: Activation checkpointing (gradient recomputation)
113    if let Some(num_segments) = spec.training.checkpoints {
114        config = config.with_checkpointing(num_segments);
115    }
116
117    // R-084: Bitwise deterministic training (C-DETERM-001)
118    if spec.training.deterministic {
119        config = config.with_deterministic(true);
120    }
121    if let Some(seed) = spec.training.seed {
122        config = config.with_seed(seed);
123    }
124
125    // KAIZEN-047: Step profiler (0 = disabled)
126    if spec.training.profile_interval > 0 {
127        config = config.with_profile_interval(spec.training.profile_interval);
128    }
129
130    // ENT-LoRA-001: Wire LoRA config from YAML spec
131    if let Some(ref lora) = spec.lora {
132        config = config.with_lora(lora.rank, lora.alpha, lora.target_modules.clone());
133        // ENT-LoRA-006: LoRA+ ratio from YAML
134        if lora.lora_plus_ratio != 1.0 {
135            config = config.with_lora_plus_ratio(lora.lora_plus_ratio);
136        }
137        // ENT-LoRA-008: Double quantization from YAML
138        if lora.double_quantize {
139            config = config.with_double_quantize(true);
140        }
141        // ENT-263: NF4 quantization for QLoRA pretraining
142        if lora.quantize_base {
143            config = config.with_quantize_nf4(true);
144        }
145    }
146
147    // Wire distributed config from YAML (#133)
148    if let Some(ref dist) = spec.training.distributed {
149        use crate::train::{DistributedBackend, DistributedRole, DistributedTrainConfig};
150
151        let role = match dist.role.as_str() {
152            "worker" => DistributedRole::Worker,
153            _ => DistributedRole::Coordinator,
154        };
155        let backend = match dist.backend.as_str() {
156            "cuda" => DistributedBackend::Cuda,
157            "wgpu" => DistributedBackend::Wgpu,
158            _ => DistributedBackend::Auto,
159        };
160        let addr: std::net::SocketAddr =
161            dist.coordinator_addr.parse().unwrap_or_else(|_| "0.0.0.0:9000".parse().unwrap());
162
163        config = config.with_distributed(DistributedTrainConfig {
164            world_size: dist.world_size,
165            rank: dist.rank,
166            local_rank: dist.local_rank,
167            role,
168            coordinator_addr: addr,
169            backend,
170        });
171    }
172
173    config
174}
175
176/// Train a transformer model (LLM) from spec
177///
178/// Uses TransformerTrainer with CausalLMLoss for language modeling.
179fn train_transformer_from_spec(spec: &TrainSpec) -> Result<()> {
180    println!("✓ Config loaded and validated (Transformer mode)");
181    println!("  Model: {}", spec.model.path.display());
182    println!("  Optimizer: {} (lr={})", spec.optimizer.name, spec.optimizer.lr);
183    println!("  Batch size: {}", spec.data.batch_size);
184    println!("  Epochs: {}", spec.training.epochs);
185    println!("  Training mode: {:?}", spec.training.mode);
186
187    if let Some(lora) = &spec.lora {
188        println!("  LoRA: rank={}, alpha={}", lora.rank, lora.alpha);
189        if lora.quantize_base {
190            println!("  QLoRA: NF4 quantized base weights (~8x VRAM compression)");
191        }
192    }
193    println!();
194
195    // Build TransformerConfig from spec
196    let model_config = build_transformer_config_from_spec(spec)?;
197
198    // Resolve model path (downloads from HF Hub if repo ID)
199    let resolved_path = resolve_model_path(&spec.model.path)?;
200
201    // C-INIT-001: Set init seed BEFORE model construction (entrenar#309)
202    crate::transformer::init::set_init_seed(spec.training.seed.unwrap_or(42));
203
204    // Try to load model weights if path exists (ENT-117)
205    // ALB-097: Check output_dir first for checkpoint resume, then model_path for initial weights
206    #[cfg(feature = "cuda")]
207    let (transformer, checkpoint_step) =
208        load_transformer_model(&resolved_path, &model_config, &spec.training.output_dir)?;
209    #[cfg(not(feature = "cuda"))]
210    let (transformer, _checkpoint_step) =
211        load_transformer_model(&resolved_path, &model_config, &spec.training.output_dir)?;
212
213    // Build TransformerTrainConfig from YAML spec fields
214    let train_config = build_train_config(model_config, spec);
215
216    // Apply deterministic settings before any CUDA operations
217    train_config.apply_deterministic_settings();
218
219    // Load training data as LMBatches (supports tokenizer + text data)
220    println!("Loading training data...");
221    let batches = load_lm_batches(spec)?;
222    println!("✓ {} LM batches created", batches.len());
223    println!();
224
225    // Try CUDA-resident training first (ALB-040), fall back to CPU
226    #[cfg(feature = "cuda")]
227    if train_config.use_cuda {
228        let cuda_config = train_config.clone();
229        let cuda_result = match transformer {
230            Some(loaded_model) => CudaTransformerTrainer::with_model(loaded_model, cuda_config),
231            None => CudaTransformerTrainer::new(cuda_config),
232        };
233
234        match cuda_result {
235            Ok(mut cuda_trainer) => {
236                // Restore step counter from checkpoint for LR schedule + AdamW bias correction
237                if checkpoint_step > 0 {
238                    cuda_trainer.set_initial_step(checkpoint_step);
239                    println!(
240                        "  Resumed at step {checkpoint_step} (lr={:.2e})",
241                        cuda_trainer.current_lr()
242                    );
243                    // ALB-096: Try APR optimizer state first, then fall back to JSON
244                    let apr_loaded = find_latest_apr_checkpoint(&spec.training.output_dir)
245                        .is_some_and(|p| {
246                            // ENT-276: Restore LoRA adapter weights from APR checkpoint
247                            let (restored, total) = cuda_trainer.restore_lora_from_apr(&p);
248                            if restored > 0 {
249                                println!("  ✓ LoRA adapters restored ({restored}/{total} layers)");
250                            }
251                            cuda_trainer.load_optimizer_state_apr(&p)
252                        });
253                    if apr_loaded {
254                        println!("  ✓ Embedding optimizer state restored (APR)");
255                    } else if cuda_trainer.load_optimizer_state(&spec.training.output_dir) {
256                        println!("  ✓ Embedding optimizer state restored (JSON)");
257                    }
258                }
259                println!("✓ CudaTransformerTrainer initialized (GPU: {})", cuda_trainer.gpu_name());
260                // #133: Dispatch to distributed training loop if distributed config present
261                if train_config.distributed.is_some() {
262                    return train_loop_cuda_distributed(cuda_trainer, &batches, spec);
263                }
264                return train_loop_cuda(&mut cuda_trainer, &batches, spec);
265            }
266            Err(e) => {
267                eprintln!("Warning: CUDA training failed ({e}), falling back to CPU");
268                // transformer was consumed — rebuild from config
269                let mut trainer = TransformerTrainer::new(train_config);
270                println!("✓ TransformerTrainer initialized (CPU fallback)");
271                println!("  Mixed precision: {}", trainer.is_mixed_precision());
272                println!("  Checkpointing: {}", trainer.is_checkpointing());
273                println!();
274                return train_loop_cpu(&mut trainer, &batches, spec);
275            }
276        }
277    }
278
279    // CPU-only path (use_cuda=false or no CUDA feature)
280    let mut trainer = if let Some(loaded_model) = transformer {
281        TransformerTrainer::with_model(loaded_model, train_config)
282    } else {
283        TransformerTrainer::new(train_config)
284    };
285    println!("✓ TransformerTrainer initialized (CPU)");
286    println!("  Mixed precision: {}", trainer.is_mixed_precision());
287    println!("  Checkpointing: {}", trainer.is_checkpointing());
288    println!();
289
290    train_loop_cpu(&mut trainer, &batches, spec)
291}
292
293/// Training loop for CPU TransformerTrainer (ALB-045, ALB-055/056)
294fn train_loop_cpu(
295    trainer: &mut TransformerTrainer,
296    batches: &[LMBatch],
297    spec: &TrainSpec,
298) -> Result<()> {
299    println!("Starting transformer training (CPU)...");
300    println!();
301
302    TRACER.enable();
303    TRACER.clear();
304
305    let num_batches = batches.len();
306    let start_time = std::time::Instant::now();
307    let log_interval = (num_batches / 100).clamp(1, 100);
308
309    // ALB-045: Initialize training state IPC for `apr monitor`
310    let state = TrainingState::new(&spec.training.output_dir);
311    let start_ms = now_ms();
312    let total_epochs = spec.training.epochs;
313
314    // ALB-055/056: Open SQLite experiment tracking (local + global)
315    let mut tracker = PretrainTracker::open(spec, "CPU");
316
317    write_training_snapshot(
318        &state,
319        start_ms,
320        0,
321        total_epochs,
322        0,
323        num_batches,
324        0.0,
325        &[],
326        0.0,
327        0.0,
328        TrainingStatus::Initializing,
329        spec,
330        "CPU",
331    );
332
333    if let Some(max_steps) = spec.training.max_steps {
334        println!("  max_steps: {max_steps} (will stop early when reached)");
335    }
336
337    let mut loss_history: Vec<f32> = Vec::new();
338
339    for epoch in 0..spec.training.epochs {
340        let epoch_start = std::time::Instant::now();
341        let avg_loss =
342            trainer.train_epoch_with_callback(batches, |batch_idx, batch_loss, trainer| {
343                loss_history.push(batch_loss);
344                if loss_history.len() > 100 {
345                    loss_history.remove(0);
346                }
347
348                if (batch_idx + 1) % log_interval == 0 || batch_idx == 0 {
349                    let elapsed = epoch_start.elapsed().as_secs_f64();
350                    let batches_done = batch_idx + 1;
351                    let seq_len = spec.data.seq_len.unwrap_or(128);
352                    let tokens_done = batches_done * spec.data.batch_size * seq_len;
353                    let batch_per_sec = batches_done as f64 / elapsed.max(0.001);
354                    let remaining = (num_batches - batches_done) as f64 / batch_per_sec.max(0.001);
355                    let tok_per_sec = tokens_done as f64 / elapsed.max(0.001);
356                    println!(
357                        "  [{}/{} batches] step={} loss={:.4} lr={:.2e} tok/s={:.0} eta={:.0}s",
358                        batches_done,
359                        num_batches,
360                        trainer.step(),
361                        batch_loss,
362                        trainer.current_lr(),
363                        tok_per_sec,
364                        remaining,
365                    );
366
367                    // ALB-045: Write snapshot for `apr monitor`
368                    write_training_snapshot(
369                        &state,
370                        start_ms,
371                        epoch + 1,
372                        total_epochs,
373                        trainer.step(),
374                        num_batches,
375                        batch_loss,
376                        &loss_history,
377                        trainer.current_lr(),
378                        tok_per_sec as f32,
379                        TrainingStatus::Running,
380                        spec,
381                        "CPU",
382                    );
383
384                    // ALB-055/056: Log step metrics to SQLite
385                    tracker.log_step(
386                        trainer.step() as u64,
387                        batch_loss,
388                        trainer.current_lr(),
389                        tok_per_sec as f32,
390                    );
391                }
392            });
393        let ppl = crate::train::perplexity(avg_loss);
394        println!(
395            "Epoch {}/{}: loss={:.6}, perplexity={:.2}, time={:.1}s",
396            epoch + 1,
397            spec.training.epochs,
398            avg_loss,
399            ppl,
400            epoch_start.elapsed().as_secs_f64(),
401        );
402
403        if trainer.reached_max_steps() {
404            println!(
405                "Reached max_steps={}, stopping training.",
406                spec.training.max_steps.unwrap_or(0)
407            );
408            break;
409        }
410    }
411
412    let total_time = start_time.elapsed();
413    println!("Total training time: {:.1}s", total_time.as_secs_f64());
414    println!("{}", TRACER.report());
415
416    // ALB-045: Write final "Completed" snapshot
417    let final_loss = trainer.metrics.losses.last().copied().unwrap_or(0.0);
418    write_training_snapshot(
419        &state,
420        start_ms,
421        total_epochs,
422        total_epochs,
423        trainer.step(),
424        num_batches,
425        final_loss,
426        &loss_history,
427        trainer.current_lr(),
428        0.0,
429        TrainingStatus::Completed,
430        spec,
431        "CPU",
432    );
433
434    // ALB-055/056: Mark run as completed in SQLite
435    tracker.complete();
436
437    save_trained_model_cpu(trainer, spec)
438}
439
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 when the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
444
445/// Query live GPU telemetry via nvidia-smi CLI (ALB-046)
446///
447/// Shells out to `nvidia-smi --query-gpu` with CSV output and parses
448/// the result into GpuTelemetry. Zero-dependency approach — nvidia-smi
449/// is always available when CUDA is. Returns None if nvidia-smi fails.
450fn query_gpu_telemetry(device_name: &str) -> Option<crate::monitor::tui::state::GpuTelemetry> {
451    let output = std::process::Command::new("nvidia-smi")
452        .args([
453            "--query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw,power.limit",
454            "--format=csv,noheader,nounits",
455        ])
456        .output()
457        .ok()?;
458
459    if !output.status.success() {
460        return None;
461    }
462
463    let stdout = String::from_utf8_lossy(&output.stdout);
464    let line = stdout.lines().next()?.trim();
465    let fields: Vec<&str> = line.split(',').map(str::trim).collect();
466    if fields.len() < 6 {
467        return None;
468    }
469
470    Some(crate::monitor::tui::state::GpuTelemetry {
471        device_name: device_name.to_string(),
472        utilization_percent: fields[0].parse().unwrap_or(0.0),
473        vram_used_gb: fields[1].parse::<f32>().unwrap_or(0.0) / 1024.0, // MiB → GiB
474        vram_total_gb: fields[2].parse::<f32>().unwrap_or(0.0) / 1024.0,
475        temperature_celsius: fields[3].parse().unwrap_or(0.0),
476        power_watts: fields[4].parse().unwrap_or(0.0),
477        power_limit_watts: fields[5].parse().unwrap_or(0.0),
478        processes: Vec::new(),
479    })
480}
481
482/// Write a TrainingSnapshot to training_state.json (ALB-045)
483///
484/// This is the IPC mechanism that `apr monitor` reads. Called on every
485/// log interval so the TUI stays current. Uses atomic write (tmp+rename)
486/// via TrainingState::write().
487fn write_training_snapshot(
488    state: &TrainingState,
489    start_ms: u64,
490    epoch: usize,
491    total_epochs: usize,
492    step: usize,
493    steps_per_epoch: usize,
494    loss: f32,
495    loss_history: &[f32],
496    lr: f32,
497    tokens_per_second: f32,
498    status: TrainingStatus,
499    spec: &TrainSpec,
500    gpu_name: &str,
501) {
502    let snapshot = TrainingSnapshot {
503        timestamp_ms: now_ms(),
504        epoch,
505        total_epochs,
506        step,
507        steps_per_epoch,
508        loss,
509        loss_history: loss_history.to_vec(),
510        learning_rate: lr,
511        lr_history: Vec::new(),
512        gradient_norm: 0.0, // not tracked per-batch in current trainer
513        tokens_per_second,
514        start_timestamp_ms: start_ms,
515        gpu: query_gpu_telemetry(gpu_name).or_else(|| {
516            Some(crate::monitor::tui::state::GpuTelemetry {
517                device_name: gpu_name.to_string(),
518                ..Default::default()
519            })
520        }),
521        sample: None,
522        status,
523        experiment_id: spec.training.output_dir.display().to_string(),
524        model_name: spec.model.path.display().to_string(),
525        model_path: spec.model.path.display().to_string(),
526        optimizer_name: spec.optimizer.name.clone(),
527        batch_size: spec.data.batch_size,
528        checkpoint_path: spec.training.output_dir.display().to_string(),
529        executable_path: String::new(),
530        accuracy: 0.0,
531        samples_per_second: 0.0,
532    };
533    if let Err(e) = state.write(&snapshot) {
534        eprintln!("[ALB-045] Failed to write training_state.json: {e}");
535    }
536}
537
538// =============================================================================
539// SQLite Experiment Tracking (ALB-055/056)
540// =============================================================================
541
542/// Best-effort experiment tracker for pretrain loops.
543///
544/// Opens two SQLite databases:
545/// - **Local**: `<output_dir>/.entrenar/experiments.db` (per-experiment metrics history)
546/// - **Global**: `~/.entrenar/experiments.db` (cross-machine experiment registry)
547///
548/// All operations are best-effort — storage failures never block training.
549struct PretrainTracker {
550    local: Option<SqliteBackend>,
551    global: Option<SqliteBackend>,
552    run_id: Option<String>,
553    global_run_id: Option<String>,
554}
555
556impl PretrainTracker {
557    /// Open both local and global SQLite stores, create experiment + run.
558    fn open(spec: &TrainSpec, device: &str) -> Self {
559        let exp_name =
560            spec.training.output_dir.file_name().and_then(|n| n.to_str()).unwrap_or("pretrain");
561
562        let config_json = serde_json::json!({
563            "task": "pretrain",
564            "model": spec.model.path.display().to_string(),
565            "optimizer": &spec.optimizer.name,
566            "lr": spec.optimizer.lr,
567            "epochs": spec.training.epochs,
568            "batch_size": spec.data.batch_size,
569            "seq_len": spec.data.seq_len,
570            "max_steps": spec.training.max_steps,
571            "device": device,
572            "output_dir": spec.training.output_dir.display().to_string(),
573        });
574
575        // Local store: in the output/checkpoint directory
576        let local = SqliteBackend::open_project(&spec.training.output_dir).ok();
577
578        // Global store: ~/.entrenar/experiments.db
579        let global = dirs::home_dir().map(|h| h.join(".entrenar")).and_then(|p| {
580            fs::create_dir_all(&p).ok()?;
581            SqliteBackend::open(p.join("experiments.db").to_string_lossy().as_ref()).ok()
582        });
583
584        let mut tracker = Self { local, global, run_id: None, global_run_id: None };
585
586        // Create experiment + run in local store
587        if let Some(store) = tracker.local.as_mut() {
588            if let Ok(eid) = store.create_experiment(exp_name, Some(config_json.clone())) {
589                if let Ok(rid) = store.create_run(&eid) {
590                    let _ = store.start_run(&rid);
591                    log_run_params(store, &rid, spec, device);
592                    tracker.run_id = Some(rid);
593                }
594            }
595        }
596
597        // Create experiment + run in global store
598        if let Some(store) = tracker.global.as_mut() {
599            if let Ok(eid) = store.create_experiment(exp_name, Some(config_json)) {
600                if let Ok(rid) = store.create_run(&eid) {
601                    let _ = store.start_run(&rid);
602                    log_run_params(store, &rid, spec, device);
603                    tracker.global_run_id = Some(rid);
604                }
605            }
606        }
607
608        tracker
609    }
610
611    /// Log a training step's metrics to both local and global stores.
612    fn log_step(&mut self, step: u64, loss: f32, lr: f32, tok_per_sec: f32) {
613        for (store, run_id) in [
614            (self.local.as_mut(), self.run_id.as_deref()),
615            (self.global.as_mut(), self.global_run_id.as_deref()),
616        ] {
617            if let (Some(s), Some(rid)) = (store, run_id) {
618                let _ = s.log_metric(rid, "loss", step, f64::from(loss));
619                let _ = s.log_metric(rid, "learning_rate", step, f64::from(lr));
620                let _ = s.log_metric(rid, "tokens_per_second", step, f64::from(tok_per_sec));
621            }
622        }
623    }
624
625    /// Mark training as completed in both stores.
626    fn complete(&mut self) {
627        for (store, run_id) in [
628            (self.local.as_mut(), self.run_id.as_deref()),
629            (self.global.as_mut(), self.global_run_id.as_deref()),
630        ] {
631            if let (Some(s), Some(rid)) = (store, run_id) {
632                let _ = s.complete_run(rid, RunStatus::Success);
633            }
634        }
635    }
636
637    /// Mark training as failed in both stores.
638    #[allow(dead_code)]
639    fn fail(&mut self) {
640        for (store, run_id) in [
641            (self.local.as_mut(), self.run_id.as_deref()),
642            (self.global.as_mut(), self.global_run_id.as_deref()),
643        ] {
644            if let (Some(s), Some(rid)) = (store, run_id) {
645                let _ = s.complete_run(rid, RunStatus::Failed);
646            }
647        }
648    }
649}
650
651/// Log hyperparameters for a pretrain run (ALB-055/056)
652fn log_run_params(store: &SqliteBackend, run_id: &str, spec: &TrainSpec, device: &str) {
653    let _ = store.log_param(run_id, "task", ParameterValue::String("pretrain".into()));
654    let _ = store.log_param(
655        run_id,
656        "model",
657        ParameterValue::String(spec.model.path.display().to_string()),
658    );
659    let _ =
660        store.log_param(run_id, "optimizer", ParameterValue::String(spec.optimizer.name.clone()));
661    let _ = store.log_param(
662        run_id,
663        "learning_rate",
664        ParameterValue::Float(f64::from(spec.optimizer.lr)),
665    );
666    let _ = store.log_param(run_id, "epochs", ParameterValue::Int(spec.training.epochs as i64));
667    let _ = store.log_param(run_id, "batch_size", ParameterValue::Int(spec.data.batch_size as i64));
668    let _ = store.log_param(run_id, "device", ParameterValue::String(device.to_string()));
669    let _ = store.log_param(
670        run_id,
671        "output_dir",
672        ParameterValue::String(spec.training.output_dir.display().to_string()),
673    );
674    if let Some(seq_len) = spec.data.seq_len {
675        let _ = store.log_param(run_id, "seq_len", ParameterValue::Int(seq_len as i64));
676    }
677    if let Some(max_steps) = spec.training.max_steps {
678        let _ = store.log_param(run_id, "max_steps", ParameterValue::Int(max_steps as i64));
679    }
680}
681
/// Announce the configured step limit on stdout.
///
/// Prints nothing when `max_steps` is `None`.
fn print_max_steps(max_steps: Option<usize>) {
    let Some(ms) = max_steps else {
        return;
    };
    println!("  max_steps: {ms} (will stop early when reached)");
}
689
690/// ALB-068: Manual batch loop for intermediate checkpoint saving.
691/// R-004: Gradient norm logging. R-008: Graceful shutdown.
692/// R-009: Multi-checkpoint retention. R-012: MFU tracking.
693/// R-014: JSONL experiment log. R-015: Per-epoch shuffling.
694/// R-006/R-007: Training state persistence.
695#[cfg(feature = "cuda")]
696fn train_loop_cuda(
697    trainer: &mut CudaTransformerTrainer,
698    batches: &[LMBatch],
699    spec: &TrainSpec,
700) -> Result<()> {
701    use std::sync::atomic::{AtomicBool, Ordering};
702    use std::sync::Arc;
703
704    println!("Starting transformer training (CUDA GPU-resident)...");
705    println!();
706
707    let num_batches = batches.len();
708
709    // ENT-275: Auto-compute max_steps for cosine LR scheduler when not explicit.
710    // Without this, current_lr() falls back to constant lr (no decay).
711    if spec.training.max_steps.is_none() {
712        let total_steps = spec.training.epochs * num_batches;
713        trainer.set_max_steps(total_steps);
714        println!(
715            "  max_steps: {total_steps} (auto: {epochs}×{num_batches})",
716            epochs = spec.training.epochs
717        );
718    }
719
720    let start_time = std::time::Instant::now();
721    // Cap log_interval so training_state.json updates at least every 100 steps
722    // (enables real-time monitoring via `apr monitor`). Previously num_batches/100
723    // gave 12905 for large datasets — too infrequent for a 12-day run.
724    let log_interval = (num_batches / 100).clamp(1, 100);
725    let save_interval = spec.training.save_interval;
726    let max_checkpoints = spec.training.max_checkpoints;
727
728    // ALB-087: Auto eval scheduling — eval_interval defaults to save_interval
729    let eval_interval =
730        if spec.training.eval_interval > 0 { spec.training.eval_interval } else { save_interval };
731    let patience = spec.training.patience;
732    let mut best_val_loss: f32 = f32::INFINITY;
733    let mut evals_without_improvement: usize = 0;
734    let mut last_eval_step: usize = 0;
735
736    // ALB-045: Initialize training state IPC for `apr monitor`
737    let state = TrainingState::new(&spec.training.output_dir);
738    let start_ms = now_ms();
739    let gpu_name = trainer.gpu_name();
740    let total_epochs = spec.training.epochs;
741
742    // ALB-055/056: Open SQLite experiment tracking (local + global)
743    let mut tracker = PretrainTracker::open(spec, &gpu_name);
744
745    // R-012: MFU calculation setup
746    let num_params = trainer.num_params();
747    let seq_len = spec.data.seq_len.unwrap_or(128);
748    let tokens_per_batch = spec.data.batch_size * seq_len;
749    // RTX 4090: 82.6 TFLOPS fp32 (query via cuDeviceGetAttribute when available)
750    let gpu_peak_tflops: f64 = 82.58e12;
751
752    // R-014: Open JSONL experiment log
753    let jsonl_path = spec.training.output_dir.join("training_log.jsonl");
754    std::fs::create_dir_all(&spec.training.output_dir).ok();
755    let mut jsonl_file =
756        std::fs::OpenOptions::new().create(true).append(true).open(&jsonl_path).ok();
757    // Write config header
758    write_jsonl_event_json(
759        &mut jsonl_file,
760        &serde_json::json!({
761            "type": "config",
762            "num_params": num_params,
763            "batch_size": spec.data.batch_size,
764            "seq_len": seq_len,
765            "max_steps": spec.training.max_steps,
766            "epochs": spec.training.epochs,
767            "lr": spec.optimizer.lr,
768            "gpu": &gpu_name,
769            "timestamp": now_ms(),
770        }),
771    );
772
773    // R-008: Graceful shutdown signal handler
774    let shutdown_flag = Arc::new(AtomicBool::new(false));
775    {
776        let flag = shutdown_flag.clone();
777        let _ = ctrlc::set_handler(move || {
778            flag.store(true, Ordering::SeqCst);
779            eprintln!("\n[SIGINT] Graceful shutdown requested. Saving checkpoint...");
780        });
781    }
782
783    // Write initial "Initializing" snapshot
784    write_training_snapshot(
785        &state,
786        start_ms,
787        0,
788        total_epochs,
789        0,
790        num_batches,
791        0.0,
792        &[],
793        0.0,
794        0.0,
795        TrainingStatus::Initializing,
796        spec,
797        &gpu_name,
798    );
799
800    print_max_steps(spec.training.max_steps);
801
802    // ALB-087: Print eval scheduling config
803    if eval_interval != save_interval {
804        println!("  eval_interval: {eval_interval} (decoupled from save_interval={save_interval})");
805    }
806    if patience > 0 {
807        println!("  early_stopping: patience={patience} eval intervals");
808    }
809
810    // ALB-082: Scaling law predictor for early convergence ceiling detection
811    let mut scaling_predictor = ScalingLawPredictor::new();
812    let tokens_per_step = tokens_per_batch * spec.training.gradient_accumulation.unwrap_or(1);
813
814    // Track loss history for TUI sparkline
815    let mut loss_history: Vec<f32> = Vec::new();
816    let mut last_save_step: usize = 0;
817
818    let model_name = spec
819        .model
820        .path
821        .file_name()
822        .and_then(|n| n.to_str())
823        .unwrap_or("entrenar-model")
824        .to_string();
825
826    // R-015: Prepare shuffled batch indices
827    let shuffle = spec.training.shuffle;
828    let seed = spec.training.seed.unwrap_or(42);
829
830    // R-005: Load validation batches if val path exists
831    let val_batches = load_val_batches(spec);
832
833    // R-018: NaN/Inf detection counter
834    let mut nan_skips: usize = 0;
835
836    // R-017: ZClip adaptive gradient clipping — EMA of gradient norms
837    let mut gnorm_ema: f64 = 0.0;
838    let mut gnorm_ema_sq: f64 = 0.0;
839    let zclip_alpha: f64 = 0.05; // EMA decay rate
840    let zclip_threshold: f64 = 2.0; // z-score threshold for spike detection
841
842    // R-003: Heartbeat file for crash detection
843    let heartbeat_path = spec.training.output_dir.join("heartbeat");
844
845    // R-016b: Loss spike rollback — EMA for spike detection
846    let mut loss_ema: f64 = 0.0;
847    let loss_ema_alpha: f64 = 0.05;
848    let loss_spike_threshold: f64 = 3.0; // spike if loss > threshold × EMA
849    let mut rollback_count: usize = 0;
850    let max_rollbacks: usize = 3;
851
852    // R-029: Gradient noise scale estimation — rolling window of grad norms
853    let mut gnorm_window: Vec<f64> = Vec::with_capacity(100);
854    let noise_scale_interval: usize = 100;
855    let mut last_noise_scale_step: usize = usize::MAX; // Dedup: only log once per optimizer step
856
857    // R-026: Save training config hash to JSONL for diff tracking
858    write_config_provenance(&mut jsonl_file, spec);
859
860    // R-023: Curriculum learning — track current stage index
861    let mut curriculum_stage: usize = 0;
862    let curriculum = spec.training.curriculum.as_deref();
863    print_curriculum_stages(curriculum);
864
865    // ALB-120: Skip batches already processed before checkpoint.
866    let grad_accum = spec.training.gradient_accumulation.unwrap_or(1);
867    let resume_batch_idx = trainer.step() * grad_accum;
868
869    'outer: for epoch in 0..spec.training.epochs {
870        let epoch_start = std::time::Instant::now();
871        let mut total_loss = 0.0;
872        let mut batches_processed = 0;
873
874        // R-015: Generate shuffled indices for this epoch
875        let batch_order = shuffled_batch_order(num_batches, shuffle, seed, epoch);
876
877        // ALB-068: Manual batch loop for intermediate checkpoint saving
878        for (iter_idx, &batch_idx) in batch_order.iter().enumerate() {
879            // ALB-120: Skip batches already processed before checkpoint
880            if iter_idx < resume_batch_idx {
881                continue;
882            }
883            // R-008: Check graceful shutdown flag
884            if shutdown_flag.load(Ordering::SeqCst) {
885                handle_graceful_shutdown(
886                    trainer,
887                    spec,
888                    &state,
889                    &mut tracker,
890                    start_ms,
891                    epoch,
892                    iter_idx,
893                    total_epochs,
894                    num_batches,
895                    &loss_history,
896                    &model_name,
897                    &gpu_name,
898                    seed,
899                    loss_ema,
900                );
901                return Ok(());
902            }
903
904            // Check max_steps before processing
905            if reached_max_steps(spec.training.max_steps, trainer.step()) {
906                break 'outer;
907            }
908
909            // R-023: Check curriculum stage transition
910            curriculum_stage = check_curriculum_transition(
911                curriculum,
912                curriculum_stage,
913                trainer.step(),
914                &mut jsonl_file,
915            );
916
917            let batch = &batches[batch_idx];
918            // R-028: Per-step timing
919            let step_start = std::time::Instant::now();
920            let batch_loss = trainer.train_batch(batch);
921            let step_elapsed = step_start.elapsed();
922
923            // R-018: NaN/Inf detection — skip step if loss is non-finite
924            if !batch_loss.is_finite() {
925                nan_skips += 1;
926                println!(
927                    "  [WARN] NaN/Inf loss at step {} (skip #{}) — skipping",
928                    trainer.step(),
929                    nan_skips
930                );
931                continue;
932            }
933            total_loss += batch_loss;
934            batches_processed += 1;
935
936            // ENT-283: Seed loss EMA to first observed loss to avoid cold-start false rollbacks
937            if loss_ema == 0.0 {
938                loss_ema = f64::from(batch_loss);
939            }
940
941            // R-016b: Loss spike detection + rollback
942            detect_loss_spike(
943                batch_loss,
944                trainer.step(),
945                &mut loss_ema,
946                loss_ema_alpha,
947                loss_spike_threshold,
948                &mut rollback_count,
949                max_rollbacks,
950                &mut jsonl_file,
951            );
952
953            // R-017: ZClip — update EMA and detect gradient spikes
954            zclip_update(
955                f64::from(trainer.last_grad_norm()),
956                trainer.step(),
957                &mut gnorm_ema,
958                &mut gnorm_ema_sq,
959                zclip_alpha,
960                zclip_threshold,
961            );
962
963            // R-029: Track grad norm for noise scale estimation
964            update_noise_scale(
965                f64::from(trainer.last_grad_norm()),
966                trainer.step(),
967                &mut gnorm_window,
968                noise_scale_interval,
969                &mut last_noise_scale_step,
970                &mut jsonl_file,
971            );
972
973            // R-003: Write heartbeat for crash detection
974            write_heartbeat(&heartbeat_path, trainer.step());
975
976            // Track loss history (keep last 100 for sparkline)
977            push_capped(&mut loss_history, batch_loss, 100);
978
979            // Logging at log_interval boundaries
980            if should_log(iter_idx, log_interval) {
981                log_step_metrics(
982                    trainer,
983                    &state,
984                    &mut tracker,
985                    &mut jsonl_file,
986                    &epoch_start,
987                    &start_time,
988                    &step_elapsed,
989                    epoch,
990                    total_epochs,
991                    iter_idx,
992                    num_batches,
993                    tokens_per_batch,
994                    num_params,
995                    gpu_peak_tflops,
996                    start_ms,
997                    batch_loss,
998                    &loss_history,
999                    spec,
1000                    &gpu_name,
1001                );
1002            }
1003
1004            // ALB-068/R-009: Intermediate checkpoint saving at save_interval
1005            let current_step = trainer.step();
1006            // ALB-132: Skip save on the first step after resume — the checkpoint we just
1007            // loaded is already at this step. Re-saving would overwrite the original with
1008            // potentially incomplete GPU optimizer state, corrupting subsequent resumes.
1009            let is_resume_step =
1010                resume_batch_idx > 0 && current_step == resume_batch_idx / grad_accum;
1011            let do_save = !is_resume_step
1012                && should_save_checkpoint(current_step, last_save_step, save_interval);
1013            let do_eval = current_step > 0
1014                && current_step != last_eval_step
1015                && current_step.is_multiple_of(eval_interval);
1016
1017            if do_save {
1018                save_and_validate_checkpoint(
1019                    trainer,
1020                    spec,
1021                    &model_name,
1022                    current_step,
1023                    epoch,
1024                    iter_idx,
1025                    max_checkpoints,
1026                    seed,
1027                    loss_ema,
1028                );
1029                last_save_step = current_step;
1030            }
1031
1032            // ALB-087: Decoupled eval + best-model tracking + early stopping
1033            if do_eval {
1034                last_eval_step = current_step;
1035                let eval_val_loss = run_validation_eval(
1036                    trainer,
1037                    &val_batches,
1038                    current_step,
1039                    &mut jsonl_file,
1040                    &mut scaling_predictor,
1041                    tokens_per_step,
1042                    spec.training.max_steps,
1043                );
1044                if let Some(val_loss) = eval_val_loss {
1045                    if val_loss < best_val_loss {
1046                        best_val_loss = val_loss;
1047                        evals_without_improvement = 0;
1048                        save_best_model(trainer, spec, &model_name, current_step);
1049                    } else {
1050                        evals_without_improvement += 1;
1051                    }
1052                    if patience > 0 && evals_without_improvement >= patience {
1053                        println!(
1054                            "  [early-stop] No improvement for {evals_without_improvement} evals (patience={patience}). \
1055                             Best val_loss={best_val_loss:.4}. Stopping.",
1056                        );
1057                        write_jsonl_event_json(
1058                            &mut jsonl_file,
1059                            &serde_json::json!({
1060                                "type": "early_stop",
1061                                "step": current_step,
1062                                "best_val_loss": best_val_loss,
1063                                "evals_without_improvement": evals_without_improvement,
1064                                "patience": patience,
1065                                "timestamp": now_ms(),
1066                            }),
1067                        );
1068                        break 'outer;
1069                    }
1070                }
1071            }
1072        }
1073
1074        let avg_loss = total_loss / batches_processed.max(1) as f32;
1075        let ppl = crate::train::perplexity(avg_loss);
1076        println!(
1077            "Epoch {}/{}: loss={:.6}, perplexity={:.2}, time={:.1}s",
1078            epoch + 1,
1079            spec.training.epochs,
1080            avg_loss,
1081            ppl,
1082            epoch_start.elapsed().as_secs_f64(),
1083        );
1084
1085        if reached_max_steps(spec.training.max_steps, trainer.step()) {
1086            break;
1087        }
1088    }
1089
1090    let total_time = start_time.elapsed();
1091    println!("Total training time: {:.1}s", total_time.as_secs_f64());
1092
1093    // KAIZEN-047: Print step profiler report at end of training
1094    trainer.print_profiler_report();
1095
1096    // ALB-045: Write final "Completed" snapshot
1097    let final_loss = trainer.metrics.losses.last().copied().unwrap_or(0.0);
1098    write_training_snapshot(
1099        &state,
1100        start_ms,
1101        total_epochs,
1102        total_epochs,
1103        trainer.step(),
1104        num_batches,
1105        final_loss,
1106        &loss_history,
1107        trainer.current_lr(),
1108        0.0,
1109        TrainingStatus::Completed,
1110        spec,
1111        &gpu_name,
1112    );
1113
1114    // ALB-055/056: Mark run as completed in SQLite
1115    tracker.complete();
1116
1117    // R-014: Write completion entry
1118    write_jsonl_event_json(
1119        &mut jsonl_file,
1120        &serde_json::json!({
1121            "type": "complete",
1122            "step": trainer.step(),
1123            "final_loss": final_loss,
1124            "total_time_s": total_time.as_secs_f64(),
1125            "timestamp": now_ms(),
1126        }),
1127    );
1128
1129    save_trained_model_cuda(trainer, spec)
1130}
1131
/// Spawn the coordinator (`GradientServer`) thread for DDP rank 0.
///
/// Part of the multi-process distributed CUDA training path (#133): rank 0
/// hosts the server in a background thread while every rank (rank 0 included)
/// connects to it as a worker and runs the DDP training step in lockstep.
///
/// The server is bound on the calling thread, so a bind failure is reported to
/// the caller. The spawned thread then waits for the workers and, for each of
/// `total_steps` steps, collects/reduces and broadcasts every block (highest
/// index first) followed by the three non-block components `0`, `1` and `2`.
///
/// A failure inside the thread is logged to stderr with its step/block context
/// and ends the thread early instead of panicking; the returned handle's
/// `join()` therefore completes normally in that case.
///
/// # Errors
///
/// Returns [`Error::ConfigError`] if the server cannot bind to `coord_addr`.
#[cfg(feature = "cuda")]
fn spawn_coordinator_thread(
    coord_addr: std::net::SocketAddr,
    world_size: usize,
    num_blocks: usize,
    total_steps: usize,
) -> Result<std::thread::JoinHandle<()>> {
    use crate::finetune::distributed::DistributedConfig;
    use crate::finetune::GradientServer;

    let server_config = DistributedConfig::coordinator(coord_addr, world_size);
    let mut server = GradientServer::bind(server_config)
        .map_err(|e| Error::ConfigError(format!("GradientServer bind failed: {e}")))?;
    println!("  ✓ GradientServer bound on {coord_addr}");

    Ok(std::thread::spawn(move || {
        // Run the whole reduce/broadcast schedule behind one fallible closure so
        // every failure path reports where it happened and exits the same way.
        let mut run = || -> std::result::Result<(), String> {
            server
                .wait_for_workers()
                .map_err(|e| format!("wait_for_workers failed: {e:?}"))?;
            eprintln!("[coordinator] All {world_size} workers connected");

            for step in 0..total_steps {
                let step_id = step as u64;

                // Blocks are processed from the last index down to the first.
                for block_idx in (0..num_blocks).rev() {
                    let reduced = server
                        .collect_and_reduce_block(step_id, block_idx as u32)
                        .map_err(|e| {
                            format!("step {step} block {block_idx}: collect/reduce failed: {e:?}")
                        })?;
                    server.broadcast_averaged_block(step_id, &reduced).map_err(|e| {
                        format!("step {step} block {block_idx}: broadcast failed: {e:?}")
                    })?;
                }

                for component in [0u8, 1, 2] {
                    let reduced =
                        server.collect_and_reduce_non_block(step_id, component).map_err(|e| {
                            format!(
                                "step {step} component {component}: collect/reduce failed: {e:?}"
                            )
                        })?;
                    server.broadcast_averaged_non_block(step_id, &reduced).map_err(|e| {
                        format!("step {step} component {component}: broadcast failed: {e:?}")
                    })?;
                }
            }
            Ok(())
        };

        match run() {
            Ok(()) => eprintln!("[coordinator] Training complete ({total_steps} steps)"),
            Err(msg) => eprintln!("[coordinator] aborting: {msg}"),
        }
    }))
}
1173
/// Distributed CUDA training loop (#133).
///
/// Multi-process DDP: each process runs this function with its own rank.
/// Rank 0 spawns the GradientServer in a background thread. All ranks
/// connect as workers and run the DDP training step in lockstep.
///
/// Data is sharded by rank: worker N processes batches N, N+ws, N+2*ws, ...
///
/// Progress logging, intermediate checkpoints and the final model save are
/// performed by rank 0 only. Batches whose loss is non-finite are excluded
/// from the epoch statistics.
///
/// # Errors
///
/// Returns [`Error::ConfigError`] when the trainer carries no distributed
/// config, when `world_size` is zero, when the coordinator cannot be started
/// (rank 0), or when the worker connection fails. Also propagates any error
/// from saving the final model on rank 0.
#[cfg(feature = "cuda")]
fn train_loop_cuda_distributed(
    mut cuda_trainer: CudaTransformerTrainer,
    batches: &[LMBatch],
    spec: &TrainSpec,
) -> Result<()> {
    use crate::finetune::distributed::DistributedConfig;
    use crate::finetune::WorkerClient;
    use crate::train::{shard_batches, DistributedComm, DistributedCudaTrainer};

    let dist_config = cuda_trainer
        .config()
        .distributed
        .clone()
        .ok_or_else(|| Error::ConfigError("missing distributed config".into()))?;

    let rank = dist_config.rank;
    let world_size = dist_config.world_size;
    let coord_addr = dist_config.coordinator_addr;

    // A zero world size would otherwise panic in `div_ceil` below.
    if world_size == 0 {
        return Err(Error::ConfigError("distributed world_size must be at least 1".into()));
    }

    println!("Starting distributed training (DDP)...");
    println!("  rank: {rank}/{world_size}");
    println!("  coordinator: {coord_addr}");

    cuda_trainer.ensure_grad_accum();

    let num_blocks = cuda_trainer
        .grad_accum_ref()
        .map_or(0, crate::train::PerBlockGradientAccumulator::num_blocks);

    // Step 1: If rank 0, spawn GradientServer in background thread
    let server_handle = if rank == 0 {
        let max_steps = spec.training.max_steps.unwrap_or(usize::MAX);
        let batches_per_worker = batches.len().div_ceil(world_size);
        // Saturate rather than overflow for very large epoch/batch counts.
        let total_steps =
            spec.training.epochs.saturating_mul(batches_per_worker).min(max_steps);
        Some(spawn_coordinator_thread(coord_addr, world_size, num_blocks, total_steps)?)
    } else {
        // NOTE(review): presumably gives rank 0 time to bind before this rank
        // connects; a fixed delay is inherently racy — confirm or add a retry.
        std::thread::sleep(std::time::Duration::from_millis(100));
        None
    };

    // Step 2: Connect as worker (all ranks, including rank 0)
    let worker_config = DistributedConfig::worker(coord_addr);
    let client = WorkerClient::connect(worker_config, 1, "cuda")
        .map_err(|e| Error::ConfigError(format!("WorkerClient connect failed: {e}")))?;
    println!("  ✓ Connected as worker {} (id={})", rank, client.worker_id());

    // Step 3: Wrap in DistributedCudaTrainer (last use of `dist_config`, so it
    // is moved rather than cloned)
    let comm = DistributedComm::Remote { client };
    let mut ddp_trainer = DistributedCudaTrainer::new(cuda_trainer, comm, dist_config);

    // Step 4: Training loop with data sharding
    let num_batches = batches.len();
    let start_time = std::time::Instant::now();
    let log_interval = std::cmp::max(num_batches / (world_size * 100).max(1), 1);
    let save_interval = spec.training.save_interval;
    let max_checkpoints = spec.training.max_checkpoints;
    let seed = spec.training.seed.unwrap_or(42);
    let seq_len = spec.data.seq_len.unwrap_or(128);

    let model_name = spec
        .model
        .path
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or("entrenar-model")
        .to_string();

    // R-005: Load validation batches.
    // NOTE(review): the result is not used by this loop yet; the call is kept
    // because any side effects of the loader are not visible from here.
    let _val_batches = load_val_batches(spec);

    let mut loss_history: Vec<f32> = Vec::new();
    let mut last_save_step: usize = 0;

    // Shard batches by rank: worker N gets N, N+ws, N+2*ws, ...
    // The inputs do not change between epochs, so compute the shard once.
    let my_batch_indices = shard_batches(num_batches, rank, world_size);

    for epoch in 0..spec.training.epochs {
        let epoch_start = std::time::Instant::now();
        let mut total_loss = 0.0;
        let mut batches_processed = 0;

        for (iter_idx, &batch_idx) in my_batch_indices.iter().enumerate() {
            if ddp_trainer.reached_max_steps() {
                break;
            }

            let batch = &batches[batch_idx];
            let step_start = std::time::Instant::now();
            let batch_loss = ddp_trainer.train_batch(batch);
            let step_elapsed = step_start.elapsed();

            // Non-finite losses are left out of the epoch statistics.
            if !batch_loss.is_finite() {
                continue;
            }
            total_loss += batch_loss;
            batches_processed += 1;
            push_capped(&mut loss_history, batch_loss, 100);

            // Logging (rank 0 only to avoid spam)
            if rank == 0 && should_log(iter_idx, log_interval) {
                let step = ddp_trainer.step();
                let elapsed = epoch_start.elapsed().as_secs_f64();
                // Throughput estimate scaled by world_size to cover all ranks.
                let tokens_done = (iter_idx + 1) * spec.data.batch_size * seq_len * world_size;
                let tok_per_sec = tokens_done as f64 / elapsed.max(0.001);
                println!(
                    "  [DDP rank 0] step={} loss={:.4} tok/s={:.0} step_time={:.1}ms",
                    step,
                    batch_loss,
                    tok_per_sec,
                    step_elapsed.as_secs_f64() * 1000.0,
                );
            }

            // Checkpoint (rank 0 only)
            if rank == 0 {
                let current_step = ddp_trainer.step();
                if should_save_checkpoint(current_step, last_save_step, save_interval) {
                    save_and_validate_checkpoint(
                        ddp_trainer.trainer_mut(),
                        spec,
                        &model_name,
                        current_step,
                        epoch,
                        iter_idx,
                        max_checkpoints,
                        seed,
                        // This loop tracks no running loss value, so a constant
                        // is passed for the final argument.
                        0.0,
                    );
                    last_save_step = current_step;
                }
            }
        }

        if batches_processed > 0 {
            let avg_loss = total_loss / batches_processed as f32;
            let ppl = crate::train::perplexity(avg_loss);
            if rank == 0 {
                println!(
                    "Epoch {}/{}: loss={:.6}, perplexity={:.2}, time={:.1}s",
                    epoch + 1,
                    spec.training.epochs,
                    avg_loss,
                    ppl,
                    epoch_start.elapsed().as_secs_f64(),
                );
            }
        }

        if ddp_trainer.reached_max_steps() {
            break;
        }
    }

    let total_time = start_time.elapsed();
    if rank == 0 {
        println!("Total distributed training time: {:.1}s", total_time.as_secs_f64());
    }

    // Save final model (rank 0 only)
    if rank == 0 {
        save_trained_model_cuda(ddp_trainer.trainer_mut(), spec)?;
    }

    // Wait for coordinator thread to finish; `join` only fails if it panicked,
    // which should be surfaced rather than silently dropped.
    if let Some(handle) = server_handle {
        if handle.join().is_err() {
            eprintln!("[coordinator] thread panicked before completing all steps");
        }
    }

    Ok(())
}
1351
1352// Training utility functions: checkpointing, logging, validation, metrics
1353include!("helpers.rs");
1354
1355// Model loading, data loading, and config parsing
1356include!("data.rs");
1357
1358#[cfg(test)]
1359mod tests;