1use nalgebra::DVector;
2use nalgebra::RowDVector;
3
4use super::network::Network;
5
6pub trait Trainer<N: Network> {
8 fn new(nn: N, p: Vec<f64>,
9 x_values: Vec<DVector<f64>>,
10 d_values: Vec<DVector<f64>>) -> Self;
11 fn make_step(&mut self);
12
13 fn cost(&self) -> f64;
14 fn grad(&mut self) -> RowDVector<f64>;
15 fn grad_norm(&mut self) -> f64;
16 fn params(&self) -> &[f64];
17}
18
19mod common;
20
21mod gd_trainer;
22pub use gd_trainer::*;
23
24mod cg_trainer;
25pub use cg_trainer::*;
26
27mod lm_trainer;
28pub use lm_trainer::*;