nevermind_neu/optimizers/mod.rs

1mod optim_sgd;
2mod optim_adagrad;
3mod optim_rms;
4mod optim_adam;
5
6#[cfg(feature = "opencl")]
7mod optim_ocl_sgd;
8#[cfg(feature = "opencl")]
9mod optim_ocl_rms;
10#[cfg(feature = "opencl")]
11mod optim_ocl_adam;
12#[cfg(feature = "opencl")]
13mod optim_ocl;
14#[cfg(feature = "opencl")]
15mod optim_ocl_fabric;
16
17mod optim_fabric;
18
19pub use optim_rms::*;
20pub use optim_adagrad::*;
21pub use optim_adam::*;
22pub use optim_sgd::*;
23pub use optim_fabric::*;
24#[cfg(feature = "opencl")]
25pub use optim_ocl_sgd::*;
26#[cfg(feature = "opencl")]
27pub use optim_ocl_rms::*;
28#[cfg(feature = "opencl")]
29pub use optim_ocl_adam::*;
30#[cfg(feature = "opencl")]
31pub use optim_ocl::*;
32#[cfg(feature ="opencl")]
33pub use optim_ocl_fabric::*;
34
35use crate::cpu_params::*;
36use crate::util::*;
37use crate::err::*;
38use crate::layers::*;
39
40pub trait Optimizer : WithParams {
41    fn optimize_params(&mut self, learn_params: &mut CpuParams, opt_prms: TrainableBufsIds);
42    fn parallel_optimize(&mut self, _learn_params: Vec<(CpuParams, TrainableBufsIds)>) { todo!("filler, default impl will be removed") }
43}
44
45pub fn optimizer_from_type(opt_type: &str) -> Result<Box<dyn Optimizer>, CustomError> {
46    match opt_type {
47        "rmsprop" => {
48            return Ok(Box::new(OptimizerRMS::default()));
49        },
50        "sgd" => {
51            return Ok(Box::new(OptimizerSGD::default()));
52        },
53        "adagrad" => {
54            return Ok(Box::new(OptimizerAdaGrad::default()));
55        },
56        "adam" => {
57            return Ok(Box::new(OptimizerAdam::default()));
58        },
59        _ => {
60            return Err(CustomError::WrongArg);
61        }
62    }
63}
64
65impl Default for Box<dyn Optimizer> {
66    fn default() -> Self {
67        Box::new(OptimizerRMS::new(1e-2, 0.9))
68    }
69}
70
71impl Clone for Box<dyn Optimizer> {
72    fn clone(&self) -> Self {
73        Box::new(OptimizerRMS::new(1e-2, 0.9))
74    }
75}