nevermind_neu/optimizers/
mod.rs1mod optim_sgd;
2mod optim_adagrad;
3mod optim_rms;
4mod optim_adam;
5
6#[cfg(feature = "opencl")]
7mod optim_ocl_sgd;
8#[cfg(feature = "opencl")]
9mod optim_ocl_rms;
10#[cfg(feature = "opencl")]
11mod optim_ocl_adam;
12#[cfg(feature = "opencl")]
13mod optim_ocl;
14#[cfg(feature = "opencl")]
15mod optim_ocl_fabric;
16
17mod optim_fabric;
18
19pub use optim_rms::*;
20pub use optim_adagrad::*;
21pub use optim_adam::*;
22pub use optim_sgd::*;
23pub use optim_fabric::*;
24#[cfg(feature = "opencl")]
25pub use optim_ocl_sgd::*;
26#[cfg(feature = "opencl")]
27pub use optim_ocl_rms::*;
28#[cfg(feature = "opencl")]
29pub use optim_ocl_adam::*;
30#[cfg(feature = "opencl")]
31pub use optim_ocl::*;
32#[cfg(feature ="opencl")]
33pub use optim_ocl_fabric::*;
34
35use crate::cpu_params::*;
36use crate::util::*;
37use crate::err::*;
38use crate::layers::*;
39
40pub trait Optimizer : WithParams {
41 fn optimize_params(&mut self, learn_params: &mut CpuParams, opt_prms: TrainableBufsIds);
42 fn parallel_optimize(&mut self, _learn_params: Vec<(CpuParams, TrainableBufsIds)>) { todo!("filler, default impl will be removed") }
43}
44
45pub fn optimizer_from_type(opt_type: &str) -> Result<Box<dyn Optimizer>, CustomError> {
46 match opt_type {
47 "rmsprop" => {
48 return Ok(Box::new(OptimizerRMS::default()));
49 },
50 "sgd" => {
51 return Ok(Box::new(OptimizerSGD::default()));
52 },
53 "adagrad" => {
54 return Ok(Box::new(OptimizerAdaGrad::default()));
55 },
56 "adam" => {
57 return Ok(Box::new(OptimizerAdam::default()));
58 },
59 _ => {
60 return Err(CustomError::WrongArg);
61 }
62 }
63}
64
65impl Default for Box<dyn Optimizer> {
66 fn default() -> Self {
67 Box::new(OptimizerRMS::new(1e-2, 0.9))
68 }
69}
70
71impl Clone for Box<dyn Optimizer> {
72 fn clone(&self) -> Self {
73 Box::new(OptimizerRMS::new(1e-2, 0.9))
74 }
75}