use crate::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Granularity at which KV-cache quantization scales are computed.
// `Hash` added so the mode can key a `HashMap`/`HashSet` (e.g. per-mode kernel
// caches); all other semantics unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KvQuantMode {
/// A single scale shared by the whole tensor.
PerTensor,
/// One scale per token.
PerToken,
}
/// Number of consecutive elements sharing one scale/zero-point in int4
/// KV-cache quantization.
///
/// Each variant's discriminant equals its group width, so the numeric value
/// can be recovered with [`Int4GroupSize::as_usize`].
// `Hash` added (config enum, useful as a map key); `as_usize` added so call
// sites don't need bare `as` casts. Both are backward-compatible additions.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Int4GroupSize {
/// 32 elements per quantization group.
Group32 = 32,
/// 64 elements per quantization group (the default).
#[default]
Group64 = 64,
/// 128 elements per quantization group.
Group128 = 128,
}

impl Int4GroupSize {
/// Group width in elements (the variant's discriminant).
pub const fn as_usize(self) -> usize {
self as usize
}
}
/// Quantize/dequantize primitives for attention KV-cache tensors.
///
/// Implementors supply three codecs: fp8 (per-token scales), int4 (grouped,
/// with scales and zero-points), and int8. The tensor layout is not visible
/// here; presumably each input is `num_tokens` rows of `head_dim` elements —
/// TODO confirm against the implementations.
///
/// NOTE(review): only the fp8 dequantizer takes an `output_dtype`; int4/int8
/// apparently fix their output dtype elsewhere — verify this asymmetry is
/// intentional.
pub trait KvCacheQuantOps<R: Runtime> {
/// Quantize `input` to fp8 with one scale per token.
///
/// Returns `(quantized, scales)` — pairing inferred from the matching
/// parameters of [`KvCacheQuantOps::dequantize_kv_fp8_per_token`].
fn quantize_kv_fp8_per_token(
&self,
input: &Tensor<R>,
num_tokens: usize,
head_dim: usize,
) -> Result<(Tensor<R>, Tensor<R>)>;
/// Reconstruct a tensor of `output_dtype` from fp8 `quantized` values and
/// per-token `scales`.
fn dequantize_kv_fp8_per_token(
&self,
quantized: &Tensor<R>,
scales: &Tensor<R>,
num_tokens: usize,
head_dim: usize,
output_dtype: numr::dtype::DType,
) -> Result<Tensor<R>>;
/// Quantize `input` to int4 using groups of `group_size` elements.
///
/// Returns `(packed, scales, zeros)` — names inferred from the matching
/// parameters of [`KvCacheQuantOps::dequantize_kv_int4`]. `packed`
/// presumably holds two 4-bit values per byte — TODO confirm.
fn quantize_kv_int4(
&self,
input: &Tensor<R>,
num_tokens: usize,
head_dim: usize,
group_size: Int4GroupSize,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>;
/// Reconstruct a tensor from int4 `packed` data with per-group `scales`
/// and `zeros`. `group_size` must match the value used when quantizing.
fn dequantize_kv_int4(
&self,
packed: &Tensor<R>,
scales: &Tensor<R>,
zeros: &Tensor<R>,
num_tokens: usize,
head_dim: usize,
group_size: Int4GroupSize,
) -> Result<Tensor<R>>;
/// Quantize `input` to int8.
///
/// Returns `(quantized, scales)`; scale granularity is not visible here
/// (per-token vs per-tensor) — see [`KvQuantMode`] and confirm against the
/// implementations.
fn quantize_kv_int8(
&self,
input: &Tensor<R>,
num_tokens: usize,
head_dim: usize,
) -> Result<Tensor<R>>;
/// Reconstruct a tensor from int8 `quantized` values and their `scales`.
fn dequantize_kv_int8(
&self,
quantized: &Tensor<R>,
scales: &Tensor<R>,
num_tokens: usize,
head_dim: usize,
) -> Result<Tensor<R>>;
}