// acme_tensor/ops/backprop.rs

/*
    Appellation: backprop <mod>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use super::TensorExpr;
6use crate::TensorBase;
7use acme::prelude::BinaryOp;
8use core::borrow::Borrow;
9use core::ops::{Deref, DerefMut};
10
/// A tensor operation that can report its own name.
///
/// The trait only fixes an associated [`Output`](TensorOp::Output) type; it
/// declares no method that produces a value of that type.
pub trait TensorOp {
    /// Returns the identifier of this operation.
    fn name(&self) -> &str;

    /// Associated output type chosen by the implementor.
    type Output;
}
16
17#[derive(Clone, Debug, Eq, Hash, PartialEq)]
18#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
19pub struct BackpropOp<A = f64, B = A>(Option<TensorExpr<A, B>>);
20
21impl<A, B> BackpropOp<A, B> {
22    pub fn new(op: TensorExpr<A, B>) -> Self {
23        BackpropOp(Some(op))
24    }
25
26    pub fn none() -> Self {
27        BackpropOp(None)
28    }
29
30    pub fn binary(lhs: TensorBase<A>, rhs: TensorBase<B>, kind: BinaryOp) -> Self {
31        BackpropOp(Some(TensorExpr::binary(lhs, rhs, kind)))
32    }
33
34    pub fn is_none(&self) -> bool {
35        self.0.is_none()
36    }
37
38    pub fn op(&self) -> Option<&TensorExpr<A, B>> {
39        self.0.as_ref()
40    }
41
42    pub fn op_mut(&mut self) -> Option<&mut TensorExpr<A, B>> {
43        self.0.as_mut()
44    }
45
46    pub fn into_inner(self) -> Option<TensorExpr<A, B>> {
47        self.0
48    }
49
50    pub fn take(&mut self) -> Option<TensorExpr<A, B>> {
51        self.0.take()
52    }
53
54    pub fn view(&self) -> BackpropOp<&A, &B> {
55        BackpropOp(self.0.as_ref().map(|op| op.view()))
56    }
57
58    pub fn view_mut(&mut self) -> BackpropOp<&mut A, &mut B> {
59        BackpropOp(self.0.as_mut().map(|op| op.view_mut()))
60    }
61}
62
63impl<S, T> Borrow<Option<TensorExpr<S, T>>> for BackpropOp<S, T> {
64    fn borrow(&self) -> &Option<TensorExpr<S, T>> {
65        &self.0
66    }
67}
68
69impl<T> Default for BackpropOp<T> {
70    fn default() -> Self {
71        Self::none()
72    }
73}
74
75impl<T> Deref for BackpropOp<T> {
76    type Target = Option<TensorExpr<T>>;
77
78    fn deref(&self) -> &Self::Target {
79        &self.0
80    }
81}
82
83impl<T> DerefMut for BackpropOp<T> {
84    fn deref_mut(&mut self) -> &mut Self::Target {
85        &mut self.0
86    }
87}
88
89impl<T> From<Option<TensorExpr<T>>> for BackpropOp<T> {
90    fn from(op: Option<TensorExpr<T>>) -> Self {
91        BackpropOp(op)
92    }
93}
94
95impl<T> From<TensorExpr<T>> for BackpropOp<T> {
96    fn from(op: TensorExpr<T>) -> Self {
97        BackpropOp(Some(op))
98    }
99}
100
101impl<T> From<BackpropOp<T>> for Option<TensorExpr<T>> {
102    fn from(op: BackpropOp<T>) -> Option<TensorExpr<T>> {
103        op.into_inner()
104    }
105}