use crate::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Fused, single-kernel optimizer update steps for tensors on runtime `R`.
///
/// Each method takes the current parameter/state tensors by shared reference and
/// returns freshly produced output tensors rather than mutating in place
/// (functional-update style — callers presumably swap the returned tensors back
/// into their optimizer state).
///
/// NOTE(review): the exact update formulas are implemented elsewhere; the docs
/// below describe the conventional meaning of each signature based on the
/// parameter names — confirm against the concrete implementations.
// The lint allows are deliberate: optimizer steps genuinely need many scalar
// hyperparameters, and the multi-tensor method's grouped-tuple signature
// triggers `type_complexity`.
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
pub trait FusedOptimizerOps<R: Runtime> {
/// One AdamW step for a single parameter tensor.
///
/// `m` and `v` are the first/second exponential moving averages of the
/// gradient; `beta1`/`beta2` are their decay rates, `eps` the denominator
/// stabilizer, and `wd` the (decoupled, per AdamW convention — TODO confirm)
/// weight-decay coefficient. `step_size` is passed alongside `lr`; it is
/// presumably the bias-corrected effective step (`lr / bias_correction1`) —
/// verify against the implementation.
///
/// Returns `(updated_param, updated_m, updated_v)` — presumably in that
/// order; confirm against call sites.
fn fused_adamw_step(
&self,
param: &Tensor<R>,
grad: &Tensor<R>,
m: &Tensor<R>,
v: &Tensor<R>,
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
wd: f64,
step_size: f64,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
/// One SGD step, optionally with momentum.
///
/// `momentum_buf` is the running momentum buffer; `None` presumably means
/// either momentum is disabled or this is the first step (buffer not yet
/// initialized) — confirm which. `dampening` scales the gradient's
/// contribution to the buffer, and `nesterov` selects the Nesterov variant.
///
/// Returns `(updated_param, updated_momentum_buf)`. NOTE(review): a buffer
/// is returned even when `momentum_buf` is `None` — verify what it holds in
/// the no-momentum case.
fn fused_sgd_step(
&self,
param: &Tensor<R>,
grad: &Tensor<R>,
momentum_buf: Option<&Tensor<R>>,
lr: f64,
momentum: f64,
dampening: f64,
wd: f64,
nesterov: bool,
) -> Result<(Tensor<R>, Tensor<R>)>;
/// One Adagrad step.
///
/// `accum` is the running sum of squared gradients; `eps` stabilizes the
/// root in the denominator and `wd` applies weight decay.
///
/// Returns `(updated_param, updated_accum)`.
fn fused_adagrad_step(
&self,
param: &Tensor<R>,
grad: &Tensor<R>,
accum: &Tensor<R>,
lr: f64,
eps: f64,
wd: f64,
) -> Result<(Tensor<R>, Tensor<R>)>;
/// One LAMB (Layer-wise Adaptive Moments) step.
///
/// Shares the Adam-style moments `m`/`v` and decay rates `beta1`/`beta2`.
/// `bias_corr1`/`bias_corr2` are presumably the precomputed bias-correction
/// factors `1 - beta^t` for the current step, passed in so the kernel need
/// not track the step count — TODO confirm. NOTE(review): no `lr` parameter
/// here, unlike the other methods — verify whether the learning rate is
/// folded into `bias_corr1` or applied by the caller via the trust ratio.
///
/// Returns `(updated_param, updated_m, updated_v)`.
fn fused_lamb_step(
&self,
param: &Tensor<R>,
grad: &Tensor<R>,
m: &Tensor<R>,
v: &Tensor<R>,
beta1: f64,
beta2: f64,
eps: f64,
wd: f64,
bias_corr1: f64,
bias_corr2: f64,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
/// Multi-tensor AdamW: applies one AdamW step to many parameters at once
/// (presumably batched into fewer kernel launches than per-tensor calls —
/// confirm against the implementation).
///
/// Each group tuple is `(param, grad, m, v)` — matching the argument order
/// of [`FusedOptimizerOps::fused_adamw_step`]; the shared hyperparameters
/// have the same meaning as there. Returns one `(param, m, v)` output
/// triple per input group, presumably in the same order as `groups`.
fn fused_multi_tensor_adamw(
&self,
groups: &[(&Tensor<R>, &Tensor<R>, &Tensor<R>, &Tensor<R>)],
lr: f64,
beta1: f64,
beta2: f64,
eps: f64,
wd: f64,
step_size: f64,
) -> Result<Vec<(Tensor<R>, Tensor<R>, Tensor<R>)>>;
}