1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
// Tests for the type-level shape machinery, tensor views (broadcast / reshape /
// transpose), element-wise and reduction operations in their static, `_coerce`
// and `_dynamic` flavours, and gradient propagation through `Variable`.
#[cfg(test)]
mod tests {
    use super::prelude::*;
    // `Bit` / `Unsigned` bring the `BOOL` / `USIZE` associated constants into scope
    // for the type-level assertions below.
    use typenum::marker_traits::{Bit, Unsigned};
    use typenum::{U0, U1, U2, U3, U4, U5, U6};

    /// Static shape metadata: dimensions, row-major strides and element count.
    #[test]
    fn shape() {
        assert_eq!(<Shape2D<U3, U2> as StaticShape>::to_vec(), vec![3, 2]);
        assert_eq!(<Shape2D<U3, U2> as StaticShape>::strides(), vec![2, 1]);
        assert_eq!(<Shape2D<U3, U2> as StaticShape>::NUM_ELEMENTS, 6);
    }

    /// Type-level shape predicates (`Broadcast`, `SameNumElements`, `Same`) and
    /// type-level shape computations (`NumElements`, `ReprShape`, `Reduction`, ...).
    #[test]
    fn shape_constraints() {
        // [2, 1, 2] is broadcast-compatible with [5, 1, 3, 2]; [2, 4, 2] is not (4 vs 3).
        assert!(<Shape3D<U2, U1, U2> as Broadcast<Shape4D<U5, U1, U3, U2>>>::Output::BOOL);
        assert!(!<Shape3D<U2, U4, U2> as Broadcast<Shape4D<U5, U1, U3, U2>>>::Output::BOOL);
        assert!(<Shape1D<U6> as SameNumElements<i32, Shape2D<U3, U2>>>::Output::BOOL);
        assert!(!<Shape1D<U6> as SameNumElements<i32, Shape2D<U3, U3>>>::Output::BOOL);
        // `Dyn` is accepted against U3 here; the second case fails on U2 vs U3.
        assert!(<Shape2D<Dyn, U2> as Same<Shape2D<U3, U2>>>::Output::BOOL);
        assert!(!<Shape2D<Dyn, U2> as Same<Shape2D<U3, U3>>>::Output::BOOL);
        assert_eq!(<Shape4D<U5, U1, U3, U2> as NumElements<i32>>::Output::USIZE, 30);
        assert_eq!(
            <Shape2D<Dyn, U2> as ReprShape<i32, Shape2D<U3, Dyn>>>::Output::to_vec(),
            vec![3, 2]
        );
        assert!(<<Shape2D<Dyn, Dyn> as ReprShapeDyn<i32, Shape2D<U2, Dyn>>>::Output as Shape>::runtime_compat(&[2, 3]));
        assert!(<<Shape2D<U2, Dyn> as ReprShapeDyn<i32, Shape2D<Dyn, Dyn>>>::Output as Shape>::runtime_compat(&[2, 3]));
        assert!(<<Shape2D<Dyn, Dyn> as ReprShapeDyn<
            i32,
            Shape2D<Dyn, Dyn>,
        >>::Output as Shape>::runtime_compat(&[3, 3]));
        assert!(<Shape2D<Dyn, Dyn> as Shape>::runtime_compat(&[3, 3]));
        assert!(<Shape2D<U3, Dyn> as Shape>::runtime_compat(&[3, 3]));
        assert_eq!(
            <<Shape2D<U2, U3> as Reduction<U1>>::Output as StaticShape>::to_vec(),
            vec![2, 1]
        );
        assert_eq!(
            <Shape2D<U2, U3> as ReductionOptChunckSize<i32, U0>>::Output::USIZE,
            3
        );
        assert_eq!(
            <Shape2D<U3, U3> as ReductionOptChunckSize<i32, U1>>::Output::USIZE,
            1
        );
        assert_eq!(<Shape4D<U5, U1, U3, U2> as At<U2>>::Output::USIZE, 3);
        assert_eq!(
            <<Shape4D<U4, U3, U6, U6> as Transpose>::Output as StaticShape>::to_vec(),
            vec![6, 6, 3, 4]
        );
    }

