1use crate::{
5 core::{CoreAlgebra, HasDims},
6 error::Result,
7 graph::{Config1, ConfigN, Graph, Value},
8 linked::LinkedAlgebra,
9 store::GradientStore,
10};
11
12pub trait ArrayAlgebra<Value> {
14 type Dims;
15 type Scalar;
16
17 fn flat(&mut self, v: &Value) -> Value;
19
20 fn moddims(&mut self, v: &Value, dims: Self::Dims) -> Result<Value>;
22
23 fn tile_as(&mut self, v: &Value, dims: Self::Dims) -> Result<Value>;
25
26 fn sum_as(&mut self, v: &Value, dims: Self::Dims) -> Result<Value>;
28
29 fn constant_as(&mut self, v: &Self::Scalar, dims: Self::Dims) -> Value;
31
32 fn as_scalar(&mut self, v: &Value) -> Result<Self::Scalar>;
34
35 fn scale(&mut self, lambda: &Self::Scalar, v: &Value) -> Value;
37
38 fn dot(&mut self, v1: &Value, v2: &Value) -> Result<Self::Scalar>;
40
41 fn norm2(&mut self, v: &Value) -> Self::Scalar {
43 self.dot(v, v).expect("norm2 should not fail")
44 }
45}
46
#[cfg(feature = "arrayfire")]
mod af_arith {
    use crate::{
        array::ArrayAlgebra,
        arrayfire::Float,
        error::{check_equal_dimensions, Error, Result},
        Check, Eval,
    };
    use arrayfire as af;

    /// Per-axis repetition counts that turn an array of shape `vdims` into one
    /// of shape `rdims` with `af::tile`.
    ///
    /// Returns `None` when some target extent is not a whole multiple of the
    /// matching source extent. A zero-length source axis can only map onto a
    /// zero-length target axis. That case is handled explicitly because integer
    /// `%` / `/` by zero panics.
    fn tile_factors(vdims: &af::Dim4, rdims: &af::Dim4) -> Option<[u64; 4]> {
        let mut factors = [1u64; 4];
        for (i, factor) in factors.iter_mut().enumerate() {
            *factor = match (vdims[i], rdims[i]) {
                (0, 0) => 1,
                (0, _) => return None,
                (src, dst) if dst % src == 0 => dst / src,
                _ => return None,
            };
        }
        Some(factors)
    }

    /// Concrete evaluation on ArrayFire arrays. Every fallible operation first
    /// validates the operand dimensions with the matching `Check` method.
    impl<T> ArrayAlgebra<af::Array<T>> for Eval
    where
        T: Float,
    {
        type Dims = af::Dim4;
        type Scalar = T;

        #[inline]
        fn flat(&mut self, v: &af::Array<T>) -> af::Array<T> {
            af::flat(v)
        }

        #[inline]
        fn moddims(&mut self, v: &af::Array<T>, dims: af::Dim4) -> Result<af::Array<T>> {
            self.check().moddims(&v.dims(), dims)?;
            Ok(af::moddims(v, dims))
        }

        #[inline]
        fn tile_as(&mut self, v: &af::Array<T>, rdims: af::Dim4) -> Result<af::Array<T>> {
            self.check().tile_as(&v.dims(), rdims)?;
            if rdims.elements() == 0 {
                // An empty target holds no data. Build it directly instead of
                // asking `af::tile` for a zero repetition count or to tile an
                // empty input.
                return Ok(af::constant(T::zero(), rdims));
            }
            // Every target axis is non-empty here, so `Check::tile_as` has
            // already ensured each source axis is non-empty and divides it.
            let factors = tile_factors(&v.dims(), &rdims)
                .expect("tile_as dimensions were validated by Check::tile_as");
            Ok(af::tile(v, af::Dim4::new(&factors)))
        }

        #[inline]
        fn sum_as(&mut self, v: &af::Array<T>, rdims: af::Dim4) -> Result<af::Array<T>> {
            self.check().sum_as(&v.dims(), rdims)?;
            let vdims = v.dims();
            // Reduce every axis whose extent differs from the target.
            // `Check::sum_as` guarantees those target extents are 1.
            let reduced = (0..4)
                .filter(|&axis| rdims[axis] != vdims[axis])
                .fold(v.clone(), |acc, axis| af::sum(&acc, axis as i32));
            Ok(reduced)
        }

        #[inline]
        fn constant_as(&mut self, v: &T, dims: af::Dim4) -> af::Array<T> {
            af::constant(*v, dims)
        }

        #[inline]
        fn as_scalar(&mut self, v: &af::Array<T>) -> Result<T> {
            self.check().as_scalar(&v.dims())?;
            // `Check::as_scalar` guarantees exactly one element, so a one-slot
            // buffer is enough for the device-to-host copy.
            let mut res = [T::zero()];
            v.host(&mut res);
            Ok(res[0])
        }

        #[inline]
        fn scale(&mut self, lambda: &T, v: &af::Array<T>) -> af::Array<T> {
            v * (*lambda)
        }

        #[inline]
        fn dot(&mut self, v1: &af::Array<T>, v2: &af::Array<T>) -> Result<T> {
            self.check().dot(&v1.dims(), &v2.dims())?;
            // `af::dot` works on vectors, so both operands are flattened first.
            let v1 = af::flat(v1);
            let v2 = af::flat(v2);
            let mut res = [T::zero()];
            af::dot(&v1, &v2, af::MatProp::CONJ, af::MatProp::NONE).host(&mut res);
            Ok(res[0])
        }
    }

