Skip to main content

entrenar/train/trainer/
core.rs

1//! Core Trainer struct and basic methods
2
3use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
4use crate::optim::Optimizer;
5use crate::train::callback::{CallbackContext, CallbackManager, TrainerCallback};
6use crate::train::{LossFn, MetricsTracker, TrainConfig};
7use crate::Tensor;
8use provable_contracts_macros::requires;
9use std::path::Path;
10use std::time::Instant;
11
/// High-level trainer that orchestrates the training loop.
///
/// A `Trainer` bundles everything a training run needs:
/// - the model parameters and the optimizer that updates them;
/// - an optional loss function;
/// - the training configuration;
/// - a [`MetricsTracker`];
/// - a [`CallbackManager`] holding user-registered callbacks.
///
/// The learning rate is not stored here. [`Trainer::lr`] and [`Trainer::set_lr`]
/// delegate straight to the optimizer.
///
/// # Example
///
/// ```no_run
/// use entrenar::train::{Trainer, TrainConfig, Batch, MSELoss, EarlyStopping};
/// use entrenar::optim::Adam;
/// use entrenar::Tensor;
///
/// // Setup
/// let params = vec![Tensor::zeros(10, true)];
/// let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
/// let config = TrainConfig::default();
///
/// let mut trainer = Trainer::new(params, Box::new(optimizer), config);
/// trainer.set_loss(Box::new(MSELoss));
/// trainer.add_callback(EarlyStopping::new(5, 0.001));
///
/// // Training with callbacks
/// // let result = trainer.train(10, || batches.clone(), |x| x.clone());
/// ```
pub struct Trainer {
    /// Model parameters.
    ///
    /// Exposed through `params()` / `params_mut()`. Each tensor is cloned into the
    /// output model by `save` / `save_with_names`.
    pub(crate) params: Vec<Tensor>,

    /// Optimizer.
    ///
    /// This is also the single source of truth for the learning rate: `lr()` and
    /// `set_lr()` forward to it.
    pub(crate) optimizer: Box<dyn Optimizer>,

    /// Loss function. `new` leaves this as `None`; `set_loss` replaces it.
    pub(crate) loss_fn: Option<Box<dyn LossFn>>,

    /// Training configuration, stored as given to `new`.
    pub(crate) config: TrainConfig,

    /// Metrics tracker (public).
    ///
    /// Its `steps` counter is reported as `global_step` in every callback context.
    pub metrics: MetricsTracker,

    /// Registry of callbacks added via `add_callback` or `callbacks_mut`.
    pub(crate) callbacks: CallbackManager,

    /// Best loss achieved during training.
    ///
    /// Starts as `None` and is copied verbatim into each callback context. It is
    /// presumably updated by the training loop, which is not shown in this file.
    pub(crate) best_loss: Option<f32>,

