1use custos::{impl_stack, number::Number, Device, MainMemory, Shape, CPU};
2
3use crate::{ew_op, Matrix};
4
5#[cfg(feature = "stack")]
6use custos::Stack;
7
8#[cfg(any(feature = "cuda", feature = "opencl"))]
9use custos::CDatatype;
10
11#[cfg(feature = "opencl")]
12use crate::cl_tew;
13#[cfg(feature = "opencl")]
14use custos::OpenCL;
15
16#[cfg(feature = "cuda")]
17use crate::cu_ew;
18
19#[cfg_attr(feature = "safe", doc = "```ignore")]
20#[cfg_attr(feature = "cpu", doc = "```")]
24#[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
25pub trait BaseOps<T, S: Shape = (), D: Device = Self>: Device {
40 #[cfg_attr(feature = "cpu", doc = "```")]
43 #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
44 fn add(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S>;
55
56 #[cfg_attr(feature = "cpu", doc = "```")]
59 #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
60 fn sub(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S>;
71
72 #[cfg_attr(feature = "cpu", doc = "```")]
75 #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
76 fn mul(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S>;
87
88 #[cfg_attr(feature = "cpu", doc = "```")]
91 #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
92 fn div(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S>;
103}
104
105#[impl_stack]
106impl<T, S, D> BaseOps<T, S, D> for CPU
107where
108 T: Number,
109 S: Shape,
110 D: MainMemory,
111{
112 fn add(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S> {
113 ew_op(self, lhs, rhs, |x, y| x + y)
114 }
115
116 fn sub(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S> {
117 ew_op(self, lhs, rhs, |x, y| x - y)
118 }
119
120 fn mul(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S> {
121 ew_op(self, lhs, rhs, |x, y| x * y)
122 }
123
124 fn div(&self, lhs: &Matrix<T, D, S>, rhs: &Matrix<T, D, S>) -> Matrix<T, Self, S> {
125 ew_op(self, lhs, rhs, |x, y| x / y)
126 }
127}
128
#[cfg(feature = "opencl")]
/// [`BaseOps`] for the [`OpenCL`] device.
///
/// Each method passes the operator symbol to `cl_tew` and converts the
/// returned buffer, paired with the dimensions of `lhs`, into a [`Matrix`].
///
/// # Panics
///
/// Every method panics if `cl_tew` returns an error.
impl<T> BaseOps<T> for OpenCL
where
    T: CDatatype,
{
    /// Calls `cl_tew` with `"+"`; the result takes its dimensions from `lhs`.
    fn add(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        (cl_tew(self, lhs, rhs, "+").unwrap(), lhs.dims()).into()
    }

    /// Calls `cl_tew` with `"-"`; the result takes its dimensions from `lhs`.
    fn sub(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        (cl_tew(self, lhs, rhs, "-").unwrap(), lhs.dims()).into()
    }

    /// Calls `cl_tew` with `"*"`; the result takes its dimensions from `lhs`.
    fn mul(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        (cl_tew(self, lhs, rhs, "*").unwrap(), lhs.dims()).into()
    }

    /// Calls `cl_tew` with `"/"`; the result takes its dimensions from `lhs`.
    fn div(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        (cl_tew(self, lhs, rhs, "/").unwrap(), lhs.dims()).into()
    }
}
151
#[cfg(feature = "cuda")]
/// [`BaseOps`] for the `custos::CUDA` device.
///
/// Each method passes the operator symbol to `cu_ew` and converts the
/// returned buffer, paired with the dimensions of `lhs`, into a [`Matrix`].
///
/// # Panics
///
/// Every method panics if `cu_ew` returns an error.
impl<T> BaseOps<T> for custos::CUDA
where
    T: CDatatype,
{
    /// Calls `cu_ew` with `"+"`; the result takes its dimensions from `lhs`.
    fn add(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        (cu_ew(self, lhs, rhs, "+").unwrap(), lhs.dims()).into()
    }

    /// Calls `cu_ew` with `"-"`; the result takes its dimensions from `lhs`.
    fn sub(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        (cu_ew(self, lhs, rhs, "-").unwrap(), lhs.dims()).into()
    }

    /// Calls `cu_ew` with `"*"`; the result takes its dimensions from `lhs`.
    fn mul(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        (cu_ew(self, lhs, rhs, "*").unwrap(), lhs.dims()).into()
    }

    /// Calls `cu_ew` with `"/"`; the result takes its dimensions from `lhs`.
    fn div(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        (cu_ew(self, lhs, rhs, "/").unwrap(), lhs.dims()).into()
    }
}