    /// Dimension-only algebra: computes the output shape of each operation,
    /// or an error when the operand shapes are incompatible.
    impl ArrayAlgebra<af::Dim4> for Check {
        type Dims = af::Dim4;
        type Scalar = ();

        #[inline]
        fn flat(&mut self, v: &af::Dim4) -> af::Dim4 {
            af::dim4!(v.elements())
        }

        #[inline]
        fn moddims(&mut self, v: &af::Dim4, dims: af::Dim4) -> Result<af::Dim4> {
            if v.elements() != dims.elements() {
                Err(Error::dimensions(func_name!(), &[v, &dims]))
            } else {
                Ok(dims)
            }
        }

        #[inline]
        fn tile_as(&mut self, v: &af::Dim4, rdims: af::Dim4) -> Result<af::Dim4> {
            // Each target extent must be a whole multiple of the source extent.
            // `tile_factors` also rejects (instead of panicking on) zero-length
            // source axes that would need to grow.
            match tile_factors(v, &rdims) {
                Some(_) => Ok(rdims),
                None => Err(Error::dimensions(func_name!(), &[v, &rdims])),
            }
        }

        #[inline]
        fn sum_as(&mut self, v: &af::Dim4, rdims: af::Dim4) -> Result<af::Dim4> {
            // Each axis must either be kept as-is or be fully reduced to 1.
            for i in 0..4 {
                if rdims[i] == v[i] {
                    continue;
                }
                if rdims[i] != 1 {
                    return Err(Error::dimensions(func_name!(), &[v, &rdims]));
                }
            }
            Ok(rdims)
        }

        #[inline]
        fn constant_as(&mut self, _v: &(), dims: af::Dim4) -> af::Dim4 {
            dims
        }

        #[inline]
        fn as_scalar(&mut self, v: &af::Dim4) -> Result<()> {
            check_equal_dimensions(func_name!(), &[v, &af::dim4!(1)])?;
            Ok(())
        }

        #[inline]
        fn scale(&mut self, _lambda: &(), v: &af::Dim4) -> af::Dim4 {
            *v
        }

