1use custos::{
5 devices::opencl::cl_device::OpenCL,
6 opencl::{
7 api::{enqueue_write_buffer, wait_for_event},
8 AsClCvoidPtr,
9 },
10 prelude::CLBuffer,
11 CDatatype, Device, Error, GraphReturn, WriteBuf, CPU,
12};
13use std::fmt::Debug;
14
15use crate::{cl_scalar_op, cl_str_op, Matrix};
16
17#[inline]
18pub fn cl_str_op_mat<'a, T: CDatatype>(
19 device: &'a OpenCL,
20 x: &Matrix<T, OpenCL>,
21 op: &str,
22) -> Result<Matrix<'a, T, OpenCL>, Error> {
23 let mut out: CLBuffer<T> = device.retrieve(x.len(), x.node.idx);
24 cl_str_op(device, x, &mut out, op)?;
25 Ok((out, x.dims()).into())
26}
27
28pub fn cl_scalar_op_mat<'a, T: CDatatype>(
29 device: &'a OpenCL,
30 x: &Matrix<T, OpenCL>,
31 scalar: T,
32 op: &str,
33) -> Result<Matrix<'a, T, OpenCL>, Error> {
34 let out = cl_scalar_op(device, x, scalar, op)?;
35 Ok((out, x.dims()).into())
36}
37
38pub fn cl_write<T>(device: &OpenCL, x: &mut CLBuffer<T>, data: &[T]) {
39 let event = unsafe { enqueue_write_buffer(&device.queue(), x.ptr.ptr, data, true).unwrap() };
40 wait_for_event(event).unwrap();
41}
42
43impl<'a, T> AsClCvoidPtr for Matrix<'a, T, OpenCL> {
44 fn as_cvoid_ptr(&self) -> *const std::ffi::c_void {
45 self.ptr.ptr
46 }
47}
48
49impl<'a, T> AsClCvoidPtr for &Matrix<'a, T, OpenCL> {
50 fn as_cvoid_ptr(&self) -> *const std::ffi::c_void {
51 self.ptr.ptr
52 }
53}
54
55pub fn cpu_exec<'a, 'o, T, F>(
72 device: &'o OpenCL,
73 matrix: &Matrix<'a, T, OpenCL>,
74 f: F,
75) -> custos::Result<Matrix<'o, T, OpenCL>>
76where
77 F: for<'b> Fn(&'b CPU, &Matrix<T>) -> Matrix<'b, T>,
78 T: Copy + Default + Debug,
79{
80 #[cfg(not(feature = "realloc"))]
82 if device.unified_mem() {
83 let no_drop = f(
88 &device.cpu,
89 &Matrix::from((matrix.ptr.host_ptr, matrix.dims)),
90 );
91
92 let dims = no_drop.dims();
93 return unsafe {
95 custos::opencl::construct_buffer(device, no_drop.to_buf(), matrix.node.idx)
96 }
97 .map(|buf| (buf, dims).into());
98 }
99
100 let cpu = CPU::new();
101
102 #[cfg(feature = "realloc")]
104 if device.unified_mem() {
105 return Ok(Matrix::from((
106 device,
107 f(&cpu, &Matrix::from((matrix.ptr.host_ptr, matrix.dims))),
108 )));
109 }
110
111 let cpu_buf: Matrix<T> = Matrix::from((&cpu, matrix.dims(), matrix.read()));
113 let mat: Matrix<T> = f(&cpu, &cpu_buf);
114 let mut convert = Matrix::from((device, mat));
115 convert.node = device.graph().add(convert.len(), matrix.node.idx);
116 Ok(convert)
117}
118
119pub fn cpu_exec_mut<T, F>(
120 device: &OpenCL,
121 matrix: &mut Matrix<T, OpenCL>,
122 f: F,
123) -> custos::Result<()>
124where
125 F: Fn(&CPU, &mut Matrix<T>),
126 T: Copy + Default,
127{
128 let cpu = CPU::new();
129
130 if device.unified_mem() {
132 return Ok(f(
133 &cpu,
134 &mut Matrix::from((matrix.ptr.host_ptr, matrix.dims)),
135 ));
136 }
137
138 let mut cpu_matrix = Matrix::from((&cpu, matrix.dims(), matrix.read()));
140 f(&cpu, &mut cpu_matrix);
141 device.write(matrix, &cpu_matrix);
143 Ok(())
144}
145
146pub fn cpu_exec_lhs_rhs<'a, 'o, T, F>(
147 device: &'o OpenCL,
148 lhs: &Matrix<'a, T, OpenCL>,
149 rhs: &Matrix<'a, T, OpenCL>,
150 f: F,
151) -> custos::Result<Matrix<'o, T, OpenCL>>
152where
153 F: for<'b> Fn(&'b CPU, &Matrix<T>, &Matrix<T>) -> Matrix<'b, T>,
154 T: Copy + Default + Debug,
155{
156 let cpu = CPU::new();
157
158 #[cfg(not(feature = "realloc"))]
159 if device.unified_mem() {
160 let no_drop = f(
161 &device.cpu,
162 &Matrix::from((lhs.ptr.host_ptr, lhs.dims)),
163 &Matrix::from((rhs.ptr.host_ptr, rhs.dims)),
164 );
165
166 let no_drop_dims = no_drop.dims();
167 return unsafe {
169 custos::opencl::construct_buffer(device, no_drop.to_buf(), (lhs.node.idx, rhs.node.idx))
170 }
171 .map(|buf| (buf, no_drop_dims).into());
172 }
173
174 #[cfg(feature = "realloc")]
175 if device.unified_mem() {
176 return Ok(Matrix::from((
177 device,
178 f(
179 &cpu,
180 &Matrix::from((lhs.ptr.host_ptr, lhs.dims)),
181 &Matrix::from((rhs.ptr.host_ptr, rhs.dims)),
182 ),
183 )));
184 }
185
186 let lhs = Matrix::from((&cpu, lhs.dims(), lhs.read()));
188 let rhs = Matrix::from((&cpu, rhs.dims(), rhs.read()));
189
190 let mut convert = Matrix::from((device, f(&cpu, &lhs, &rhs)));
191 convert.node = device
192 .graph()
193 .add(convert.len(), (lhs.node.idx, rhs.node.idx));
194
195 Ok(convert)
196}
197
198pub fn cpu_exec_lhs_rhs_mut<T, F>(
199 device: &OpenCL,
200 lhs: &mut Matrix<T, OpenCL>,
201 rhs: &Matrix<T, OpenCL>,
202 f: F,
203) -> custos::Result<()>
204where
205 F: Fn(&CPU, &mut Matrix<T>, &Matrix<T>),
206 T: Copy + Default,
207{
208 let cpu = CPU::new();
209
210 if device.unified_mem() {
212 return Ok(f(
213 &cpu,
214 &mut Matrix::from((lhs.ptr.host_ptr, lhs.dims)),
215 &Matrix::from((rhs.ptr.host_ptr, rhs.dims)),
216 ));
217 }
218
219 let mut cpu_lhs = Matrix::from((&cpu, lhs.dims(), lhs.read()));
221 let cpu_rhs = Matrix::from((&cpu, rhs.dims(), rhs.read()));
222 f(&cpu, &mut cpu_lhs, &cpu_rhs);
223
224 device.write(lhs, &cpu_lhs);
226 Ok(())
227}
228
229pub fn cpu_exec_scalar<T, F>(device: &OpenCL, matrix: &Matrix<T, OpenCL>, f: F) -> T
230where
231 F: Fn(&CPU, &Matrix<T>) -> T,
232 T: Copy + Default,
233{
234 let cpu = CPU::new();
235
236 if device.unified_mem() {
237 return f(&cpu, &Matrix::from((matrix.ptr.host_ptr, matrix.dims)));
238 }
239
240 let cpu_buf = Matrix::from((&cpu, matrix.dims(), matrix.read()));
242
243 f(&cpu, &cpu_buf)
244}