gad/
array.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    core::{CoreAlgebra, HasDims},
6    error::Result,
7    graph::{Config1, ConfigN, Graph, Value},
8    linked::LinkedAlgebra,
9    store::GradientStore,
10};
11
12/// Array operations.
13pub trait ArrayAlgebra<Value> {
14    type Dims;
15    type Scalar;
16
17    /// Re-shape the input into a single dimension array.
18    fn flat(&mut self, v: &Value) -> Value;
19
20    /// Re-shape the input into an array of the given dimensions.
21    fn moddims(&mut self, v: &Value, dims: Self::Dims) -> Result<Value>;
22
23    /// Repeats the input to match the given shape.
24    fn tile_as(&mut self, v: &Value, dims: Self::Dims) -> Result<Value>;
25
26    /// Sums some of the dimension of the input to fit the given shape.
27    fn sum_as(&mut self, v: &Value, dims: Self::Dims) -> Result<Value>;
28
29    /// Fill an array of the given shape with the given scalar value.
30    fn constant_as(&mut self, v: &Self::Scalar, dims: Self::Dims) -> Value;
31
32    /// Read the scalar value in a one-element array.
33    fn as_scalar(&mut self, v: &Value) -> Result<Self::Scalar>;
34
35    /// Multiply the array element-wise by the given scalar.
36    fn scale(&mut self, lambda: &Self::Scalar, v: &Value) -> Value;
37
38    /// Compute the dot-product of two arrays of the same shape.
39    fn dot(&mut self, v1: &Value, v2: &Value) -> Result<Self::Scalar>;
40
41    /// Compute the L2-norm of an array.
42    fn norm2(&mut self, v: &Value) -> Self::Scalar {
43        self.dot(v, v).expect("norm2 should not fail")
44    }
45}
46
/// ArrayFire implementations of `ArrayAlgebra`: `Eval` computes on
/// `af::Array<T>` values, while `Check` only computes and validates the
/// resulting shapes (`af::Dim4`).
#[cfg(feature = "arrayfire")]
mod af_arith {
    use crate::{
        array::ArrayAlgebra,
        arrayfire::Float,
        error::{check_equal_dimensions, Error, Result},
        Check, Eval,
    };
    use arrayfire as af;

    impl<T> ArrayAlgebra<af::Array<T>> for Eval
    where
        T: Float,
    {
        type Dims = af::Dim4;
        type Scalar = T;

        #[inline]
        fn flat(&mut self, v: &af::Array<T>) -> af::Array<T> {
            af::flat(v)
        }

        #[inline]
        fn moddims(&mut self, v: &af::Array<T>, dims: af::Dim4) -> Result<af::Array<T>> {
            self.check().moddims(&v.dims(), dims)?;
            Ok(af::moddims(v, dims))
        }

        #[inline]
        fn tile_as(&mut self, v: &af::Array<T>, rdims: af::Dim4) -> Result<af::Array<T>> {
            let vdims = v.dims();
            self.check().tile_as(&vdims, rdims)?;
            // Repetition count per axis. The check above guarantees that
            // `rdims[i]` is a multiple of `vdims[i]`. A zero-length axis may
            // only be tiled into a zero-length axis, so keep a single copy
            // there instead of dividing by zero.
            let mut tdims = [1u64; 4];
            for (i, t) in tdims.iter_mut().enumerate() {
                if vdims[i] != 0 {
                    *t = rdims[i] / vdims[i];
                }
            }
            Ok(af::tile(v, af::Dim4::new(&tdims)))
        }

        #[inline]
        fn sum_as(&mut self, v: &af::Array<T>, rdims: af::Dim4) -> Result<af::Array<T>> {
            self.check().sum_as(&v.dims(), rdims)?;
            let vdims = v.dims();
            let mut result = v.clone();
            // Reduce every axis whose size differs from the target; the check
            // guarantees the target size is 1 on those axes.
            for i in 0..4 {
                if rdims[i] == vdims[i] {
                    continue;
                }
                result = af::sum(&result, i as i32);
            }
            Ok(result)
        }

        #[inline]
        fn constant_as(&mut self, v: &T, dims: af::Dim4) -> af::Array<T> {
            af::constant(*v, dims)
        }

        #[inline]
        fn as_scalar(&mut self, v: &af::Array<T>) -> Result<T> {
            self.check().as_scalar(&v.dims())?;
            // Copy the single element back to the host.
            let mut res = vec![T::zero(); 1];
            v.host(&mut res);
            Ok(res[0])
        }

        #[inline]
        fn scale(&mut self, lambda: &T, v: &af::Array<T>) -> af::Array<T> {
            v * (*lambda)
        }

        #[inline]
        fn dot(&mut self, v1: &af::Array<T>, v2: &af::Array<T>) -> Result<T> {
            self.check().dot(&v1.dims(), &v2.dims())?;
            // `af::dot` is applied to flattened copies of both inputs.
            let v1 = af::flat(v1);
            let v2 = af::flat(v2);
            let mut res = vec![T::zero(); 1];
            af::dot(&v1, &v2, af::MatProp::CONJ, af::MatProp::NONE).host(&mut res);
            Ok(res[0])
        }
    }

