use crate::prelude::{CanUpdateWithGradients, Gradients, UnusedTensors};
/// An optimizer that can update a module of type `M` using [`Gradients`].
///
/// `M` must implement [`CanUpdateWithGradients`].
pub trait Optimizer<M>
where
    M: CanUpdateWithGradients,
{
    /// Updates `module` in place using `gradients`, which are consumed by the call.
    ///
    /// # Errors
    ///
    /// Returns [`UnusedParamsError`] when the update reports unused tensors.
    /// NOTE(review): the exact condition is up to each implementation — confirm per impl.
    fn update(&mut self, module: &mut M, gradients: Gradients) -> Result<(), UnusedParamsError>;
}
#[derive(Debug)]
pub struct UnusedParamsError(UnusedTensors);
/// User-facing message for [`UnusedParamsError`].
///
/// `Display` is meant for humans, so this prints a short sentence instead of repeating
/// the derived `Debug` struct dump. The tensors themselves are still rendered with
/// their `Debug` formatting, so no diagnostic detail is lost.
impl std::fmt::Display for UnusedParamsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "unused parameters: {:?}", self.0)
    }
}
impl std::error::Error for UnusedParamsError {}
/// Converts a set of [`UnusedTensors`] into an optimizer result.
///
/// An empty set becomes `Ok(())`; a non-empty set is wrapped in
/// [`UnusedParamsError`]. Implementing `From` (instead of hand-writing `Into`) is the
/// idiomatic direction: the standard library's blanket impl still provides
/// `Into<Result<(), UnusedParamsError>>` for `UnusedTensors`, so existing `.into()`
/// calls keep working, and `Result::from(unused)` works as well.
// Allowed by the orphan rules because `UnusedTensors` is a local type appearing as the
// trait's type parameter (the same condition the previous `Into` impl relied on).
impl From<UnusedTensors> for Result<(), UnusedParamsError> {
    fn from(unused: UnusedTensors) -> Self {
        if unused.is_empty() {
            Ok(())
        } else {
            Err(UnusedParamsError(unused))
        }
    }
}