custos_math/ops/
arithmetic.rs

1use custos::{impl_stack, number::Number, Device, MainMemory, Shape, CPU};
2
3use crate::{ew_op, Matrix};
4
5#[cfg(feature = "stack")]
6use custos::Stack;
7
8#[cfg(any(feature = "cuda", feature = "opencl"))]
9use custos::CDatatype;
10
11#[cfg(feature = "opencl")]
12use crate::cl_tew;
13#[cfg(feature = "opencl")]
14use custos::OpenCL;
15
16#[cfg(feature = "cuda")]
17use crate::cu_ew;
18
19#[cfg_attr(feature = "safe", doc = "```ignore")]
20/// Element-wise +, -, *, / operations for matrices.
21///
22/// # Examples
23#[cfg_attr(feature = "cpu", doc = "```")]
24#[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
25/// use custos::CPU;
26/// use custos_math::Matrix;
27///
28/// let device = CPU::new();
29/// let a = Matrix::from((&device, (2, 3), [2, 4, 6, 8, 10, 12]));
30/// let b = Matrix::from((&device, (2, 3), [12, 4, 3, 1, -5, -3]));
31///
32/// let c = &a + &b;
33/// assert_eq!(c.read(), vec![14, 8, 9, 9, 5, 9]);
34///
35/// use custos_math::BaseOps;
36/// let sub = device.sub(&a, &b);
37/// assert_eq!(sub.read(), vec![-10, 0, 3, 7, 15, 15]);
38/// ```
39pub trait BaseOps<T, S: Shape = (), D: Device = Self>: Device {
40    /// Element-wise addition
41    /// # Example
42    #[cfg_attr(feature = "cpu", doc = "```")]
43    #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
44    /// use custos::CPU;
45    /// use custos_math::Matrix;
46    ///
47    /// let device = CPU::new();
48    /// let a = Matrix::from((&device, 2, 3, [2, 4, 6, 8, 10, 12]));
49    /// let b = Matrix::from((&device, 2, 3, [12, 4, 3, 1, -5, -3]));
50    ///
51    /// let c = a + b;
52    /// assert_eq!(c.read(), vec![14, 8, 9, 9, 5, 9]);
53    /// ```
54    fn add(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S>;
55
56    /// Element-wise subtraction
57    /// # Example
58    #[cfg_attr(feature = "cpu", doc = "```")]
59    #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
60    /// use custos::CPU;
61    /// use custos_math::{Matrix, BaseOps};
62    ///
63    /// let device = CPU::new();
64    /// let a = Matrix::from((&device, 2, 3, [2, 4, 6, 8, 10, 12]));
65    /// let b = Matrix::from((&device, 2, 3, [12, 4, 3, 1, -5, -3]));
66    ///
67    /// let sub = device.sub(&a, &b);
68    /// assert_eq!(sub.read(), vec![-10, 0, 3, 7, 15, 15]);
69    /// ```
70    fn sub(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S>;
71
72    /// Element-wise multiplication
73    /// # Example
74    #[cfg_attr(feature = "cpu", doc = "```")]
75    #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
76    /// use custos::CPU;
77    /// use custos_math::{Matrix, BaseOps};
78    ///
79    /// let device = CPU::new();
80    /// let a = Matrix::from((&device, 2, 3, [2, 4, 6, 8, 10, 12]));
81    /// let b = Matrix::from((&device, 2, 3, [12, 4, 3, 1, -5, -3]));
82    ///
83    /// let mul = a * b;
84    /// assert_eq!(mul.read(), vec![24, 16, 18, 8, -50, -36]);
85    /// ```
86    fn mul(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S>;
87
88    /// Element-wise division
89    /// # Example
90    #[cfg_attr(feature = "cpu", doc = "```")]
91    #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
92    /// use custos::CPU;
93    /// use custos_math::{Matrix, BaseOps};
94    ///
95    /// let device = CPU::new();
96    /// let a = Matrix::from((&device, 2, 3, [2, 4, 6, 8, 10, 12]));
97    /// let b = Matrix::from((&device, 2, 3, [12, 4, 3, 1, -5, -3]));
98    ///
99    /// let div = device.div(&a, &b);
100    /// assert_eq!(div.read(), vec![0, 1, 2, 8, -2, -4]);
101    /// ```
102    fn div(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S>;
103}
104
105#[impl_stack]
106impl<T, S, D> BaseOps<T, S, D> for CPU
107where
108    T: Number,
109    S: Shape,
110    D: MainMemory,
111{
112    fn add(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S> {
113        ew_op(self, lhs, rhs, |x, y| x + y)
114    }
115
116    fn sub(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S> {
117        ew_op(self, lhs, rhs, |x, y| x - y)
118    }
119
120    fn mul(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S> {
121        ew_op(self, lhs, rhs, |x, y| x * y)
122    }
123
124    fn div(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S> {
125        ew_op(self, lhs, rhs, |x, y| x / y)
126    }
127}
128
// OpenCL implementation: each operation hands its operator token to `cl_tew`
// and wraps the resulting buffer in a matrix with the dimensions of `lhs`.
#[cfg(feature = "opencl")]
impl<T: CDatatype> BaseOps<T> for OpenCL {
    /// Element-wise `lhs + rhs` through `cl_tew` with the `"+"` operator.
    ///
    /// # Panics
    /// Panics if `cl_tew` fails (its result is unwrapped).
    fn add(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_tew(self, lhs, rhs, "+")
            .map(|out| (out, lhs.dims()).into())
            .unwrap()
    }

    /// Element-wise `lhs - rhs` through `cl_tew` with the `"-"` operator.
    ///
    /// # Panics
    /// Panics if `cl_tew` fails (its result is unwrapped).
    fn sub(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_tew(self, lhs, rhs, "-")
            .map(|out| (out, lhs.dims()).into())
            .unwrap()
    }

    /// Element-wise `lhs * rhs` through `cl_tew` with the `"*"` operator.
    ///
    /// # Panics
    /// Panics if `cl_tew` fails (its result is unwrapped).
    fn mul(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_tew(self, lhs, rhs, "*")
            .map(|out| (out, lhs.dims()).into())
            .unwrap()
    }

    /// Element-wise `lhs / rhs` through `cl_tew` with the `"/"` operator.
    ///
    /// # Panics
    /// Panics if `cl_tew` fails (its result is unwrapped).
    fn div(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        cl_tew(self, lhs, rhs, "/")
            .map(|out| (out, lhs.dims()).into())
            .unwrap()
    }
}
151
// CUDA implementation: each operation hands its operator token to `cu_ew`
// and wraps the resulting buffer in a matrix with the dimensions of `lhs`.
// NOTE(review): no check here that `rhs` has the same dimensions as `lhs`.
#[cfg(feature = "cuda")]
impl<T: CDatatype> BaseOps<T> for custos::CUDA {
    /// Element-wise `lhs + rhs` through `cu_ew` with the `"+"` operator.
    ///
    /// # Panics
    /// Panics if `cu_ew` fails (its result is unwrapped).
    fn add(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        let buf = cu_ew(self, lhs, rhs, "+").unwrap();
        (buf, lhs.dims()).into()
    }

    /// Element-wise `lhs - rhs` through `cu_ew` with the `"-"` operator.
    ///
    /// # Panics
    /// Panics if `cu_ew` fails (its result is unwrapped).
    fn sub(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        let buf = cu_ew(self, lhs, rhs, "-").unwrap();
        (buf, lhs.dims()).into()
    }

    /// Element-wise `lhs * rhs` through `cu_ew` with the `"*"` operator.
    ///
    /// # Panics
    /// Panics if `cu_ew` fails (its result is unwrapped).
    fn mul(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        let buf = cu_ew(self, lhs, rhs, "*").unwrap();
        (buf, lhs.dims()).into()
    }

    /// Element-wise `lhs / rhs` through `cu_ew` with the `"/"` operator.
    ///
    /// # Panics
    /// Panics if `cu_ew` fails (its result is unwrapped).
    fn div(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        let buf = cu_ew(self, lhs, rhs, "/").unwrap();
        (buf, lhs.dims()).into()
    }
}