// ferrite/tensor/device/cpu/kernels/arithmetic.rs

1use crate::*;
2use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign, Div, DivAssign};
3
4
5
6impl ArithmeticOps for CpuStorage {
7  fn add_tensor(&self, other: &Self) -> Self {
8    let (tensor_a, tensor_b) = CpuStorage::broadcast_tensors(self, other);
9    tensor_a.elementwise_op(&tensor_b, |a, b| a + b)
10  }
11
12  fn sub_tensor(&self, other: &Self) -> Self {
13    let (tensor_a, tensor_b) = CpuStorage::broadcast_tensors(self, other);
14    tensor_a.elementwise_op(&tensor_b, |a, b| a - b)
15  }
16
17  fn mul_tensor(&self, other: &Self) -> Self {
18    let (tensor_a, tensor_b) = CpuStorage::broadcast_tensors(self, other);
19    tensor_a.elementwise_op(&tensor_b, |a, b| a * b)
20  }
21
22  fn div_tensor(&self, other: &Self) -> Self {
23    let (tensor_a, tensor_b) = CpuStorage::broadcast_tensors(self, other);
24   tensor_a.elementwise_op(&tensor_b, |a, b| a / b)
25  }
26
27  fn add_f32(&self, other: f32) -> Self {
28    self.scalar_op(other, |a, b| a + b)
29  }
30
31  fn sub_f32(&self, other: f32) -> Self {
32    self.scalar_op(other, |a, b| a - b)
33  }
34
35  fn mul_f32(&self, other: f32) -> Self {
36    self.scalar_op(other, |a, b| a * b)
37  }
38
39  fn div_f32(&self, other: f32) -> Self {
40    self.scalar_op(other, |a, b| a / b)
41  }
42
43  fn pow_f32(&self, other: f32) -> Self {
44    self.scalar_op(other, |a, b| f32::powf(a, b))
45  }
46
47  fn greater_than(&self, other: &Self, make_binary: bool) -> Self {
48    let (tensor_a, tensor_b) = CpuStorage::broadcast_tensors(self, other);
49    tensor_a.elementwise_op(&tensor_b, |a, b| if a > b { 1.0 } else if make_binary { 0.0 } else {-1.0})
50  }
51
52  fn greater_than_f32(&self, other: f32, make_binary: bool) -> Self {
53    self.scalar_op(other, |a, b| if a > b { 1.0 } else if make_binary { 0.0 } else {-1.0})
54  }
55
56  fn less_than(&self, other: &Self, make_binary: bool) -> Self {
57    let (tensor_a, tensor_b) = CpuStorage::broadcast_tensors(self, other);
58    tensor_a.elementwise_op(&tensor_b, |a, b| if a < b { 1.0 } else if make_binary { 0.0 } else {-1.0})
59  }
60
61  fn less_than_f32(&self, other: f32, make_binary: bool) -> Self {
62    self.scalar_op(other, |a, b| if a < b { 1.0 } else if make_binary { 0.0 } else {-1.0})
63  }
64
65  fn sign(&self) -> Self {
66    self.apply(|a| if a > 0. { 1.0 } else if a < 0. { -1. } else { 0.0 })
67  }
68
69  fn abs(&self) -> Self {
70    self.apply(|a| f32::abs(a))
71  }
72
73  fn add_tensor_assign(&mut self, other: &Self) {
74    // Only broadcast one side
75    let broadcast_b = other.broadcast(&self.shape());
76    self.elementwise_op_assign(&broadcast_b, |a, b| a + b)
77  }
78
79  fn sub_tensor_assign(&mut self, other: &Self) {
80    self.elementwise_op_assign(other, |a, b| a - b)
81  }
82
83  fn mul_tensor_assign(&mut self, other: &Self) {
84    self.elementwise_op_assign(other, |a, b| a * b)
85  }
86
87  fn div_tensor_assign(&mut self, other: &Self) {
88    self.elementwise_op_assign(other, |a, b| a / b)
89  }
90
91  fn add_f32_assign(&mut self, other: f32) {
92    self.scalar_op_assign(other, |a, b| a + b)
93  }
94
95  fn sub_f32_assign(&mut self, other: f32) {
96    self.scalar_op_assign(other, |a, b| a - b)
97  }
98
99  fn mul_f32_assign(&mut self, other: f32) {
100    self.scalar_op_assign(other, |a, b| a * b)
101  }
102
103  fn div_f32_assign(&mut self, other: f32) {
104    self.scalar_op_assign(other, |a, b| a / b)
105  }
106
107  fn pow_f32_assign(&mut self, other: f32) {
108    self.scalar_op_assign(other, |a, b| f32::powf(a, b))
109  }
110
111
112  fn abs_assign(&mut self) {
113    self.apply_assign(|a| f32::abs(a))
114  }
115}