libdt/trainer/mod.rs

1use nalgebra::DVector;
2use nalgebra::RowDVector;
3
4use super::network::Network;
5
6/// Neural network trainer.
7pub trait Trainer<N: Network> {
8    fn new(nn: N, p: Vec<f64>,
9           x_values: Vec<DVector<f64>>,
10           d_values: Vec<DVector<f64>>) -> Self;
11    fn make_step(&mut self);
12
13    fn cost(&self) -> f64;
14    fn grad(&mut self) -> RowDVector<f64>;
15    fn grad_norm(&mut self) -> f64;
16    fn params(&self) -> &[f64];
17}
18
19mod common;
20
21mod gd_trainer;
22pub use gd_trainer::*;
23
24mod cg_trainer;
25pub use cg_trainer::*;
26
27mod lm_trainer;
28pub use lm_trainer::*;