acme_tensor/ops/
backprop.rs1use super::TensorExpr;
6use crate::TensorBase;
7use acme::prelude::BinaryOp;
8use core::borrow::Borrow;
9use core::ops::{Deref, DerefMut};
10
/// A named operation.
///
/// Implementors pick an associated [`TensorOp::Output`] type and report an
/// identifying name through [`TensorOp::name`]. This trait places no bounds
/// on `Output`.
pub trait TensorOp {
    /// Associated output type of the operation.
    type Output;

    /// Returns the name of this operation.
    ///
    /// The returned string slice borrows from `self`.
    fn name(&self) -> &str;
}
16
17#[derive(Clone, Debug, Eq, Hash, PartialEq)]
18#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
19pub struct BackpropOp<A = f64, B = A>(Option<TensorExpr<A, B>>);
20
21impl<A, B> BackpropOp<A, B> {
22 pub fn new(op: TensorExpr<A, B>) -> Self {
23 BackpropOp(Some(op))
24 }
25
26 pub fn none() -> Self {
27 BackpropOp(None)
28 }
29
30 pub fn binary(lhs: TensorBase<A>, rhs: TensorBase<B>, kind: BinaryOp) -> Self {
31 BackpropOp(Some(TensorExpr::binary(lhs, rhs, kind)))
32 }
33
34 pub fn is_none(&self) -> bool {
35 self.0.is_none()
36 }
37
38 pub fn op(&self) -> Option<&TensorExpr<A, B>> {
39 self.0.as_ref()
40 }
41
42 pub fn op_mut(&mut self) -> Option<&mut TensorExpr<A, B>> {
43 self.0.as_mut()
44 }
45
46 pub fn into_inner(self) -> Option<TensorExpr<A, B>> {
47 self.0
48 }
49
50 pub fn take(&mut self) -> Option<TensorExpr<A, B>> {
51 self.0.take()
52 }
53
54 pub fn view(&self) -> BackpropOp<&A, &B> {
55 BackpropOp(self.0.as_ref().map(|op| op.view()))
56 }
57
58 pub fn view_mut(&mut self) -> BackpropOp<&mut A, &mut B> {
59 BackpropOp(self.0.as_mut().map(|op| op.view_mut()))
60 }
61}
62
63impl<S, T> Borrow<Option<TensorExpr<S, T>>> for BackpropOp<S, T> {
64 fn borrow(&self) -> &Option<TensorExpr<S, T>> {
65 &self.0
66 }
67}
68
69impl<T> Default for BackpropOp<T> {
70 fn default() -> Self {
71 Self::none()
72 }
73}
74
75impl<T> Deref for BackpropOp<T> {
76 type Target = Option<TensorExpr<T>>;
77
78 fn deref(&self) -> &Self::Target {
79 &self.0
80 }
81}
82
83impl<T> DerefMut for BackpropOp<T> {
84 fn deref_mut(&mut self) -> &mut Self::Target {
85 &mut self.0
86 }
87}
88
89impl<T> From<Option<TensorExpr<T>>> for BackpropOp<T> {
90 fn from(op: Option<TensorExpr<T>>) -> Self {
91 BackpropOp(op)
92 }
93}
94
95impl<T> From<TensorExpr<T>> for BackpropOp<T> {
96 fn from(op: TensorExpr<T>) -> Self {
97 BackpropOp(Some(op))
98 }
99}
100
101impl<T> From<BackpropOp<T>> for Option<TensorExpr<T>> {
102 fn from(op: BackpropOp<T>) -> Option<TensorExpr<T>> {
103 op.into_inner()
104 }
105}