    impl ArrayAlgebra<af::Dim4> for Check {
        type Dims = af::Dim4;
        type Scalar = ();

        #[inline]
        fn flat(&mut self, v: &af::Dim4) -> af::Dim4 {
            af::dim4!(v.elements())
        }

        #[inline]
        fn moddims(&mut self, v: &af::Dim4, dims: af::Dim4) -> Result<af::Dim4> {
            if v.elements() != dims.elements() {
                Err(Error::dimensions(func_name!(), &[v, &dims]))
            } else {
                Ok(dims)
            }
        }

        #[inline]
        fn tile_as(&mut self, v: &af::Dim4, rdims: af::Dim4) -> Result<af::Dim4> {
            for i in 0..4 {
                // A zero-length axis can only be tiled into a zero-length
                // axis; otherwise the target length must be a multiple of the
                // source length. Testing `v[i] == 0` first avoids a
                // remainder-by-zero panic.
                let tileable = if v[i] == 0 {
                    rdims[i] == 0
                } else {
                    rdims[i] % v[i] == 0
                };
                if !tileable {
                    return Err(Error::dimensions(func_name!(), &[v, &rdims]));
                }
            }
            Ok(rdims)
        }

        #[inline]
        fn sum_as(&mut self, v: &af::Dim4, rdims: af::Dim4) -> Result<af::Dim4> {
            // Every target axis must either keep the input size or be reduced to 1.
            for i in 0..4 {
                if rdims[i] == v[i] {
                    continue;
                }
                if rdims[i] != 1 {
                    return Err(Error::dimensions(func_name!(), &[v, &rdims]));
                }
            }
            Ok(rdims)
        }

        #[inline]
        fn constant_as(&mut self, _v: &(), dims: af::Dim4) -> af::Dim4 {
            dims
        }

        #[inline]
        fn as_scalar(&mut self, v: &af::Dim4) -> Result<()> {
            check_equal_dimensions(func_name!(), &[v, &af::dim4!(1)])?;
            Ok(())
        }

        #[inline]
        fn scale(&mut self, _lambda: &(), v: &af::Dim4) -> af::Dim4 {
            *v
        }

