use crate::error::Result;
use crate::nn::weight::Weight;
use crate::quant::decomposed::DecomposedQuantLinear;
use crate::quant::tensor::QuantTensor;
use crate::quant::traits::QuantMatmulOps;
use numr::autograd::{Var, var_add, var_matmul, var_transpose};
use numr::dtype::DType;
use numr::ops::{BinaryOps, TensorOps, TypeConversionOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Fully-connected layer computing `output = input · weightᵀ (+ bias)`.
///
/// Both parameters are stored as autograd [`Var`]s so gradients can flow
/// through them when `trainable` is set at construction time.
pub struct Linear<R: Runtime> {
    /// Weight matrix; `forward` multiplies by its transpose, so the stored
    /// layout is presumably `(out_features, in_features)` — TODO confirm.
    weight: Var<R>,
    /// Optional bias added element-wise after the matmul.
    bias: Option<Var<R>>,
}
impl<R: Runtime> Linear<R> {
pub fn new(weight: Tensor<R>, bias: Option<Tensor<R>>, trainable: bool) -> Self {
Self {
weight: Var::new(weight, trainable),
bias: bias.map(|b| Var::new(b, trainable)),
}
}
pub fn forward<C>(&self, client: &C, input: &Var<R>) -> Result<Var<R>>
where
C: RuntimeClient<R> + TensorOps<R>,
R::Client: TensorOps<R>,
{
let w_t = var_transpose(&self.weight).map_err(crate::error::Error::Numr)?;
let output = var_matmul(input, &w_t, client).map_err(crate::error::Error::Numr)?;
match &self.bias {
Some(bias) => var_add(&output, bias, client).map_err(crate::error::Error::Numr),
None => Ok(output),
}
}
pub fn weight(&self) -> &Var<R> {
&self.weight
}
pub fn bias(&self) -> Option<&Var<R>> {
self.bias.as_ref()
}
}
/// Linear layer whose weight is stored in quantized form.
///
/// Operates on plain tensors (no autograd) and dispatches the matmul to
/// the client's [`QuantMatmulOps`] implementation.
pub struct QuantLinear<R: Runtime> {
    /// Quantized weight consumed directly by `quant_matmul`.
    weight: QuantTensor<R>,
    /// Optional unquantized bias added after the matmul.
    bias: Option<Tensor<R>>,
}
impl<R: Runtime> QuantLinear<R> {
pub fn new(weight: QuantTensor<R>, bias: Option<Tensor<R>>) -> Self {
Self { weight, bias }
}
pub fn forward<C>(&self, client: &C, input: &Tensor<R>) -> Result<Tensor<R>>
where
C: QuantMatmulOps<R> + BinaryOps<R> + RuntimeClient<R>,
{
let output = client.quant_matmul(input, &self.weight)?;
match &self.bias {
Some(bias) => client.add(&output, bias).map_err(crate::error::Error::Numr),
None => Ok(output),
}
}
pub fn weight(&self) -> &QuantTensor<R> {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor<R>> {
self.bias.as_ref()
}
}
/// A linear layer that is either standard, quantized, or
/// decomposed-quantized, selected from the [`Weight`] variant at load time.
pub enum MaybeQuantLinear<R: Runtime> {
    /// Full-precision layer backed by autograd [`Var`]s.
    Standard(Linear<R>),
    /// Quantized layer; its forward pass detaches from the autograd graph.
    Quantized(QuantLinear<R>),
    /// Decomposed quantized layer (boxed, presumably to keep this enum's
    /// size down — TODO confirm the variant is large enough to warrant it).
    DecomposedQuant(Box<DecomposedQuantLinear<R>>),
}
impl<R: Runtime> MaybeQuantLinear<R> {
    /// Builds the matching layer variant for the given [`Weight`].
    ///
    /// Standard weights are wrapped with `trainable = false`, so layers
    /// constructed through this path do not accumulate gradients.
    pub fn from_weight(weight: Weight<R>, bias: Option<Tensor<R>>) -> Self {
        match weight {
            Weight::Standard(t) => Self::Standard(Linear::new(t, bias, false)),
            Weight::Quantized(qt) => Self::Quantized(QuantLinear::new(qt, bias)),
            Weight::DecomposedQuant(dq) => {
                Self::DecomposedQuant(Box::new(DecomposedQuantLinear::new(*dq, bias)))
            }
        }
    }

    /// Runs the forward pass of whichever variant this is.
    ///
    /// The quantized variants operate on `input.tensor()` and re-wrap the
    /// result in a non-trainable [`Var`], so the autograd graph is cut at
    /// quantized layers.
    ///
    /// # Errors
    /// Propagates any error from the underlying layer's forward pass.
    pub fn forward<C>(&self, client: &C, input: &Var<R>) -> Result<Var<R>>
    where
        C: RuntimeClient<R>
            + TensorOps<R>
            + QuantMatmulOps<R>
            + BinaryOps<R>
            + TypeConversionOps<R>,
        R: Runtime<DType = DType>,
        R::Client: TensorOps<R>,
    {
        match self {
            Self::Standard(linear) => linear.forward(client, input),
            Self::Quantized(qlinear) => {
                let out = qlinear.forward(client, input.tensor())?;
                Ok(Var::new(out, false))
            }
            Self::DecomposedQuant(dqlinear) => {
                let out = dqlinear.forward(client, input.tensor())?;
                Ok(Var::new(out, false))
            }
        }
    }

    /// Applies every layer in `layers` to the same `input`.
    ///
    /// Fast path: when there is at least one layer and all layers are
    /// quantized with no bias, a single `quant_matmul_batch` call computes
    /// every output at once. Otherwise each layer is forwarded
    /// individually. Returns one output per layer; an empty `layers`
    /// yields an empty `Vec`.
    ///
    /// # Errors
    /// Propagates the first error from the batched matmul or from any
    /// individual forward pass.
    pub fn forward_batch<C>(
        layers: &[&MaybeQuantLinear<R>],
        client: &C,
        input: &Var<R>,
    ) -> Result<Vec<Var<R>>>
    where
        C: RuntimeClient<R>
            + TensorOps<R>
            + QuantMatmulOps<R>
            + BinaryOps<R>
            + TypeConversionOps<R>,
        R: Runtime<DType = DType>,
        R::Client: TensorOps<R>,
    {
        // FIX: `all()` is vacuously true on an empty slice, which previously
        // sent an empty weight list into `quant_matmul_batch`. Require at
        // least one layer before taking the batched path; the fallback loop
        // handles the empty case by returning an empty Vec.
        let batchable = !layers.is_empty()
            && layers
                .iter()
                .all(|l| matches!(l, MaybeQuantLinear::Quantized(ql) if ql.bias().is_none()));
        if batchable {
            let weights: Vec<&QuantTensor<R>> = layers
                .iter()
                .map(|l| match l {
                    MaybeQuantLinear::Quantized(ql) => ql.weight(),
                    // `batchable` guarantees every layer is Quantized.
                    _ => unreachable!(),
                })
                .collect();
            let outputs = client.quant_matmul_batch(input.tensor(), &weights)?;
            Ok(outputs.into_iter().map(|t| Var::new(t, false)).collect())
        } else {
            layers.iter().map(|l| l.forward(client, input)).collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::cpu::CpuRuntime;

    /// A (4×3) weight applied to a (2×3) input must produce a (2×4) output.
    #[test]
    fn test_linear_output_shape() {
        let (client, device) = cpu_setup();
        let w = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 12], &[4, 3], &device);
        let layer = Linear::new(w, None, false);
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32; 6], &[2, 3], &device),
            false,
        );
        let y = layer.forward(&client, &x).unwrap();
        assert_eq!(y.shape(), &[2, 4]);
    }

    /// Identity weight plus a bias: output equals input shifted by the bias.
    #[test]
    fn test_linear_with_bias() {
        let (client, device) = cpu_setup();
        let identity =
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 0.0, 0.0, 1.0], &[2, 2], &device);
        let shift = Tensor::<CpuRuntime>::from_slice(&[10.0f32, 20.0], &[2], &device);
        let layer = Linear::new(identity, Some(shift), false);
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[1, 2], &device),
            false,
        );
        let y = layer.forward(&client, &x).unwrap();
        let values: Vec<f32> = y.tensor().to_vec();
        assert_eq!(values, vec![11.0, 22.0]);
    }

    /// Leading batch dims pass through: (4,5,3) · (2,3)ᵀ → (4,5,2).
    #[test]
    fn test_linear_batched() {
        let (client, device) = cpu_setup();
        let w = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 6], &[2, 3], &device);
        let layer = Linear::new(w, None, false);
        let x = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[0.1f32; 60], &[4, 5, 3], &device),
            false,
        );
        let y = layer.forward(&client, &x).unwrap();
        assert_eq!(y.shape(), &[4, 5, 2]);
    }
}