1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
use crate::tensor::Tensor;
use super::OpTrait;

/// Defines a binary operator type together with its [`OpTrait`] implementation.
///
/// * `$a` — name of the generated struct (e.g. `Add`).
/// * `$b` — expression whose `to_string()` is returned by `get_name`.
/// * `$c` — a single token tree (in practice a parenthesised closure) that
///   `apply` calls as `$c(input, output)`.
macro_rules! new_binary_op {
    ($a:ident, $b:expr, $c:tt) => {
        /// Binary operator generated by `new_binary_op!`.
        ///
        /// The struct has no fields, so it holds no tensors of its own.
        #[derive(Debug, Default)]
        pub struct $a {}

        impl $a {
            /// Creates the operator; equivalent to `Default::default()`.
            pub fn new() -> $a {
                $a {}
            }
        }

        impl OpTrait for $a {
            fn get_name(&self) -> String {
                ($b).to_string()
            }

            /// Runs the operator by forwarding `input` and `output` to the
            /// closure supplied at macro invocation.
            fn apply(&mut self, input: &[&Tensor], output: &[&Tensor]) {
                $c(input, output)
            }

            // NOTE(review): the gradient is not implemented yet. This only
            // prints a message and never writes to `input_grad`. The
            // parameters are underscore-prefixed to silence unused-variable
            // warnings until a real implementation lands.
            fn grad(
                &self,
                _input: &[&Tensor],
                _output_grad: &[&Tensor],
                _input_grad: &[&Tensor],
            ) {
                println!("binary op grad");
            }

            /// Always empty: the operator struct stores no tensors.
            fn get_values(&self) -> Vec<&Tensor> {
                Vec::new()
            }

            /// Always empty: the operator struct stores no tensors.
            fn get_grads(&self) -> Vec<&Tensor> {
                Vec::new()
            }

            /// No-op: there are no stored values to overwrite.
            fn set_values(&self, _v: &[Tensor]) {}
        }
    };
}

// Each closure computes `input[0].<op>(input[1])` and hands the resulting
// tensor to `output[0].swap(..)`. All four ops pass the right-hand operand
// the same way (`input[1]`, a `&Tensor`); `Add` previously passed `&input[1]`
// and relied on deref coercion.
new_binary_op!(
    Add,
    "add",
    (|input: &[&Tensor], output: &[&Tensor]| output[0].swap(input[0].add(input[1])))
);
new_binary_op!(
    Sub,
    "sub",
    (|input: &[&Tensor], output: &[&Tensor]| output[0].swap(input[0].sub(input[1])))
);
new_binary_op!(
    Mul,
    "mul",
    (|input: &[&Tensor], output: &[&Tensor]| output[0].swap(input[0].mul(input[1])))
);
new_binary_op!(
    Div,
    "div",
    (|input: &[&Tensor], output: &[&Tensor]| output[0].swap(input[0].div(input[1])))
);