#![warn(missing_docs)]
use std::{error::Error, marker::PhantomData, str::FromStr};
use crate::{Tensor, TensorElement};
use rayon::prelude::*;
/// An optimizer that adjusts a vector of `T` values against a cost function.
///
/// Implementors are created with [`Optimizer::init`] and driven through
/// [`Optimizer::minimize`].
pub trait Optimizer<'a, T: TensorElement>
where
    <T as FromStr>::Err: Error,
{
    /// Builds the optimizer from a learning rate (in the element type `T`)
    /// and a momentum coefficient.
    fn init(learning_rate: T, momentum: f32) -> Self;

    /// Runs the optimizer on `vars` against `cost`.
    ///
    /// `cost` maps a pair of tensors to a scalar of type `T`, and `vars` is
    /// borrowed mutably. NOTE(review): presumably `vars` is updated in place
    /// to reduce `cost`, and the tensor pair is (prediction, target) — the
    /// contract is not pinned down by this signature, confirm with callers.
    fn minimize<F: Fn(&Tensor<'a, T>, &Tensor<'a, T>) -> T>(
        &mut self,
        cost: F,
        vars: &mut Vec<T>,
    );
}
/// State for an Adam-style optimizer.
///
/// Holds a learning rate in the element type `T`, a momentum coefficient and
/// a set of `f32` hyperparameters and accumulators. The optimization step for
/// this type is still a stub, so these fields are currently only populated
/// when the optimizer is constructed.
#[derive(Debug, Clone)]
pub struct Adam<T> {
    /// Learning rate, stored in the tensor element type `T`.
    lr: T,
    /// Momentum coefficient supplied at construction.
    momentum: f32,
    /// Decay rate (constructed as `0.2`). NOTE(review): exact role (e.g.
    /// learning-rate decay) is not established yet — confirm.
    decay_rate: f32,
    /// Presumably the first-moment decay rate (Adam's `beta1`; constructed as `0.9`).
    beta1: f32,
    /// Presumably the second-moment decay rate (Adam's `beta2`; constructed as `0.999`).
    beta2: f32,
    /// Presumably the numerical-stability term (Adam's `epsilon`; constructed as `1e-8`).
    epsilon: f32,
    /// Presumably the running first-moment estimate for weights (starts at `0.0`).
    m_dw: f32,
    /// Presumably the running second-moment estimate for weights (starts at `0.0`).
    v_dw: f32,
    /// Presumably the running first-moment estimate for biases (starts at `0.0`).
    m_db: f32,
    /// Presumably the running second-moment estimate for biases (starts at `0.0`).
    v_db: f32,
}
impl<'a, T> Optimizer<'a, T> for Adam<T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    /// Creates an `Adam` optimizer with the given learning rate and momentum.
    ///
    /// All other settings are fixed: `decay_rate = 0.2`, `beta1 = 0.9`,
    /// `beta2 = 0.999`, `epsilon = 1e-8`, and every moment accumulator
    /// (`m_dw`, `v_dw`, `m_db`, `v_db`) starts at `0.0`.
    fn init(learning_rate: T, momentum: f32) -> Self {
        Self {
            lr: learning_rate,
            momentum,
            decay_rate: 0.2,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            m_dw: 0.0,
            v_dw: 0.0,
            m_db: 0.0,
            v_db: 0.0,
        }
    }

    /// Not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics with `not yet implemented: Adam::minimize`.
    fn minimize<F>(&mut self, _cost: F, _vars: &mut Vec<T>)
    where
        F: Fn(&Tensor<'a, T>, &Tensor<'a, T>) -> T,
    {
        // Parameters are `_`-prefixed so the stub builds without
        // `unused_variables` warnings; drop the prefixes when implementing.
        todo!("Adam::minimize")
    }
}
/// State for a stochastic-gradient-descent style optimizer.
///
/// Holds a learning rate in the element type `T` plus `f32` momentum and
/// decay settings. The optimization step for this type is still a stub, so
/// these fields are currently only populated when it is constructed.
#[derive(Debug, Clone)]
pub struct SGD<T> {
    /// Learning rate, stored in the tensor element type `T`.
    lr: T,
    /// Momentum coefficient supplied at construction.
    momentum: f32,
    /// Decay rate (constructed as `0.2`). NOTE(review): exact role (e.g.
    /// learning-rate decay) is not established yet — confirm.
    decay_rate: f32,
}
impl<'a, T> Optimizer<'a, T> for SGD<T>
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    /// Creates an `SGD` optimizer with the given learning rate and momentum.
    ///
    /// `decay_rate` is fixed at `0.2`.
    fn init(learning_rate: T, momentum: f32) -> Self {
        Self {
            lr: learning_rate,
            momentum,
            decay_rate: 0.2,
        }
    }

    /// Not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics with `not yet implemented: SGD::minimize`.
    fn minimize<F>(&mut self, _cost: F, _vars: &mut Vec<T>)
    where
        F: Fn(&Tensor<'a, T>, &Tensor<'a, T>) -> T,
    {
        // Parameters are `_`-prefixed so the stub builds without
        // `unused_variables` warnings; drop the prefixes when implementing.
        todo!("SGD::minimize")
    }
}