use crate::ops::elementwise::add_bias;
use crate::ops::matmul::matmul_t_b;
use crate::tensor::Tensor;
/// Applies a fully-connected (affine) layer: `x · weightᵀ`, plus `bias` when one is given.
///
/// `weight` is laid out as `[out_features, in_features]`. For example, an input of shape
/// `[1, 2]` with a weight of shape `[3, 2]` yields an output of shape `[1, 3]`.
///
/// # Arguments
/// * `x` - input tensor whose last dimension matches `weight`'s second dimension.
/// * `weight` - weight matrix, multiplied as its transpose via [`matmul_t_b`].
/// * `bias` - optional bias added to the product via [`add_bias`]. It is presumably a
///   1-D tensor of length `out_features`, broadcast across rows (TODO: confirm against
///   `add_bias`).
///
/// # Returns
/// A newly allocated tensor holding the projected (and optionally biased) result.
pub fn linear(x: &Tensor, weight: &Tensor, bias: Option<&Tensor>) -> Tensor {
    let projected = matmul_t_b(x, weight);
    if let Some(b) = bias {
        add_bias(&projected, b)
    } else {
        // No bias: the raw product is already the answer, so return it without copying.
        projected
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-row input through a 2→3 layer with bias yields `x · weightᵀ + bias`.
    #[test]
    fn linear_with_bias() {
        let x = Tensor::from_vec(vec![1.0, 2.0], &[1, 2]);
        // Weight rows are (1, 0), (0, 1), (1, 1), so x · weightᵀ = [1, 2, 3].
        let w = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0], &[3, 2]);
        let b = Tensor::from_vec(vec![10.0, 20.0, 30.0], &[3]);
        let y = linear(&x, &w, Some(&b));
        assert_eq!(y.shape().as_slice(), &[1, 3]);
        assert_eq!(y.data(), &[11.0, 22.0, 33.0]);
    }

    /// Without a bias, an identity weight must return the input unchanged, shape included.
    #[test]
    fn linear_without_bias() {
        let x = Tensor::from_vec(vec![1.0, 2.0], &[1, 2]);
        // 2x2 identity matrix.
        let w = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], &[2, 2]);
        let y = linear(&x, &w, None);
        // The shape is checked here too (as in `linear_with_bias`), so a no-bias path that
        // returned the right values in the wrong layout would still fail.
        assert_eq!(y.shape().as_slice(), &[1, 2]);
        assert_eq!(y.data(), &[1.0, 2.0]);
    }
}