#[allow(unused_import_braces)]
pub use self::sgd::Momentum;
pub mod sgd;
use crate::co::{IBackend, SharedTensor};
use crate::layer::*;
use crate::solver::*;
use crate::util::*;
/// Implementation details shared by all Stochastic Gradient Descent solvers.
///
/// Provides the gradient pre-processing steps common to SGD variants
/// (global-norm clipping, minibatch normalization, weight-decay
/// regularization) on top of `ISolver`, leaving the concrete update rule
/// to `compute_update_value`.
trait SGDSolver<SolverB: IBackend + SolverOps<f32>, NetB: IBackend + LayerOps<f32>>: ISolver<SolverB, NetB> {
    /// Compute the update value that will be applied to `weight_blob`.
    ///
    /// `history_blob_id` identifies the solver-internal history tensor for
    /// this weight (e.g. a momentum buffer — see the `Momentum` solver
    /// re-exported from this module), while `global_lr` and `blob_lr` are the
    /// global and per-blob learning-rate multipliers.
    fn compute_update_value(
        &mut self,
        config: &SolverConfig,
        weight_blob: &ArcLock<SharedTensor<f32>>,
        history_blob_id: usize,
        global_lr: &f32,
        blob_lr: &f32,
    );

    /// Clip the gradients of all learnable weights in `net` by their combined L2 norm.
    ///
    /// If `config.clip_gradients` is set and the global L2 norm of all
    /// learnable-weight gradients exceeds that threshold, every gradient is
    /// scaled by `threshold / norm` so the combined norm equals the threshold.
    /// Does nothing when no clipping threshold is configured.
    #[allow(unused_must_use)]
    fn clip_gradients<B: IBackend + LayerOps<f32> + 'static>(&self, config: &SolverConfig, net: &mut Layer<B>) {
        if let Some(clip_threshold) = config.clip_gradients {
            // Native backend is only needed to read the scalar dot-product
            // result back to host memory.
            let native = native_backend();
            let net_gradients = net.learnable_weights_gradients();
            let backend = self.backend();

            // Accumulate the squared L2 norm over all gradient blobs:
            // gradient · gradient == squared L2 norm of one blob.
            let mut sumsq_diff = 0f32;
            for net_gradient in net_gradients.iter() {
                let gradient = net_gradient.read().unwrap();
                let mut result = SharedTensor::new(&[1]);
                backend.dot(&gradient, &gradient, &mut result);
                let sumsq_diff_slice = result.read(native.device()).unwrap().as_slice::<f32>();
                sumsq_diff += sumsq_diff_slice[0];
            }

            let l2norm_diff = sumsq_diff.sqrt();
            if l2norm_diff > clip_threshold {
                // Uniform down-scaling so the global norm lands exactly on the threshold.
                let scale_factor = clip_threshold / l2norm_diff;
                info!(
                    "Gradient clipping: scaling down gradients (L2 norm {} > {}) by scale factor {}",
                    l2norm_diff, clip_threshold, scale_factor
                );
                let mut scale_shared = native_scalar(scale_factor);
                for weight_gradient in net_gradients {
                    let mut gradient = weight_gradient.write().unwrap();
                    backend.scal(&mut scale_shared, &mut gradient);
                }
            }
        }
    }

    /// Scale the accumulated gradient of `weight_blob` down by the minibatch size.
    ///
    /// Turns a gradient summed over the minibatch into the minibatch average;
    /// a no-op for minibatches of size 1.
    fn normalize(&self, config: &SolverConfig, weight_blob: &ArcLock<SharedTensor<f32>>) {
        if config.minibatch_size > 1 {
            let scale_factor = 1f32 / config.minibatch_size as f32;
            let mut gradient = weight_blob.write().unwrap();
            let mut scale_factor_shared = native_scalar(scale_factor);
            self.backend().scal(&mut scale_factor_shared, &mut gradient).unwrap();
        }
    }

    /// Apply the configured regularization (weight decay) to `weight_gradient`.
    ///
    /// The effective decay is `config.weight_decay * blob_weight_decay`.
    /// Does nothing when no global weight decay or no regularization method
    /// is configured; logs an error when the per-blob decay multiplier is
    /// missing.
    fn regularize(
        &self,
        config: &SolverConfig,
        weight_gradient: &ArcLock<SharedTensor<f32>>,
        blob_weight_decay: Option<f32>,
    ) {
        if let Some(global_weight_decay) = config.weight_decay {
            if let Some(regularization_method) = config.regularization_method {
                match blob_weight_decay {
                    Some(weight_decay_mult) => {
                        let local_decay = global_weight_decay * weight_decay_mult;
                        match regularization_method {
                            RegularizationMethod::L2 => {
                                // L2 weight decay is not implemented yet; the
                                // bindings below stage its inputs (underscore
                                // prefix silences unused warnings until then).
                                let _decay_shared = native_scalar(local_decay);
                                let _gradient = &mut weight_gradient.write().unwrap();
                                unimplemented!();
                            }
                        }
                    }
                    None => {
                        error!("Weight decay multiplier for gradient missing.");
                    }
                }
            }
        }
    }
}