Skip to main content

hermes_llm/
training.rs

1use anyhow::Result;
2use candle_core::{DType, Device};
3use candle_nn::VarMap;
4use candle_nn::optim::{AdamW, Optimizer, ParamsAdamW};
5use indicatif::{ProgressBar, ProgressStyle};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use tracing::info;
11
12use crate::config::TrainingConfig;
13use crate::data::DataLoader;
14use crate::mal::ModelDef;
15use crate::model::{Transformer, cross_entropy_loss};
16
17/// Create a styled progress bar for training.
18pub fn create_progress_bar(total: u64, show: bool) -> ProgressBar {
19    if show {
20        let pb = ProgressBar::new(total);
21        pb.set_style(
22            ProgressStyle::default_bar()
23                .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} loss: {msg}")
24                .unwrap()
25                .progress_chars("##-"),
26        );
27        pb.enable_steady_tick(std::time::Duration::from_millis(100));
28        pb
29    } else {
30        ProgressBar::hidden()
31    }
32}
33
34/// Training state for checkpointing and resume
35#[derive(Serialize, Deserialize, Clone, Debug)]
36pub struct TrainingState {
37    pub epoch: usize,
38    pub batch_position: usize,
39    pub global_step: usize,
40    pub shuffle_seed: u64,
41}
42
43impl Default for TrainingState {
44    fn default() -> Self {
45        Self {
46            epoch: 0,
47            batch_position: 0,
48            global_step: 0,
49            shuffle_seed: 42,
50        }
51    }
52}
53
54pub struct Trainer {
55    model: Transformer,
56    optimizer: AdamW,
57    var_map: VarMap,
58    #[allow(dead_code)]
59    config: ModelDef,
60    training_config: TrainingConfig,
61    device: Device,
62    global_step: usize,
63    /// Signal for graceful shutdown on Ctrl+C
64    interrupted: Arc<AtomicBool>,
65}
66
67impl Trainer {
68    pub fn new(config: ModelDef, training_config: TrainingConfig, device: Device) -> Result<Self> {
69        let var_map = VarMap::new();
70        let vb = candle_nn::VarBuilder::from_varmap(&var_map, DType::F32, &device);
71        let model = Transformer::new(&config, vb)?;
72
73        let params = ParamsAdamW {
74            lr: training_config.learning_rate,
75            beta1: training_config.beta1,
76            beta2: training_config.beta2,
77            weight_decay: training_config.weight_decay,
78            eps: 1e-8,
79        };
80        let optimizer = AdamW::new(var_map.all_vars(), params)?;
81
82        info!(
83            "Initialized model with {} parameters",
84            model.num_parameters()
85        );
86
87        // Set up Ctrl+C handler
88        let interrupted = Arc::new(AtomicBool::new(false));
89        let interrupted_clone = Arc::clone(&interrupted);
90        let _ = ctrlc::set_handler(move || {
91            // Use eprintln here since tracing may not be flushed before exit
92            eprintln!("\nInterrupt received, saving checkpoint...");
93            interrupted_clone.store(true, Ordering::SeqCst);
94        });
95
96        Ok(Self {
97            model,
98            optimizer,
99            var_map,
100            config,
101            training_config,
102            device,
103            global_step: 0,
104            interrupted,
105        })
106    }
107
108    /// Check if training was interrupted
109    pub fn is_interrupted(&self) -> bool {
110        self.interrupted.load(Ordering::SeqCst)
111    }
112
113    /// Reset interrupt flag
114    pub fn clear_interrupt(&self) {
115        self.interrupted.store(false, Ordering::SeqCst);
116    }
117
118    /// Freeze the first N layers (embeddings count as layer 0)
119    /// Frozen layers will not be updated during training
120    pub fn freeze_layers(&mut self, num_layers: usize) -> Result<()> {
121        let frozen_prefixes: Vec<String> =
122            (0..num_layers).map(|i| format!("layers.{}", i)).collect();
123
124        // Also freeze embeddings if num_layers > 0
125        let mut prefixes = frozen_prefixes;
126        if num_layers > 0 {
127            prefixes.push("tok_emb".to_string());
128        }
129
130        let mut frozen_count = 0;
131        for (name, var) in self.var_map.data().lock().unwrap().iter() {
132            for prefix in &prefixes {
133                if name.starts_with(prefix) {
134                    // Detach tensor to prevent gradient computation
135                    let tensor = var.as_tensor();
136                    let _ = var.set(&tensor.detach());
137                    frozen_count += 1;
138                    break;
139                }
140            }
141        }
142
143        info!("Frozen {} parameter tensors", frozen_count);
144        Ok(())
145    }
146
147    pub fn train_epoch(&mut self, train_loader: &mut DataLoader) -> Result<f64> {
148        self.train_epoch_distributed(train_loader, None)
149    }
150
151    pub fn train_epoch_distributed(
152        &mut self,
153        train_loader: &mut DataLoader,
154        comm: Option<&crate::distributed::NcclCommunicator>,
155    ) -> Result<f64> {
156        train_loader.reset();
157        let (loss, _interrupted) = self.train_epoch_interruptible(train_loader, comm)?;
158        Ok(loss)
159    }
160
161    pub fn evaluate(&self, eval_loader: &mut DataLoader) -> Result<f64> {
162        let mut total_loss = 0.0;
163        let mut num_batches = 0;
164
165        eval_loader.reset();
166
167        while let Some(batch_result) = eval_loader.next_batch(&self.device)? {
168            let (input, target) = (batch_result.0, batch_result.1);
169
170            let logits = self.model.forward(&input, 0, false)?;
171            let loss = cross_entropy_loss(&logits, &target)?;
172
173            total_loss += loss.to_scalar::<f32>()? as f64;
174            num_batches += 1;
175        }
176
177        if num_batches > 0 {
178            Ok(total_loss / num_batches as f64)
179        } else {
180            Ok(0.0)
181        }
182    }
183
184    pub fn train(
185        &mut self,
186        train_loader: &mut DataLoader,
187        eval_loader: Option<&mut DataLoader>,
188        checkpoint_dir: Option<&str>,
189    ) -> Result<bool> {
190        self.train_distributed(train_loader, eval_loader, checkpoint_dir, None)
191    }
192
193    pub fn train_distributed(
194        &mut self,
195        train_loader: &mut DataLoader,
196        eval_loader: Option<&mut DataLoader>,
197        checkpoint_dir: Option<&str>,
198        comm: Option<&crate::distributed::NcclCommunicator>,
199    ) -> Result<bool> {
200        self.train_resumable(train_loader, eval_loader, checkpoint_dir, comm, None)
201    }
202
203    /// Train with support for interruption and resume
204    /// Returns true if training completed, false if interrupted
205    pub fn train_resumable(
206        &mut self,
207        train_loader: &mut DataLoader,
208        mut eval_loader: Option<&mut DataLoader>,
209        checkpoint_dir: Option<&str>,
210        comm: Option<&crate::distributed::NcclCommunicator>,
211        resume_state: Option<TrainingState>,
212    ) -> Result<bool> {
213        let is_main = comm.is_none_or(|c| c.rank() == 0);
214
215        // Resume from saved state if provided
216        let (start_epoch, start_position) = match resume_state {
217            Some(ref state) => {
218                self.global_step = state.global_step;
219                if is_main {
220                    info!(
221                        "Resuming from epoch {}, batch position {}, global step {}",
222                        state.epoch + 1,
223                        state.batch_position,
224                        state.global_step
225                    );
226                }
227                (state.epoch, state.batch_position)
228            }
229            None => (0, 0),
230        };
231
232        if is_main {
233            info!(
234                "Starting training for {} epochs",
235                self.training_config.epochs
236            );
237        }
238
239        for epoch in start_epoch..self.training_config.epochs {
240            if is_main {
241                info!("Epoch {}/{}", epoch + 1, self.training_config.epochs);
242            }
243
244            // Use epoch as shuffle seed for reproducibility
245            let shuffle_seed = epoch as u64;
246            train_loader.reset_with_seed(shuffle_seed);
247
248            // Resume from position if this is the starting epoch
249            if epoch == start_epoch && start_position > 0 {
250                train_loader.set_position(start_position);
251                if is_main {
252                    info!("Resuming from batch position {}", start_position);
253                }
254            }
255
256            let (train_loss, interrupted) = self.train_epoch_interruptible(train_loader, comm)?;
257
258            // Handle interrupt
259            if interrupted {
260                if is_main && let Some(dir) = checkpoint_dir {
261                    let state = TrainingState {
262                        epoch,
263                        batch_position: train_loader.position(),
264                        global_step: self.global_step,
265                        shuffle_seed,
266                    };
267                    self.save_training_state(dir, &state)?;
268                    info!("Saved interrupt checkpoint to {}", dir);
269                }
270                return Ok(false);
271            }
272
273            if is_main {
274                info!("Epoch {} train loss: {:.4}", epoch + 1, train_loss);
275            }
276
277            if let Some(ref mut eval) = eval_loader {
278                let eval_loss = self.evaluate(eval)?;
279                if is_main {
280                    info!("Epoch {} eval loss: {:.4}", epoch + 1, eval_loss);
281                }
282            }
283
284            // Only main process saves checkpoints
285            if is_main && let Some(dir) = checkpoint_dir {
286                let path = format!("{}/checkpoint_epoch_{}.safetensors", dir, epoch + 1);
287                self.save_checkpoint(&path)?;
288                info!("Saved checkpoint to {}", path);
289            }
290
291            // Sync all ranks after checkpoint save
292            if let Some(c) = comm {
293                c.barrier()?;
294            }
295        }
296
297        Ok(true)
298    }
299
300    /// Train epoch with interrupt checking
301    /// Returns (loss, was_interrupted)
302    fn train_epoch_interruptible(
303        &mut self,
304        train_loader: &mut DataLoader,
305        comm: Option<&crate::distributed::NcclCommunicator>,
306    ) -> Result<(f64, bool)> {
307        let is_main = comm.is_none_or(|c| c.rank() == 0);
308        let num_batches = train_loader.num_batches();
309
310        let pb = create_progress_bar(num_batches as u64, is_main);
311
312        let mut total_loss = 0.0;
313        let mut num_steps = 0;
314        let mut accumulated_loss = 0.0;
315
316        while let Some(batch_result) = train_loader.next_batch(&self.device)? {
317            // Check for interrupt
318            if self.is_interrupted() {
319                pb.finish_with_message("interrupted");
320                return Ok((total_loss / num_steps.max(1) as f64, true));
321            }
322
323            let (input, target) = (batch_result.0, batch_result.1);
324
325            let logits = self.model.forward(&input, 0, true)?;
326            let loss = cross_entropy_loss(&logits, &target)?;
327
328            accumulated_loss += loss.to_scalar::<f32>()? as f64;
329            num_steps += 1;
330
331            if num_steps % self.training_config.gradient_accumulation_steps == 0 {
332                let avg_loss =
333                    accumulated_loss / self.training_config.gradient_accumulation_steps as f64;
334
335                self.optimizer.backward_step(&loss)?;
336
337                if let Some(c) = comm {
338                    crate::distributed::sync_gradients(&self.var_map, c)?;
339                }
340
341                if self.training_config.grad_clip > 0.0 {
342                    for var in self.var_map.all_vars() {
343                        let grad = var.as_tensor();
344                        let norm = grad.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()?;
345                        if norm > self.training_config.grad_clip as f32 {
346                            let scale = self.training_config.grad_clip as f32 / norm;
347                            let _ = var.set(&grad.affine(scale as f64, 0.0)?);
348                        }
349                    }
350                }
351
352                total_loss += avg_loss;
353                accumulated_loss = 0.0;
354                self.global_step += 1;
355
356                if self
357                    .global_step
358                    .is_multiple_of(self.training_config.log_every)
359                {
360                    pb.set_message(format!("{:.4}", avg_loss));
361                }
362            }
363
364            pb.inc(1);
365        }
366
367        pb.finish_with_message("done");
368
369        let effective_steps = self.global_step;
370        if effective_steps > 0 {
371            Ok((total_loss / effective_steps as f64, false))
372        } else {
373            Ok((0.0, false))
374        }
375    }
376
377    pub fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> Result<()> {
378        self.var_map.save(path)?;
379        Ok(())
380    }
381
382    pub fn load_checkpoint<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
383        self.var_map.load(path)?;
384        Ok(())
385    }
386
387    /// Save full training state (weights + progress) for resumable training
388    pub fn save_training_state<P: AsRef<Path>>(&self, dir: P, state: &TrainingState) -> Result<()> {
389        let dir = dir.as_ref();
390        std::fs::create_dir_all(dir)?;
391
392        // Save model weights
393        let weights_path = dir.join("weights.safetensors");
394        self.var_map.save(&weights_path)?;
395
396        // Save training state
397        let state_path = dir.join("training_state.json");
398        let state_json = serde_json::to_string_pretty(state)?;
399        std::fs::write(&state_path, state_json)?;
400
401        Ok(())
402    }
403
404    /// Load full training state (weights + progress) for resuming
405    pub fn load_training_state<P: AsRef<Path>>(&mut self, dir: P) -> Result<TrainingState> {
406        let dir = dir.as_ref();
407
408        // Load model weights
409        let weights_path = dir.join("weights.safetensors");
410        if weights_path.exists() {
411            self.var_map.load(&weights_path)?;
412        }
413
414        // Load training state
415        let state_path = dir.join("training_state.json");
416        let state = if state_path.exists() {
417            let state_json = std::fs::read_to_string(&state_path)?;
418            serde_json::from_str(&state_json)?
419        } else {
420            TrainingState::default()
421        };
422
423        self.global_step = state.global_step;
424        Ok(state)
425    }
426
427    pub fn model(&self) -> &Transformer {
428        &self.model
429    }
430
431    pub fn var_map(&self) -> &VarMap {
432        &self.var_map
433    }
434
435    pub fn device(&self) -> &Device {
436        &self.device
437    }
438
439    pub fn global_step(&self) -> usize {
440        self.global_step
441    }
442}
443
444// Re-export from generate module for backward compatibility
445pub use crate::generate::TextGenerator;