//! Implements the `+`, `-`, `*`, `/` operators for `Tensor`.
//! `+=`, `-=`, `*=` and `/=` are provided as `inplace_*` methods.
//! `*=` and `/=` don't propagate gradients.
use crate::ndarray_ext::{NdArray, NdArrayView};
use crate::op;
use crate::tensor::Tensor;
use crate::Float;
use crate::Graph;
use ndarray;
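
// A minimal usage sketch of these ops (assuming the crate's `with`/`grad`
// graph API and its `ones` constructor; names here are illustrative only):
//
//     ag::with(|g: &mut ag::Graph<f32>| {
//         let a = g.ones(&[2, 3]);
//         let b = g.ones(&[3]);               // broadcast against `a`
//         let c = a + b;                      // builds an AddOp node
//         let grads = g.grad(&[c], &[a, b]);  // each grad reduced to its input's shape
//     });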

pub struct AddOp;
pub struct SubOp;
pub struct MulOp;
pub struct DivOp;
pub struct MaybeReduce;
pub struct MaybeBroadcast;

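// Elementwise kernel for same-shape operands: dispatches to MKL VML functions
// for f32/f64 element types and falls back to ndarray's operator otherwise.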
#[cfg(feature = "mkl")]
macro_rules! bin_op_same_shape {
    ($vms_op:ident, $vmd_op:ident, $std_op:tt, $a:expr, $b:expr) => {
        unsafe {
            if same_type::<T, f32>() {
                let mut y = Vec::with_capacity($a.len());
                $vms_op($a.len() as MklInt, $a.as_ptr() as *const f32, $b.as_ptr() as *const f32, y.as_mut_ptr() as *mut f32);
                y.set_len($a.len());
                NdArray::from_shape_vec_unchecked($a.shape(), y)
            } else if same_type::<T, f64>() {
                let mut y = Vec::with_capacity($a.len());
                $vmd_op($a.len() as MklInt, $a.as_ptr() as *const f64, $b.as_ptr() as *const f64, y.as_mut_ptr() as *mut f64);
                y.set_len($a.len());
                NdArray::from_shape_vec_unchecked($a.shape(), y)
            } else {
                $a $std_op $b
            }
        }
    };
}

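// Sums a gradient (input 0) along the axes that were broadcast in the forward
// pass, reducing it to the target shape (input 1, given as a shape tensor).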
impl<T: Float> op::Op<T> for MaybeReduce {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let input = ctx.input(0);
        let target_shape__ = crate::ndarray_ext::as_shape(&ctx.input(1));
        let target_shape = target_shape__.as_slice();
        let input_shape = input.shape();

        if target_shape == input_shape {
            // The case where forward path didn't cause broadcast.
            ctx.append_output_view(input.clone());
        } else {
            // Broadcasting occurred in the forward pass, so `input` must be reduced.
            // First, handle the case where the target shape is scalar.
            let target_shape_is_scalar = crate::ndarray_ext::is_scalar_shape(target_shape);
            let target_shape_ = if target_shape_is_scalar {
                vec![1; input_shape.len()]
            } else {
                target_shape.to_vec()
            };

            // Reduce each dim as necessary
            let mut folded: Option<NdArray<T>> = None;
            for (i, (x_axis, gy_axis)) in target_shape_.iter().zip(input_shape).enumerate() {
                if x_axis < gy_axis {
                    if *x_axis == 1 {
                        // `fold_axis` squashes the axis automatically.
                        let axis = ndarray::Axis(if target_shape_is_scalar { 0 } else { i });
                        let ret = match folded {
                            Some(ref a) => a.fold_axis(axis, T::zero(), |&a, &b| a + b),
                            None => input.fold_axis(axis, T::zero(), |&a, &b| a + b),
                        };
                        if target_shape_is_scalar {
                            folded = Some(ret);
                        } else {
                            // Re-expand the squashed axis so ranks stay aligned.
                            folded = Some(crate::ndarray_ext::expand_dims(ret, i));
                        }
                    } else {
                        ctx.set_error(op::OpError::IncompatibleShape(
                            "Incorrect gradient shape".to_string(),
                        ));
                        return;
                    }
                }
                // case of x_axis > gy_axis: unreachable
                // case of x_axis == gy_axis: nothing to do
            }
            // `folded` is `Some` here: the shapes differ, so at least one axis was folded above.
            let ret = folded.unwrap();
            debug_assert_eq!(target_shape, ret.shape());
            ctx.append_output(ret);
        };
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let g = ctx.graph();
        let gx = Tensor::builder()
            .set_ro_inputs(&[&ctx.output_grad(), &g.shape(ctx.input(0))])
            .build(g, MaybeBroadcast);
        ctx.append_input_grad(Some(gx));
        ctx.append_input_grad(None);
    }
}

// Broadcasts input 0 to the target shape (input 1) if necessary; the inverse of `MaybeReduce`.
impl<T: Float> op::Op<T> for MaybeBroadcast {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let target_shape_ = ctx.input(1);
        let target_shape_ = crate::ndarray_ext::as_shape(&target_shape_);
        let target_shape = target_shape_.as_slice();

        let raw_input = ctx.input(0);
        if raw_input.shape() == target_shape {
            ctx.append_output_view(raw_input);
            return;
        }

        // make broadcast dims if needed
        let input_is_scalar = crate::ndarray_ext::is_scalar_shape(raw_input.shape());
        let input = if input_is_scalar {
            raw_input.into_shape(vec![1; target_shape.len()]).unwrap()
        } else {
            raw_input
        };

        // do broadcast
        if let Some(ret) = input.broadcast(target_shape) {
            ctx.append_output(ret.to_owned());
        } else {
            ctx.set_error(op::OpError::IncompatibleShape(
                "MaybeBroadcast: can't broadcast.".to_string(),
            ));
        }
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let g = ctx.graph();
        let gx = maybe_reduce(&g.shape(ctx.input(0)), &ctx.output_grad(), g);
        ctx.append_input_grad(Some(gx));
        ctx.append_input_grad(None);
    }
}

impl<T: Float> op::Op<T> for AddOp {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let ret = add_forward(&ctx.input(0), &ctx.input(1));
        ctx.append_output(ret);
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let g = ctx.graph();
        let x0 = ctx.input(0);
        let x1 = ctx.input(1);
        let gy = ctx.output_grad();
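        // d(x0 + x1)/dx0 = d(x0 + x1)/dx1 = 1: `gy` flows to both inputs,
        // reduced back to each input's shape in case of broadcast.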
        let shape0 = &g.shape(x0);
        let shape1 = &g.shape(x1);
        let gy0 = maybe_reduce(shape0, &gy, g);
        let gy1 = maybe_reduce(shape1, &gy, g);
        ctx.append_input_grad(Some(gy0));
        ctx.append_input_grad(Some(gy1));
    }
}

impl<T: Float> op::Op<T> for SubOp {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let x0 = &ctx.input(0);
        let x1 = &ctx.input(1);
        let shape0: &[usize] = x0.shape();
        let shape1: &[usize] = x1.shape();
        let ret = if shape0 == [] {
            // x0 is a scalar
            let x0_elem = x0[ndarray::IxDyn(&[])];
            x1.map(move |&a| x0_elem - a)
        } else if shape0 == shape1 {
            #[cfg(feature = "mkl")]
            {
                use crate::{ops::mkl_ffi::*, same_type};
                bin_op_same_shape!(vsSub, vdSub, -, x0, x1)
            }
            #[cfg(not(feature = "mkl"))]
            {
                x0 - x1
            }
        } else {
            x0 - x1
        };
        ctx.append_output(ret);
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let g = ctx.graph();
        let x0 = ctx.input(0);
        let x1 = ctx.input(1);
        let shape0 = &g.shape(x0);
        let shape1 = &g.shape(x1);
        let gy = &ctx.output_grad();
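        // d(x0 - x1)/dx0 = 1 and d(x0 - x1)/dx1 = -1, hence `gy` and `-gy`.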
        let gy0 = maybe_reduce(shape0, gy, g);
        let gy1 = maybe_reduce(shape1, gy, g);
        ctx.append_input_grad(Some(gy0));
        ctx.append_input_grad(Some(g.neg(&gy1)));
    }
}

impl<T: Float> op::Op<T> for MulOp {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let ret = mul_forward(&ctx.input(0), &ctx.input(1));
        ctx.append_output(ret);
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let graph = ctx.graph();
        let x0 = ctx.input(0);
        let x1 = ctx.input(1);

        let shape0 = &graph.shape(x0);
        let shape1 = &graph.shape(x1);

        let gy = ctx.output_grad();

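        // Product rule: d(x0*x1)/dx0 = x1 and d(x0*x1)/dx1 = x0.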
        let gx0 = gy * x1;
        let gx1 = gy * x0;

        let gx0 = maybe_reduce(shape0, &gx0, graph);
        let gx1 = maybe_reduce(shape1, &gx1, graph);

        ctx.append_input_grad(Some(gx0));
        ctx.append_input_grad(Some(gx1));
    }
}

impl<T: Float> op::Op<T> for DivOp {
    fn compute(&self, ctx: &mut crate::op::ComputeContext<T>) {
        let x0 = &ctx.input(0);
        let x1 = &ctx.input(1);
        let shape0: &[usize] = x0.shape();
        let shape1: &[usize] = x1.shape();
        let is_scalar0 = shape0 == [] || shape0 == [0];
        let is_scalar1 = shape1 == [] || shape1 == [0];
        let ret = if is_scalar0 {
            // x0 is a scalar
            let x0_elem = x0[ndarray::IxDyn(&[])];
            x1.map(move |&a| x0_elem / a)
        } else if is_scalar1 {
            // x1 is a scalar
            let x1_elem = x1[ndarray::IxDyn(&[])];
            let rhs = T::one() / x1_elem;
            x0.mapv(|x0_elem| x0_elem * rhs)
        } else if shape0 == shape1 {
            #[cfg(feature = "mkl")]
            {
                use crate::{ops::mkl_ffi::*, same_type};
                bin_op_same_shape!(vsDiv, vdDiv, /, x0, x1)
            }
            #[cfg(not(feature = "mkl"))]
            {
                x0 / x1
            }
        } else {
            x0 / x1
        };
        ctx.append_output(ret);
    }

    fn grad(&self, ctx: &mut crate::op::GradientContext<T>) {
        let g = ctx.graph();
        let x0 = ctx.input(0);
        let x1 = ctx.input(1);
        let shape0 = &g.shape(x0);
        let shape1 = &g.shape(x1);
        let gy = ctx.output_grad();

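        // Quotient rule: d(x0/x1)/dx0 = 1/x1 and d(x0/x1)/dx1 = -x0/x1^2.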
        let gx0 = gy / x1;
        let gx1 = g.neg(x0) * g.pow(x1, T::from(-2.).unwrap()) * gy;

        let gx0 = maybe_reduce(shape0, &gx0, g);
        let gx1 = maybe_reduce(shape1, &gx1, g);

        ctx.append_input_grad(Some(gx0));
        ctx.append_input_grad(Some(gx1));
    }
}

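/// Builds a `MaybeReduce` node that, at run time, sums `x` down to
/// `target_shape` (a shape tensor) if broadcasting occurred in the forward pass.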
fn maybe_reduce<'g, T: Float>(
    target_shape: &Tensor<'g, T>,
    x: &Tensor<'g, T>,
    graph: &'g Graph<T>,
) -> Tensor<'g, T> {
    Tensor::builder()
        .set_ro_inputs(&[x, target_shape])
        .set_shape(target_shape)
        .build(graph, MaybeReduce)
}

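// Generates the forward function of a commutative binary op (`+` or `*`).
// The scalar branches apply `a $bin_op elem` regardless of operand order,
// which is correct only for commutative ops; `SubOp` and `DivOp` therefore
// handle scalars explicitly in their own `compute`.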
macro_rules! impl_bin_op_forward {
    ($forward_name:ident, $bin_op:tt, $vms_op:ident, $vmd_op:ident) => {
        fn $forward_name<'v, T: Float>(x0: &NdArrayView<'v, T>, x1: &NdArrayView<'v, T>) -> NdArray<T>
        {
            let shape0: &[usize] = x0.shape();
            let shape1: &[usize] = x1.shape();
            let scalar_shape = &[];
            let scalar_shape1 = &[0];

            let x0_is_scalar = shape0 == scalar_shape || shape0 == scalar_shape1;
            let x1_is_scalar = shape1 == scalar_shape || shape1 == scalar_shape1;

            if x0_is_scalar && !x1_is_scalar {
                let elem = x0[ndarray::IxDyn(&[])];
                x1.map(move |&a| a $bin_op elem)
            } else if x1_is_scalar && !x0_is_scalar {
                let elem = x1[ndarray::IxDyn(&[])];
                x0.map(move |&a| a $bin_op elem )
            } else if !x0_is_scalar && !x1_is_scalar {
                let len0: usize = shape0.iter().product();
                let len1: usize = shape1.iter().product();
                if len0 > len1 {
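                    // ndarray broadcasts the smaller `x1` to `x0`'s shape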
                    x0 $bin_op x1
                } else {
                    // tensor vs tensor (same shapes)
                    #[cfg(feature = "mkl")]
                    {
                        use crate::{ops::mkl_ffi::*, same_type};
                        bin_op_same_shape!($vms_op, $vmd_op, $bin_op, x0, x1)
                    }
                    #[cfg(not(feature = "mkl"))] {
                        x0 $bin_op x1
                    }
                }
            } else {
                // scalar vs scalar
                x0 $bin_op x1
            }
        }
    };
}

impl_bin_op_forward!(add_forward, +, vsAdd, vdAdd);
impl_bin_op_forward!(mul_forward, *, vsMul, vdMul);