        #[inline]
        fn dot(&mut self, v1: &af::Dim4, v2: &af::Dim4) -> Result<()> {
            check_equal_dimensions(func_name!(), &[v1, v2])?;
            Ok(())
        }
    }
}
195
// Generates the differentiable `ArrayAlgebra` implementation for
// `Graph<$config<E>>`. Each operation computes its forward result with the
// inner algebra (`self.eval()`), then records a graph node whose closure
// receives the incoming `gradient` and adds the input gradients to `store` for
// every input that has an id.
macro_rules! impl_graph {
    ($config:ident) => {
        impl<D, E, T, Dims> ArrayAlgebra<Value<D>> for Graph<$config<E>>
        where
            E: Default
                + Clone
                + CoreAlgebra<D, Value = D>
                + CoreAlgebra<T, Value = T>
                + LinkedAlgebra<Value<D>, D>
                + LinkedAlgebra<Value<T>, T>
                + ArrayAlgebra<D, Scalar = T, Dims = Dims>,
            Dims: PartialEq + Clone + Copy + std::fmt::Debug + Default + 'static + Send + Sync,
            D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
            T: crate::Number,
        {
            type Dims = Dims;
            type Scalar = Value<T>;

            // Backward: reshape the flat gradient back to the input's dimensions.
            fn flat(&mut self, v: &Value<D>) -> Value<D> {
                let result = self.eval().flat(v.data());
                self.make_node(result, vec![v.input()], {
                    let vdims = v.data().dims();
                    let id = v.id();
                    move |graph, store, gradient| {
                        if let Some(id) = id {
                            let x = graph.moddims(&gradient, vdims)?;
                            store.add_gradient::<D, _>(graph, id, &x)?;
                        }
                        Ok(())
                    }
                })
            }

            // Backward: reshape the gradient back to the input's original dimensions.
            fn moddims(&mut self, v: &Value<D>, rdims: Dims) -> Result<Value<D>> {
                let result = self.eval().moddims(v.data(), rdims)?;
                let value = self.make_node(result, vec![v.input()], {
                    let vdims = v.data().dims();
                    let id = v.id();
                    move |graph, store, gradient| {
                        if let Some(id) = id {
                            let x = graph.moddims(&gradient, vdims)?;
                            store.add_gradient::<D, _>(graph, id, &x)?;
                        }
                        Ok(())
                    }
                });
                Ok(value)
            }

            // Backward: sum the tiled gradient back down to the input's dimensions.
            // NOTE(review): with the ArrayFire `Check` backend, `sum_as` only
            // accepts target extents equal to the source or 1. Tiling a
            // non-singleton axis (e.g. 2 -> 4) passes the forward check, but its
            // backward `sum_as` would be rejected. Confirm this is intended.
            fn tile_as(&mut self, v: &Value<D>, rdims: Dims) -> Result<Value<D>> {
                let result = self.eval().tile_as(v.data(), rdims)?;
                let value = self.make_node(result, vec![v.input()], {
                    let vdims = v.data().dims();
                    let id = v.id();
                    move |graph, store, gradient| {
                        if let Some(id) = id {
                            let x = graph.sum_as(&gradient, vdims)?;
                            store.add_gradient::<D, _>(graph, id, &x)?;
                        }
                        Ok(())
                    }
                });
                Ok(value)
            }

            // Backward: broadcast (tile) the reduced gradient back to the input's dimensions.
            fn sum_as(&mut self, v: &Value<D>, rdims: Dims) -> Result<Value<D>> {
                let result = self.eval().sum_as(v.data(), rdims)?;
                let value = self.make_node(result, vec![v.input()], {
                    let vdims = v.data().dims();
                    let id = v.id();
                    move |graph, store, gradient| {
                        if let Some(id) = id {
                            let x = graph.tile_as(&gradient, vdims)?;
                            store.add_gradient::<D, _>(graph, id, &x)?;
                        }
                        Ok(())
                    }
                });
                Ok(value)
            }

            // Scalar -> array node. Backward: the scalar's gradient is the sum
            // of the array gradient. It is reduced to `Dims::default()`
            // (presumably the all-ones shape, e.g. `[1, 1, 1, 1]` for
            // `af::Dim4`; confirm for other `Dims` types) and then read back as
            // a scalar.
            fn constant_as(&mut self, v: &Value<T>, dims: Dims) -> Value<D> {
                let result = self.eval().constant_as(v.data(), dims);
                let value = self.make_generic_node::<T, D, _, _, _, _>(result, vec![v.input()], {
                    let id = v.id();
                    move |graph, store, gradient| {
                        if let Some(id) = id {
                            let x = graph.sum_as(&gradient, Dims::default())?;
                            let y = graph.as_scalar(&x)?;
                            store.add_gradient::<T, _>(graph, id, &y)?;
                        }
                        Ok(())
                    }
                });
                value
            }

            // Array -> scalar node. Backward: spread the scalar gradient into a
            // constant array with the input's dimensions.
            fn as_scalar(&mut self, v: &Value<D>) -> Result<Value<T>> {
                let result = self.eval().as_scalar(v.data())?;
                let value = self.make_generic_node::<D, T, _, _, _, _>(result, vec![v.input()], {
                    let vdims = v.dims();
                    let id = v.id();
                    move |graph, store, gradient| {
                        if let Some(id) = id {
                            let x = graph.constant_as(&gradient, vdims);
                            store.add_gradient::<D, _>(graph, id, &x)?;
                        }
                        Ok(())
                    }
                });
                Ok(value)
            }

            // Forward: v1 * v2 with scalar v1. Backward: d/dv1 = dot(gradient, v2)
            // and d/dv2 = v1 * gradient. The forward values are captured and
            // passed through `graph.link` (see `LinkedAlgebra`) before they are
            // used in the derivative computation.
            fn scale(&mut self, v1: &Value<T>, v2: &Value<D>) -> Value<D> {
                let result = self.eval().scale(v1.data(), v2.data());
                let value = self.make_node(result, vec![v1.input(), v2.input()], {
                    let v1 = v1.clone();
                    let v2 = v2.clone();
                    move |graph, store, gradient| {
                        if let Some(id) = v1.id() {
                            let c2 = graph.link(&v2);
                            let grad = graph.dot(&gradient, c2)?;
                            store.add_gradient::<T, _>(graph, id, &grad)?;
                        }
                        if let Some(id) = v2.id() {
                            let c1 = graph.link(&v1);
                            let grad = graph.scale(c1, &gradient);
                            store.add_gradient::<D, _>(graph, id, &grad)?;
                        }
                        Ok(())
                    }
                });
                value
            }

            // Forward: scalar inner product. Backward (scalar gradient):
            // d/dv1 = gradient * v2 and d/dv2 = gradient * v1.
            fn dot(&mut self, v1: &Value<D>, v2: &Value<D>) -> Result<Value<T>> {
                let result = self.eval().dot(v1.data(), v2.data())?;
                let value = self.make_node(result, vec![v1.input(), v2.input()], {
                    let v1 = v1.clone();
                    let v2 = v2.clone();
                    move |graph, store, gradient| {
                        if let Some(id) = v1.id() {
                            let c2 = graph.link(&v2);
                            let grad = graph.scale(&gradient, c2);
                            store.add_gradient::<D, _>(graph, id, &grad)?;
                        }
                        if let Some(id) = v2.id() {
                            let c1 = graph.link(&v1);
                            let grad = graph.scale(&gradient, c1);
                            store.add_gradient::<D, _>(graph, id, &grad)?;
                        }
                        Ok(())
                    }
                });
                Ok(value)
            }
        }
    };
}
355
356impl_graph!(Config1);
357impl_graph!(ConfigN);