1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
use std::cell::{RefCell, Ref};
use std::rc::Rc;
use super::tensor::Tensor;
/// Behaviour shared by every operation that can be wrapped in an [`Op`] handle.
///
/// All tensor arguments are passed as shared references (`&[&Tensor]`), so an
/// implementation that writes results into `output` or `input_grad` must rely on
/// interior mutability inside `Tensor`.
/// NOTE(review): confirm that `Tensor` (from `super::tensor`) supports this.
pub trait OpTrait {
/// Returns the name of this operation.
fn get_name(&self) -> String;
/// Presumably the forward pass: reads `input` and writes results into `output`
/// (TODO confirm). This is the only method taking `&mut self`, i.e. the only one
/// allowed to mutate the operation's own fields directly.
fn apply(&mut self, input: &[&Tensor], output: &[&Tensor]);
/// Presumably the backward pass: given the forward `input` and the gradient with
/// respect to the outputs (`output_grad`), fills `input_grad` (TODO confirm).
/// Takes `&self`, so it cannot update the op's own fields except through
/// interior mutability.
fn grad(&self, input: &[&Tensor], output_grad: &[&Tensor], input_grad: &[&Tensor]);
/// Returns references to tensors held by the operation (presumably its
/// learnable parameters — confirm). `Op::get_values` clones these out.
fn get_values(&self) -> Vec<&Tensor>;
/// Presumably overwrites the op's values with `v`, in the same order as
/// `get_values` (NOTE(review): confirm). Takes `&self`, so implementations need
/// interior mutability to store anything.
fn set_values(&self, v: &[Tensor]);
/// Returns references to gradient tensors held by the operation (presumably one
/// per entry of `get_values` — confirm).
fn get_grads(&self) -> Vec<&Tensor>;
}
/// Cheaply clonable, shared handle to a boxed [`OpTrait`] implementation.
///
/// The operation lives behind `Rc<RefCell<..>>`. Cloning an `Op` (the derived
/// `Clone` is just `Rc::clone`) therefore yields another handle to the *same*
/// operation, not a copy of it. Each method borrows the inner `RefCell` only for
/// the duration of the call.
///
/// # Panics
///
/// Methods follow `RefCell` borrow rules:
/// - [`Op::apply`] and [`Op::set_values`] take a mutable borrow. They panic if
///   any other borrow of the same op is alive, e.g. a guard returned by
///   [`Op::get`] on this handle or on a clone of it.
/// - All other methods take a shared borrow. They panic only if a mutable borrow
///   is alive at the same time.
#[derive(Clone)]
pub struct Op {
    o: Rc<RefCell<Box<dyn OpTrait>>>,
}

impl Op {
    /// Wraps `o` in a new handle; this is the first reference to it.
    pub fn new(o: Box<dyn OpTrait>) -> Self {
        Op {
            o: Rc::new(RefCell::new(o)),
        }
    }

    /// Returns a shared-borrow guard to the wrapped operation.
    ///
    /// While the guard is alive, [`Op::apply`] and [`Op::set_values`] on this
    /// handle (or any clone of it) will panic.
    pub fn get(&self) -> Ref<'_, Box<dyn OpTrait>> {
        self.o.borrow()
    }

    /// Returns the wrapped operation's name (forwards to `OpTrait::get_name`).
    pub fn get_name(&self) -> String {
        // `OpTrait::get_name` takes `&self`, so a shared borrow suffices; a
        // mutable borrow would needlessly panic while a `get()` guard is held.
        self.o.borrow().get_name()
    }

    /// Forwards `input` and `output` to `OpTrait::apply`.
    pub fn apply(&self, input: &[&Tensor], output: &[&Tensor]) {
        // `OpTrait::apply` takes `&mut self`, so an exclusive borrow is required.
        self.o.borrow_mut().apply(input, output)
    }

    /// Forwards the arguments to `OpTrait::grad`.
    pub fn grad(&self, input: &[&Tensor], output_grad: &[&Tensor], input_grad: &[&Tensor]) {
        // `OpTrait::grad` takes `&self`; a shared borrow avoids spurious
        // `BorrowMutError` panics when other readers hold the op.
        self.o.borrow().grad(input, output_grad, input_grad);
    }

    /// Returns owned clones of the tensors reported by `OpTrait::get_values`.
    pub fn get_values(&self) -> Vec<Tensor> {
        // Clone out so the returned data does not keep the `RefCell` borrowed.
        self.o.borrow().get_values().into_iter().cloned().collect()
    }

    /// Forwards `v` to `OpTrait::set_values`.
    pub fn set_values(&self, v: &[Tensor]) {
        // `OpTrait::set_values` only needs `&self`, but the exclusive borrow is
        // kept as before so that setting values requires sole access to the op.
        self.o.borrow_mut().set_values(v);
    }

    /// Returns owned clones of the tensors reported by `OpTrait::get_grads`.
    pub fn get_grads(&self) -> Vec<Tensor> {
        // Clone out so the returned data does not keep the `RefCell` borrowed.
        self.o.borrow().get_grads().into_iter().cloned().collect()
    }
}
pub mod local;
pub use local::Add as Add;
pub use local::Sub as Sub;
pub use local::Mul as Mul;
pub use local::Div as Div;
pub mod linear;
pub use linear::Linear as Linear;
pub mod loss;
pub use loss::MSELoss as MSELoss;