use crate::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
// Lint justification:
// - `fused_qkv_projection_bwd` takes 10 parameters (including `self`), which
//   exceeds clippy's default `too_many_arguments` threshold of 7.
// - The multi-tensor tuple return types are the kind of signature
//   `type_complexity` flags.
// Bundling these into structs would change the public trait interface for
// every implementor, so the lints are silenced here instead.
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
/// Fused projection operations for attention layers, with their backward passes.
///
/// The trait declares four required operations:
/// - a combined input projection that yields three tensors;
/// - an output projection combined with a residual tensor;
/// - a backward pass for each of the above.
///
/// The trait only fixes the signatures. Whether, and how, the work is
/// actually fused into fewer kernels is up to each implementation for the
/// runtime `R`.
pub trait FusedQkvOps<R: Runtime> {
    /// Projects `input` through one combined `weight` (plus optional `bias`)
    /// and returns the result split into three tensors.
    ///
    /// `num_heads`, `num_kv_heads` and `head_dim` describe how the projected
    /// result is partitioned.
    ///
    /// NOTE(review): the following is presumed from the parameter names, so
    /// confirm it against implementations:
    /// - the first output has `num_heads` heads;
    /// - the other two outputs have `num_kv_heads` heads each
    ///   (grouped-query style);
    /// - every head is `head_dim` wide.
    ///
    /// # Returns
    ///
    /// A triple of tensors, presumably `(q, k, v)` in that order. TODO
    /// confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails. Which conditions it
    /// checks (e.g. shape mismatches) is implementation-defined.
    fn fused_qkv_projection(
        &self,
        input: &Tensor<R>,
        weight: &Tensor<R>,
        bias: Option<&Tensor<R>>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;

    /// Applies the output projection (`weight`, optional `bias`) to
    /// `attn_out` and combines the result with `residual`.
    ///
    /// NOTE(review): presumably computes `attn_out · weightᵀ + bias +
    /// residual`. Confirm the weight layout and the exact combination against
    /// implementations.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails. Which conditions it
    /// checks is implementation-defined.
    fn fused_output_projection_residual(
        &self,
        attn_out: &Tensor<R>,
        weight: &Tensor<R>,
        bias: Option<&Tensor<R>>,
        residual: &Tensor<R>,
    ) -> Result<Tensor<R>>;

    /// Backward pass for [`Self::fused_qkv_projection`].
    ///
    /// The arguments are:
    /// - `dq`, `dk`, `dv`: the incoming gradients for the three forward
    ///   outputs;
    /// - `input`, `weight`: the forward operands;
    /// - `has_bias`: whether the forward call used a bias;
    /// - `num_heads`, `num_kv_heads`, `head_dim`: the forward head layout.
    ///
    /// # Returns
    ///
    /// A tuple of two tensors and an optional tensor, presumably
    /// `(d_input, d_weight, d_bias)`. The bias gradient is expected to be
    /// `Some` only when `has_bias` is true. TODO confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails. Which conditions it
    /// checks is implementation-defined.
    fn fused_qkv_projection_bwd(
        &self,
        dq: &Tensor<R>,
        dk: &Tensor<R>,
        dv: &Tensor<R>,
        input: &Tensor<R>,
        weight: &Tensor<R>,
        has_bias: bool,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
    ) -> Result<(Tensor<R>, Tensor<R>, Option<Tensor<R>>)>;

    /// Backward pass for [`Self::fused_output_projection_residual`].
    ///
    /// The arguments are:
    /// - `d_output`: the incoming gradient of the forward result;
    /// - `attn_out`, `weight`: the forward operands;
    /// - `has_bias`: whether the forward call used a bias.
    ///
    /// The forward `residual` is not passed in. NOTE(review): presumably this
    /// is because its gradient does not depend on its value.
    ///
    /// # Returns
    ///
    /// Four values, presumably `(d_attn_out, d_weight, d_bias, d_residual)`.
    /// The bias gradient is expected to be `Some` only when `has_bias` is
    /// true. TODO confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails. Which conditions it
    /// checks is implementation-defined.
    fn fused_output_projection_residual_bwd(
        &self,
        d_output: &Tensor<R>,
        attn_out: &Tensor<R>,
        weight: &Tensor<R>,
        has_bias: bool,
    ) -> Result<(Tensor<R>, Tensor<R>, Option<Tensor<R>>, Tensor<R>)>;
}