        #[inline]
        fn dot(&mut self, v1: &af::Dim4, v2: &af::Dim4) -> Result<()> {
            check_equal_dimensions(func_name!(), &[v1, v2])?;
            Ok(())
        }
    }
}
195
// Implements `ArrayAlgebra` for `Graph<$config<E>>`. This file invokes it for
// `Config1` and `ConfigN`. Each operation computes its forward value through
// `self.eval()` and hands `make_node` / `make_generic_node` a closure that
// propagates the incoming gradient to every input that has an id.
macro_rules! impl_graph {
    ($config:ident) => {
        impl<D, E, T, Dims> ArrayAlgebra<Value<D>> for Graph<$config<E>>
        where
            E: Default
                + Clone
                + CoreAlgebra<D, Value = D>
                + CoreAlgebra<T, Value = T>
                + LinkedAlgebra<Value<D>, D>
                + LinkedAlgebra<Value<T>, T>
                + ArrayAlgebra<D, Scalar = T, Dims = Dims>,
            Dims: PartialEq + Clone + Copy + std::fmt::Debug + Default + 'static + Send + Sync,
            D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
            T: crate::Number,
        {
            type Dims = Dims;
            type Scalar = Value<T>;

            fn flat(&mut self, v: &Value<D>) -> Value<D> {
                let flattened = self.eval().flat(v.data());
                let (shape, input_id) = (v.data().dims(), v.id());
                // Backward: reshape the gradient back to the input's shape.
                self.make_node(flattened, vec![v.input()], move |graph, store, gradient| {
                    if let Some(id) = input_id {
                        let grad = graph.moddims(&gradient, shape)?;
                        store.add_gradient::<D, _>(graph, id, &grad)?;
                    }
                    Ok(())
                })
            }

            fn moddims(&mut self, v: &Value<D>, rdims: Dims) -> Result<Value<D>> {
                let reshaped = self.eval().moddims(v.data(), rdims)?;
                let (shape, input_id) = (v.data().dims(), v.id());
                // Backward: reshape the gradient back to the input's shape.
                Ok(self.make_node(reshaped, vec![v.input()], move |graph, store, gradient| {
                    if let Some(id) = input_id {
                        let grad = graph.moddims(&gradient, shape)?;
                        store.add_gradient::<D, _>(graph, id, &grad)?;
                    }
                    Ok(())
                }))
            }

            fn tile_as(&mut self, v: &Value<D>, rdims: Dims) -> Result<Value<D>> {
                let tiled = self.eval().tile_as(v.data(), rdims)?;
                let (shape, input_id) = (v.data().dims(), v.id());
                // Backward: sum the gradient of the repeated copies back down
                // to the input's shape.
                Ok(self.make_node(tiled, vec![v.input()], move |graph, store, gradient| {
                    if let Some(id) = input_id {
                        let grad = graph.sum_as(&gradient, shape)?;
                        store.add_gradient::<D, _>(graph, id, &grad)?;
                    }
                    Ok(())
                }))
            }

            fn sum_as(&mut self, v: &Value<D>, rdims: Dims) -> Result<Value<D>> {
                let summed = self.eval().sum_as(v.data(), rdims)?;
                let (shape, input_id) = (v.data().dims(), v.id());
                // Backward: every summed entry receives the same gradient, so
                // tile the gradient back up to the input's shape.
                Ok(self.make_node(summed, vec![v.input()], move |graph, store, gradient| {
                    if let Some(id) = input_id {
                        let grad = graph.tile_as(&gradient, shape)?;
                        store.add_gradient::<D, _>(graph, id, &grad)?;
                    }
                    Ok(())
                }))
            }

            fn constant_as(&mut self, v: &Value<T>, dims: Dims) -> Value<D> {
                let filled = self.eval().constant_as(v.data(), dims);
                let scalar_id = v.id();
                // Backward: every output entry depends on the scalar, so its
                // gradient is the total of the incoming gradient. This sums
                // down to `Dims::default()` and reads the single value;
                // presumably `Dims::default()` is the all-ones shape (TODO confirm).
                self.make_generic_node::<T, D, _, _, _, _>(
                    filled,
                    vec![v.input()],
                    move |graph, store, gradient| {
                        if let Some(id) = scalar_id {
                            let total = graph.sum_as(&gradient, Dims::default())?;
                            let grad = graph.as_scalar(&total)?;
                            store.add_gradient::<T, _>(graph, id, &grad)?;
                        }
                        Ok(())
                    },
                )
            }

            fn as_scalar(&mut self, v: &Value<D>) -> Result<Value<T>> {
                let scalar = self.eval().as_scalar(v.data())?;
                let (shape, input_id) = (v.dims(), v.id());
                // Backward: broadcast the scalar gradient to an array of the
                // input's shape.
                Ok(self.make_generic_node::<D, T, _, _, _, _>(
                    scalar,
                    vec![v.input()],
                    move |graph, store, gradient| {
                        if let Some(id) = input_id {
                            let grad = graph.constant_as(&gradient, shape);
                            store.add_gradient::<D, _>(graph, id, &grad)?;
                        }
                        Ok(())
                    },
                ))
            }

            fn scale(&mut self, lambda: &Value<T>, v: &Value<D>) -> Value<D> {
                let scaled = self.eval().scale(lambda.data(), v.data());
                let inputs = vec![lambda.input(), v.input()];
                let (lambda, v) = (lambda.clone(), v.clone());
                self.make_node(scaled, inputs, move |graph, store, gradient| {
                    // d(lambda * v)/d(lambda) = <gradient, v>
                    if let Some(id) = lambda.id() {
                        let linked_v = graph.link(&v);
                        let grad = graph.dot(&gradient, linked_v)?;
                        store.add_gradient::<T, _>(graph, id, &grad)?;
                    }
                    // d(lambda * v)/d(v) = lambda * gradient
                    if let Some(id) = v.id() {
                        let linked_lambda = graph.link(&lambda);
                        let grad = graph.scale(linked_lambda, &gradient);
                        store.add_gradient::<D, _>(graph, id, &grad)?;
                    }
                    Ok(())
                })
            }

            fn dot(&mut self, v1: &Value<D>, v2: &Value<D>) -> Result<Value<T>> {
                let product = self.eval().dot(v1.data(), v2.data())?;
                let inputs = vec![v1.input(), v2.input()];
                let (lhs, rhs) = (v1.clone(), v2.clone());
                Ok(self.make_node(product, inputs, move |graph, store, gradient| {
                    // d<lhs, rhs>/d(lhs) = gradient * rhs
                    if let Some(id) = lhs.id() {
                        let other = graph.link(&rhs);
                        let grad = graph.scale(&gradient, other);
                        store.add_gradient::<D, _>(graph, id, &grad)?;
                    }
                    // d<lhs, rhs>/d(rhs) = gradient * lhs
                    if let Some(id) = rhs.id() {
                        let other = graph.link(&lhs);
                        let grad = graph.scale(&gradient, other);
                        store.add_gradient::<D, _>(graph, id, &grad)?;
                    }
                    Ok(())
                }))
            }
        }
    };
}
355
356impl_graph!(Config1);
357impl_graph!(ConfigN);