    /// Broadcasting a [1, 2] tensor to [2, 2] gives stride 0 on the repeated axis.
    #[test]
    fn broadcast_same_order() {
        let a: SliceTensor<i32, Shape2D<U1, U2>> = Tensor::from_slice(&[1, 2]);
        let b: StridedSliceTensor<_, Shape2D<U2, U2>> = a.broadcast();

        let c: StridedSliceTensor<i32, Shape2D<U2, U2>> = Tensor::from_slice(&[1, 2, 1, 2]);
        assert_eq!(b, c);
        assert_eq!(b.shape(), vec![2, 2]);
        assert_eq!(b.strides(), vec![0, 1]);
        assert_eq!(b.opt_chunk_size(), 2);
    }

    /// Broadcasting a rank-1 [2] tensor to rank-2 [2, 2] behaves like the [1, 2] case.
    #[test]
    fn broadcast_different_order() {
        let a: SliceTensor<i32, Shape1D<U2>> = Tensor::from_slice(&[1, 2]);
        let b: StridedSliceTensor<_, Shape2D<U2, U2>> = a.broadcast();

        let c: StridedSliceTensor<i32, Shape2D<U2, U2>> = Tensor::from_slice(&[1, 2, 1, 2]);
        assert_eq!(b, c);
        assert_eq!(b.shape(), vec![2, 2]);
        assert_eq!(b.strides(), vec![0, 1]);
        assert_eq!(b.opt_chunk_size(), 2);
    }

    /// Reshaping a 4-element tensor into 2x2 keeps contiguous row-major strides.
    #[test]
    fn reshape() {
        let mut a: StaticTensor<i32, Shape1D<U4>> = Tensor::default();
        for (x, y) in a.iter_mut().zip(&[1, 2, 1, 2]) {
            *x = *y;
        }
        let b: SliceTensor<_, Shape2D<U2, U2>> = a.reshape();

        let c: SliceTensor<i32, Shape2D<U2, U2>> = Tensor::from_slice(&[1, 2, 1, 2]);
        assert_eq!(b, c);
        assert_eq!(b.shape(), vec![2, 2]);
        assert_eq!(b.strides(), vec![2, 1]);
        assert_eq!(b.opt_chunk_size(), 4);
    }

