//! Provides the generic [`Solver`] for training a network, along with its
//! configuration types and the [`ISolver`] trait implemented by concrete
//! solver workers.
pub mod confusion_matrix;
pub mod regression_evaluator;

pub use self::confusion_matrix::ConfusionMatrix;
pub use self::regression_evaluator::{RegressionEvaluator, RegressionLoss};

use std::marker::PhantomData;
use std::rc::Rc;

use crate::co::prelude::*;
use crate::layer::*;
use crate::layers::SequentialConfig;
use crate::solvers::*;
use crate::util::{ArcLock, LayerOps, SolverOps};

/// Trains a network by iterating [`Solver::train_minibatch`]: gradients are
/// computed through an objective (loss) layer, and a solver worker such as
/// SGD with momentum turns them into weight updates.
#[derive(Debug)]
pub struct Solver<SolverB, B>
where
    SolverB: IBackend + SolverOps<f32>,
    B: IBackend + LayerOps<f32>,
{
    /// The network being trained.
    net: Layer<B>,
    /// The objective (loss) layer, evaluated on the solver backend.
    objective: Layer<SolverB>,
    /// The concrete solver implementation that computes the weight updates.
    pub worker: Box<dyn ISolver<SolverB, B>>,
    /// The configuration this solver was built from.
    config: SolverConfig,
    /// The current training iteration, i.e. the number of minibatches seen.
    iter: usize,
    solver_backend: PhantomData<SolverB>,
}

impl<SolverB, B> Solver<SolverB, B>
where
    SolverB: IBackend + SolverOps<f32> + 'static,
    B: IBackend + LayerOps<f32> + 'static,
{
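    /// Creates a `Solver` from a [`SolverConfig`].
    ///
    /// Instantiates the network and objective layers from their
    /// `LayerConfig`s and initializes the configured solver worker.
    ///
    /// A minimal sketch (marked `ignore` because it assumes a concrete
    /// coaster backend is available at build time):
    ///
    /// ```ignore
    /// use std::rc::Rc;
    ///
    /// // Assumes a native coaster backend; a CUDA backend works the same way.
    /// let backend = Rc::new(Backend::<Native>::default().unwrap());
    /// let config = SolverConfig::default(); // see `SolverConfig` for customization
    /// // The same backend may serve both the network and the objective layer.
    /// let mut solver = Solver::from_config(backend.clone(), backend.clone(), &config);
    /// ```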
    pub fn from_config(net_backend: Rc<B>, obj_backend: Rc<SolverB>, config: &SolverConfig) -> Solver<SolverB, B> {
        let network = Layer::from_config(net_backend, &config.network);
        let mut worker = config.solver.with_config(obj_backend.clone(), config);
        worker.init(&network);

        Solver {
            worker,
            net: network,
            objective: Layer::from_config(obj_backend, &config.objective),
            iter: 0,
            config: config.clone(),
            solver_backend: PhantomData::<SolverB>,
        }
    }
}

impl<SolverB, B> Solver<SolverB, B>
where
    SolverB: IBackend + SolverOps<f32> + 'static,
    B: IBackend + LayerOps<f32> + 'static,
{
    /// Reinitializes the solver and its network from the stored configuration.
    fn init(&mut self, backend: Rc<B>) {
        info!("Initializing solver from configuration");

        let config = self.config.clone();
        self.init_net(backend, &config);
    }

    /// Instantiates the network described by `param.network` on `backend`.
    fn init_net(&mut self, backend: Rc<B>, param: &SolverConfig) {
        self.net = Layer::from_config(backend, &param.network);
    }
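
    /// Trains the network on one minibatch and returns the network's output.
    ///
    /// One call performs a full training step: a forward pass through the
    /// network, a forward pass through the objective layer to compute the
    /// loss, the corresponding backward passes, and a weight update computed
    /// by the solver worker.
    ///
    /// A minimal training-loop sketch (marked `ignore` because `solver` and
    /// the pre-filled `data`/`target` tensors of type
    /// `ArcLock<SharedTensor<f32>>` are assumed to exist):
    ///
    /// ```ignore
    /// for _ in 0..1_000 {
    ///     // `data` holds one minibatch of inputs, `target` its labels.
    ///     let output = solver.train_minibatch(data.clone(), target.clone());
    /// }
    /// ```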
    pub fn train_minibatch(
        &mut self,
        mb_data: ArcLock<SharedTensor<f32>>,
        mb_target: ArcLock<SharedTensor<f32>>,
    ) -> ArcLock<SharedTensor<f32>> {
        // Forward the minibatch through the network, then through the
        // objective layer together with the targets to compute the loss.
        let network_out = self.net.forward(&[mb_data])[0].clone();
        let _ = self.objective.forward(&[network_out.clone(), mb_target]);
        // Backpropagate: the objective produces the input gradient for the
        // network's backward pass.
        let classifier_gradient = self.objective.backward(&[]);
        self.net.backward(&classifier_gradient[0..1]);
        // Let the solver worker turn the raw gradients into an update
        // (applying learning rate, momentum, etc.), then apply it.
        self.worker.compute_update(&self.config, &mut self.net, self.iter);
        self.net.update_weights(self.worker.backend());
        self.iter += 1;

        network_out
    }

    /// Returns a reference to the network being trained.
    pub fn network(&self) -> &Layer<B> {
        &self.net
    }

    /// Returns a mutable reference to the network being trained.
    pub fn mut_network(&mut self) -> &mut Layer<B> {
        &mut self.net
    }
}

/// Implemented by concrete solvers (see [`SolverKind`]) to turn the gradients
/// of a network into weight updates.
pub trait ISolver<SolverB, B>
where
    B: IBackend + LayerOps<f32>,
    SolverB: IBackend + SolverOps<f32>,
{
    /// Initializes the solver with the network it will optimize.
    ///
    /// The default implementation does nothing.
    #[allow(unused_variables)]
    fn init(&mut self, net: &Layer<B>) {}

    /// Computes an update for the network's learnable weights from the
    /// current gradients, based on the hyperparameters in `param` and the
    /// current iteration `iter`.
    fn compute_update(&mut self, param: &SolverConfig, network: &mut Layer<B>, iter: usize);

    /// Returns the backend this solver operates on.
    fn backend(&self) -> &SolverB;
}

impl<SolverB, B: IBackend + LayerOps<f32>> ::std::fmt::Debug for dyn ISolver<SolverB, B> {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "ISolver")
    }
}

/// Configuration for a [`Solver`]: the network and objective layers to
/// instantiate plus the hyperparameters of the training run.
#[derive(Debug, Clone)]
pub struct SolverConfig {
    /// The name of this solver, used for debugging.
    pub name: String,
    /// The `LayerConfig` of the network to train.
    pub network: LayerConfig,
    /// The `LayerConfig` of the objective (loss) layer.
    pub objective: LayerConfig,
    /// The kind of solver to use, e.g. SGD with momentum.
    pub solver: SolverKind,
    /// The number of samples in one minibatch.
    pub minibatch_size: usize,
    /// The policy that determines how the learning rate changes during
    /// training (see [`SolverConfig::get_learning_rate`]).
    pub lr_policy: LRPolicy,
    /// The initial learning rate.
    pub base_lr: f32,
    /// The learning-rate decay factor used by [`LRPolicy::Step`] and
    /// [`LRPolicy::Exp`].
    pub gamma: f32,
    /// The number of iterations between learning-rate steps for
    /// [`LRPolicy::Step`].
    pub stepsize: usize,
    /// The threshold used for gradient clipping, if any.
    pub clip_gradients: Option<f32>,
    /// The weight-decay coefficient, if any.
    pub weight_decay: Option<f32>,
    /// The regularization method to apply, if any.
    pub regularization_method: Option<RegularizationMethod>,
    /// The momentum coefficient for SGD; `0.0` disables momentum.
    pub momentum: f32,
}

impl Default for SolverConfig {
    fn default() -> SolverConfig {
        SolverConfig {
            name: "".to_owned(),
            network: LayerConfig::new("default", SequentialConfig::default()),
            objective: LayerConfig::new("default", SequentialConfig::default()),
            solver: SolverKind::SGD(SGDKind::Momentum),
            minibatch_size: 1,
            lr_policy: LRPolicy::Fixed,
            base_lr: 0.01f32,
            gamma: 0.1f32,
            stepsize: 10,
            clip_gradients: None,
            weight_decay: None,
            regularization_method: None,
            momentum: 0f32,
        }
    }
}
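
// A minimal configuration sketch building on `Default`: override only the
// fields you need via struct-update syntax. `net_cfg` and `loss_cfg` are
// hypothetical `LayerConfig`s for the network and the loss.
//
//     let config = SolverConfig {
//         name: "sgd-momentum".to_owned(),
//         network: net_cfg,
//         objective: loss_cfg,
//         minibatch_size: 32,
//         momentum: 0.9,
//         ..SolverConfig::default()
//     };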

impl SolverConfig {
    /// Returns the learning rate for training iteration `iter`, according to
    /// the configured [`LRPolicy`]:
    ///
    /// - `Fixed`: `base_lr`
    /// - `Step`: `base_lr * gamma^(iter / stepsize)`
    /// - `Exp`: `base_lr * gamma^iter`
    pub fn get_learning_rate(&self, iter: usize) -> f32 {
        match self.lr_policy() {
            LRPolicy::Fixed => self.base_lr(),
            LRPolicy::Step => {
                let current_step = self.step(iter);
                self.base_lr() * self.gamma().powf(current_step as f32)
            }
            LRPolicy::Exp => self.base_lr() * self.gamma().powf(iter as f32),
        }
    }

    /// Returns the number of completed learning-rate steps (`iter / stepsize`).
    fn step(&self, iter: usize) -> usize {
        iter / self.stepsize()
    }

    fn lr_policy(&self) -> LRPolicy {
        self.lr_policy
    }

    fn base_lr(&self) -> f32 {
        self.base_lr
    }

    fn gamma(&self) -> f32 {
        self.gamma
    }

    fn stepsize(&self) -> usize {
        self.stepsize
    }
}
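
// A worked example of the schedules above; a minimal sketch, not part of the
// crate's official test suite.
#[cfg(test)]
mod lr_policy_examples {
    use super::*;

    #[test]
    fn step_and_exp_schedules() {
        // base_lr = 0.01, gamma = 0.1, stepsize = 10:
        // `Step` decays the rate tenfold every 10 iterations.
        let config = SolverConfig {
            lr_policy: LRPolicy::Step,
            base_lr: 0.01,
            gamma: 0.1,
            stepsize: 10,
            ..SolverConfig::default()
        };
        // Iteration 25 lies in the third step: 0.01 * 0.1^(25 / 10) = 0.01 * 0.1^2.
        assert!((config.get_learning_rate(25) - 1e-4).abs() < 1e-8);

        // `Exp` decays the rate tenfold every iteration: 0.01 * 0.1^3.
        let config = SolverConfig {
            lr_policy: LRPolicy::Exp,
            ..config
        };
        assert!((config.get_learning_rate(3) - 1e-5).abs() < 1e-8);
    }
}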

/// The kinds of solvers that are available.
#[derive(Debug, Copy, Clone)]
pub enum SolverKind {
    /// Stochastic Gradient Descent.
    SGD(SGDKind),
}

impl SolverKind {
    /// Creates a boxed solver worker of this kind for the given backend and
    /// configuration.
    pub fn with_config<B: IBackend + SolverOps<f32> + 'static, NetB: IBackend + LayerOps<f32> + 'static>(
        &self,
        backend: Rc<B>,
        config: &SolverConfig,
    ) -> Box<dyn ISolver<B, NetB>> {
        match *self {
            SolverKind::SGD(sgd) => sgd.with_config(backend, config),
        }
    }
}

/// The kinds of SGD solvers that are available.
#[derive(Debug, Copy, Clone)]
pub enum SGDKind {
    /// SGD with momentum.
    Momentum,
}

impl SGDKind {
    /// Creates a boxed SGD worker of this kind for the given backend and
    /// configuration.
    pub fn with_config<B: IBackend + SolverOps<f32> + 'static, NetB: IBackend + LayerOps<f32> + 'static>(
        &self,
        backend: Rc<B>,
        config: &SolverConfig,
    ) -> Box<dyn ISolver<B, NetB>> {
        match *self {
            SGDKind::Momentum => Box::new(Momentum::<B>::new(backend)),
        }
    }
}

/// Learning-rate policies: how the learning rate changes over the course of
/// training (see [`SolverConfig::get_learning_rate`]).
#[derive(Debug, Copy, Clone)]
pub enum LRPolicy {
    /// The learning rate stays fixed at `base_lr`.
    Fixed,
    /// The learning rate is multiplied by `gamma` every `stepsize` iterations.
    Step,
    /// The learning rate is multiplied by `gamma` every iteration.
    Exp,
}

/// The available regularization methods.
#[derive(Debug, Copy, Clone)]
pub enum RegularizationMethod {
    /// L2 regularization: penalizes the squared L2 norm of the weights.
    L2,
}