1use anyhow::Result;
2use candle_core::{DType, Device};
3use candle_nn::VarMap;
4use candle_nn::optim::{AdamW, Optimizer, ParamsAdamW};
5use indicatif::{ProgressBar, ProgressStyle};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use tracing::info;
11
12use crate::config::TrainingConfig;
13use crate::data::DataLoader;
14use crate::mal::ModelDef;
15use crate::model::{Transformer, cross_entropy_loss};
16
17pub fn create_progress_bar(total: u64, show: bool) -> ProgressBar {
19 if show {
20 let pb = ProgressBar::new(total);
21 pb.set_style(
22 ProgressStyle::default_bar()
23 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} loss: {msg}")
24 .unwrap()
25 .progress_chars("##-"),
26 );
27 pb.enable_steady_tick(std::time::Duration::from_millis(100));
28 pb
29 } else {
30 ProgressBar::hidden()
31 }
32}
33
34#[derive(Serialize, Deserialize, Clone, Debug)]
36pub struct TrainingState {
37 pub epoch: usize,
38 pub batch_position: usize,
39 pub global_step: usize,
40 pub shuffle_seed: u64,
41}
42
43impl Default for TrainingState {
44 fn default() -> Self {
45 Self {
46 epoch: 0,
47 batch_position: 0,
48 global_step: 0,
49 shuffle_seed: 42,
50 }
51 }
52}
53
54pub struct Trainer {
55 model: Transformer,
56 optimizer: AdamW,
57 var_map: VarMap,
58 #[allow(dead_code)]
59 config: ModelDef,
60 training_config: TrainingConfig,
61 device: Device,
62 global_step: usize,
63 interrupted: Arc<AtomicBool>,
65}
66
67impl Trainer {
68 pub fn new(config: ModelDef, training_config: TrainingConfig, device: Device) -> Result<Self> {
69 let var_map = VarMap::new();
70 let vb = candle_nn::VarBuilder::from_varmap(&var_map, DType::F32, &device);
71 let model = Transformer::new(&config, vb)?;
72
73 let params = ParamsAdamW {
74 lr: training_config.learning_rate,
75 beta1: training_config.beta1,
76 beta2: training_config.beta2,
77 weight_decay: training_config.weight_decay,
78 eps: 1e-8,
79 };
80 let optimizer = AdamW::new(var_map.all_vars(), params)?;
81
82 info!(
83 "Initialized model with {} parameters",
84 model.num_parameters()
85 );
86
87 let interrupted = Arc::new(AtomicBool::new(false));
89 let interrupted_clone = Arc::clone(&interrupted);
90 let _ = ctrlc::set_handler(move || {
91 eprintln!("\nInterrupt received, saving checkpoint...");
93 interrupted_clone.store(true, Ordering::SeqCst);
94 });
95
96 Ok(Self {
97 model,
98 optimizer,
99 var_map,
100 config,
101 training_config,
102 device,
103 global_step: 0,
104 interrupted,
105 })
106 }
107
108 pub fn is_interrupted(&self) -> bool {
110 self.interrupted.load(Ordering::SeqCst)
111 }
112
113 pub fn clear_interrupt(&self) {
115 self.interrupted.store(false, Ordering::SeqCst);
116 }
117
118 pub fn freeze_layers(&mut self, num_layers: usize) -> Result<()> {
121 let frozen_prefixes: Vec<String> =
122 (0..num_layers).map(|i| format!("layers.{}", i)).collect();
123
124 let mut prefixes = frozen_prefixes;
126 if num_layers > 0 {
127 prefixes.push("tok_emb".to_string());
128 }
129
130 let mut frozen_count = 0;
131 for (name, var) in self.var_map.data().lock().unwrap().iter() {
132 for prefix in &prefixes {
133 if name.starts_with(prefix) {
134 let tensor = var.as_tensor();
136 let _ = var.set(&tensor.detach());
137 frozen_count += 1;
138 break;
139 }
140 }
141 }
142
143 info!("Frozen {} parameter tensors", frozen_count);
144 Ok(())
145 }
146
147 pub fn train_epoch(&mut self, train_loader: &mut DataLoader) -> Result<f64> {
148 self.train_epoch_distributed(train_loader, None)
149 }
150
151 pub fn train_epoch_distributed(
152 &mut self,
153 train_loader: &mut DataLoader,
154 comm: Option<&crate::distributed::NcclCommunicator>,
155 ) -> Result<f64> {
156 train_loader.reset();
157 let (loss, _interrupted) = self.train_epoch_interruptible(train_loader, comm)?;
158 Ok(loss)
159 }
160
161 pub fn evaluate(&self, eval_loader: &mut DataLoader) -> Result<f64> {
162 let mut total_loss = 0.0;
163 let mut num_batches = 0;
164
165 eval_loader.reset();
166
167 while let Some(batch_result) = eval_loader.next_batch(&self.device)? {
168 let (input, target) = (batch_result.0, batch_result.1);
169
170 let logits = self.model.forward(&input, 0, false)?;
171 let loss = cross_entropy_loss(&logits, &target)?;
172
173 total_loss += loss.to_scalar::<f32>()? as f64;
174 num_batches += 1;
175 }
176
177 if num_batches > 0 {
178 Ok(total_loss / num_batches as f64)
179 } else {
180 Ok(0.0)
181 }
182 }
183
184 pub fn train(
185 &mut self,
186 train_loader: &mut DataLoader,
187 eval_loader: Option<&mut DataLoader>,
188 checkpoint_dir: Option<&str>,
189 ) -> Result<bool> {
190 self.train_distributed(train_loader, eval_loader, checkpoint_dir, None)
191 }
192
193 pub fn train_distributed(
194 &mut self,
195 train_loader: &mut DataLoader,
196 eval_loader: Option<&mut DataLoader>,
197 checkpoint_dir: Option<&str>,
198 comm: Option<&crate::distributed::NcclCommunicator>,
199 ) -> Result<bool> {
200 self.train_resumable(train_loader, eval_loader, checkpoint_dir, comm, None)
201 }
202
203 pub fn train_resumable(
206 &mut self,
207 train_loader: &mut DataLoader,
208 mut eval_loader: Option<&mut DataLoader>,
209 checkpoint_dir: Option<&str>,
210 comm: Option<&crate::distributed::NcclCommunicator>,
211 resume_state: Option<TrainingState>,
212 ) -> Result<bool> {
213 let is_main = comm.is_none_or(|c| c.rank() == 0);
214
215 let (start_epoch, start_position) = match resume_state {
217 Some(ref state) => {
218 self.global_step = state.global_step;
219 if is_main {
220 info!(
221 "Resuming from epoch {}, batch position {}, global step {}",
222 state.epoch + 1,
223 state.batch_position,
224 state.global_step
225 );
226 }
227 (state.epoch, state.batch_position)
228 }
229 None => (0, 0),
230 };
231
232 if is_main {
233 info!(
234 "Starting training for {} epochs",
235 self.training_config.epochs
236 );
237 }
238
239 for epoch in start_epoch..self.training_config.epochs {
240 if is_main {
241 info!("Epoch {}/{}", epoch + 1, self.training_config.epochs);
242 }
243
244 let shuffle_seed = epoch as u64;
246 train_loader.reset_with_seed(shuffle_seed);
247
248 if epoch == start_epoch && start_position > 0 {
250 train_loader.set_position(start_position);
251 if is_main {
252 info!("Resuming from batch position {}", start_position);
253 }
254 }
255
256 let (train_loss, interrupted) = self.train_epoch_interruptible(train_loader, comm)?;
257
258 if interrupted {
260 if is_main && let Some(dir) = checkpoint_dir {
261 let state = TrainingState {
262 epoch,
263 batch_position: train_loader.position(),
264 global_step: self.global_step,
265 shuffle_seed,
266 };
267 self.save_training_state(dir, &state)?;
268 info!("Saved interrupt checkpoint to {}", dir);
269 }
270 return Ok(false);
271 }
272
273 if is_main {
274 info!("Epoch {} train loss: {:.4}", epoch + 1, train_loss);
275 }
276
277 if let Some(ref mut eval) = eval_loader {
278 let eval_loss = self.evaluate(eval)?;
279 if is_main {
280 info!("Epoch {} eval loss: {:.4}", epoch + 1, eval_loss);
281 }
282 }
283
284 if is_main && let Some(dir) = checkpoint_dir {
286 let path = format!("{}/checkpoint_epoch_{}.safetensors", dir, epoch + 1);
287 self.save_checkpoint(&path)?;
288 info!("Saved checkpoint to {}", path);
289 }
290
291 if let Some(c) = comm {
293 c.barrier()?;
294 }
295 }
296
297 Ok(true)
298 }
299
300 fn train_epoch_interruptible(
303 &mut self,
304 train_loader: &mut DataLoader,
305 comm: Option<&crate::distributed::NcclCommunicator>,
306 ) -> Result<(f64, bool)> {
307 let is_main = comm.is_none_or(|c| c.rank() == 0);
308 let num_batches = train_loader.num_batches();
309
310 let pb = create_progress_bar(num_batches as u64, is_main);
311
312 let mut total_loss = 0.0;
313 let mut num_steps = 0;
314 let mut accumulated_loss = 0.0;
315
316 while let Some(batch_result) = train_loader.next_batch(&self.device)? {
317 if self.is_interrupted() {
319 pb.finish_with_message("interrupted");
320 return Ok((total_loss / num_steps.max(1) as f64, true));
321 }
322
323 let (input, target) = (batch_result.0, batch_result.1);
324
325 let logits = self.model.forward(&input, 0, true)?;
326 let loss = cross_entropy_loss(&logits, &target)?;
327
328 accumulated_loss += loss.to_scalar::<f32>()? as f64;
329 num_steps += 1;
330
331 if num_steps % self.training_config.gradient_accumulation_steps == 0 {
332 let avg_loss =
333 accumulated_loss / self.training_config.gradient_accumulation_steps as f64;
334
335 self.optimizer.backward_step(&loss)?;
336
337 if let Some(c) = comm {
338 crate::distributed::sync_gradients(&self.var_map, c)?;
339 }
340
341 if self.training_config.grad_clip > 0.0 {
342 for var in self.var_map.all_vars() {
343 let grad = var.as_tensor();
344 let norm = grad.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()?;
345 if norm > self.training_config.grad_clip as f32 {
346 let scale = self.training_config.grad_clip as f32 / norm;
347 let _ = var.set(&grad.affine(scale as f64, 0.0)?);
348 }
349 }
350 }
351
352 total_loss += avg_loss;
353 accumulated_loss = 0.0;
354 self.global_step += 1;
355
356 if self
357 .global_step
358 .is_multiple_of(self.training_config.log_every)
359 {
360 pb.set_message(format!("{:.4}", avg_loss));
361 }
362 }
363
364 pb.inc(1);
365 }
366
367 pb.finish_with_message("done");
368
369 let effective_steps = self.global_step;
370 if effective_steps > 0 {
371 Ok((total_loss / effective_steps as f64, false))
372 } else {
373 Ok((0.0, false))
374 }
375 }
376
377 pub fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> Result<()> {
378 self.var_map.save(path)?;
379 Ok(())
380 }
381
382 pub fn load_checkpoint<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
383 self.var_map.load(path)?;
384 Ok(())
385 }
386
387 pub fn save_training_state<P: AsRef<Path>>(&self, dir: P, state: &TrainingState) -> Result<()> {
389 let dir = dir.as_ref();
390 std::fs::create_dir_all(dir)?;
391
392 let weights_path = dir.join("weights.safetensors");
394 self.var_map.save(&weights_path)?;
395
396 let state_path = dir.join("training_state.json");
398 let state_json = serde_json::to_string_pretty(state)?;
399 std::fs::write(&state_path, state_json)?;
400
401 Ok(())
402 }
403
404 pub fn load_training_state<P: AsRef<Path>>(&mut self, dir: P) -> Result<TrainingState> {
406 let dir = dir.as_ref();
407
408 let weights_path = dir.join("weights.safetensors");
410 if weights_path.exists() {
411 self.var_map.load(&weights_path)?;
412 }
413
414 let state_path = dir.join("training_state.json");
416 let state = if state_path.exists() {
417 let state_json = std::fs::read_to_string(&state_path)?;
418 serde_json::from_str(&state_json)?
419 } else {
420 TrainingState::default()
421 };
422
423 self.global_step = state.global_step;
424 Ok(state)
425 }
426
427 pub fn model(&self) -> &Transformer {
428 &self.model
429 }
430
431 pub fn var_map(&self) -> &VarMap {
432 &self.var_map
433 }
434
435 pub fn device(&self) -> &Device {
436 &self.device
437 }
438
439 pub fn global_step(&self) -> usize {
440 self.global_step
441 }
442}
443
444pub use crate::generate::TextGenerator;