nncombinator/
computational_graph.rs

1//! Computational graph implementation
2use std::marker::PhantomData;
3use std::ops::{Add, Div, Mul, Neg, Sub};
4use num_traits::FromPrimitive;
5use crate::arr::{AsView, MakeView, SerializedVec, SerializedVecView, SliceSize};
6use crate::ope::{One, Sqrt, Sum};
7
/// A node of the computational graph used for a neural network's forward and
/// backward (gradient) passes.
///
/// The four type parameters name the data flowing through the node:
/// * `FI` / `FO` - input and output of the forward pass.
/// * `BI` / `BO` - input and output of the backward pass.
pub trait GraphNode<FI,FO,BI,BO> {
    /// Runs the forward pass.
    /// # Arguments
    /// * `input` - value entering the node on the forward pass.
    ///
    fn forward(&self,input:FI) -> FO;

    /// Runs the backward pass.
    /// # Arguments
    /// * `grad` - value entering the node on the backward pass.
    ///
    fn backward(&self,grad:BI) -> BO;
}
/// Addition node: `forward` computes `l + r`, and `backward` hands the incoming
/// gradient unchanged to both operands.
///
/// The node holds no data; `U` is only a type marker. The arithmetic bounds
/// live on the `GraphNode` implementation rather than on the struct, so the
/// type can be named and constructed for any `U`.
pub struct AddNode<U> {
    u:PhantomData<U>
}
impl<U> AddNode<U> {
    /// Creates a new (stateless) addition node.
    pub fn new() -> AddNode<U> {
        AddNode {
            u:PhantomData::<U>
        }
    }
}
33impl<U> GraphNode<(U,U),U,U,(U,U)> for AddNode<U> where U: Add<Output = U> + Clone {
34    #[inline]
35    fn forward(&self,(l,r):(U,U)) -> U {
36        l + r
37    }
38
39    #[inline]
40    fn backward(&self,d:U) -> (U,U) {
41        (d.clone(),d)
42    }
43}
/// Multiplication node: `forward` computes `l * r`, and `backward` applies the
/// product rule using the saved operands and the incoming gradient.
///
/// The node holds no data; `U` is only a type marker. The arithmetic bounds
/// live on the `GraphNode` implementation rather than on the struct.
pub struct MulNode<U> {
    u:PhantomData<U>
}
impl<U> MulNode<U> {
    /// Creates a new (stateless) multiplication node.
    pub fn new() -> MulNode<U> {
        MulNode {
            u:PhantomData::<U>
        }
    }
}
55impl<U> GraphNode<(U,U),U,(U,U,U),(U,U)> for MulNode<U> where U: Mul<Output = U> + Clone {
56    #[inline]
57    fn forward(&self,(l,r):(U,U)) -> U {
58        l * r
59    }
60
61    #[inline]
62    fn backward(&self,(l,r,d):(U,U,U)) -> (U,U) {
63        (r * d.clone(), l * d)
64    }
65}
/// Branch (fan-out) node: `forward` duplicates its input into two outputs, and
/// `backward` adds the two incoming gradients back together.
///
/// The node holds no data; `U` is only a type marker. The arithmetic bounds
/// live on the `GraphNode` implementation rather than on the struct.
pub struct BranchNode<U> {
    u:PhantomData<U>
}
impl<U> BranchNode<U> {
    /// Creates a new (stateless) branch node.
    pub fn new() -> BranchNode<U> {
        BranchNode {
            u:PhantomData::<U>
        }
    }
}
77impl<U> GraphNode<U,(U,U),(U,U),U> for BranchNode<U> where U: Add<Output = U> + Clone {
78    #[inline]
79    fn forward(&self,v:U) -> (U,U) {
80        (v.clone(),v)
81    }
82
83    #[inline]
84    fn backward(&self,(d1,d2):(U,U)) -> U {
85        d1 + d2
86    }
87}
/// Sum node: `forward` reduces a serialized collection of `T` values into a
/// single `T`, and `backward` replicates the incoming gradient once per
/// reduced element.
///
/// `C` selects the input container the node accepts (the `GraphNode`
/// implementations in this module cover `SerializedVec` and
/// `SerializedVecView`). The node holds no data, so no trait bounds are needed
/// on the struct itself. They live on the `GraphNode` implementations.
pub struct SumNode<U,C> {
    u:PhantomData<U>,
    c:PhantomData<C>
}
impl<U,C> SumNode<U,C> {
    /// Creates a new (stateless) sum node.
    pub fn new() -> SumNode<U,C> {
        SumNode {
            u:PhantomData::<U>,
            c:PhantomData::<C>
        }
    }
}
101impl<U,T> GraphNode<&SerializedVec<U,T>,T,(&T,usize),SerializedVec<U,T>> for SumNode<U,SerializedVec<U,T>>
102    where U: Default + Clone + Copy + Send + Sync + Add<Output=U> + 'static,
103          for<'a> T: SliceSize + AsView<'a> + MakeView<'a,U> + Default + Clone + Send + Sync +
104                  Add<Output=T> + Add<<T as AsView<'a>>::ViewType,Output=T>,
105          for<'a> <T as AsView<'a>>::ViewType: Send,
106          SerializedVec<U,T>: From<Vec<T>> {
107    #[inline]
108    fn forward(&self,v: &SerializedVec<U,T>) -> T {
109        v.sum()
110    }
111
112    #[inline]
113    fn backward(&self,(d,n): (&T,usize)) -> SerializedVec<U,T> {
114        (0..n).map(|_| {
115            d.clone().into()
116        }).collect::<Vec<T>>().into()
117    }
118}
119impl<'a,U,T> GraphNode<SerializedVecView<'a,U,T>,T,(&T,usize),SerializedVec<U,T>> for SumNode<U,SerializedVecView<'a,U,T>>
120    where U: Default + Clone + Copy + Send + Sync + Add<Output=U> + 'static,
121          for<'b> T: SliceSize + AsView<'b> + MakeView<'b,U> +
122                  Default + Clone + Send + Sync +
123                  Add<Output=T> + Add<<T as AsView<'b>>::ViewType,Output=T>,
124          for<'b> <T as AsView<'b>>::ViewType: Send,
125          SerializedVec<U,T>: From<Vec<T>> {
126    #[inline]
127    fn forward(&self,v: SerializedVecView<'a,U,T>) -> T {
128        v.sum()
129    }
130
131    #[inline]
132    fn backward(&self,(d,n): (&T,usize)) -> SerializedVec<U,T> {
133        (0..n).map(|_| {
134            d.clone().into()
135        }).collect::<Vec<T>>().into()
136    }
137}
/// Broadcast node: `forward` replicates a single `T` into a serialized
/// collection of `n` copies, and `backward` sums the incoming gradients back
/// into one `T`. This is the mirror image of `SumNode`.
///
/// `C` selects the gradient container accepted on the backward pass (the
/// implementations in this module cover `&SerializedVec` and
/// `SerializedVecView`). The node holds no data, so no trait bounds are needed
/// on the struct itself.
pub struct BroadcastNode<U,C> {
    u:PhantomData<U>,
    c:PhantomData<C>
}
impl<U,C> BroadcastNode<U,C> {
    /// Creates a new (stateless) broadcast node.
    pub fn new() -> BroadcastNode<U,C> {
        BroadcastNode {
            u:PhantomData::<U>,
            c:PhantomData::<C>
        }
    }
}
151impl<U,T> GraphNode<(&T,usize),SerializedVec<U,T>,&SerializedVec<U,T>,T> for BroadcastNode<U,&SerializedVec<U,T>>
152    where U: Default + Clone + Copy + Send + Sync + Add<Output=U> + 'static,
153          for<'a> T: SliceSize + AsView<'a> + MakeView<'a,U> + Default + Clone + Send + Sync +
154                  Add<Output=T> + Add<<T as AsView<'a>>::ViewType,Output=T>,
155          for<'a> <T as AsView<'a>>::ViewType: Send,
156          SerializedVec<U,T>: From<Vec<T>> {
157    #[inline]
158    fn forward(&self,(v,n): (&T,usize)) -> SerializedVec<U,T> {
159        (0..n).map(|_| v.clone()).collect::<Vec<_>>().into()
160    }
161
162    #[inline]
163    fn backward(&self,d: &SerializedVec<U,T>) -> T {
164        d.sum()
165    }
166}
167impl<'b,U,T> GraphNode<(&T,usize),SerializedVec<U,T>,SerializedVecView<'b,U,T>,T> for BroadcastNode<U,SerializedVecView<'b,U,T>>
168    where U: Default + Clone + Copy + Send + Sync + Add<Output=U> + 'static,
169          for<'a> T: SliceSize + AsView<'a> + MakeView<'a,U> + Default + Clone + Send + Sync +
170                  Add<Output=T> + Add<<T as AsView<'a>>::ViewType,Output=T>,
171          for<'a> <T as AsView<'a>>::ViewType: Send,
172          SerializedVec<U,T>: From<Vec<T>> {
173    #[inline]
174    fn forward(&self,(v,n): (&T,usize)) -> SerializedVec<U,T> {
175        (0..n).map(|_| v.clone()).collect::<Vec<_>>().into()
176    }
177
178    #[inline]
179    fn backward(&self,d: SerializedVecView<'b,U,T>) -> T {
180        d.sum()
181    }
182}
/// Reciprocal node: `forward` computes `1 / v`, and `backward` computes the
/// derivative `-1 / d²` of `1/x` at `x = d`.
///
/// The node holds no data; `U` is only a type marker. The arithmetic bounds
/// live on the `GraphNode` implementation rather than on the struct.
pub struct ReciprocalNode<U> {
    u:PhantomData<U>
}
impl<U> ReciprocalNode<U> {
    /// Creates a new (stateless) reciprocal node.
    pub fn new() -> ReciprocalNode<U> {
        ReciprocalNode {
            u:PhantomData::<U>
        }
    }
}
194impl<U> GraphNode<U,U,U,U> for ReciprocalNode<U>
195    where U: Div + Div<Output = U> + Neg + Neg<Output = U> + One + Mul + Mul<Output = U> + One + Clone + Copy {
196    #[inline]
197    fn forward(&self,v: U) -> U {
198        U::one() / v
199    }
200
201    #[inline]
202    fn backward(&self,d: U) -> U {
203        -(U::one() / (d * d))
204    }
205}
/// Square-root node: `forward` computes `√v`, and `backward` computes the
/// derivative `1 / (2√d)` of `√x` at `x = d`.
///
/// The node holds no data; `U` is only a type marker. The numeric bounds live
/// on the `GraphNode` implementation rather than on the struct.
pub struct SqrtNode<U> {
    u:PhantomData<U>
}
impl<U> SqrtNode<U> {
    /// Creates a new (stateless) square-root node.
    pub fn new() -> SqrtNode<U> {
        SqrtNode {
            u:PhantomData::<U>
        }
    }
}
217impl<U> GraphNode<U,U,U,U> for SqrtNode<U>
218    where U: Sqrt + Div + Div<Output = U> + Mul + Mul<Output = U> + One + FromPrimitive {
219
220    #[inline]
221    fn forward(&self,v: U) -> U {
222        v.sqrt()
223    }
224
225    #[inline]
226    fn backward(&self,d: U) -> U {
227        U::one() / (U::from_f64(2.).expect("Error in type conversion from f64.") * d.sqrt())
228    }
229}
/// Square node: `forward` computes `v * v`, and `backward` computes
/// `2 * input * grad` from the forward input and the incoming gradient.
///
/// The node holds no data; `U` is only a type marker. The numeric bounds live
/// on the `GraphNode` implementation rather than on the struct.
pub struct SquareNode<U> {
    u:PhantomData<U>
}
impl<U> SquareNode<U> {
    /// Creates a new (stateless) square node.
    pub fn new() -> SquareNode<U> {
        SquareNode {
            u:PhantomData::<U>
        }
    }
}
241impl<U> GraphNode<U,U,(U,U),U> for SquareNode<U>
242    where U: FromPrimitive + Mul + Mul<Output = U> + Clone + Copy {
243
244    #[inline]
245    fn forward(&self,v: U) -> U {
246        v * v
247    }
248
249    #[inline]
250    fn backward(&self,(i,d): (U, U)) -> U {
251        U::from_f64(2.).expect("Error in type conversion from f64.") * i * d
252    }
253}
/// Subtraction node: `forward` computes `l - r`, and `backward` returns the
/// incoming gradient for `l` and its negation for `r`.
///
/// The node holds no data; `U` is only a type marker. The arithmetic bounds
/// live on the `GraphNode` implementation rather than on the struct.
pub struct SubNode<U> {
    u:PhantomData<U>
}
impl<U> SubNode<U> {
    /// Creates a new (stateless) subtraction node.
    pub fn new() -> SubNode<U> {
        SubNode {
            u:PhantomData::<U>
        }
    }
}
265impl<U> GraphNode<(U,U),U,U,(U,U)> for SubNode<U> where U: Sub + Sub<Output = U> + Neg + Neg<Output = U> + Clone {
266    #[inline]
267    fn forward(&self,(l,r): (U, U)) -> U {
268        l - r
269    }
270
271    #[inline]
272    fn backward(&self,d: U) -> (U,U) {
273        (d.clone(),-d)
274    }
275}