use crate::core::error::ModelError;
use crate::core::types::{Matrix, Vector};
use crate::model::core::base::OptimizableModel;
use crate::optim::core::optimizer::Optimizer;
use crate::optim::core::state::OptimizerState;
use crate::optim::sgd::state::GradientDescentState;
/// Gradient-descent optimizer for models implementing [`OptimizableModel`].
///
/// Holds the number of training iterations, the per-iteration cost log, and
/// the [`GradientDescentState`] that is asked to update the model's weights.
pub struct GradientDescent<Input, Output, M>
where
    M: OptimizableModel<Input, Output>,
{
    /// Number of iterations executed by a single call to `fit`.
    epochs: usize,
    /// Cost values recorded by `fit`, one entry appended per iteration.
    pub cost_history: Vec<f64>,
    /// Optimizer state, constructed from the learning rate passed to `new`.
    state: GradientDescentState<Input, Output, M>,
}
impl<Input, Output, M> GradientDescent<Input, Output, M>
where
    M: OptimizableModel<Input, Output>,
{
    /// Creates an optimizer that will run `epochs` iterations per `fit` call.
    ///
    /// `learning_rate` is forwarded unchanged to [`GradientDescentState::new`];
    /// the cost history starts out empty.
    pub fn new(learning_rate: f64, epochs: usize) -> Self {
        let state = GradientDescentState::new(learning_rate);
        Self {
            epochs,
            cost_history: Vec::new(),
            state,
        }
    }
}
impl<M> Optimizer<Matrix, Vector, M> for GradientDescent<Matrix, Vector, M>
where
    M: OptimizableModel<Matrix, Vector>,
{
    /// Trains `model` on `x` / `y` for `self.epochs` iterations.
    ///
    /// Each iteration, in order:
    /// 1. computes the cost via `model.compute_cost` and appends it to
    ///    `cost_history` (so the recorded value precedes that iteration's
    ///    weight update),
    /// 2. obtains the output gradient via `model.compute_output_gradient`,
    /// 3. calls `model.backward` with that gradient,
    /// 4. asks the optimizer state to update the model's weights.
    ///
    /// `cost_history` is only appended to here, never cleared, so entries
    /// accumulate across repeated calls.
    ///
    /// # Errors
    ///
    /// Returns the first [`ModelError`] produced by any of the steps above;
    /// remaining iterations are skipped.
    fn fit(&mut self, model: &mut M, x: &Matrix, y: &Vector) -> Result<(), ModelError> {
        (0..self.epochs).try_for_each(|_| -> Result<(), ModelError> {
            let current_cost = model.compute_cost(x, y)?;
            self.cost_history.push(current_cost);

            let grad = model.compute_output_gradient(x, y)?;
            model.backward(x, &grad)?;

            self.state.update_weights(model)?;
            Ok(())
        })
    }
}