    /// Element-wise addition of two statically shaped tensors.
    #[test]
    fn add() {
        let a: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[1, 0, 0, 0, 1, 0, 0, 0, 1]);
        let b: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[1, 1, 1, 1, 1, 1, 1, 1, 1]);
        let c = a.add(&b);

        let d: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[2, 1, 1, 1, 2, 1, 1, 1, 2]);
        assert_eq!(c.as_view(), d);
    }

    /// Element-wise multiplication of two statically shaped tensors.
    #[test]
    fn mul() {
        let a: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[1, 0, 0, 0, 3, 0, 0, 0, 1]);
        let b: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[3, 1, 1, 1, 1, 1, 1, 1, 1]);
        let c = a.mul(&b);

        let d: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[3, 0, 0, 0, 3, 0, 0, 0, 1]);
        assert_eq!(c.as_view(), d);
    }

    /// `add_coerce` accepts a fully dynamic lhs whose runtime shape matches the static rhs.
    #[test]
    fn add_static_ok() {
        let a: SliceTensor<i32, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[1, 0, 0, 0, 1, 0, 0, 0, 1], vec![3, 3]);
        let b: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[1, 1, 1, 1, 1, 1, 1, 1, 1]);
        let c = a.add_coerce(&b);

        let d: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[2, 1, 1, 1, 2, 1, 1, 1, 2]);
        assert_eq!(c.as_view(), d);
    }

    /// `add_coerce` panics when the dynamic lhs's runtime shape (3x3) differs from
    /// the static rhs (2x2).
    #[test]
    #[should_panic(expected = "Tensors must have same shape")]
    fn add_static_panic() {
        let a: SliceTensor<i32, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[1, 0, 0, 0, 1, 0, 0, 0, 1], vec![3, 3]);
        let b: SliceTensor<i32, Shape2D<U2, U2>> = Tensor::from_slice(&[1, 1, 1, 1]);
        let _c = a.add_coerce(&b);
    }

    /// `add_dynamic` on two partially dynamic shapes; the result is checked via `as_static`.
    #[test]
    fn add_dyn_ok() {
        let a: SliceTensor<i32, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[1, 0, 0, 0, 1, 0, 0, 0, 1], vec![3, 3]);
        let b: SliceTensor<i32, Shape2D<U3, Dyn>> =
            Tensor::from_slice_dyn(&[1, 1, 1, 1, 1, 1, 1, 1, 1], vec![3, 3]);
        let c = a.add_dynamic(&b);

        let d: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[2, 1, 1, 1, 2, 1, 1, 1, 2]);
        assert_eq!(c.as_static(), d);
    }

    /// Adding a broadcast [3] row vector to every row of a 3x3 tensor.
    #[test]
    fn add_broadcast() {
        let a: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[1, 0, 0, 0, 1, 0, 0, 0, 1]);
        let b: SliceTensor<i32, Shape1D<U3>> = Tensor::from_slice(&[1, 1, 1]);
        let c: StaticTensor<_, _> = a.add(&b.broadcast());

        let d: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[2, 1, 1, 1, 2, 1, 1, 1, 2]);
        assert_eq!(c.as_view(), d);
    }

    /// Adding a scalar to every element of a static tensor.
    #[test]
    fn scal_add() {
        let a: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[1, 0, 0, 0, 1, 0, 0, 0, 1]);
        let c = a.scal_add(1);

        let d: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[2, 1, 1, 1, 2, 1, 1, 1, 2]);
        assert_eq!(c.as_view(), d);
    }

    /// Adding a scalar to every element of a dynamically shaped tensor.
    #[test]
    fn scal_add_dyn() {
        let a: SliceTensor<i32, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[1, 0, 0, 0, 1, 0, 0, 0, 1], vec![3, 3]);
        let c = a.scal_add_dynamic(1);

        let d: SliceTensor<i32, Shape2D<U3, U3>> = Tensor::from_slice(&[2, 1, 1, 1, 2, 1, 1, 1, 2]);
        assert_eq!(c.as_static(), d);
    }

    /// Element-wise `exp` on a static tensor: exp(ln 2) = 2 and exp(0) = 1.
    #[test]
    fn exp() {
        let ln2 = 2.0_f64.ln();
        let data = [ln2, 0.0, 0.0, 0.0, ln2, 0.0, 0.0, 0.0, ln2];
        let a: SliceTensor<f64, Shape2D<U3, U3>> = Tensor::from_slice(&data);
        let c = a.exp();

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[2.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 2.0]);
        assert_eq!(c.as_view(), d);
    }

    /// Element-wise `exp_dynamic` on a dynamically shaped tensor.
    #[test]
    fn exp_dyn() {
        let ln2 = 2.0_f64.ln();
        let data = [ln2, 0.0, 0.0, 0.0, ln2, 0.0, 0.0, 0.0, ln2];
        let a: SliceTensor<f64, Shape2D<Dyn, Dyn>> = Tensor::from_slice_dyn(&data, vec![3, 3]);
        let c = a.exp_dynamic();

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[2.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 2.0]);
        assert_eq!(c.as_static(), d);
    }

    /// Element-wise integer power on a static tensor.
    #[test]
    fn powi() {
        let a: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[2.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 2.0]);
        let c = a.powi(2);

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[4.0, 1.0, 1.0, 1.0, 4.0, 1.0, 1.0, 1.0, 4.0]);
        assert_eq!(c.as_view(), d);
    }

    /// Element-wise integer power on a dynamically shaped tensor.
    #[test]
    fn powi_dyn() {
        let a: SliceTensor<f64, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[2.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 2.0], vec![3, 3]);
        let c = a.powi_dynamic(2);

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[4.0, 1.0, 1.0, 1.0, 4.0, 1.0, 1.0, 1.0, 4.0]);
        assert_eq!(c.as_static(), d);
    }

    /// Element-wise maximum of two static tensors.
    #[test]
    fn max() {
        let a: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[1.0, 0.0, 5.0, 9.0]);
        let b: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[3.0, 1.0, 0.0, 1.0]);
        let c = a.max(&b);

        let d: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[3.0, 1.0, 5.0, 9.0]);
        assert_eq!(c.as_view(), d);
    }

    /// `max_coerce` with complementary partially dynamic shapes ([Dyn, 2] and [2, Dyn]).
    #[test]
    fn max_static() {
        let a: SliceTensor<f64, Shape2D<Dyn, U2>> =
            Tensor::from_slice_dyn(&[1.0, 0.0, 5.0, 9.0], vec![2, 2]);
        let b: SliceTensor<f64, Shape2D<U2, Dyn>> =
            Tensor::from_slice_dyn(&[3.0, 1.0, 0.0, 1.0], vec![2, 2]);
        let c = a.max_coerce(&b);

        let d: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[3.0, 1.0, 5.0, 9.0]);
        assert_eq!(c.as_view(), d);
    }

    /// `max_dynamic` on two fully dynamic tensors.
    #[test]
    fn max_dyn() {
        let a: SliceTensor<f64, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[1.0, 0.0, 5.0, 9.0], vec![2, 2]);
        let b: SliceTensor<f64, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[3.0, 1.0, 0.0, 1.0], vec![2, 2]);
        let c = a.max_dynamic(&b);

        let d: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[3.0, 1.0, 5.0, 9.0]);
        assert_eq!(c.as_static(), d);
    }

    /// Element-wise fused `a * x + b` on static tensors.
    #[test]
    fn mul_add() {
        let a: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
        let x: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0]);
        let b: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);
        let c = a.mul_add(&x, &b);

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0, 2.0]);
        assert_eq!(c.as_view(), d);
    }

    /// `mul_add_coerce` with mixed static/dynamic operand shapes.
    #[test]
    fn mul_add_static() {
        let a: SliceTensor<f64, Shape2D<Dyn, U3>> =
            Tensor::from_slice_dyn(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], vec![3, 3]);
        let x: SliceTensor<f64, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0], vec![3, 3]);
        let b: SliceTensor<f64, Shape2D<U3, Dyn>> =
            Tensor::from_slice_dyn(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![3, 3]);
        let c = a.mul_add_coerce(&x, &b);

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0, 2.0]);
        assert_eq!(c.as_view(), d);
    }

    /// `mul_add_dynamic` with mixed static/dynamic operand shapes.
    #[test]
    fn mul_add_dyn() {
        let a: SliceTensor<f64, Shape2D<Dyn, U3>> =
            Tensor::from_slice_dyn(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], vec![3, 3]);
        let x: SliceTensor<f64, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0], vec![3, 3]);
        let b: SliceTensor<f64, Shape2D<Dyn, Dyn>> =
            Tensor::from_slice_dyn(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], vec![3, 3]);
        let c = a.mul_add_dynamic(&x, &b);

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 1.0, 2.0]);
        assert_eq!(c.as_static(), d);
    }

    /// Scalar fused multiply-add `a * 2.0 + 1.0` on a static tensor.
    #[test]
    fn scal_mul_add() {
        let a: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
        let c = a.scal_mul_add(2.0, 1.0);

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[3.0, 1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 3.0]);
        assert_eq!(c.as_view(), d);
    }

    /// Scalar fused multiply-add on a partially dynamic tensor.
    #[test]
    fn scal_mul_add_dyn() {
        let a: SliceTensor<f64, Shape2D<Dyn, U3>> =
            Tensor::from_slice_dyn(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], vec![3, 3]);
        let c = a.scal_mul_add_dynamic(2.0, 1.0);

        let d: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[3.0, 1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 3.0]);
        assert_eq!(c.as_static(), d);
    }

    /// Summing a 3x3 tensor over axis 1 yields a 3x1 tensor of row sums.
    #[test]
    fn sum() {
        let a: SliceTensor<f64, Shape2D<U3, U3>> =
            Tensor::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
        let c = a.sum::<U1>();

        let d: SliceTensor<f64, Shape2D<U3, U1>> = Tensor::from_slice(&[1.0, 1.0, 1.0]);
        assert_eq!(c.as_view(), d);
    }

    /// A matrix times its inverse is the identity.
    #[test]
    fn inverse_dot() {
        let a: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[1.0, 1.0, 0.0, 1.0]);
        let b: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[1.0, -1.0, 0.0, 1.0]);
        let c = a.dot(&b);

        let d: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[1.0, 0.0, 0.0, 1.0]);
        assert_eq!(c.as_view(), d);
    }

    /// Rotating the column vector (1, 0, 3) by π/4 in the first two coordinates yields
    /// (cos π/4, sin π/4, 3).
    #[test]
    fn rotation_dot() {
        use std::f64::consts::FRAC_PI_4;
        let (cos, sin) = (FRAC_PI_4.cos(), FRAC_PI_4.sin());
        let a_data = [cos, -sin, 0.0, sin, cos, 0.0, 0.0, 0.0, 1.0];
        let d_data = [cos, sin, 3.0];
        let a: SliceTensor<f64, Shape2D<U3, U3>> = Tensor::from_slice(&a_data);
        let b: SliceTensor<f64, Shape2D<U3, U1>> = Tensor::from_slice(&[1.0, 0.0, 3.0]);
        let c = a.dot(&b);

        let t: SliceTensor<f64, Shape2D<U3, U1>> = Tensor::from_slice(&d_data);
        // Sum of *squared* residuals: summing signed differences lets errors of
        // opposite sign cancel, and any negative total would pass `< EPSILON`
        // regardless of its magnitude.
        let err = c.sub(&t).powi(2).sum::<U0>().chunks(1).next().unwrap()[0];
        assert!(err < f64::EPSILON);
    }

    /// Transposing a 2x3 tensor yields a 3x2 view with swapped strides [1, 3].
    #[test]
    fn transpose() {
        let a: SliceTensor<i32, Shape2D<U2, U3>> = Tensor::from_slice(&[1, 2, 3, 4, 5, 6]);
        let b = a.transpose();

        let c: SliceTensor<i32, Shape2D<U3, U2>> = Tensor::from_slice(&[1, 4, 2, 5, 3, 6]);
        assert_eq!(*b, *c);
        assert_eq!(b.shape(), vec![3, 2]);
        assert_eq!(b.strides(), vec![1, 3]);
        assert_eq!(b.opt_chunk_size(), 1);
    }

    /// With an all-ones upstream gradient, d(a + b)/da is all ones.
    #[test]
    fn backprop() {
        let a: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[1.0, 1.0, 0.0, 1.0]);
        let b: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[1.0, 1.0, 0.0, 1.0]);

        // NOTE(review): the bool presumably marks whether a gradient is tracked — confirm.
        let a = Variable::new(a, true);
        let b = Variable::new(b, false);

        let c = Variable::clone(&a) + b;
        c.backward(StaticTensor::fill(1.0));

        assert_eq!(a.grad().unwrap(), StaticTensor::fill(1.0));
    }

    /// With an all-ones upstream gradient, d(a * b + c)/da equals `b`.
    #[test]
    fn backprop2() {
        let a: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let b: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[2.0, 1.0, 0.0, 2.0]);
        let c: SliceTensor<f64, Shape2D<U2, U2>> = Tensor::from_slice(&[1.0, 0.0, 0.0, 1.0]);

        let a = Variable::new(a, true);
        let b = Variable::new(b, false);
        let c = Variable::new(c, false);

        let a_times_b = Variable::clone(&a) * b;
        let result = a_times_b + c;
        result.backward(StaticTensor::fill(1.0));

        assert_eq!(a.grad().unwrap().as_view(), Tensor::from_slice(&[2.0, 1.0, 0.0, 2.0]));
    }
}

pub mod prelude;
pub mod tensor;
pub mod backprop;
pub mod ring;