    /// Training start time.
    ///
    /// While this is `None`, `build_context` reports `elapsed_secs` as `0.0`.
    pub(crate) start_time: Option<Instant>,
}
58
59impl Trainer {
60    /// Create a new trainer
61    pub fn new(params: Vec<Tensor>, optimizer: Box<dyn Optimizer>, config: TrainConfig) -> Self {
62        Self {
63            params,
64            optimizer,
65            loss_fn: None,
66            config,
67            metrics: MetricsTracker::new(),
68            callbacks: CallbackManager::new(),
69            best_loss: None,
70            start_time: None,
71        }
72    }
73
74    /// Set the loss function
75    pub fn set_loss(&mut self, loss_fn: Box<dyn LossFn>) {
76        self.loss_fn = Some(loss_fn);
77    }
78
79    /// Add a callback to the trainer
80    pub fn add_callback<C: TrainerCallback + 'static>(&mut self, callback: C) {
81        self.callbacks.add(callback);
82    }
83
84    /// Get current learning rate
85    pub fn lr(&self) -> f32 {
86        self.optimizer.lr()
87    }
88
89    /// Set learning rate
90    pub fn set_lr(&mut self, lr: f32) {
91        self.optimizer.set_lr(lr);
92    }
93
94    /// Get reference to model parameters
95    pub fn params(&self) -> &[Tensor] {
96        &self.params
97    }
98
99    /// Get mutable reference to model parameters
100    pub fn params_mut(&mut self) -> &mut [Tensor] {
101        &mut self.params
102    }
103
104    /// Get reference to callback manager
105    pub fn callbacks(&self) -> &CallbackManager {
106        &self.callbacks
107    }
108
109    /// Get mutable reference to callback manager
110    pub fn callbacks_mut(&mut self) -> &mut CallbackManager {
111        &mut self.callbacks
112    }
113
114    /// Build callback context from current state
115    pub(crate) fn build_context(
116        &self,
117        epoch: usize,
118        max_epochs: usize,
119        step: usize,
120        steps_per_epoch: usize,
121        loss: f32,
122        val_loss: Option<f32>,
123    ) -> CallbackContext {
124        CallbackContext {
125            epoch,
126            max_epochs,
127            step,
128            steps_per_epoch,
129            global_step: self.metrics.steps,
130            loss,
131            lr: self.lr(),
132            best_loss: self.best_loss,
133            val_loss,
134            elapsed_secs: self.start_time.map_or(0.0, |t| t.elapsed().as_secs_f64()),
135        }
136    }
137
138    /// Save model parameters to a file
139    ///
140    /// This method persists the trained model weights to disk in SafeTensors format.
141    /// Call this after training completes to preserve the learned parameters.
142    ///
143    /// # Arguments
144    ///
145    /// * `path` - Output file path (should end in .safetensors)
146    /// * `name` - Model name for metadata
147    /// * `architecture` - Model architecture description
148    ///
149    /// # Example
150    ///
151    /// ```no_run
152    /// # use entrenar::train::{Trainer, TrainConfig};
153    /// # use entrenar::optim::Adam;
154    /// # use entrenar::Tensor;
155    /// # let params = vec![Tensor::zeros(10, true)];
156    /// # let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
157    /// # let mut trainer = Trainer::new(params, Box::new(optimizer), TrainConfig::default());
158    /// // After training...
159    /// trainer.save("model.safetensors", "my-model", "linear").expect("save failed");
160    /// ```
161    ///
162    /// # Errors
163    ///
164    /// Returns an error if the file cannot be written.
165    #[requires(!self.params.is_empty())]
166    pub fn save(
167        &self,
168        path: impl AsRef<Path>,
169        name: &str,
170        architecture: &str,
171    ) -> crate::Result<()> {
172        // Convert trainer params to io::Model format
173        let params: Vec<(String, Tensor)> = self
174            .params
175            .iter()
176            .enumerate()
177            .map(|(i, t)| (format!("param_{i}"), t.clone()))
178            .collect();
179
180        let metadata = ModelMetadata::new(name, architecture);
181        let model = Model::new(metadata, params);
182        let config = SaveConfig::new(ModelFormat::SafeTensors);
183
184        save_model(&model, path, &config)
185    }
186
187    /// Save model with custom parameter names
188    ///
189    /// Like `save()` but allows specifying custom names for each parameter tensor.
190    ///
191    /// # Arguments
192    ///
193    /// * `path` - Output file path
194    /// * `name` - Model name
195    /// * `architecture` - Architecture description
196    /// * `param_names` - Names for each parameter (must match params length)
197    ///
198    /// # Errors
199    ///
200    /// Returns an error if param_names length doesn't match params or file cannot be written.
201    pub fn save_with_names(
202        &self,
203        path: impl AsRef<Path>,
204        name: &str,
205        architecture: &str,
206        param_names: &[&str],
207    ) -> crate::Result<()> {
208        if param_names.len() != self.params.len() {
209            return Err(crate::Error::InvalidParameter(format!(
210                "param_names length {} doesn't match params length {}",
211                param_names.len(),
212                self.params.len()
213            )));
214        }
215
216        let params: Vec<(String, Tensor)> = self
217            .params
218            .iter()
219            .zip(param_names.iter())
220            .map(|(t, name)| (name.to_string(), t.clone()))
221            .collect();
222
223        let metadata = ModelMetadata::new(name, architecture);
224        let model = Model::new(metadata, params);
225        let config = SaveConfig::new(ModelFormat::SafeTensors);
226
227        save_model(&model, path, &config)
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::Adam;

    /// Build a trainer using Adam (lr = 0.001) and the default config.
    fn make_trainer(params: Vec<Tensor>) -> Trainer {
        let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
        Trainer::new(params, Box::new(optimizer), TrainConfig::default())
    }

    /// Temp-dir path unique to this process.
    ///
    /// Using the process id keeps concurrent test runs from clobbering each
    /// other's output files.
    fn temp_path(stem: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{stem}_{}.safetensors", std::process::id()))
    }

    #[test]
    fn test_trainer_creation() {
        let trainer = make_trainer(vec![Tensor::zeros(10, true)]);

        assert_eq!(trainer.params().len(), 1);
        assert_eq!(trainer.lr(), 0.001);
    }

    #[test]
    fn test_set_lr() {
        let mut trainer = make_trainer(vec![Tensor::zeros(10, true)]);
        assert_eq!(trainer.lr(), 0.001);

        trainer.set_lr(0.01);
        assert_eq!(trainer.lr(), 0.01);
    }

    #[test]
    fn test_params_mut() {
        let mut trainer = make_trainer(vec![Tensor::from_vec(vec![1.0, 2.0], true)]);
        let params = trainer.params_mut();
        assert_eq!(params.len(), 1);
        // Replacing a tensor through the mutable slice must be visible afterwards
        params[0] = Tensor::from_vec(vec![3.0, 4.0], true);
        assert_eq!(trainer.params()[0].data()[0], 3.0);
    }

    #[test]
    fn test_add_callback() {
        use crate::train::ProgressCallback;

        let mut trainer = make_trainer(vec![Tensor::zeros(10, true)]);
        trainer.add_callback(ProgressCallback::new(5));

        // Verify callback was added
        assert!(!trainer.callbacks().is_empty());
    }

    #[test]
    fn test_callbacks_mut() {
        use crate::train::ProgressCallback;

        let mut trainer = make_trainer(vec![Tensor::zeros(10, true)]);
        assert!(trainer.callbacks().is_empty());

        // Add through the mutable manager reference (not `add_callback`) so this
        // test actually exercises `callbacks_mut`.
        trainer.callbacks_mut().add(ProgressCallback::new(10));
        assert!(!trainer.callbacks().is_empty());
    }

    #[test]
    fn test_set_loss() {
        use crate::train::MSELoss;

        let mut trainer = make_trainer(vec![Tensor::zeros(10, true)]);
        assert!(trainer.loss_fn.is_none());

        trainer.set_loss(Box::new(MSELoss));
        assert!(trainer.loss_fn.is_some());
    }

    #[test]
    fn test_build_context() {
        let mut trainer = make_trainer(vec![Tensor::zeros(10, true)]);
        trainer.best_loss = Some(0.5);
        trainer.start_time = Some(Instant::now());

        let ctx = trainer.build_context(2, 10, 5, 100, 0.1, Some(0.2));

        assert_eq!(ctx.epoch, 2);
        assert_eq!(ctx.max_epochs, 10);
        assert_eq!(ctx.step, 5);
        assert_eq!(ctx.steps_per_epoch, 100);
        assert_eq!(ctx.loss, 0.1);
        assert_eq!(ctx.val_loss, Some(0.2));
        assert_eq!(ctx.best_loss, Some(0.5));
        // Filled from trainer state rather than from the arguments
        assert_eq!(ctx.lr, 0.001);
        assert_eq!(ctx.global_step, 0);
        // When start_time is set, elapsed_secs should be a finite non-negative number
        assert!(ctx.elapsed_secs.is_finite() && ctx.elapsed_secs >= 0.0);
    }

    #[test]
    fn test_build_context_no_start_time() {
        // start_time is None on a fresh trainer
        let trainer = make_trainer(vec![Tensor::zeros(10, true)]);

        let ctx = trainer.build_context(0, 5, 0, 50, 1.0, None);

        assert_eq!(ctx.epoch, 0);
        assert_eq!(ctx.elapsed_secs, 0.0);
        assert!(ctx.val_loss.is_none());
        assert!(ctx.best_loss.is_none());
    }

    #[test]
    fn test_save_with_names_length_mismatch() {
        let trainer = make_trainer(vec![Tensor::zeros(10, true), Tensor::zeros(20, true)]);
        let path = temp_path("test_trainer_names_mismatch");

        // Mismatch: 2 params, 3 names
        let result = trainer.save_with_names(&path, "test", "linear", &["a", "b", "c"]);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("doesn't match"));
    }

    #[test]
    fn test_save() {
        let trainer = make_trainer(vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], false)]);
        let path = temp_path("test_trainer_save");

        let result = trainer.save(&path, "test-model", "linear");
        let written = path.exists();
        // Clean up before asserting so a failure doesn't leave the file behind
        let _ = std::fs::remove_file(&path);

        if let Err(e) = result {
            panic!("save failed: {e}");
        }
        assert!(written, "save reported success but no file was written");
    }

    #[test]
    fn test_save_with_names() {
        let trainer = make_trainer(vec![
            Tensor::from_vec(vec![1.0, 2.0], false),
            Tensor::from_vec(vec![3.0, 4.0, 5.0], false),
        ]);
        let path = temp_path("test_trainer_save_names");

        let result = trainer.save_with_names(&path, "test-model", "mlp", &["weights", "bias"]);
        let written = path.exists();
        // Clean up before asserting so a failure doesn't leave the file behind
        let _ = std::fs::remove_file(&path);

        if let Err(e) = result {
            panic!("save_with_names failed: {e}");
        }
        assert!(written, "save_with_names reported success but no file was written");
    }

    #[test]
    fn test_trainer_metrics_tracker() {
        let mut trainer = make_trainer(vec![Tensor::zeros(10, true)]);

        // Metrics tracker should be accessible
        assert_eq!(trainer.metrics.steps, 0);
        trainer.metrics.steps = 100;
        assert_eq!(trainer.metrics.steps, 100);
    }
}