1use std::marker::PhantomData;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4use num_traits::FromPrimitive;
5use crate::arr::{AsView, MakeView, SerializedVec, SerializedVecView, SliceSize};
6use crate::ope::{One, Sqrt, Sum};
7
/// A node of a computation graph with a forward pass and a backward pass.
///
/// The four type parameters let each implementor choose its own signatures:
///
/// * `FI` – argument of [`GraphNode::forward`]
/// * `FO` – result of [`GraphNode::forward`]
/// * `BI` – argument of [`GraphNode::backward`]
/// * `BO` – result of [`GraphNode::backward`]
pub trait GraphNode<FI, FO, BI, BO> {
    /// Evaluates the node on `v`.
    fn forward(&self, v: FI) -> FO;

    /// Runs the node's backward computation on `d`.
    fn backward(&self, d: BI) -> BO;
}
22pub struct AddNode<U> where U: Add<Output = U> + Clone {
24 u:PhantomData<U>
25}
26impl<U> AddNode<U> where U: Add<Output = U> + Clone {
27 pub fn new() -> AddNode<U> {
28 AddNode {
29 u:PhantomData::<U>
30 }
31 }
32}
33impl<U> GraphNode<(U,U),U,U,(U,U)> for AddNode<U> where U: Add<Output = U> + Clone {
34 #[inline]
35 fn forward(&self,(l,r):(U,U)) -> U {
36 l + r
37 }
38
39 #[inline]
40 fn backward(&self,d:U) -> (U,U) {
41 (d.clone(),d)
42 }
43}
44pub struct MulNode<U> where U: Mul<Output = U> + Clone {
46 u:PhantomData<U>
47}
48impl<U> MulNode<U> where U: Mul<Output = U> + Clone {
49 pub fn new() -> MulNode<U> {
50 MulNode {
51 u:PhantomData::<U>
52 }
53 }
54}
55impl<U> GraphNode<(U,U),U,(U,U,U),(U,U)> for MulNode<U> where U: Mul<Output = U> + Clone {
56 #[inline]
57 fn forward(&self,(l,r):(U,U)) -> U {
58 l * r
59 }
60
61 #[inline]
62 fn backward(&self,(l,r,d):(U,U,U)) -> (U,U) {
63 (r * d.clone(), l * d)
64 }
65}
66pub struct BranchNode<U> where U: Add<Output = U> + Clone {
68 u:PhantomData<U>
69}
70impl<U> BranchNode<U> where U: Add<Output = U> + Clone {
71 pub fn new() -> BranchNode<U> {
72 BranchNode {
73 u:PhantomData::<U>
74 }
75 }
76}
77impl<U> GraphNode<U,(U,U),(U,U),U> for BranchNode<U> where U: Add<Output = U> + Clone {
78 #[inline]
79 fn forward(&self,v:U) -> (U,U) {
80 (v.clone(),v)
81 }
82
83 #[inline]
84 fn backward(&self,(d1,d2):(U,U)) -> U {
85 d1 + d2
86 }
87}
88pub struct SumNode<U,C> where U: Default + Clone + Send + Sync {
90 u:PhantomData<U>,
91 c:PhantomData<C>
92}
93impl<U,C> SumNode<U,C> where U: Default + Clone + Send + Sync {
94 pub fn new() -> SumNode<U,C> {
95 SumNode {
96 u:PhantomData::<U>,
97 c:PhantomData::<C>
98 }
99 }
100}
101impl<U,T> GraphNode<&SerializedVec<U,T>,T,(&T,usize),SerializedVec<U,T>> for SumNode<U,SerializedVec<U,T>>
102 where U: Default + Clone + Copy + Send + Sync + Add<Output=U> + 'static,
103 for<'a> T: SliceSize + AsView<'a> + MakeView<'a,U> + Default + Clone + Send + Sync +
104 Add<Output=T> + Add<<T as AsView<'a>>::ViewType,Output=T>,
105 for<'a> <T as AsView<'a>>::ViewType: Send,
106 SerializedVec<U,T>: From<Vec<T>> {
107 #[inline]
108 fn forward(&self,v: &SerializedVec<U,T>) -> T {
109 v.sum()
110 }
111
112 #[inline]
113 fn backward(&self,(d,n): (&T,usize)) -> SerializedVec<U,T> {
114 (0..n).map(|_| {
115 d.clone().into()
116 }).collect::<Vec<T>>().into()
117 }
118}
119impl<'a,U,T> GraphNode<SerializedVecView<'a,U,T>,T,(&T,usize),SerializedVec<U,T>> for SumNode<U,SerializedVecView<'a,U,T>>
120 where U: Default + Clone + Copy + Send + Sync + Add<Output=U> + 'static,
121 for<'b> T: SliceSize + AsView<'b> + MakeView<'b,U> +
122 Default + Clone + Send + Sync +
123 Add<Output=T> + Add<<T as AsView<'b>>::ViewType,Output=T>,
124 for<'b> <T as AsView<'b>>::ViewType: Send,
125 SerializedVec<U,T>: From<Vec<T>> {
126 #[inline]
127 fn forward(&self,v: SerializedVecView<'a,U,T>) -> T {
128 v.sum()
129 }
130
131 #[inline]
132 fn backward(&self,(d,n): (&T,usize)) -> SerializedVec<U,T> {
133 (0..n).map(|_| {
134 d.clone().into()
135 }).collect::<Vec<T>>().into()
136 }
137}
138pub struct BroadcastNode<U,C> where U: Default + Clone + Send + Sync {
140 u:PhantomData<U>,
141 c:PhantomData<C>
142}
143impl<U,C> BroadcastNode<U,C> where U: Default + Clone + Send + Sync {
144 pub fn new() -> BroadcastNode<U,C> {
145 BroadcastNode {
146 u:PhantomData::<U>,
147 c:PhantomData::<C>
148 }
149 }
150}
151impl<U,T> GraphNode<(&T,usize),SerializedVec<U,T>,&SerializedVec<U,T>,T> for BroadcastNode<U,&SerializedVec<U,T>>
152 where U: Default + Clone + Copy + Send + Sync + Add<Output=U> + 'static,
153 for<'a> T: SliceSize + AsView<'a> + MakeView<'a,U> + Default + Clone + Send + Sync +
154 Add<Output=T> + Add<<T as AsView<'a>>::ViewType,Output=T>,
155 for<'a> <T as AsView<'a>>::ViewType: Send,
156 SerializedVec<U,T>: From<Vec<T>> {
157 #[inline]
158 fn forward(&self,(v,n): (&T,usize)) -> SerializedVec<U,T> {
159 (0..n).map(|_| v.clone()).collect::<Vec<_>>().into()
160 }
161
162 #[inline]
163 fn backward(&self,d: &SerializedVec<U,T>) -> T {
164 d.sum()
165 }
166}
167impl<'b,U,T> GraphNode<(&T,usize),SerializedVec<U,T>,SerializedVecView<'b,U,T>,T> for BroadcastNode<U,SerializedVecView<'b,U,T>>
168 where U: Default + Clone + Copy + Send + Sync + Add<Output=U> + 'static,
169 for<'a> T: SliceSize + AsView<'a> + MakeView<'a,U> + Default + Clone + Send + Sync +
170 Add<Output=T> + Add<<T as AsView<'a>>::ViewType,Output=T>,
171 for<'a> <T as AsView<'a>>::ViewType: Send,
172 SerializedVec<U,T>: From<Vec<T>> {
173 #[inline]
174 fn forward(&self,(v,n): (&T,usize)) -> SerializedVec<U,T> {
175 (0..n).map(|_| v.clone()).collect::<Vec<_>>().into()
176 }
177
178 #[inline]
179 fn backward(&self,d: SerializedVecView<'b,U,T>) -> T {
180 d.sum()
181 }
182}
183pub struct ReciprocalNode<U> where U: Div + Div<Output = U> + Mul + Mul<Output = U> + Neg {
185 u:PhantomData<U>
186}
187impl<U> ReciprocalNode<U> where U: Div + Div<Output = U> + Mul + Mul<Output = U> + Neg {
188 pub fn new() -> ReciprocalNode<U> {
189 ReciprocalNode {
190 u:PhantomData::<U>
191 }
192 }
193}
194impl<U> GraphNode<U,U,U,U> for ReciprocalNode<U>
195 where U: Div + Div<Output = U> + Neg + Neg<Output = U> + One + Mul + Mul<Output = U> + One + Clone + Copy {
196 #[inline]
197 fn forward(&self,v: U) -> U {
198 U::one() / v
199 }
200
201 #[inline]
202 fn backward(&self,d: U) -> U {
203 -(U::one() / (d * d))
204 }
205}
206pub struct SqrtNode<U> where U: Sqrt + Div + Div<Output = U> + FromPrimitive {
208 u:PhantomData<U>
209}
210impl<U> SqrtNode<U> where U: Sqrt + Div + Div<Output = U> + FromPrimitive {
211 pub fn new() -> SqrtNode<U> {
212 SqrtNode {
213 u:PhantomData::<U>
214 }
215 }
216}
217impl<U> GraphNode<U,U,U,U> for SqrtNode<U>
218 where U: Sqrt + Div + Div<Output = U> + Mul + Mul<Output = U> + One + FromPrimitive {
219
220 #[inline]
221 fn forward(&self,v: U) -> U {
222 v.sqrt()
223 }
224
225 #[inline]
226 fn backward(&self,d: U) -> U {
227 U::one() / (U::from_f64(2.).expect("Error in type conversion from f64.") * d.sqrt())
228 }
229}
230pub struct SquareNode<U> where U: FromPrimitive + Mul + Mul<Output = U> {
232 u:PhantomData<U>
233}
234impl<U> SquareNode<U> where U: FromPrimitive + Mul + Mul<Output = U> {
235 pub fn new() -> SquareNode<U> {
236 SquareNode {
237 u:PhantomData::<U>
238 }
239 }
240}
241impl<U> GraphNode<U,U,(U,U),U> for SquareNode<U>
242 where U: FromPrimitive + Mul + Mul<Output = U> + Clone + Copy {
243
244 #[inline]
245 fn forward(&self,v: U) -> U {
246 v * v
247 }
248
249 #[inline]
250 fn backward(&self,(i,d): (U, U)) -> U {
251 U::from_f64(2.).expect("Error in type conversion from f64.") * i * d
252 }
253}
254pub struct SubNode<U> where U: Sub + Sub<Output = U> + Neg + Clone {
256 u:PhantomData<U>
257}
258impl<U> SubNode<U> where U: Sub + Sub<Output = U> + Neg + Clone{
259 pub fn new() -> SubNode<U> {
260 SubNode {
261 u:PhantomData::<U>
262 }
263 }
264}
265impl<U> GraphNode<(U,U),U,U,(U,U)> for SubNode<U> where U: Sub + Sub<Output = U> + Neg + Neg<Output = U> + Clone {
266 #[inline]
267 fn forward(&self,(l,r): (U, U)) -> U {
268 l - r
269 }
270
271 #[inline]
272 fn backward(&self,d: U) -> (U,U) {
273 (d.clone(),-d)
274 }
275}