burn-tensor 0.1.0

This library provides multiple tensor implementations hidden behind an easy to use API that supports reverse mode automatic differentiation.
#[macro_export]
macro_rules! define_ops {
    (
        name $name:ident
    ) => {
        #[derive(Debug)]
        struct $name<P, const D: usize> {
            _kind: $crate::tensor::backend::autodiff::ADKind<P>,
        }

        impl<P: Default, const D: usize> $name<P, D> {
            pub fn new() -> Self {
                Self {
                    _kind: $crate::tensor::backend::autodiff::ADKind::new(),
                }
            }
        }
    };
    (
        name $name:ident,
        state $state_ident:ident,
    ) => {
        #[derive(Debug)]
        struct $name<P, const D: usize> {
            pub state: $state_ident,
            _kind: $crate::tensor::backend::autodiff::ADKind<P>,
        }

        impl<P: Default, const D: usize> $name<P, D> {
            pub fn new(value: $state_ident) -> Self {
                Self {
                    state: value,
                    _kind: $crate::tensor::backend::autodiff::ADKind::new(),
                }
            }
        }
    };
}

#[macro_export]
macro_rules! register_ops {
    (
        ops $ops:ty,
        name $name:ident,
        partial_left $partial_left:expr,
        partial_right $partial_right:expr,
    ) => {
        $crate::define_ops!(
            name $name
        );

        impl<T, P, const D: usize> $ops for $name<P, D>
        where
            P: $crate::tensor::Element,
            T: $crate::tensor::Tensor<P, D>,
        {
            fn partial_left(&self, state: &$crate::graph::ops::BinaryOpsNodeState<T, T, T>) -> T {
                $partial_left(state)
            }

            fn partial_right(&self, state: &$crate::graph::ops::BinaryOpsNodeState<T, T, T>) -> T {
                $partial_right(state)
            }
        }
    };
    (
        ops $ops:ty,
        name $name:ident state $ops_tensor_state:ident,
        partial $partial:expr,
    ) => {
        $crate::define_ops!(
            name $name,
            state $ops_tensor_state,
        );

        impl<T, P, const D: usize> $ops for $name<P, D>
        where
            P: $crate::tensor::Element,
            T: $crate::tensor::Tensor<P, D>,
        {
            fn partial(&self, state: &$crate::graph::ops::UnaryOpsNodeState<T, T>) -> T {
                $partial(self.state, state)
            }
        }
    };
    (
        ops $ops:ty,
        name $name:ident,
        partial $partial:expr,
    ) => {
        $crate::define_ops!(
            name $name
        );

        impl<T, P, const D: usize> $ops for $name<P, D>
        where
            P: $crate::tensor::Element,
            T: $crate::tensor::Tensor<P, D>,
        {
            fn partial(&self, state: &$crate::graph::ops::UnaryOpsNodeState<T, T>) -> T {
                $partial(state)
            }
        }
    }

}

#[macro_export]
macro_rules! execute_ops {
    (
        lhs $lhs:expr,
        rhs $rhs:expr,
        out $out:expr,
        ops $ops:expr,
    ) => {{
        let callback = || {
            let state = $crate::graph::node::ForwardNodeState::new($out);

            let ops = std::sync::Arc::new($ops);
            let ops = $crate::graph::ops::ForwardBinaryRecordedOps::new($lhs, $rhs, ops.clone());
            let ops = std::sync::Arc::new(ops);

            let node = $crate::graph::node::ForwardNode::from_binary(&$lhs, &$rhs, state, ops);
            std::sync::Arc::new(node)
        };
        callback()
    }};
    (
        input $input:expr,
        out $out:expr,
        ops $ops:expr,
    ) => {{
        let callback = || {
            let state = $crate::graph::node::ForwardNodeState::new($out);

            let ops = std::sync::Arc::new($ops);
            let ops = $crate::graph::ops::ForwardUnaryRecordedOps::new($input, ops.clone());
            let ops = std::sync::Arc::new(ops);

            let node = $crate::graph::node::ForwardNode::from_unary(&$input, state, ops);
            std::sync::Arc::new(node)
        };
        callback()
    }};
    (
        init $out:expr
    ) => {{
        let callback = || {
            let state = $crate::graph::node::ForwardNodeState::new($out);

            let ops = $crate::graph::ops::InitRecordedOps::new();
            let ops = std::sync::Arc::new(ops);

            let node = $crate::graph::node::ForwardNode::from_root(state, ops);
            std::sync::Arc::new(node)
        };
        callback()
    }};
}