use yscv_autograd::{Graph, NodeId};
use yscv_optim::{Adam, AdamW, LearningRate, LrScheduler, RmsProp, Sgd};
use yscv_tensor::Tensor;

use crate::{
    BatchIterOptions, GradientAggregator, ModelError, SequentialModel, SupervisedDataset, bce_loss,
    cross_entropy_loss, hinge_loss, huber_loss, mae_loss, mse_loss, nll_loss,
};

trait GraphOptimizer {
    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError>;
}

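/// Loss functions available for supervised training.
///
/// Illustrative usage (a minimal sketch; the `delta` value is arbitrary):
///
/// ```ignore
/// let loss = SupervisedLoss::Huber { delta: 1.0 };
/// ```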
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum SupervisedLoss {
    #[default]
    Mse,
    Mae,
    Huber { delta: f32 },
    Hinge { margin: f32 },
    Bce,
    Nll,
    CrossEntropy,
}

fn build_loss_node(
    graph: &mut Graph,
    prediction: NodeId,
    target: NodeId,
    loss: SupervisedLoss,
) -> Result<NodeId, ModelError> {
    match loss {
        SupervisedLoss::Mse => mse_loss(graph, prediction, target),
        SupervisedLoss::Mae => mae_loss(graph, prediction, target),
        SupervisedLoss::Huber { delta } => huber_loss(graph, prediction, target, delta),
        SupervisedLoss::Hinge { margin } => hinge_loss(graph, prediction, target, margin),
        SupervisedLoss::Bce => bce_loss(graph, prediction, target),
        SupervisedLoss::Nll => nll_loss(graph, prediction, target),
        SupervisedLoss::CrossEntropy => cross_entropy_loss(graph, prediction, target),
    }
}

impl GraphOptimizer for Sgd {
    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError> {
        Sgd::step_graph_node(self, graph, node).map_err(Into::into)
    }
}

impl GraphOptimizer for Adam {
    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError> {
        Adam::step_graph_node(self, graph, node).map_err(Into::into)
    }
}

impl GraphOptimizer for AdamW {
    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError> {
        AdamW::step_graph_node(self, graph, node).map_err(Into::into)
    }
}

impl GraphOptimizer for RmsProp {
    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError> {
        RmsProp::step_graph_node(self, graph, node).map_err(Into::into)
    }
}

fn train_step_with_optimizer<O: GraphOptimizer>(
    graph: &mut Graph,
    optimizer: &mut O,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    let loss_node = build_loss_node(graph, prediction, target, loss)?;
    graph.backward(loss_node)?;

    let loss_value = graph.value(loss_node)?.data()[0];
    for node in trainable_nodes {
        optimizer.step_graph_node(graph, *node)?;
    }
    Ok(loss_value)
}

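/// Runs a single SGD training step with the default MSE loss.
///
/// A minimal sketch of intended use; `graph`, `prediction`, `target`, and
/// `params` are assumed to have been built beforehand (e.g. via
/// `SequentialModel::forward`):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.01)?;
/// let loss = train_step_sgd(&mut graph, &mut optimizer, prediction, target, &params)?;
/// println!("step loss: {loss}");
/// ```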
pub fn train_step_sgd(
    graph: &mut Graph,
    optimizer: &mut Sgd,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
) -> Result<f32, ModelError> {
    train_step_sgd_with_loss(
        graph,
        optimizer,
        prediction,
        target,
        trainable_nodes,
        SupervisedLoss::Mse,
    )
}

pub fn train_step_sgd_with_loss(
    graph: &mut Graph,
    optimizer: &mut Sgd,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    train_step_with_optimizer(graph, optimizer, prediction, target, trainable_nodes, loss)
}

pub fn train_step_adam(
    graph: &mut Graph,
    optimizer: &mut Adam,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
) -> Result<f32, ModelError> {
    train_step_adam_with_loss(
        graph,
        optimizer,
        prediction,
        target,
        trainable_nodes,
        SupervisedLoss::Mse,
    )
}

pub fn train_step_adam_with_loss(
    graph: &mut Graph,
    optimizer: &mut Adam,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    train_step_with_optimizer(graph, optimizer, prediction, target, trainable_nodes, loss)
}

pub fn train_step_adamw(
    graph: &mut Graph,
    optimizer: &mut AdamW,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
) -> Result<f32, ModelError> {
    train_step_adamw_with_loss(
        graph,
        optimizer,
        prediction,
        target,
        trainable_nodes,
        SupervisedLoss::Mse,
    )
}

pub fn train_step_adamw_with_loss(
    graph: &mut Graph,
    optimizer: &mut AdamW,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    train_step_with_optimizer(graph, optimizer, prediction, target, trainable_nodes, loss)
}

pub fn train_step_rmsprop(
    graph: &mut Graph,
    optimizer: &mut RmsProp,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
) -> Result<f32, ModelError> {
    train_step_rmsprop_with_loss(
        graph,
        optimizer,
        prediction,
        target,
        trainable_nodes,
        SupervisedLoss::Mse,
    )
}

pub fn train_step_rmsprop_with_loss(
    graph: &mut Graph,
    optimizer: &mut RmsProp,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    train_step_with_optimizer(graph, optimizer, prediction, target, trainable_nodes, loss)
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EpochMetrics {
    pub mean_loss: f32,
    pub steps: usize,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ScheduledEpochMetrics {
    pub epoch: usize,
    pub mean_loss: f32,
    pub steps: usize,
    pub learning_rate: f32,
}

#[derive(Debug, Clone, PartialEq)]
pub struct EpochTrainOptions {
    pub batch_size: usize,
    pub batch_iter_options: BatchIterOptions,
}

impl Default for EpochTrainOptions {
    fn default() -> Self {
        Self {
            batch_size: 1,
            batch_iter_options: BatchIterOptions::default(),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Default)]
pub struct SchedulerTrainOptions {
    pub epoch_options: EpochTrainOptions,
    pub loss: SupervisedLoss,
}

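/// Trains one epoch over `dataset` with SGD and the default MSE loss.
///
/// Sketch of intended use (assumes `model` and `dataset` were constructed
/// elsewhere; the batch size of 4 is arbitrary):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.01)?;
/// let metrics = train_epoch_sgd(&mut graph, &model, &mut optimizer, &dataset, 4)?;
/// println!("mean loss over {} steps: {}", metrics.steps, metrics.mean_loss);
/// ```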
pub fn train_epoch_sgd(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    dataset: &SupervisedDataset,
    batch_size: usize,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_sgd_with_loss(
        graph,
        model,
        optimizer,
        dataset,
        batch_size,
        SupervisedLoss::Mse,
    )
}

pub fn train_epoch_sgd_with_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    dataset: &SupervisedDataset,
    batch_size: usize,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_sgd_with_options_and_loss(
        graph,
        model,
        optimizer,
        dataset,
        EpochTrainOptions {
            batch_size,
            batch_iter_options: BatchIterOptions::default(),
        },
        loss,
    )
}

pub fn train_epoch_adam(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Adam,
    dataset: &SupervisedDataset,
    batch_size: usize,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_adam_with_loss(
        graph,
        model,
        optimizer,
        dataset,
        batch_size,
        SupervisedLoss::Mse,
    )
}

pub fn train_epoch_adam_with_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Adam,
    dataset: &SupervisedDataset,
    batch_size: usize,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_adam_with_options_and_loss(
        graph,
        model,
        optimizer,
        dataset,
        EpochTrainOptions {
            batch_size,
            batch_iter_options: BatchIterOptions::default(),
        },
        loss,
    )
}

pub fn train_epoch_adamw(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut AdamW,
    dataset: &SupervisedDataset,
    batch_size: usize,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_adamw_with_loss(
        graph,
        model,
        optimizer,
        dataset,
        batch_size,
        SupervisedLoss::Mse,
    )
}

pub fn train_epoch_adamw_with_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut AdamW,
    dataset: &SupervisedDataset,
    batch_size: usize,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_adamw_with_options_and_loss(
        graph,
        model,
        optimizer,
        dataset,
        EpochTrainOptions {
            batch_size,
            batch_iter_options: BatchIterOptions::default(),
        },
        loss,
    )
}

pub fn train_epoch_rmsprop(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut RmsProp,
    dataset: &SupervisedDataset,
    batch_size: usize,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_rmsprop_with_loss(
        graph,
        model,
        optimizer,
        dataset,
        batch_size,
        SupervisedLoss::Mse,
    )
}

pub fn train_epoch_rmsprop_with_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut RmsProp,
    dataset: &SupervisedDataset,
    batch_size: usize,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_rmsprop_with_options_and_loss(
        graph,
        model,
        optimizer,
        dataset,
        EpochTrainOptions {
            batch_size,
            batch_iter_options: BatchIterOptions::default(),
        },
        loss,
    )
}

pub fn train_epoch_sgd_with_options(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_sgd_with_options_and_loss(
        graph,
        model,
        optimizer,
        dataset,
        options,
        SupervisedLoss::Mse,
    )
}

pub fn train_epoch_sgd_with_options_and_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_with_options(graph, model, optimizer, dataset, options, loss)
}

pub fn train_epoch_adam_with_options(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Adam,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_adam_with_options_and_loss(
        graph,
        model,
        optimizer,
        dataset,
        options,
        SupervisedLoss::Mse,
    )
}

pub fn train_epoch_adam_with_options_and_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Adam,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_with_options(graph, model, optimizer, dataset, options, loss)
}

pub fn train_epoch_adamw_with_options(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut AdamW,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_adamw_with_options_and_loss(
        graph,
        model,
        optimizer,
        dataset,
        options,
        SupervisedLoss::Mse,
    )
}

pub fn train_epoch_adamw_with_options_and_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut AdamW,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_with_options(graph, model, optimizer, dataset, options, loss)
}

pub fn train_epoch_rmsprop_with_options(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut RmsProp,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_rmsprop_with_options_and_loss(
        graph,
        model,
        optimizer,
        dataset,
        options,
        SupervisedLoss::Mse,
    )
}

pub fn train_epoch_rmsprop_with_options_and_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut RmsProp,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_with_options(graph, model, optimizer, dataset, options, loss)
}

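/// Trains for `epochs` epochs with SGD, stepping `scheduler` once per epoch
/// and using the default MSE loss.
///
/// Sketch of intended use (assumes a scheduler type implementing
/// `LrScheduler`; `StepDecay` is named here purely for illustration):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.1)?;
/// let mut scheduler = StepDecay::new(/* ... */);
/// let history = train_epochs_sgd_with_scheduler(
///     &mut graph,
///     &model,
///     &mut optimizer,
///     &mut scheduler,
///     &dataset,
///     10,
///     EpochTrainOptions::default(),
/// )?;
/// for m in &history {
///     println!("epoch {}: loss {}, lr {}", m.epoch, m.mean_loss, m.learning_rate);
/// }
/// ```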
pub fn train_epochs_sgd_with_scheduler<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: EpochTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_sgd_with_scheduler_and_loss(
        graph,
        model,
        optimizer,
        scheduler,
        dataset,
        epochs,
        SchedulerTrainOptions {
            epoch_options: options,
            loss: SupervisedLoss::Mse,
        },
    )
}

pub fn train_epochs_sgd_with_scheduler_and_loss<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_with_scheduler(graph, model, optimizer, scheduler, dataset, epochs, options)
}

pub fn train_epochs_adam_with_scheduler<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Adam,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: EpochTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_adam_with_scheduler_and_loss(
        graph,
        model,
        optimizer,
        scheduler,
        dataset,
        epochs,
        SchedulerTrainOptions {
            epoch_options: options,
            loss: SupervisedLoss::Mse,
        },
    )
}

pub fn train_epochs_adam_with_scheduler_and_loss<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Adam,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_with_scheduler(graph, model, optimizer, scheduler, dataset, epochs, options)
}

pub fn train_epochs_adamw_with_scheduler<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut AdamW,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: EpochTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_adamw_with_scheduler_and_loss(
        graph,
        model,
        optimizer,
        scheduler,
        dataset,
        epochs,
        SchedulerTrainOptions {
            epoch_options: options,
            loss: SupervisedLoss::Mse,
        },
    )
}

pub fn train_epochs_adamw_with_scheduler_and_loss<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut AdamW,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_with_scheduler(graph, model, optimizer, scheduler, dataset, epochs, options)
}

pub fn train_epochs_rmsprop_with_scheduler<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut RmsProp,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: EpochTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_rmsprop_with_scheduler_and_loss(
        graph,
        model,
        optimizer,
        scheduler,
        dataset,
        epochs,
        SchedulerTrainOptions {
            epoch_options: options,
            loss: SupervisedLoss::Mse,
        },
    )
}

pub fn train_epochs_rmsprop_with_scheduler_and_loss<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut RmsProp,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_with_scheduler(graph, model, optimizer, scheduler, dataset, epochs, options)
}

fn train_epoch_with_options<O: GraphOptimizer>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut O,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    if dataset.is_empty() {
        return Err(ModelError::EmptyDataset);
    }
    let batches = dataset.batches_with_options(options.batch_size, options.batch_iter_options)?;
    let trainable_nodes = model.trainable_nodes();

    let mut loss_sum = 0.0f32;
    let mut steps = 0usize;
    for batch in batches {
        graph.truncate(model.persistent_node_count())?;

        let input = graph.constant(batch.inputs);
        let target = graph.constant(batch.targets);
        let prediction = model.forward(graph, input)?;
        let loss_value = train_step_with_optimizer(
            graph,
            optimizer,
            prediction,
            target,
            &trainable_nodes,
            loss,
        )?;
        loss_sum += loss_value;
        steps += 1;
    }
    if steps == 0 {
        return Err(ModelError::EmptyDataset);
    }

    Ok(EpochMetrics {
        mean_loss: loss_sum / steps as f32,
        steps,
    })
}

fn train_epochs_with_scheduler<O, S>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut O,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError>
where
    O: GraphOptimizer + LearningRate,
    S: LrScheduler,
{
    if epochs == 0 {
        return Err(ModelError::InvalidEpochCount { epochs });
    }

    let mut all_metrics = Vec::with_capacity(epochs);
    for epoch_index in 0..epochs {
        let epoch_metrics = train_epoch_with_options(
            graph,
            model,
            optimizer,
            dataset,
            options.epoch_options.clone(),
            options.loss,
        )?;
        let learning_rate = scheduler.step(optimizer)?;
        all_metrics.push(ScheduledEpochMetrics {
            epoch: epoch_index + 1,
            mean_loss: epoch_metrics.mean_loss,
            steps: epoch_metrics.steps,
            learning_rate,
        });
    }
    Ok(all_metrics)
}

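/// Hyperparameters for the CNN epoch helpers below.
///
/// Sketch of overriding a single field while keeping the remaining defaults
/// (the value shown is arbitrary):
///
/// ```ignore
/// let config = CnnTrainConfig {
///     lr: 0.001,
///     ..CnnTrainConfig::default()
/// };
/// ```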
#[derive(Debug, Clone)]
pub struct CnnTrainConfig {
    pub lr: f32,
    pub batch_size: usize,
    pub loss: SupervisedLoss,
    pub batch_iter_options: BatchIterOptions,
}

impl Default for CnnTrainConfig {
    fn default() -> Self {
        Self {
            lr: 0.01,
            batch_size: 16,
            loss: SupervisedLoss::CrossEntropy,
            batch_iter_options: BatchIterOptions::default(),
        }
    }
}

pub fn train_cnn_epoch_sgd(
    graph: &mut Graph,
    model: &mut SequentialModel,
    dataset: &SupervisedDataset,
    config: &CnnTrainConfig,
) -> Result<EpochMetrics, ModelError> {
    let mut optimizer = Sgd::new(config.lr)?;
    train_cnn_epoch_with_optimizer(graph, model, dataset, &mut optimizer, config)
}

pub fn train_cnn_epoch_adam(
    graph: &mut Graph,
    model: &mut SequentialModel,
    dataset: &SupervisedDataset,
    config: &CnnTrainConfig,
) -> Result<EpochMetrics, ModelError> {
    let mut optimizer = Adam::new(config.lr)?;
    train_cnn_epoch_with_optimizer(graph, model, dataset, &mut optimizer, config)
}

pub fn train_cnn_epoch_adamw(
    graph: &mut Graph,
    model: &mut SequentialModel,
    dataset: &SupervisedDataset,
    config: &CnnTrainConfig,
) -> Result<EpochMetrics, ModelError> {
    let mut optimizer = AdamW::new(config.lr)?;
    train_cnn_epoch_with_optimizer(graph, model, dataset, &mut optimizer, config)
}

fn train_cnn_epoch_with_optimizer<O: GraphOptimizer>(
    graph: &mut Graph,
    model: &mut SequentialModel,
    dataset: &SupervisedDataset,
    optimizer: &mut O,
    config: &CnnTrainConfig,
) -> Result<EpochMetrics, ModelError> {
    model.register_cnn_params(graph);
    let param_nodes = model.trainable_nodes();
    let persistent = model.persistent_node_count();
    let iter =
        dataset.batches_with_options(config.batch_size, config.batch_iter_options.clone())?;

    let mut total_loss = 0.0f32;
    let mut steps = 0usize;

    for batch in iter {
        graph.truncate(persistent)?;
        let input_node = graph.variable(batch.inputs);
        let target_node = graph.variable(batch.targets);
        let prediction = model.forward(graph, input_node)?;
        let loss_val = train_step_with_optimizer(
            graph,
            optimizer,
            prediction,
            target_node,
            &param_nodes,
            config.loss,
        )?;
        model.sync_cnn_from_graph(graph)?;
        total_loss += loss_val;
        steps += 1;
    }

    Ok(EpochMetrics {
        mean_loss: if steps > 0 {
            total_loss / steps as f32
        } else {
            0.0
        },
        steps,
    })
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizerType {
    Sgd,
    Adam,
    AdamW,
}

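/// Trains a CNN model for `epochs` epochs with the chosen optimizer.
///
/// Sketch of intended use (assumes `model` and `dataset` were set up
/// elsewhere; the epoch count is arbitrary):
///
/// ```ignore
/// let config = CnnTrainConfig::default();
/// let history =
///     train_cnn_epochs(&mut graph, &mut model, &dataset, 5, &config, OptimizerType::Adam)?;
/// ```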
pub fn train_cnn_epochs(
    graph: &mut Graph,
    model: &mut SequentialModel,
    dataset: &SupervisedDataset,
    epochs: usize,
    config: &CnnTrainConfig,
    optimizer_type: OptimizerType,
) -> Result<Vec<EpochMetrics>, ModelError> {
    if epochs == 0 {
        return Err(ModelError::InvalidEpochCount { epochs });
    }
    let mut all = Vec::with_capacity(epochs);
    for _ in 0..epochs {
        let metrics = match optimizer_type {
            OptimizerType::Sgd => train_cnn_epoch_sgd(graph, model, dataset, config)?,
            OptimizerType::Adam => train_cnn_epoch_adam(graph, model, dataset, config)?,
            OptimizerType::AdamW => train_cnn_epoch_adamw(graph, model, dataset, config)?,
        };
        all.push(metrics);
    }
    Ok(all)
}

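/// Scales the gradient of each node in `nodes` by `scale` in place.
///
/// Useful, for example, to average gradients that were summed over
/// micro-batches. Sketch (`micro_batches` is a hypothetical count):
///
/// ```ignore
/// // Divide accumulated gradients by the number of micro-batches.
/// scale_gradients(&mut graph, &params, 1.0 / micro_batches as f32)?;
/// ```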
pub fn scale_gradients(graph: &mut Graph, nodes: &[NodeId], scale: f32) -> Result<(), ModelError> {
    for &node in nodes {
        if let Some(grad) = graph.grad_mut(node)? {
            let scaled = grad.scale(scale);
            *grad = scaled;
        }
    }
    Ok(())
}

pub fn accumulate_gradients(
    graph: &mut Graph,
    nodes: &[NodeId],
    source_grads: &[Option<Tensor>],
) -> Result<(), ModelError> {
    assert_eq!(
        nodes.len(),
        source_grads.len(),
        "nodes and source_grads must have the same length"
    );
    for (i, &node) in nodes.iter().enumerate() {
        if let Some(src) = &source_grads[i] {
            let existing = graph.grad(node)?;
            let new_grad = match existing {
                Some(current) => current.add(src)?,
                None => src.clone(),
            };
            graph.set_grad(node, new_grad)?;
        }
    }
    Ok(())
}

pub fn collect_gradients(
    graph: &Graph,
    nodes: &[NodeId],
) -> Result<Vec<Option<Tensor>>, ModelError> {
    let mut grads = Vec::with_capacity(nodes.len());
    for &node in nodes {
        grads.push(graph.grad(node)?.cloned());
    }
    Ok(grads)
}

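/// Runs one SGD optimizer step over `accumulation_steps` micro-batches,
/// averaging gradients across the micro-batches before stepping.
///
/// `micro_batch_fn` builds one micro-batch forward pass and returns its
/// `(prediction, target)` node pair. Sketch (the closure body is
/// illustrative; `next_inputs`/`next_targets` are hypothetical helpers that
/// produce the batch tensors):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.01)?;
/// let mean_loss = train_step_sgd_with_accumulation(
///     &mut graph,
///     &mut optimizer,
///     &params,
///     4,
///     SupervisedLoss::Mse,
///     |g| {
///         let input = g.constant(next_inputs());
///         let target = g.constant(next_targets());
///         let prediction = model.forward(g, input)?;
///         Ok((prediction, target))
///     },
/// )?;
/// ```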
pub fn train_step_sgd_with_accumulation<F>(
    graph: &mut Graph,
    optimizer: &mut Sgd,
    trainable_nodes: &[NodeId],
    accumulation_steps: usize,
    loss_fn: SupervisedLoss,
    mut micro_batch_fn: F,
) -> Result<f32, ModelError>
where
    F: FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
{
    train_step_with_accumulation_impl(
        graph,
        optimizer,
        trainable_nodes,
        accumulation_steps,
        loss_fn,
        &mut micro_batch_fn,
    )
}

pub fn train_step_adam_with_accumulation<F>(
    graph: &mut Graph,
    optimizer: &mut Adam,
    trainable_nodes: &[NodeId],
    accumulation_steps: usize,
    loss_fn: SupervisedLoss,
    mut micro_batch_fn: F,
) -> Result<f32, ModelError>
where
    F: FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
{
    train_step_with_accumulation_impl(
        graph,
        optimizer,
        trainable_nodes,
        accumulation_steps,
        loss_fn,
        &mut micro_batch_fn,
    )
}

pub fn train_step_adamw_with_accumulation<F>(
    graph: &mut Graph,
    optimizer: &mut AdamW,
    trainable_nodes: &[NodeId],
    accumulation_steps: usize,
    loss_fn: SupervisedLoss,
    mut micro_batch_fn: F,
) -> Result<f32, ModelError>
where
    F: FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
{
    train_step_with_accumulation_impl(
        graph,
        optimizer,
        trainable_nodes,
        accumulation_steps,
        loss_fn,
        &mut micro_batch_fn,
    )
}

pub fn train_step_rmsprop_with_accumulation<F>(
    graph: &mut Graph,
    optimizer: &mut RmsProp,
    trainable_nodes: &[NodeId],
    accumulation_steps: usize,
    loss_fn: SupervisedLoss,
    mut micro_batch_fn: F,
) -> Result<f32, ModelError>
where
    F: FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
{
    train_step_with_accumulation_impl(
        graph,
        optimizer,
        trainable_nodes,
        accumulation_steps,
        loss_fn,
        &mut micro_batch_fn,
    )
}

#[allow(clippy::type_complexity)]
fn train_step_with_accumulation_impl<O: GraphOptimizer>(
    graph: &mut Graph,
    optimizer: &mut O,
    trainable_nodes: &[NodeId],
    accumulation_steps: usize,
    loss_fn: SupervisedLoss,
    micro_batch_fn: &mut dyn FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
) -> Result<f32, ModelError> {
    if accumulation_steps == 0 {
        return Err(ModelError::InvalidAccumulationSteps {
            steps: accumulation_steps,
        });
    }

    let scale = 1.0 / accumulation_steps as f32;
    let mut accumulated: Vec<Option<Tensor>> = vec![None; trainable_nodes.len()];
    let mut total_loss = 0.0f32;

    for _ in 0..accumulation_steps {
        graph.zero_grads();

        let (prediction, target) = micro_batch_fn(graph)?;
        let loss_node = build_loss_node(graph, prediction, target, loss_fn)?;
        graph.backward(loss_node)?;

        let loss_value = graph.value(loss_node)?.data()[0];
        total_loss += loss_value;

        for (i, &node) in trainable_nodes.iter().enumerate() {
            if let Some(grad) = graph.grad(node)? {
                let scaled = grad.scale(scale);
                accumulated[i] = Some(match accumulated[i].take() {
                    Some(acc) => acc.add(&scaled)?,
                    None => scaled,
                });
            }
        }
    }

    for (i, &node) in trainable_nodes.iter().enumerate() {
        if let Some(grad) = accumulated[i].take() {
            graph.set_grad(node, grad)?;
        }
    }

    for &node in trainable_nodes {
        optimizer.step_graph_node(graph, node)?;
    }

    Ok(total_loss / accumulation_steps as f32)
}

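/// Runs inference on a batch without touching the autograd graph.
///
/// Sketch (assumes `input` is a `Tensor` shaped for the model's first layer):
///
/// ```ignore
/// let output = infer_batch(&model, &input)?;
/// ```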
pub fn infer_batch(model: &SequentialModel, input: &Tensor) -> Result<Tensor, ModelError> {
    model.forward_inference(input)
}

pub fn infer_batch_graph(
    graph: &mut Graph,
    model: &SequentialModel,
    input: Tensor,
) -> Result<Tensor, ModelError> {
    let persistent = model.persistent_node_count();
    graph.truncate(persistent)?;
    let input_node = graph.variable(input);
    let output_node = model.forward(graph, input_node)?;
    Ok(graph.value(output_node)?.clone())
}

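/// Trains one epoch where each batch's gradients are passed through a
/// `GradientAggregator` (e.g. for data-parallel averaging) before the
/// optimizer step.
///
/// `train_batch_fn` must run the forward and backward pass for the batch at
/// the given index and return its loss; this function then collects the
/// resulting gradients, aggregates them, writes them back, and steps the
/// optimizer. Nodes with no gradient contribute zeros of the matching shape.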
#[allow(private_bounds)]
pub fn train_epoch_distributed<F, O: GraphOptimizer>(
    graph: &mut Graph,
    optimizer: &mut O,
    aggregator: &mut dyn GradientAggregator,
    trainable_nodes: &[NodeId],
    num_batches: usize,
    train_batch_fn: &mut F,
) -> Result<EpochMetrics, ModelError>
where
    F: FnMut(&mut Graph, usize) -> Result<f32, ModelError>,
{
    if num_batches == 0 {
        return Err(ModelError::EmptyDataset);
    }

    let mut loss_sum = 0.0f32;

    for batch_idx in 0..num_batches {
        let loss_value = train_batch_fn(graph, batch_idx)?;
        loss_sum += loss_value;

        let mut local_grads = Vec::with_capacity(trainable_nodes.len());
        for &node in trainable_nodes {
            let grad = match graph.grad(node)?.cloned() {
                Some(g) => g,
                None => {
                    let val = graph.value(node)?;
                    Tensor::zeros(val.shape().to_vec())?
                }
            };
            local_grads.push(grad);
        }

        let aggregated = aggregator.aggregate(&local_grads)?;

        for (i, &node) in trainable_nodes.iter().enumerate() {
            graph.set_grad(node, aggregated[i].clone())?;
        }

        for &node in trainable_nodes {
            optimizer.step_graph_node(graph, node)?;
        }
    }

    Ok(EpochMetrics {
        mean_loss: loss_sum / num_batches as f32,
        steps: num_batches,
    })
}

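/// Convenience wrapper over [`train_epoch_distributed`] that drives the batch
/// loop from a `SupervisedDataset` and uses SGD.
///
/// Sketch (assumes an aggregator implementing `GradientAggregator`;
/// `MeanAggregator` is named here purely for illustration):
///
/// ```ignore
/// let mut optimizer = Sgd::new(0.01)?;
/// let mut aggregator = MeanAggregator::default();
/// let metrics = train_epoch_distributed_sgd(
///     &mut graph,
///     &model,
///     &mut optimizer,
///     &mut aggregator,
///     &dataset,
///     8,
///     SupervisedLoss::Mse,
/// )?;
/// ```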
pub fn train_epoch_distributed_sgd(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    aggregator: &mut dyn GradientAggregator,
    dataset: &SupervisedDataset,
    batch_size: usize,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    if dataset.is_empty() {
        return Err(ModelError::EmptyDataset);
    }
    let batches: Vec<_> = dataset
        .batches_with_options(batch_size, BatchIterOptions::default())?
        .collect();
    let trainable_nodes = model.trainable_nodes();
    let persistent = model.persistent_node_count();
    let num_batches = batches.len();

    let mut batch_iter = batches.into_iter();

    train_epoch_distributed(
        graph,
        optimizer,
        aggregator,
        &trainable_nodes,
        num_batches,
        &mut |g, _batch_idx| {
            let batch = batch_iter.next().ok_or(ModelError::EmptyDataset)?;
            g.truncate(persistent)?;
            let input = g.constant(batch.inputs);
            let target = g.constant(batch.targets);
            let prediction = model.forward(g, input)?;
            let loss_node = build_loss_node(g, prediction, target, loss)?;
            g.backward(loss_node)?;
            let loss_value = g.value(loss_node)?.data()[0];
            Ok(loss_value)
        },
    )
}