pub mod basic;
pub mod momentum;
pub mod nesterov;
pub mod adagrad;
pub mod adam;
pub use basic::BasicOptimizer as Basic;
pub use momentum::MomentumOptimizer as Momentum;
pub use nesterov::NesterovOptimizer as Nesterov;
pub use adagrad::AdagradOptimizer as Adagrad;
pub use adam::AdamOptimizer as Adam;
use intricate_macros::FromForAllUnnamedVariants;
use opencl3::{device::cl_float, error_codes::ClError, memory::Buffer};
use crate::utils::{opencl::BufferOperationError, OpenCLState};
/// Errors that can occur while an [`Optimizer`] initializes itself or computes/applies updates.
///
/// NOTE(review): the project derive `FromForAllUnnamedVariants` presumably generates a
/// `From<Inner>` impl for every single-field (tuple) variant. That would let `?` convert
/// `ClError` / `BufferOperationError` into this type — confirm against `intricate_macros`.
#[derive(Debug, FromForAllUnnamedVariants)]
pub enum OptimizationError {
    /// An error reported by an OpenCL call (`opencl3`).
    OpenCL(ClError),
    /// An error reported by the crate's buffer utilities (`crate::utils::opencl`).
    BufferOperation(BufferOperationError),
    /// Presumably returned when the OpenCL state has no command queue to enqueue work on.
    NoCommandQueueFound,
    /// Presumably returned when an optimizer is used before `init` supplied its OpenCL state
    /// — confirm against the optimizer implementations.
    UninitializedState,
}

/// Human-readable description of the error.
///
/// Inner errors are rendered with `{:?}`. That formatting is guaranteed to exist because the
/// enum itself derives `Debug`, which requires every payload type to be `Debug`.
impl std::fmt::Display for OptimizationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            OptimizationError::OpenCL(err) => write!(f, "OpenCL error: {:?}", err),
            OptimizationError::BufferOperation(err) => {
                write!(f, "buffer operation failed: {:?}", err)
            }
            OptimizationError::NoCommandQueueFound => {
                write!(f, "no OpenCL command queue was found")
            }
            OptimizationError::UninitializedState => {
                write!(f, "the optimizer state has not been initialized")
            }
        }
    }
}

/// Allows `OptimizationError` to be used with `Box<dyn Error>`, `?` into generic error types,
/// and error-reporting crates. `source` is left at its default (`None`) because the payload
/// types are not known here to implement `std::error::Error`.
impl std::error::Error for OptimizationError {}
/// Common interface implemented by every optimizer in this module
/// (`Basic`, `Momentum`, `Nesterov`, `Adagrad`, `Adam`).
///
/// The trait lifetime `'a` ties an optimizer to the [`OpenCLState`] it is initialized with.
/// That allows an implementation to keep a borrow of the state after [`Optimizer::init`].
/// Every implementor must also be `Debug`.
pub trait Optimizer<'a>: std::fmt::Debug {
    /// Prepares the optimizer to run on the given OpenCL state.
    ///
    /// # Errors
    /// Returns the `ClError` from any OpenCL call that fails during initialization.
    fn init(&mut self, opencl_state: &'a OpenCLState) -> Result<(), ClError>;

    /// Modifies `parameters` in place for the given parameter, timestep and layer.
    ///
    /// NOTE(review): presumably used by optimizers that adjust parameters before gradients are
    /// computed (e.g. a Nesterov look-ahead); other optimizers may make this a no-op — confirm.
    /// It takes `&self`, so implementations cannot mutate their own state here.
    ///
    /// # Errors
    /// Returns an [`OptimizationError`] if an OpenCL or buffer operation fails.
    fn optimize_parameters(
        &self,
        parameters: &mut Buffer<cl_float>,
        parameter_id: String,
        timestep: usize,
        layer_index: usize,
    ) -> Result<(), OptimizationError>;

    /// Turns `gradients` into a newly allocated buffer of update vectors.
    ///
    /// It takes `&mut self`, so implementations may update internal state. Presumably this is
    /// per-parameter history keyed by `parameter_id` and `layer_index` — verify in each optimizer.
    ///
    /// # Errors
    /// Returns an [`OptimizationError`] if an OpenCL or buffer operation fails.
    fn compute_update_vectors(
        &mut self,
        gradients: &Buffer<cl_float>,
        parameter_id: String,
        timestep: usize,
        layer_index: usize,
    ) -> Result<Buffer<cl_float>, OptimizationError>;
}