mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;
use crate::{
shapes::{Dtype, Shape},
tensor::{Storage, Tensor},
};
use super::optim::{Momentum, WeightDecay};
/// Hyperparameters passed to [`SgdKernel::sgd_kernel`] when applying an SGD update.
///
/// The type is `Copy`, so it can be handed around by value cheaply.
/// Its [`Default`] value is `lr = 1e-2` with `momentum` and `weight_decay` both `None`.
#[derive(Debug, Clone, Copy)]
pub struct SgdConfig {
/// Learning rate. How it scales the update is decided by the
/// [`SgdKernel`] implementation, which is not defined in this file section.
pub lr: f64,
/// Optional momentum setting.
/// NOTE(review): `None` presumably means "no momentum" — confirm against the kernels.
pub momentum: Option<Momentum>,
/// Optional weight-decay setting.
/// NOTE(review): `None` presumably means "no weight decay" — confirm against the kernels.
pub weight_decay: Option<WeightDecay>,
}
impl Default for SgdConfig {
fn default() -> Self {
Self {
lr: 1e-2,
momentum: None,
weight_decay: None,
}
}
}
/// Device-side implementation of an SGD update for element type `E`.
///
/// A type implementing this trait is also a [`Storage<E>`]; the buffers passed
/// to [`sgd_kernel`](SgdKernel::sgd_kernel) are that storage's `Vec` type.
pub trait SgdKernel<E: Dtype>: Storage<E> {
/// Applies one update described by `cfg`.
///
/// * `cfg` - hyperparameters for the update.
/// * `param` - parameter buffer, passed mutably.
/// * `velocity` - auxiliary buffer, passed mutably.
///   NOTE(review): presumably the momentum state carried between steps — confirm
///   against the kernel implementations.
/// * `grad` - gradient buffer, read-only.
///
/// NOTE(review): the three buffers are presumably expected to have the same
/// length — confirm against the kernel implementations.
///
/// # Errors
///
/// Returns `Self::Err` if the implementation reports a failure.
fn sgd_kernel(
&self,
cfg: &SgdConfig,
param: &mut Self::Vec,
velocity: &mut Self::Vec,
grad: &Self::Vec,
) -> Result<(), Self::Err>;
}
impl SgdConfig {
    /// Runs the SGD kernel of `param`'s device on `param`'s data buffer,
    /// using this configuration together with `velocity` and `grad`.
    ///
    /// Mutable access to the tensor's data is obtained through
    /// [`std::sync::Arc::make_mut`], so if the buffer is currently shared with
    /// other `Arc` handles it is cloned first and only this tensor's copy is
    /// modified.
    ///
    /// # Errors
    ///
    /// Forwards any error returned by [`SgdKernel::sgd_kernel`].
    pub fn try_update<S, E, D>(
        &self,
        param: &mut Tensor<S, E, D>,
        velocity: &mut D::Vec,
        grad: &D::Vec,
    ) -> Result<(), D::Err>
    where
        S: Shape,
        E: Dtype,
        D: SgdKernel<E>,
    {
        // Un-share the buffer (cloning only when needed) before handing it to the kernel.
        let weights = std::sync::Arc::make_mut(&mut param.data);
        param.device.sgd_kernel(self, weights, velocity, grad)
    }
}