gtensor 1.0.0

Reverse-mode automatic differentiation of computational graphs with tensors, and more, for machine learning.
Documentation

use super::*;

/// Stateless marker operator used by [`flatten`].
///
/// A unit struct with no fields; it is constructed simply as `Flatten`.
#[derive(Clone)]
#[derive(Serialize, Deserialize)]
struct Flatten;

/// Graph operator behind [`flatten`].
///
/// Both passes are empty: this operator performs no computation of its own and always
/// reports success.
impl Operator for Flatten {
    /// Empty forward pass; always returns `Ok(())`.
    fn forward(&mut self, _node: &Node) -> Result<()> {
        Ok(())
    }

    /// Empty backward pass; always returns `Ok(())`.
    fn backward(&mut self, _node: &Node) -> Result<()> {
        Ok(())
    }
}

/// Renders the operator as its bare name, `Flatten`.
impl Display for Flatten {
    fn fmt(&self, out: &mut Formatter<'_>) -> std::fmt::Result {
        out.write_str("Flatten")
    }
}

/// Merges axes 1, 2 and 3 of `x`'s shape into axis 1 and returns the new graph node.
///
/// The returned node wraps the [`Flatten`] operator (whose `forward` and `backward` do
/// nothing), depends only on `x`, is registered with `skip: true` and no initializer, and
/// inherits `x.is_batched`. When `x` is batched, the output shape is passed through
/// `add_batch(1)`.
///
/// NOTE(review): presumably `skip: true` makes this node reuse its dependency's data
/// rather than compute its own — confirm against `NodeBuilder`.
///
/// # Panics
///
/// Panics if `x` has fewer than four dimensions, or if the (optionally re-batched) output
/// shape has a different number of dimensions than `x.shape()`.
pub fn flatten<'t>(x: Var<'t>) -> Var<'t> {
    let mut shape = x.shape();

    // Axes 1..=3 are read below; report a descriptive error instead of a bare
    // index-out-of-bounds panic when `x` is lower-rank.
    assert!(
        shape.len() >= 4,
        "flatten expects at least 4 dimensions, got {}",
        shape.len()
    );

    // NOTE(review): only axis 1 is rewritten; axes 2 and 3 keep their original sizes, so the
    // product of all dimensions is multiplied by shape[2] * shape[3]. If this node aliases its
    // input's data, axes 2 and 3 probably should be set to 1 — confirm before relying on them.
    shape[1] = shape[1] * shape[2] * shape[3];

    if x.is_batched {
        shape = shape.add_batch(1);
    }

    // The rank must be preserved; this can only trip if `add_batch` changes the rank.
    assert!(
        x.shape().len() == shape.len(),
        "New shape len does not match old shape len! New: {}, Old: {}",
        shape,
        x.shape()
    );

    x.extend(NodeBuilder {
        op: Box::new(Flatten),
        deps: vec![x.index],
        shape,
        skip: true,
        init: None,
        is_batched: x.is_batched,
    })
}