1use crate::io::{save_model, Model, ModelFormat, ModelMetadata, SaveConfig};
4use crate::optim::Optimizer;
5use crate::train::callback::{CallbackContext, CallbackManager, TrainerCallback};
6use crate::train::{LossFn, MetricsTracker, TrainConfig};
7use crate::Tensor;
8use provable_contracts_macros::requires;
9use std::path::Path;
10use std::time::Instant;
11
12pub struct Trainer {
34 pub(crate) params: Vec<Tensor>,
36
37 pub(crate) optimizer: Box<dyn Optimizer>,
39
40 pub(crate) loss_fn: Option<Box<dyn LossFn>>,
42
43 pub(crate) config: TrainConfig,
45
46 pub metrics: MetricsTracker,
48
49 pub(crate) callbacks: CallbackManager,
51
52 pub(crate) best_loss: Option<f32>,
54
55 pub(crate) start_time: Option<Instant>,
57}
58
59impl Trainer {
60 pub fn new(params: Vec<Tensor>, optimizer: Box<dyn Optimizer>, config: TrainConfig) -> Self {
62 Self {
63 params,
64 optimizer,
65 loss_fn: None,
66 config,
67 metrics: MetricsTracker::new(),
68 callbacks: CallbackManager::new(),
69 best_loss: None,
70 start_time: None,
71 }
72 }
73
74 pub fn set_loss(&mut self, loss_fn: Box<dyn LossFn>) {
76 self.loss_fn = Some(loss_fn);
77 }
78
79 pub fn add_callback<C: TrainerCallback + 'static>(&mut self, callback: C) {
81 self.callbacks.add(callback);
82 }
83
84 pub fn lr(&self) -> f32 {
86 self.optimizer.lr()
87 }
88
89 pub fn set_lr(&mut self, lr: f32) {
91 self.optimizer.set_lr(lr);
92 }
93
94 pub fn params(&self) -> &[Tensor] {
96 &self.params
97 }
98
99 pub fn params_mut(&mut self) -> &mut [Tensor] {
101 &mut self.params
102 }
103
104 pub fn callbacks(&self) -> &CallbackManager {
106 &self.callbacks
107 }
108
109 pub fn callbacks_mut(&mut self) -> &mut CallbackManager {
111 &mut self.callbacks
112 }
113
114 pub(crate) fn build_context(
116 &self,
117 epoch: usize,
118 max_epochs: usize,
119 step: usize,
120 steps_per_epoch: usize,
121 loss: f32,
122 val_loss: Option<f32>,
123 ) -> CallbackContext {
124 CallbackContext {
125 epoch,
126 max_epochs,
127 step,
128 steps_per_epoch,
129 global_step: self.metrics.steps,
130 loss,
131 lr: self.lr(),
132 best_loss: self.best_loss,
133 val_loss,
134 elapsed_secs: self.start_time.map_or(0.0, |t| t.elapsed().as_secs_f64()),
135 }
136 }
137
138 #[requires(!self.params.is_empty())]
166 pub fn save(
167 &self,
168 path: impl AsRef<Path>,
169 name: &str,
170 architecture: &str,
171 ) -> crate::Result<()> {
172 let params: Vec<(String, Tensor)> = self
174 .params
175 .iter()
176 .enumerate()
177 .map(|(i, t)| (format!("param_{i}"), t.clone()))
178 .collect();
179
180 let metadata = ModelMetadata::new(name, architecture);
181 let model = Model::new(metadata, params);
182 let config = SaveConfig::new(ModelFormat::SafeTensors);
183
184 save_model(&model, path, &config)
185 }
186
187 pub fn save_with_names(
202 &self,
203 path: impl AsRef<Path>,
204 name: &str,
205 architecture: &str,
206 param_names: &[&str],
207 ) -> crate::Result<()> {
208 if param_names.len() != self.params.len() {
209 return Err(crate::Error::InvalidParameter(format!(
210 "param_names length {} doesn't match params length {}",
211 param_names.len(),
212 self.params.len()
213 )));
214 }
215
216 let params: Vec<(String, Tensor)> = self
217 .params
218 .iter()
219 .zip(param_names.iter())
220 .map(|(t, name)| (name.to_string(), t.clone()))
221 .collect();
222
223 let metadata = ModelMetadata::new(name, architecture);
224 let model = Model::new(metadata, params);
225 let config = SaveConfig::new(ModelFormat::SafeTensors);
226
227 save_model(&model, path, &config)
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optim::Adam;
    use crate::train::{MSELoss, ProgressCallback};

    /// Builds a trainer around `params` using the Adam settings and default
    /// config shared by every test in this module.
    fn trainer_with(params: Vec<Tensor>) -> Trainer {
        let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
        Trainer::new(params, Box::new(optimizer), TrainConfig::default())
    }

    /// Returns a scratch file path inside the platform temp directory.
    ///
    /// The process id is part of the file name so that test binaries running
    /// at the same time do not overwrite or delete each other's files.
    fn scratch_path(stem: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{stem}_{}.safetensors", std::process::id()))
    }

    /// A new trainer exposes its parameters and the optimizer's learning rate.
    #[test]
    fn test_trainer_creation() {
        let trainer = trainer_with(vec![Tensor::zeros(10, true)]);

        assert_eq!(trainer.params().len(), 1);
        assert_eq!(trainer.lr(), 0.001);
    }

    /// `set_lr` is visible through `lr`.
    #[test]
    fn test_set_lr() {
        let mut trainer = trainer_with(vec![Tensor::zeros(10, true)]);
        assert_eq!(trainer.lr(), 0.001);

        trainer.set_lr(0.01);
        assert_eq!(trainer.lr(), 0.01);
    }

    /// Replacing an element through `params_mut` is visible through `params`.
    #[test]
    fn test_params_mut() {
        let mut trainer = trainer_with(vec![Tensor::from_vec(vec![1.0, 2.0], true)]);

        let params = trainer.params_mut();
        assert_eq!(params.len(), 1);
        params[0] = Tensor::from_vec(vec![3.0, 4.0], true);

        assert_eq!(trainer.params()[0].data()[0], 3.0);
    }

    /// `add_callback` makes the callback manager non-empty.
    #[test]
    fn test_add_callback() {
        let mut trainer = trainer_with(vec![Tensor::zeros(10, true)]);
        trainer.add_callback(ProgressCallback::new(5));

        assert!(!trainer.callbacks().is_empty());
    }

    /// A callback registered through `callbacks_mut` is visible through
    /// `callbacks`.
    #[test]
    fn test_callbacks_mut() {
        let mut trainer = trainer_with(vec![Tensor::zeros(10, true)]);
        assert!(trainer.callbacks().is_empty());

        // Register through the mutable accessor itself so that this test
        // actually exercises it.
        trainer.callbacks_mut().add(ProgressCallback::new(10));
        assert!(!trainer.callbacks().is_empty());
    }

    /// `set_loss` installs a loss function where there was none.
    #[test]
    fn test_set_loss() {
        let mut trainer = trainer_with(vec![Tensor::zeros(10, true)]);
        assert!(trainer.loss_fn.is_none());

        trainer.set_loss(Box::new(MSELoss));
        assert!(trainer.loss_fn.is_some());
    }

    /// `build_context` copies its arguments and the trainer state into the
    /// context.
    #[test]
    fn test_build_context() {
        let mut trainer = trainer_with(vec![Tensor::zeros(10, true)]);
        trainer.best_loss = Some(0.5);
        trainer.start_time = Some(Instant::now());

        let ctx = trainer.build_context(2, 10, 5, 100, 0.1, Some(0.2));

        assert_eq!(ctx.epoch, 2);
        assert_eq!(ctx.max_epochs, 10);
        assert_eq!(ctx.step, 5);
        assert_eq!(ctx.steps_per_epoch, 100);
        assert_eq!(ctx.loss, 0.1);
        assert_eq!(ctx.val_loss, Some(0.2));
        assert_eq!(ctx.best_loss, Some(0.5));
        assert!(ctx.elapsed_secs.is_finite());
    }

    /// Without a start time the context reports zero elapsed seconds.
    #[test]
    fn test_build_context_no_start_time() {
        let trainer = trainer_with(vec![Tensor::zeros(10, true)]);
        let ctx = trainer.build_context(0, 5, 0, 50, 1.0, None);

        assert_eq!(ctx.epoch, 0);
        assert_eq!(ctx.elapsed_secs, 0.0);
        assert!(ctx.val_loss.is_none());
        assert!(ctx.best_loss.is_none());
    }

    /// Passing more names than parameters is rejected with a descriptive error.
    #[test]
    fn test_save_with_names_length_mismatch() {
        let trainer = trainer_with(vec![Tensor::zeros(10, true), Tensor::zeros(20, true)]);

        // Portable scratch location instead of a hard-coded Unix path.
        let path = scratch_path("test_trainer_save_mismatch");
        let result = trainer.save_with_names(&path, "test", "linear", &["a", "b", "c"]);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("doesn't match"));
    }

    /// `save` succeeds for a single-parameter trainer.
    #[test]
    fn test_save() {
        let trainer = trainer_with(vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], false)]);
        let path = scratch_path("test_trainer_save");

        let result = trainer.save(&path, "test-model", "linear");

        // Remove the file before asserting so a failing run leaves nothing
        // behind.
        let _ = std::fs::remove_file(&path);
        assert!(result.is_ok());
    }

    /// `save_with_names` succeeds when there is one name per parameter.
    #[test]
    fn test_save_with_names() {
        let trainer = trainer_with(vec![
            Tensor::from_vec(vec![1.0, 2.0], false),
            Tensor::from_vec(vec![3.0, 4.0, 5.0], false),
        ]);
        let path = scratch_path("test_trainer_save_names");

        let result = trainer.save_with_names(&path, "test-model", "mlp", &["weights", "bias"]);

        // Remove the file before asserting so a failing run leaves nothing
        // behind.
        let _ = std::fs::remove_file(&path);
        assert!(result.is_ok());
    }

    /// The public `metrics` field starts at zero steps and is writable.
    #[test]
    fn test_trainer_metrics_tracker() {
        let mut trainer = trainer_with(vec![Tensor::zeros(10, true)]);

        assert_eq!(trainer.metrics.steps, 0);
        trainer.metrics.steps = 100;
        assert_eq!(trainer.metrics.steps, 100);
    }
}