pub use self::momentum::Momentum;
/// Implements [`ISolver`] for a concrete SGD-flavoured solver type `$t`.
///
/// The generated impl wires the solver's per-weight history buffers and
/// per-iteration update loop to the `SGDSolver` helper methods
/// (`clip_gradients`, `normalize`, `compute_update_value`).
#[macro_export]
macro_rules! impl_isolver_sgd {
    ($t:ty) => {
        impl<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32> + 'static>
            ISolver<SolverB, NetB> for $t
        {
            /// Allocates one zero-filled history tensor per learnable weight
            /// gradient of `net`, each shaped like its gradient tensor.
            fn init(&mut self, net: &Layer<NetB>) {
                self.history = net
                    .learnable_weights_gradients()
                    .iter()
                    .map(|gradient| {
                        // The momentum history must mirror the gradient's shape.
                        let shape = gradient.read().unwrap().desc().clone();
                        let mut zeroed = SharedTensor::new(&shape);
                        let zero_filler = crate::weight::FillerType::Constant { value: 0f32 };
                        zero_filler.fill(&mut zeroed);
                        Arc::new(RwLock::new(zeroed))
                    })
                    .collect();
            }

            /// Runs one SGD update step at iteration `iter`: clips gradients
            /// globally, then normalizes each weight gradient and computes its
            /// update value scaled by the (possibly scheduled) learning rate.
            fn compute_update(&mut self, config: &SolverConfig, net: &mut Layer<NetB>, iter: usize) {
                // Learning rate may decay over iterations per the solver config.
                let learning_rate = config.get_learning_rate(iter);
                SGDSolver::<SolverB, NetB>::clip_gradients(self, config, net);
                for (weight_id, gradient) in net.learnable_weights_gradients().iter().enumerate() {
                    SGDSolver::<SolverB, NetB>::normalize(self, config, gradient);
                    // Per-weight learning-rate multiplier; `unwrap` follows the
                    // original contract that every learnable weight has one set.
                    let weight_lr = net.learnable_weights_lr()[weight_id].unwrap();
                    SGDSolver::<SolverB, NetB>::compute_update_value(
                        self,
                        config,
                        gradient,
                        weight_id,
                        &learning_rate,
                        &weight_lr,
                    );
                }
            }

            /// Returns the backend used for solver computations.
            fn backend(&self) -> &SolverB {
                &self.backend
            }
        }
    };
}
pub mod momentum;