gtensor 1.0.0

Reverse-mode autodifferentiation of computational graphs with tensors and more for machine learning.
Documentation

use super::*;

/// Matrix-multiplication operator attached to graph nodes.
///
/// This is a unit struct: it holds no state of its own. All data it works on
/// is obtained from the `Node` handed to its `Operator` methods.
#[derive(Clone, Serialize, Deserialize)]
struct Matmul;

impl Operator for Matmul {
    /// Computes the node's output from its two inputs.
    ///
    /// Fetches the output via `node.y()` and the inputs via `node.x(1)` and
    /// `node.x(2)`, then passes the input values together with each input's
    /// `shape2()` to `bmls::matmul`, which writes into the output buffer.
    /// An error returned by `bmls::matmul` is propagated to the caller.
    fn forward(&mut self, node: &Node) -> Result<()> {
        let (out, _) = node.y();
        let (lhs, _) = node.x(1);
        let (rhs, _) = node.x(2);

        let lhs_dim = lhs.shape2();
        let rhs_dim = rhs.shape2();

        // The borrows from `read()`/`write()` are temporaries that end with this statement.
        bmls::matmul(&lhs.read(), &rhs.read(), &mut out.write(), lhs_dim, rhs_dim)?;

        Ok(())
    }

    /// Fills in the gradients of both inputs from the gradient of the output.
    ///
    /// The second element of each pair returned by `node.y()` / `node.x(..)`
    /// is used as the gradient buffer. The output gradient is read once and
    /// shared by both calls: `bmls::matmul_wrt_a` writes the first input's
    /// gradient and `bmls::matmul_wrt_b` writes the second input's gradient.
    /// Either call may fail, in which case the error is propagated and the
    /// remaining work is skipped.
    ///
    /// NOTE(review): whether the `bmls` routines overwrite or accumulate into
    /// the gradient buffers is not visible from this file — confirm against
    /// the `bmls` documentation.
    fn backward(&mut self, node: &Node) -> Result<()> {
        let (_, out_grad) = node.y();
        let (lhs, lhs_grad) = node.x(1);
        let (rhs, rhs_grad) = node.x(2);

        let lhs_dim = lhs.shape2();
        let rhs_dim = rhs.shape2();

        // Held across both calls so the output gradient is borrowed only once.
        let upstream = out_grad.read();

        bmls::matmul_wrt_a(&upstream, &rhs.read(), &mut lhs_grad.write(), lhs_dim, rhs_dim)?;
        bmls::matmul_wrt_b(&lhs.read(), &upstream, &mut rhs_grad.write(), lhs_dim, rhs_dim)?;

        Ok(())
    }
}

impl Display for Matmul {
    /// Writes the operator's name, `Matmul`, to the formatter.
    ///
    /// The text is emitted verbatim; width and alignment flags on the
    /// formatter are not applied.
    fn fmt(&self, out: &mut Formatter<'_>) -> std::fmt::Result {
        out.write_str("Matmul")
    }
}

/// Adds a matrix-multiplication node for `x1 * x2` to the graph and returns it.
///
/// The new node depends on `x1` and `x2` (in that order) and is given the
/// shape `[x1.shape()[0], x2.shape()[1], 1, 1]`. It is marked as batched when
/// either operand is batched.
///
/// # Panics
///
/// Panics when `x1.shape()[1]` (the columns of `x1`) differs from
/// `x2.shape()[0]` (the rows of `x2`).
pub fn matmul<'t>(x1: Var<'t>, x2: Var<'t>) -> Var<'t> {
    let lhs_cols = x1.shape()[1];
    let rhs_rows = x2.shape()[0];

    if lhs_cols != rhs_rows {
        panic!("X1 cols must be equal to X2 rows! X1: {}, X2: {}", x1.shape(), x2.shape())
    }

    let out_rows = x1.shape()[0];
    let out_cols = x2.shape()[1];
    let shape = [out_rows, out_cols, 1, 1].to_shape();

    let builder = NodeBuilder {
        op: Box::new(Matmul),
        deps: vec![x1.index, x2.index],
        shape,
        skip: false,
        init: None,
        is_batched: x1.is_batched || x2.is_batched,
    };

    x1.extend(builder)
}