mod cpu_kernel;
#[cfg(feature = "cuda")]
mod cuda_kernel;
use crate::{
shapes::{Dtype, Shape},
tensor::{Storage, Tensor},
};
use super::WeightDecay;
/// Hyperparameters for the Adam update.
///
/// A reference to this struct is handed to [`AdamKernel::adam_kernel`] on every
/// step (see [`AdamConfig::try_update`]). The `Default` impl supplies
/// `lr = 1e-3`, `betas = [0.9, 0.999]`, `eps = 1e-8` and no weight decay.
#[derive(Debug, Clone, Copy)]
pub struct AdamConfig {
/// Learning rate. Defaults to `1e-3`.
pub lr: f64,
/// The two beta coefficients. Defaults to `[0.9, 0.999]`.
///
/// NOTE(review): presumably the decay rates for the first and second moment
/// estimates, in that order -- confirm against the kernel implementations.
pub betas: [f64; 2],
/// Epsilon term. Defaults to `1e-8`.
///
/// NOTE(review): presumably the small constant added for numerical
/// stability -- the kernels that consume it are not visible here.
pub eps: f64,
/// Optional weight decay setting. Defaults to `None`.
pub weight_decay: Option<WeightDecay>,
}
impl Default for AdamConfig {
fn default() -> Self {
Self {
lr: 1e-3,
betas: [0.9, 0.999],
eps: 1e-8,
weight_decay: None,
}
}
}
/// Backend hook that performs the Adam update for element type `E`.
///
/// Implementors are storage backends (`Storage<E>`); the buffers are passed as
/// the backend's own `Self::Vec` type and failures are reported through
/// `Self::Err`.
///
/// NOTE(review): the `cpu_kernel` and feature-gated `cuda_kernel` submodules
/// declared in this file presumably hold the implementations -- not visible here.
pub trait AdamKernel<E: Dtype>: Storage<E> {
/// Runs one Adam step over the given buffers.
///
/// * `t` - step index supplied by the caller.
///   NOTE(review): presumably the step count used for bias correction -- confirm.
/// * `cfg` - hyperparameters for this step.
/// * `param` - parameter buffer, mutably borrowed so it can be updated in place.
/// * `moment1`, `moment2` - mutably borrowed optimizer state buffers.
///   NOTE(review): presumably the running first/second moment estimates -- confirm.
/// * `grad` - gradient buffer, read-only.
///
/// # Errors
///
/// Returns `Self::Err` if the backend fails to perform the update.
fn adam_kernel(
&self,
t: i32,
cfg: &AdamConfig,
param: &mut Self::Vec,
moment1: &mut Self::Vec,
moment2: &mut Self::Vec,
grad: &Self::Vec,
) -> Result<(), Self::Err>;
}
impl AdamConfig {
    /// Applies one Adam step to `param` by delegating to the [`AdamKernel`]
    /// implementation of the tensor's device.
    ///
    /// `t`, `moment1`, `moment2` and `grad` are forwarded to the kernel
    /// unchanged, together with `self` as the hyperparameter set.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the device kernel returns.
    pub fn try_update<S: Shape, E: Dtype, D: AdamKernel<E>>(
        &self,
        t: i32,
        param: &mut Tensor<S, E, D>,
        moment1: &mut D::Vec,
        moment2: &mut D::Vec,
        grad: &D::Vec,
    ) -> Result<(), D::Err> {
        // `Arc::make_mut` yields a unique `&mut` to the parameter buffer,
        // cloning the data first if other `Arc` handles still point at it.
        let weights = std::sync::Arc::make_mut(&mut param.data);
        param
            .device
            .adam_kernel(t, self, weights, moment1, moment2, grad)
    }
}