1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
use super::ops::{BinaryKernel, UnaryKernel};
use crate::{
shapes::{Dtype, Shape},
tensor::{
cpu::{Cpu, LendingIterator, NdIndex},
unique_id, Tensor, ZerosTensor,
},
};
/// An element-wise operation of one argument, together with its derivative.
///
/// Any `Op` implementing this trait gets a `UnaryKernel<Op, E>` implementation on `Cpu`:
/// the forward pass maps every element through [`UnaryDerivative::f`], and the backward
/// pass accumulates [`UnaryDerivative::df`] times the incoming gradient.
pub trait UnaryDerivative<E> {
    /// Evaluates the operation at `x` and returns the resulting element.
    fn f(&self, x: &E) -> E;

    /// Evaluates the derivative of the operation at `x`.
    ///
    /// The `Cpu` backward pass multiplies this value by the matching output gradient
    /// before adding it to the input gradient.
    fn df(&self, x: &E) -> E;
}
/// An element-wise operation of two arguments, together with its partial derivatives.
///
/// Any `Op` implementing this trait gets a `BinaryKernel<Op, E>` implementation on `Cpu`:
/// the forward pass combines paired elements with [`BinaryDerivative::f`], and the backward
/// pass accumulates [`BinaryDerivative::dfdx`] / [`BinaryDerivative::dfdy`] times the
/// incoming gradient into the left / right gradient buffers.
pub trait BinaryDerivative<E> {
    /// Evaluates the operation at the pair `(x, y)` and returns the resulting element.
    fn f(&self, x: &E, y: &E) -> E;

    /// Evaluates the partial derivative with respect to the first argument at `(x, y)`.
    fn dfdx(&self, x: &E, y: &E) -> E;

    /// Evaluates the partial derivative with respect to the second argument at `(x, y)`.
    fn dfdy(&self, x: &E, y: &E) -> E;
}
/// Element-wise unary kernel on `Cpu`, provided for every `Op` that implements
/// `UnaryDerivative<E>`.
impl<E: Dtype, Op: UnaryDerivative<E>> UnaryKernel<Op, E> for Cpu {
    /// Builds a new tensor whose buffer elements are `op.f` of `inp`'s buffer elements.
    ///
    /// The result receives a fresh id from `unique_id()` and a default tape; its shape,
    /// strides and device are taken from `inp` and `self`.
    fn forward<S: Shape>(
        &self,
        op: Op,
        inp: &Tensor<S, E, Self>,
    ) -> Result<Tensor<S, E, Self>, Self::Err> {
        // Begin with a clone of the input's data so the result has the same layout;
        // each buffer element is then replaced by its mapped value.
        let mut result = Tensor {
            id: unique_id(),
            data: inp.data.clone(),
            shape: inp.shape,
            strides: inp.strides,
            device: self.clone(),
            tape: Default::default(),
        };
        for value in result.buf_iter_mut() {
            let mapped = op.f(value);
            *value = mapped;
        }
        Ok(result)
    }

    /// Accumulates `op.df(inp[i]) * grad_out[i]` into `grad_inp[i]` for every position `i`.
    ///
    /// In debug builds, asserts that `grad_inp`, `grad_out` and `inp.data` all have the
    /// same length. The gradient is added to (not assigned over) the existing contents
    /// of `grad_inp`.
    fn backward<S: Shape>(
        &self,
        op: Op,
        inp: &Tensor<S, E, Self>,
        grad_inp: &mut Self::Vec<E>,
        grad_out: &Self::Vec<E>,
    ) -> Result<(), Self::Err> {
        debug_assert_eq!(grad_inp.len(), grad_out.len());
        debug_assert_eq!(inp.data.len(), grad_out.len());
        grad_inp.iter_mut().enumerate().for_each(|(pos, slot)| {
            *slot += op.df(&inp.data[pos]) * grad_out[pos];
        });
        Ok(())
    }
}
/// Element-wise binary kernel on `Cpu`, provided for every `Op` that implements
/// `BinaryDerivative<E>`.
impl<E: Dtype, Op: BinaryDerivative<E>> BinaryKernel<Op, E> for Cpu {
    /// Allocates a zeroed tensor shaped like `lhs` and fills each of its buffer elements
    /// with `op.f` applied to the next elements yielded by `lhs.iter()` and `rhs.iter()`.
    ///
    /// # Panics
    ///
    /// Panics if either input iterator is exhausted before the output buffer is filled.
    fn forward<S: Shape>(
        &self,
        op: Op,
        lhs: &Tensor<S, E, Self>,
        rhs: &Tensor<S, E, Self>,
    ) -> Result<Tensor<S, E, Self>, Self::Err> {
        let mut result = self.try_zeros_like(&lhs.shape)?;
        let (mut left, mut right) = (lhs.iter(), rhs.iter());
        for slot in result.buf_iter_mut() {
            // Advance both inputs in lock-step with the output buffer.
            let (l, r) = (left.next().unwrap(), right.next().unwrap());
            *slot = op.f(l, r);
        }
        Ok(result)
    }

    /// Accumulates the gradients of both operands from `grad_out`.
    ///
    /// For each entry of `grad_out`, in order, the next index is drawn from an `NdIndex`
    /// built from each operand's shape and strides. That index selects both the operand
    /// element read from its data buffer and the gradient slot that is updated:
    /// `grad_lhs` receives `op.dfdx(l, r) * go` and `grad_rhs` receives
    /// `op.dfdy(l, r) * go`, each added to the existing value.
    ///
    /// # Panics
    ///
    /// Panics if either `NdIndex` is exhausted before `grad_out` is, or if a produced
    /// index is out of bounds for the buffers it is used with.
    fn backward<S: Shape>(
        &self,
        op: Op,
        lhs: &Tensor<S, E, Self>,
        grad_lhs: &mut Self::Vec<E>,
        rhs: &Tensor<S, E, Self>,
        grad_rhs: &mut Self::Vec<E>,
        grad_out: &Self::Vec<E>,
    ) -> Result<(), Self::Err> {
        let mut left_pos = NdIndex::new(lhs.shape, lhs.strides);
        let mut right_pos = NdIndex::new(rhs.shape, rhs.strides);
        let left_data = lhs.data.as_ref();
        let right_data = rhs.data.as_ref();
        for upstream in grad_out.iter() {
            let go = *upstream;
            let (li, ri) = (left_pos.next().unwrap(), right_pos.next().unwrap());
            let (l, r) = (&left_data[li], &right_data[ri]);
            grad_lhs[li] += op.dfdx(l, r) * go;
            grad_rhs[ri] += op.dfdy(l, r) * go;
        }
        Ok(())
    }
}