use crate::core::error::ModelError;
use crate::core::types::{Matrix, Vector};
use crate::model::core::base::OptimizableModel;
use crate::optim::core::state::OptimizerState;
use ndarray::{ArrayView, ArrayViewMut, IxDyn};
use std::marker::PhantomData;
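
/// Optimizer state for plain gradient descent.
///
/// Holds only the learning rate; the update applied to every parameter is
/// `param <- param - learning_rate * grad`. The `PhantomData` field ties the
/// otherwise-unused `Input`, `Output`, and `M` type parameters to the struct.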
pub struct GradientDescentState<Input, Output, M: OptimizableModel<Input, Output>> {
learning_rate: f64,
_phantom: PhantomData<(Input, Output, M)>,
}

impl<Input, Output, M: OptimizableModel<Input, Output>> GradientDescentState<Input, Output, M> {
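    /// Creates a new gradient-descent state with the given learning rate.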
pub fn new(learning_rate: f64) -> Self {
Self {
learning_rate,
_phantom: PhantomData,
}
}
}

impl<M: OptimizableModel<Matrix, Vector>> OptimizerState<Matrix, Vector, M>
for GradientDescentState<Matrix, Vector, M>
{
fn update_weights(&mut self, model: &mut M) -> Result<(), ModelError> {
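        // Phase 1: read parameters and gradients through immutable views and
        // buffer the updated values, so the immutable borrows of `model` end
        // before the mutable borrows taken in the write-back loop below.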
let mut updates = Vec::new();
for (key, param_view) in model.param_iter() {
let grad_view: ArrayView<f64, IxDyn> = model.get_gradient(key)?;
            // Compare full shapes, not just `ndim`: a same-rank shape mismatch
            // would otherwise panic later in the element-wise subtraction.
            if param_view.shape() != grad_view.shape() {
                return Err(ModelError::DimensionalityError(
                    "Parameter and gradient shapes do not match".to_string(),
                ));
            }
            // Bring the parameter view into dynamic dimensions so the
            // arithmetic below lines up with the `IxDyn` gradient view.
            let param_view = param_view.into_dimensionality::<IxDyn>().map_err(|_| {
                ModelError::DimensionalityError(
                    "Failed to convert parameter to dynamic dimensions".to_string(),
                )
            })?;
            // `grad_view` is already `ArrayView<f64, IxDyn>`, so no conversion
            // is needed on the gradient side.
            let updated_param = &param_view - (self.learning_rate * &grad_view);
updates.push((key.to_string(), updated_param.to_owned()));
}
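        // Phase 2: write the buffered updates back through mutable views.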
for (key, updated_param) in updates {
let mut current_param: ArrayViewMut<f64, IxDyn> = model.get_mut(&key)?;
current_param.assign(&updated_param);
}
Ok(())
}
}
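
// A minimal usage sketch (hypothetical): `LinearModel`, `compute_gradients`,
// `x`, `y`, and `num_epochs` are placeholders rather than items defined in
// this crate; any type implementing `OptimizableModel<Matrix, Vector>` with
// the `param_iter`/`get_gradient`/`get_mut` methods used above would do.
//
//     let mut model = LinearModel::new(n_features);
//     let mut optimizer = GradientDescentState::new(0.01);
//     for _epoch in 0..num_epochs {
//         model.compute_gradients(&x, &y)?; // populate the gradients read above
//         optimizer.update_weights(&mut model)?;
//     }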