use std::mem::size_of;
use std::ptr::null_mut;
use std::sync::Arc;
use std::sync::Mutex;
use crate::Backend;
use crate::BackendArray;
use crate::Error;
use crate::Result;
use crate::mutex_lock;

pub use opencl3::context::Context;
pub use opencl3::device::Device;
pub use opencl3::device::CL_DEVICE_TYPE_ACCELERATOR;
pub use opencl3::device::CL_DEVICE_TYPE_ALL;
pub use opencl3::device::CL_DEVICE_TYPE_CPU;
pub use opencl3::device::CL_DEVICE_TYPE_CUSTOM;
pub use opencl3::device::CL_DEVICE_TYPE_DEFAULT;
pub use opencl3::device::CL_DEVICE_TYPE_GPU;
pub use opencl3::device::cl_device_id;
pub use opencl3::error_codes::ClError;
pub use opencl3::platform::Platform;
pub use opencl3::platform::get_platforms;

use cl3::info_type::InfoType;
use opencl3::command_queue::CommandQueue;
use opencl3::device::CL_DEVICE_MAX_WORK_GROUP_SIZE;
use opencl3::device::CL_DEVICE_PREFERRED_WORK_GROUP_SIZE_MULTIPLE;
use opencl3::device::get_device_info;
use opencl3::event::Event;
use opencl3::kernel::ExecuteKernel;
use opencl3::kernel::Kernel;
use opencl3::memory::Buffer;
use opencl3::memory::ClMem;
use opencl3::memory::cl_mem;
use opencl3::memory::CL_MEM_READ_WRITE;
use opencl3::program::Program;
use opencl3::types::CL_TRUE;

const SOURCE: &str = include_str!("opencl.cl");

/// A backend array for the OpenCL backend: a reference-counted OpenCL buffer of
/// `f32` elements together with its element count.
#[derive(Debug)]
pub struct ClBackendArray
{
    buffer: Arc<Mutex<Buffer<f32>>>,
    len: usize,
}

struct ClInnerBackend
{
    context: Context,
    command_queue: CommandQueue,
    program: Program,
    group_size_for_1d: usize,
    group_size_for_2d: usize,
}

/// An OpenCL backend that runs the matrix operations as kernels built from the
/// embedded `opencl.cl` source. The OpenCL context, command queue, program, and
/// queried work group sizes are kept behind a mutex.
pub struct ClBackend
{
    inner: Mutex<ClInnerBackend>,
}

/// Computes local and global work sizes for a kernel launch, returned as
/// `(local_n, local_m, global_n, global_m)` with the global sizes rounded up to
/// multiples of the local sizes. Vectors (`n == 1` or `m == 1`) use the 1D group
/// size, matrices use the 2D group size, and multiplication kernels (`is_mul`)
/// use half of the 2D group size with work item counts derived from `n` and `m`
/// divided by 4 (rounded up).
fn preferred_work_sizes(n: usize, m: usize, group_size_for_1d: usize, group_size_for_2d: usize, is_mul: bool) -> (usize, usize, usize, usize)
{
    if m == 1 && !is_mul {
        let n2 = ((n + group_size_for_1d - 1) / group_size_for_1d) * group_size_for_1d;
        (group_size_for_1d, 1, n2, 1)
    } else if n == 1 && !is_mul {
        let m2 = ((m + group_size_for_1d - 1) / group_size_for_1d) * group_size_for_1d;
        (1, group_size_for_1d, 1, m2)
    } else if is_mul {
        let n2 = (((n + 3) / 4 + ((group_size_for_2d + 1) / 2) - 1) / ((group_size_for_2d + 1) / 2)) * ((group_size_for_2d + 1) / 2);
        let m2 = (((m + 3) / 4 + ((group_size_for_2d + 1) / 2) - 1) / ((group_size_for_2d + 1) / 2)) * ((group_size_for_2d + 1) / 2);
        ((group_size_for_2d + 1) / 2, (group_size_for_2d + 1) / 2, n2, m2)
    } else {
        let n2 = ((n + group_size_for_2d - 1) / group_size_for_2d) * group_size_for_2d;
        let m2 = ((m + group_size_for_2d - 1) / group_size_for_2d) * group_size_for_2d;
        (group_size_for_2d, group_size_for_2d, n2, m2)
    }
}
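
// A few worked examples for `preferred_work_sizes`, kept as an inline sanity
// check of the rounding arithmetic. The group sizes used here (256 for 1D,
// 16 for 2D) are illustrative values chosen for the example, not values
// queried from any particular device.
#[cfg(test)]
mod preferred_work_size_examples
{
    use super::preferred_work_sizes;

    #[test]
    fn rounds_work_sizes_up_to_group_size_multiples()
    {
        // A column vector (m == 1) is laid out along the 1D group size.
        assert_eq!((256, 1, 256, 1), preferred_work_sizes(100, 1, 256, 16, false));
        // A row vector (n == 1) is laid out along the other axis.
        assert_eq!((1, 256, 1, 256), preferred_work_sizes(1, 100, 256, 16, false));
        // A 100x50 matrix uses the 2D group size on both axes: 112 = ceil(100 / 16) * 16.
        assert_eq!((16, 16, 112, 64), preferred_work_sizes(100, 50, 256, 16, false));
        // Multiplication halves the 2D group size and rounds ceil(n / 4) and ceil(m / 4)
        // up to multiples of it: 32 = ceil(25 / 8) * 8 and 16 = ceil(13 / 8) * 8.
        assert_eq!((8, 8, 32, 16), preferred_work_sizes(100, 50, 256, 16, true));
    }
}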

impl ClBackend
{
    /// Creates a backend for the first default device of the first available
    /// OpenCL platform.
    pub fn new() -> Result<ClBackend>
    {
        let platforms = match get_platforms() {
            Ok(tmp_platforms) => tmp_platforms,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        if platforms.is_empty() {
            return Err(Error::NoPlatform);
        }
        let device_ids = match platforms[0].get_devices(CL_DEVICE_TYPE_DEFAULT) {
            Ok(tmp_device_ids) => tmp_device_ids,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        if device_ids.is_empty() {
            return Err(Error::NoDevice);
        }
        let device = Device::new(device_ids[0]);
        let context = match Context::from_device(&device) {
            Ok(tmp_context) => tmp_context,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        Self::new_with_context(context)
    }

    /// Creates a backend for the given OpenCL context, building the kernel
    /// program from the embedded `opencl.cl` source and querying the default
    /// device for its work group sizes.
    pub fn new_with_context(context: Context) -> Result<ClBackend>
    {
        let command_queue = match CommandQueue::create_default_with_properties(&context, 0, 0) {
            Ok(tmp_command_queue) => tmp_command_queue,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        let program = match Program::create_and_build_from_source(&context, SOURCE, "") {
            Ok(tmp_program) => tmp_program,
            Err(msg) => return Err(Error::Compilation(msg)),
        };
        let group_size_for_1d = match get_device_info(context.default_device(), CL_DEVICE_MAX_WORK_GROUP_SIZE) {
            Ok(InfoType::Size(tmp_group_size_for_1d)) => tmp_group_size_for_1d,
            _ => return Err(Error::InvalidDeviceInfoType),
        };
        let group_size_for_2d = match get_device_info(context.default_device(), CL_DEVICE_PREFERRED_WORK_GROUP_SIZE_MULTIPLE) {
            Ok(InfoType::Size(tmp_group_size_for_2d)) => tmp_group_size_for_2d,
            _ => return Err(Error::InvalidDeviceInfoType),
        };
        let inner = ClInnerBackend {
            context,
            command_queue,
            program,
            group_size_for_1d,
            group_size_for_2d,
        };
        Ok(ClBackend { inner: Mutex::new(inner), })
    }

    fn check_and_enqueue_nd_range2<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&ClBackendArray, &ClBackendArray) -> Result<()>,
              G: FnOnce(&ClInnerBackend, &Kernel, cl_mem, cl_mem) -> Result<Event>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::OpenCl(a2), BackendArray::OpenCl(b2)) => {
                f(a2, b2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match Kernel::create(&inner_g.program, kernel_name) {
                    Ok(tmp_kernel) => tmp_kernel,
                    Err(err) => return Err(Error::OpenCl(err)),
                };
                // Lock each distinct buffer exactly once: when `a` and `b` share a
                // buffer, its mutex must not be locked a second time.
                let event = if !Arc::ptr_eq(&a2.buffer, &b2.buffer) {
                    let a_buffer_g = mutex_lock(&a2.buffer)?;
                    let mut b_buffer_g = mutex_lock(&b2.buffer)?;
                    g(&*inner_g, &kernel, a_buffer_g.get(), b_buffer_g.get_mut())?
                } else {
                    let mut a_buffer_g = mutex_lock(&a2.buffer)?;
                    g(&*inner_g, &kernel, a_buffer_g.get(), a_buffer_g.get_mut())?
                };
                match event.wait() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::OpenCl(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_enqueue_nd_range3<F, G>(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, f: F, g: G) -> Result<()>
        where F: FnOnce(&ClBackendArray, &ClBackendArray, &ClBackendArray) -> Result<()>,
              G: FnOnce(&ClInnerBackend, &Kernel, cl_mem, cl_mem, cl_mem) -> Result<Event>
    {
        #[allow(unreachable_patterns)]
        match (a, b, c) {
            (BackendArray::OpenCl(a2), BackendArray::OpenCl(b2), BackendArray::OpenCl(c2)) => {
                f(a2, b2, c2)?;
                let inner_g = mutex_lock(&self.inner)?;
                let kernel = match Kernel::create(&inner_g.program, kernel_name) {
                    Ok(tmp_kernel) => tmp_kernel,
                    Err(err) => return Err(Error::OpenCl(err)),
                };
                // Lock each distinct buffer exactly once, choosing the locking pattern
                // from which of the three arrays alias one another.
                let event = match (Arc::ptr_eq(&a2.buffer, &b2.buffer), Arc::ptr_eq(&a2.buffer, &c2.buffer), Arc::ptr_eq(&b2.buffer, &c2.buffer)) {
                    (false, false, false) => {
                        let a_buffer_g = mutex_lock(&a2.buffer)?;
                        let b_buffer_g = mutex_lock(&b2.buffer)?;
                        let mut c_buffer_g = mutex_lock(&c2.buffer)?;
                        g(&*inner_g, &kernel, a_buffer_g.get(), b_buffer_g.get(), c_buffer_g.get_mut())?
                    },
                    (true, false, false) => {
                        let a_buffer_g = mutex_lock(&a2.buffer)?;
                        let mut c_buffer_g = mutex_lock(&c2.buffer)?;
                        g(&*inner_g, &kernel, a_buffer_g.get(), a_buffer_g.get(), c_buffer_g.get_mut())?
                    },
                    (false, true, false) => {
                        let mut a_buffer_g = mutex_lock(&a2.buffer)?;
                        let b_buffer_g = mutex_lock(&b2.buffer)?;
                        g(&*inner_g, &kernel, a_buffer_g.get(), b_buffer_g.get(), a_buffer_g.get_mut())?
                    },
                    (false, false, true) => {
                        let a_buffer_g = mutex_lock(&a2.buffer)?;
                        let mut b_buffer_g = mutex_lock(&b2.buffer)?;
                        g(&*inner_g, &kernel, a_buffer_g.get(), b_buffer_g.get(), b_buffer_g.get_mut())?
                    },
                    _ => {
                        let mut a_buffer_g = mutex_lock(&a2.buffer)?;
                        g(&*inner_g, &kernel, a_buffer_g.get(), a_buffer_g.get(), a_buffer_g.get_mut())?
                    },
                };
                match event.wait() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::OpenCl(err)),
                }
                Ok(())
            },
            _ => Err(Error::InvalidBackendArray),
        }
    }

    fn check_and_enqueue_nd_range_for_fun(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        self.check_and_enqueue_nd_range2(kernel_name, a, b, |a2, b2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |inner, kernel, a_mem, b_mem| {
            let n2 = n as u64;
            let m2 = m as u64;
            let (n3, m3, n4, m4) = preferred_work_sizes(n, m, inner.group_size_for_1d, inner.group_size_for_2d, false);
            unsafe {
                let res = ExecuteKernel::new(kernel)
                    .set_arg(&a_mem)
                    .set_arg(&b_mem)
                    .set_arg(&n2)
                    .set_arg(&m2)
                    .set_local_work_sizes(&[n3, m3])
                    .set_global_work_sizes(&[n4, m4])
                    .enqueue_nd_range(&inner.command_queue);
                match res {
                    Ok(event) => Ok(event),
                    Err(err) => Err(Error::OpenCl(err)),
                }
            }
        })
    }

    fn check_and_enqueue_nd_range_for_op(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        self.check_and_enqueue_nd_range3(kernel_name, a, b, c, |a2, b2, c2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |inner, kernel, a_mem, b_mem, c_mem| {
            let n2 = n as u64;
            let m2 = m as u64;
            let (n3, m3, n4, m4) = preferred_work_sizes(n, m, inner.group_size_for_1d, inner.group_size_for_2d, false);
            unsafe {
                let res = ExecuteKernel::new(kernel)
                    .set_arg(&a_mem)
                    .set_arg(&b_mem)
                    .set_arg(&c_mem)
                    .set_arg(&n2)
                    .set_arg(&m2)
                    .set_local_work_sizes(&[n3, m3])
                    .set_global_work_sizes(&[n4, m4])
                    .enqueue_nd_range(&inner.command_queue);
                match res {
                    Ok(event) => Ok(event),
                    Err(err) => Err(Error::OpenCl(err)),
                }
            }
        })
    }

    fn check_and_enqueue_nd_range_for_mul(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    {
        self.check_and_enqueue_nd_range3(kernel_name, a, b, c, |a2, b2, c2| {
            if a2.len != n * l {
                return Err(Error::BackendArrayElemCount(a2.len, n * l));
            }
            if b2.len != l * m {
                return Err(Error::BackendArrayElemCount(b2.len, l * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |inner, kernel, a_mem, b_mem, c_mem| {
            let n2 = n as u64;
            let m2 = m as u64;
            let l2 = l as u64;
            let (n3, m3, n4, m4) = preferred_work_sizes(n, m, inner.group_size_for_1d, inner.group_size_for_2d, true);
            unsafe {
                let res = ExecuteKernel::new(kernel)
                    .set_arg(&a_mem)
                    .set_arg(&b_mem)
                    .set_arg(&c_mem)
                    // Two local scratch buffers for the multiplication kernel,
                    // sized at 4 floats per work item each.
                    .set_arg_local_buffer(n3 * m3 * 4 * size_of::<f32>())
                    .set_arg_local_buffer(n3 * m3 * 4 * size_of::<f32>())
                    .set_arg(&n2)
                    .set_arg(&m2)
                    .set_arg(&l2)
                    .set_local_work_sizes(&[n3, m3])
                    .set_global_work_sizes(&[n4, m4])
                    .enqueue_nd_range(&inner.command_queue);
                match res {
                    Ok(event) => Ok(event),
                    Err(err) => Err(Error::OpenCl(err)),
                }
            }
        })
    }

    fn check_and_enqueue_nd_range_for_scalar(&self, kernel_name: &str, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        self.check_and_enqueue_nd_range2(kernel_name, a, c, |a2, c2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if c2.len != n * m {
                return Err(Error::BackendArrayElemCount(c2.len, n * m));
            }
            Ok(())
        }, |inner, kernel, a_mem, c_mem| {
            let n2 = n as u64;
            let m2 = m as u64;
            let (n3, m3, n4, m4) = preferred_work_sizes(n, m, inner.group_size_for_1d, inner.group_size_for_2d, false);
            unsafe {
                let res = ExecuteKernel::new(kernel)
                    .set_arg(&a_mem)
                    .set_arg(&b)
                    .set_arg(&c_mem)
                    .set_arg(&n2)
                    .set_arg(&m2)
                    .set_local_work_sizes(&[n3, m3])
                    .set_global_work_sizes(&[n4, m4])
                    .enqueue_nd_range(&inner.command_queue);
                match res {
                    Ok(event) => Ok(event),
                    Err(err) => Err(Error::OpenCl(err)),
                }
            }
        })
    }

    fn check_and_enqueue_nd_range_for_fun_and_tiles(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        self.check_and_enqueue_nd_range2(kernel_name, a, b, |a2, b2| {
            if a2.len != n * m {
                return Err(Error::BackendArrayElemCount(a2.len, n * m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |inner, kernel, a_mem, b_mem| {
            let n2 = n as u64;
            let m2 = m as u64;
            let (n3, m3, n4, m4) = preferred_work_sizes(n, m, inner.group_size_for_1d, inner.group_size_for_2d, false);
            unsafe {
                let res = ExecuteKernel::new(kernel)
                    .set_arg(&a_mem)
                    .set_arg(&b_mem)
                    .set_arg_local_buffer(n3 * m3 * size_of::<f32>())
                    .set_arg(&n2)
                    .set_arg(&m2)
                    .set_local_work_sizes(&[n3, m3])
                    .set_global_work_sizes(&[n4, m4])
                    .enqueue_nd_range(&inner.command_queue);
                match res {
                    Ok(event) => Ok(event),
                    Err(err) => Err(Error::OpenCl(err)),
                }
            }
        })
    }

    fn check_and_enqueue_nd_range_for_repeat_col(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        self.check_and_enqueue_nd_range2(kernel_name, a, b, |a2, b2| {
            if a2.len != n {
                return Err(Error::BackendArrayElemCount(a2.len, n));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |inner, kernel, a_mem, b_mem| {
            let n2 = n as u64;
            let m2 = m as u64;
            let (n3, m3, n4, m4) = preferred_work_sizes(n, m, inner.group_size_for_1d, inner.group_size_for_2d, false);
            unsafe {
                let res = ExecuteKernel::new(kernel)
                    .set_arg(&a_mem)
                    .set_arg(&b_mem)
                    .set_arg(&n2)
                    .set_arg(&m2)
                    .set_local_work_sizes(&[n3, m3])
                    .set_global_work_sizes(&[n4, m4])
                    .enqueue_nd_range(&inner.command_queue);
                match res {
                    Ok(event) => Ok(event),
                    Err(err) => Err(Error::OpenCl(err)),
                }
            }
        })
    }

    fn check_and_enqueue_nd_range_for_repeat_row(&self, kernel_name: &str, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    {
        self.check_and_enqueue_nd_range2(kernel_name, a, b, |a2, b2| {
            if a2.len != m {
                return Err(Error::BackendArrayElemCount(a2.len, m));
            }
            if b2.len != n * m {
                return Err(Error::BackendArrayElemCount(b2.len, n * m));
            }
            Ok(())
        }, |inner, kernel, a_mem, b_mem| {
            let n2 = n as u64;
            let m2 = m as u64;
            let (n3, m3, n4, m4) = preferred_work_sizes(n, m, inner.group_size_for_1d, inner.group_size_for_2d, false);
            unsafe {
                let res = ExecuteKernel::new(kernel)
                    .set_arg(&a_mem)
                    .set_arg(&b_mem)
                    .set_arg(&n2)
                    .set_arg(&m2)
                    .set_local_work_sizes(&[n3, m3])
                    .set_global_work_sizes(&[n4, m4])
                    .enqueue_nd_range(&inner.command_queue);
                match res {
                    Ok(event) => Ok(event),
                    Err(err) => Err(Error::OpenCl(err)),
                }
            }
        })
    }
}
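
// A construction sketch using the opencl3 items re-exported above, for callers
// that want a specific device type instead of the platform default. This is a
// hedged example: it assumes an OpenCL platform with at least one GPU device is
// present at run time, so the test is marked `#[ignore]`.
#[cfg(test)]
mod construction_example
{
    use super::*;

    #[test]
    #[ignore]
    fn creates_backend_for_first_gpu_device()
    {
        let platforms = get_platforms().expect("can't get platforms");
        assert!(!platforms.is_empty());
        let device_ids = platforms[0].get_devices(CL_DEVICE_TYPE_GPU).expect("can't get devices");
        assert!(!device_ids.is_empty());
        let device = Device::new(device_ids[0]);
        let context = Context::from_device(&device).expect("can't create context");
        let backend = match ClBackend::new_with_context(context) {
            Ok(tmp_backend) => tmp_backend,
            Err(_) => panic!("can't create backend"),
        };
        assert_eq!("OpenCL", backend.name());
    }
}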

impl Backend for ClBackend
{
    fn name(&self) -> &'static str
    { "OpenCL" }

    fn has_cublas(&self) -> bool
    { false }

    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let buffer: Buffer<f32> = match Buffer::create(&inner_g.context, CL_MEM_READ_WRITE, n, null_mut()) {
            Ok(tmp_buffer) => tmp_buffer,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        let cl_array = ClBackendArray { buffer: Arc::new(Mutex::new(buffer)), len: n, };
        Ok(BackendArray::OpenCl(cl_array))
    }

    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let mut buffer: Buffer<f32> = match unsafe { Buffer::create(&inner_g.context, CL_MEM_READ_WRITE, n, null_mut()) } {
            Ok(tmp_buffer) => tmp_buffer,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        let event = match unsafe { inner_g.command_queue.enqueue_fill_buffer(&mut buffer, &[0.0f32], 0, n * size_of::<f32>(), &[]) } {
            Ok(tmp_event) => tmp_event,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        match event.wait() {
            Ok(()) => (),
            Err(err) => return Err(Error::OpenCl(err)),
        }
        let cl_array = ClBackendArray { buffer: Arc::new(Mutex::new(buffer)), len: n, };
        Ok(BackendArray::OpenCl(cl_array))
    }

    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>
    {
        let inner_g = mutex_lock(&self.inner)?;
        let mut buffer: Buffer<f32> = match unsafe { Buffer::create(&inner_g.context, CL_MEM_READ_WRITE, elems.len(), null_mut()) } {
            Ok(tmp_buffer) => tmp_buffer,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        let event = match unsafe { inner_g.command_queue.enqueue_write_buffer(&mut buffer, CL_TRUE, 0, elems, &[]) } {
            Ok(tmp_event) => tmp_event,
            Err(err) => return Err(Error::OpenCl(err)),
        };
        match event.wait() {
            Ok(()) => (),
            Err(err) => return Err(Error::OpenCl(err)),
        }
        let cl_array = ClBackendArray { buffer: Arc::new(Mutex::new(buffer)), len: elems.len(), };
        Ok(BackendArray::OpenCl(cl_array))
    }

    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::OpenCl(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_buffer_g = mutex_lock(&a2.buffer)?;
                let event = match unsafe { inner_g.command_queue.enqueue_read_buffer(&*a_buffer_g, CL_TRUE, 0, elems, &[]) } {
                    Ok(tmp_event) => tmp_event,
                    Err(err) => return Err(Error::OpenCl(err)),
                };
                match event.wait() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::OpenCl(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match a {
            BackendArray::OpenCl(a2) => {
                if a2.len != elems.len() {
                    return Err(Error::BackendArrayElemCount(a2.len, elems.len()));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let mut a_buffer_g = mutex_lock(&a2.buffer)?;
                let event = match unsafe { inner_g.command_queue.enqueue_write_buffer(&mut *a_buffer_g, CL_TRUE, 0, elems, &[]) } {
                    Ok(tmp_event) => tmp_event,
                    Err(err) => return Err(Error::OpenCl(err)),
                };
                match event.wait() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::OpenCl(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>
    {
        #[allow(unreachable_patterns)]
        match (a, b) {
            (BackendArray::OpenCl(a2), BackendArray::OpenCl(b2)) => {
                if Arc::ptr_eq(&a2.buffer, &b2.buffer) {
                    return Ok(());
                }
                if a2.len != b2.len {
                    return Err(Error::TwoBackendArrayElemCounts(a2.len, b2.len));
                }
                let inner_g = mutex_lock(&self.inner)?;
                let a_buffer_g = mutex_lock(&a2.buffer)?;
                let mut b_buffer_g = mutex_lock(&b2.buffer)?;
                let event = match unsafe { inner_g.command_queue.enqueue_copy_buffer(&*a_buffer_g, &mut *b_buffer_g, 0, 0, a2.len * size_of::<f32>(), &[]) } {
                    Ok(tmp_event) => tmp_event,
                    Err(err) => return Err(Error::OpenCl(err)),
                };
                match event.wait() {
                    Ok(()) => (),
                    Err(err) => return Err(Error::OpenCl(err)),
                }
            },
            _ => return Err(Error::InvalidBackendArray),
        }
        Ok(())
    }

    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_fun("transpose_a", a, b, n, m) }

    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("add_a_b", a, b, c, n, m) }

    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("add_at_b", a, b, c, n, m) }

    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("add_a_bt", a, b, c, n, m) }

    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("add_at_bt", a, b, c, n, m) }

    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("sub_a_b", a, b, c, n, m) }

    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("sub_at_b", a, b, c, n, m) }

    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("sub_a_bt", a, b, c, n, m) }

    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("sub_at_bt", a, b, c, n, m) }

    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_mul("mul_a_b", a, b, c, n, m, l) }

    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_mul("mul_at_b", a, b, c, n, m, l) }

    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_mul("mul_a_bt", a, b, c, n, m, l) }

    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_mul("mul_at_bt", a, b, c, n, m, l) }

    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("mul_a_b_for_elems", a, b, c, n, m) }

    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("mul_at_b_for_elems", a, b, c, n, m) }

    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("mul_a_bt_for_elems", a, b, c, n, m) }

    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("mul_at_bt_for_elems", a, b, c, n, m) }

    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("div_a_b_for_elems", a, b, c, n, m) }

    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("div_at_b_for_elems", a, b, c, n, m) }

    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("div_a_bt_for_elems", a, b, c, n, m) }

    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_op("div_at_bt_for_elems", a, b, c, n, m) }

    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("add_a_b_for_scalar", a, b, c, n, m) }

    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("add_at_b_for_scalar", a, b, c, n, m) }

    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("sub_a_b_for_scalar", a, b, c, n, m) }

    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("sub_at_b_for_scalar", a, b, c, n, m) }

    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("rsub_a_b_for_scalar", a, b, c, n, m) }

    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("rsub_at_b_for_scalar", a, b, c, n, m) }

    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("mul_a_b_for_scalar", a, b, c, n, m) }

    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("mul_at_b_for_scalar", a, b, c, n, m) }

    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("div_a_b_for_scalar", a, b, c, n, m) }

    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("div_at_b_for_scalar", a, b, c, n, m) }

    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("rdiv_a_b_for_scalar", a, b, c, n, m) }

    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_scalar("rdiv_at_b_for_scalar", a, b, c, n, m) }

    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_fun("sigmoid_a", a, b, n, m) }

    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_fun("sigmoid_at", a, b, n, m) }

    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_fun("tanh_a", a, b, n, m) }

    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_fun("tanh_at", a, b, n, m) }

    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_fun_and_tiles("softmax_a", a, b, n, m) }

    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_fun_and_tiles("softmax_at", a, b, n, m) }

    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_repeat_col("repeat_col_a", a, b, n, m) }

    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>
    { self.check_and_enqueue_nd_range_for_repeat_row("repeat_row_a", a, b, n, m) }
}
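
// A minimal round-trip sketch through the `Backend` trait: store a slice in a
// backend array and read it back. It only exercises the buffer transfer paths
// above and is marked `#[ignore]` because it needs a working OpenCL platform
// and device at run time.
#[cfg(test)]
mod round_trip_example
{
    use super::*;

    #[test]
    #[ignore]
    fn stores_and_loads_elements()
    {
        let backend = match ClBackend::new() {
            Ok(tmp_backend) => tmp_backend,
            Err(_) => panic!("can't create backend"),
        };
        let elems = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = match backend.alloc_and_store(&elems) {
            Ok(tmp_a) => tmp_a,
            Err(_) => panic!("can't allocate and store"),
        };
        let mut loaded = [0.0f32; 6];
        match backend.load(&a, &mut loaded) {
            Ok(()) => (),
            Err(_) => panic!("can't load"),
        }
        assert_eq!(elems, loaded);
    }
}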

#[cfg(test)]
mod tests;