//! Optimizers such as [Sgd], [Adam], and [RMSprop] for updating neural network parameters.
//!
//! # Initializing
//!
//! All the optimizers provide [Default] implementations, and also let you specify all the
//! relevant hyperparameters through the corresponding config object (see the sketch after
//! this list):
//! - [Sgd::new()] with [SgdConfig]
//! - [Adam::new()] with [AdamConfig]
//! - [RMSprop::new()] with [RMSpropConfig]
//!
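//! For example, a [Sgd] can be built from an explicit [SgdConfig] instead of the defaults.
//! This is a minimal sketch; the `lr`, `momentum`, and `weight_decay` fields, and the
//! [Momentum]/[WeightDecay] variants shown, are assumed to match the config types re-exported
//! from [crate::tensor_ops]:
//!
//! ```rust
//! # use dfdx::{prelude::*, optim::*, losses};
//! # type MyModel = Linear<5, 2>;
//! # let dev: Cpu = Default::default();
//! # let mut model = MyModel::build_on_device(&dev);
//! # let mut grads = model.alloc_grads();
//! // Sketch: spell out every SgdConfig field instead of using Default::default().
//! let mut opt = Sgd::new(
//!     &model,
//!     SgdConfig {
//!         lr: 1e-2,
//!         momentum: Some(Momentum::Nesterov(0.9)),
//!         weight_decay: Some(WeightDecay::L2(1e-4)),
//!     },
//! );
//! # let x: Tensor<Rank1<5>, f32, _> = dev.zeros();
//! # let y = model.forward(x.traced(grads));
//! # let loss = losses::mse_loss(y, dev.zeros());
//! # grads = loss.backward();
//! # opt.update(&mut model, &grads);
//! ```
//!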
//! # Updating network parameters
//!
//! This is done via [Optimizer::update()], where you pass in a mutable reference to the
//! [crate::nn::Module] and the [crate::tensor::Gradients]:
//!
//! ```rust
//! # use dfdx::{prelude::*, optim::*, losses};
//! # type MyModel = Linear<5, 2>;
//! # let dev: Cpu = Default::default();
//! let mut model = MyModel::build_on_device(&dev);
//! let mut grads = model.alloc_grads();
//! let mut opt = Sgd::new(&model, Default::default());
//! # let x: Tensor<Rank1<5>, f32, _> = dev.zeros();
//! # let y = model.forward(x.traced(grads));
//! # let loss = losses::mse_loss(y, dev.zeros());
//! // -- snip loss computation --
//!
//! grads = loss.backward();
//! opt.update(&mut model, &grads);
//! model.zero_grads(&mut grads);
//! ```
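//!
//! [Optimizer::update()] reports problems, such as parameters that never received a gradient
//! (tracked via [UnusedTensors]), through [OptimizerUpdateError]. Below is a minimal sketch of
//! checking that result; it assumes `update` returns a `Result` carrying
//! [OptimizerUpdateError], as this module's exports suggest:
//!
//! ```rust
//! # use dfdx::{prelude::*, optim::*, losses};
//! # type MyModel = Linear<5, 2>;
//! # let dev: Cpu = Default::default();
//! # let mut model = MyModel::build_on_device(&dev);
//! # let mut grads = model.alloc_grads();
//! # let mut opt = Sgd::new(&model, Default::default());
//! # let x: Tensor<Rank1<5>, f32, _> = dev.zeros();
//! # let y = model.forward(x.traced(grads));
//! # let loss = losses::mse_loss(y, dev.zeros());
//! # grads = loss.backward();
//! // Sketch: surface unused-parameter or device errors instead of discarding them.
//! match opt.update(&mut model, &grads) {
//!     Ok(()) => model.zero_grads(&mut grads),
//!     Err(err) => eprintln!("optimizer update failed: {err:?}"),
//! }
//! ```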

mod adam;
mod optimizer;
mod rmsprop;
mod sgd;

pub use adam::Adam;
pub use optimizer::{Optimizer, OptimizerUpdateError, UnusedTensors};
pub use rmsprop::RMSprop;
pub use sgd::Sgd;

// re-exports
pub use crate::tensor_ops::{AdamConfig, Momentum, RMSpropConfig, SgdConfig, WeightDecay};

pub mod prelude {
    pub use super::{Optimizer, OptimizerUpdateError, UnusedTensors};
}