use crate::thinfn::ThinFn;
use crate::{core::*, non_null_const};
use blaze_proc::docfg;
use opencl_sys::*;
use std::any::Any;
use std::ffi::c_void;
use std::marker::PhantomData;
use std::panic::resume_unwind;
use std::sync::mpsc::TryRecvError;
use std::time::{Duration, SystemTime};
use std::{mem::MaybeUninit, ptr::NonNull};
use super::ext::NoopEvent;
use super::{CommandType, Consumer, Event, EventStatus, ProfilingInfo};
/// Owning wrapper around a raw OpenCL `cl_event` handle.
///
/// The wrapped handle is never null. Each `RawEvent` owns one reference to the event:
/// `Clone` retains it and `Drop` releases it.
///
/// The derived `PartialEq`/`Eq`/`Hash` compare the handle pointer itself.
/// `#[repr(transparent)]` gives a `[RawEvent]` the same layout as a `[cl_event]`.
#[derive(Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct RawEvent(NonNull<c_void>);
impl RawEvent {
    /// Wraps `inner` without checking it for null.
    ///
    /// The returned value takes over one reference to the event, which is released when the
    /// `RawEvent` is dropped.
    ///
    /// # Safety
    /// `inner` must be a valid, non-null `cl_event`, and the caller must hand over ownership of
    /// one reference to it.
    #[inline(always)]
    pub const unsafe fn from_id_unchecked(inner: cl_event) -> Self {
        Self(NonNull::new_unchecked(inner))
    }

    /// Wraps `inner`, returning `None` when `non_null_const` rejects it (presumably when it is
    /// null).
    ///
    /// # Safety
    /// If non-null, `inner` must be a valid `cl_event` whose reference is handed over to the
    /// returned value.
    #[inline(always)]
    pub const unsafe fn from_id(inner: cl_event) -> Option<Self> {
        // `Option::map` is not usable in a `const fn`, hence the explicit match.
        match non_null_const(inner) {
            Some(x) => Some(Self(x)),
            None => None,
        }
    }

    /// Returns the underlying `cl_event` handle without touching its reference count.
    #[inline(always)]
    pub const fn id(&self) -> cl_event {
        self.0.as_ptr()
    }

    /// Increments the event's reference count with `clRetainEvent`.
    ///
    /// # Safety
    /// The extra reference is not owned by this `RawEvent`. The caller must balance it with a
    /// release (e.g. by wrapping the handle in another `RawEvent`), otherwise the event leaks.
    #[inline(always)]
    pub unsafe fn retain(&self) -> Result<()> {
        tri!(clRetainEvent(self.id()));
        Ok(())
    }

    /// Blocks the current thread until this event has finished executing.
    ///
    /// # Errors
    /// Propagates the error reported by `clWaitForEvents`.
    #[inline(always)]
    pub fn join_by_ref(&self) -> Result<()> {
        let list = [self.id()];
        unsafe { tri!(clWaitForEvents(1, list.as_ptr())) }
        Ok(())
    }

    /// Blocks the current thread until every event in `v` has finished executing.
    ///
    /// An empty slice returns `Ok(())` immediately. `clWaitForEvents` itself rejects a zero
    /// event count with `CL_INVALID_VALUE`, and waiting on nothing is trivially done.
    ///
    /// # Errors
    /// Propagates the error reported by `clWaitForEvents`.
    ///
    /// # Panics
    /// Panics if `v.len()` does not fit in a `u32` (the `cl_uint` count OpenCL expects).
    #[inline(always)]
    pub fn join_all_by_ref(v: &[RawEvent]) -> Result<()> {
        if v.is_empty() {
            return Ok(());
        }
        let len = u32::try_from(v.len()).expect("too many events to wait on in one call");
        // SAFETY: `RawEvent` is `#[repr(transparent)]` over a non-null event pointer, so the
        // slice can be read by OpenCL as an array of `cl_event` of length `len`.
        unsafe { tri!(clWaitForEvents(len, v.as_ptr().cast())) }
        Ok(())
    }
}
impl RawEvent {
#[inline(always)]
pub fn into_event(self) -> NoopEvent {
self.into()
}
#[inline(always)]
pub fn join_with_nanos_by_ref(self) -> Result<ProfilingInfo<u64>> {
self.join_by_ref()?;
self.profiling_nanos()
}
#[inline(always)]
pub fn join_with_time_by_ref(self) -> Result<ProfilingInfo<SystemTime>> {
self.join_by_ref()?;
self.profiling_time()
}
#[inline(always)]
pub fn join_with_duration_by_ref(self) -> Result<Duration> {
self.join_by_ref()?;
self.duration()
}
#[inline(always)]
pub fn join_unwrap_by_ref(self) {
self.join_by_ref().unwrap()
}
#[inline(always)]
pub fn ty(&self) -> Result<CommandType> {
self.get_info(CL_EVENT_COMMAND_TYPE)
}
#[inline(always)]
pub fn status(&self) -> Result<EventStatus> {
let int: i32 = self.get_info(CL_EVENT_COMMAND_EXECUTION_STATUS)?;
EventStatus::try_from(int)
}
#[inline(always)]
pub fn command_queue(&self) -> Result<Option<RawCommandQueue>> {
match self.get_info(CL_EVENT_COMMAND_QUEUE) {
Ok(x) => unsafe { Ok(RawCommandQueue::from_id(x)) },
Err(e) => Err(e),
}
}
#[inline(always)]
pub fn reference_count(&self) -> Result<u32> {
self.get_info(opencl_sys::CL_EVENT_REFERENCE_COUNT)
}
#[docfg(feature = "cl1_1")]
#[inline(always)]
pub fn raw_context(&self) -> Result<crate::prelude::RawContext> {
let ctx = self.get_info::<cl_context>(CL_EVENT_CONTEXT)?;
unsafe {
tri!(clRetainContext(ctx));
Ok(crate::prelude::RawContext::from_id_unchecked(ctx))
}
}
#[inline(always)]
pub fn profiling_nanos(&self) -> Result<ProfilingInfo<u64>> {
ProfilingInfo::<u64>::new(self)
}
#[inline(always)]
pub fn profiling_time(&self) -> Result<ProfilingInfo<SystemTime>> {
ProfilingInfo::<SystemTime>::new(self)
}
#[inline(always)]
pub fn duration(&self) -> Result<Duration> {
let nanos = self.profiling_nanos()?;
Ok(nanos.duration())
}
#[inline(always)]
pub fn is_queued(&self) -> bool {
self.status().as_ref().map_or(true, EventStatus::is_queued)
}
#[inline(always)]
pub fn has_submited(&self) -> bool {
self.status()
.as_ref()
.map_or(true, EventStatus::has_submitted)
}
#[inline(always)]
pub fn has_started_running(&self) -> bool {
self.status()
.as_ref()
.map_or(true, EventStatus::has_started_running)
}
#[inline(always)]
pub fn has_completed(&self) -> bool {
self.status()
.as_ref()
.map_or(true, EventStatus::has_completed)
}
#[inline(always)]
pub fn get_info<T: Copy>(&self, id: cl_event_info) -> Result<T> {
let mut result = MaybeUninit::<T>::uninit();
unsafe {
tri!(clGetEventInfo(
self.id(),
id,
core::mem::size_of::<T>(),
result.as_mut_ptr().cast(),
core::ptr::null_mut()
));
Ok(result.assume_init())
}
}
}
#[docfg(feature = "cl1_1")]
impl RawEvent {
    /// Runs `f` once the event reaches [`EventStatus::Submitted`].
    ///
    /// Returns a handle to `f`'s result; see [`RawEvent::on_status`].
    #[inline(always)]
    pub fn on_submit<T: 'static + Send>(
        &self,
        f: impl 'static + Send + FnOnce(RawEvent, Result<EventStatus>) -> T,
    ) -> Result<CallbackHandle<T>> {
        self.on_status(EventStatus::Submitted, f)
    }

    /// Runs `f` once the event reaches [`EventStatus::Running`].
    ///
    /// Returns a handle to `f`'s result; see [`RawEvent::on_status`].
    #[inline(always)]
    pub fn on_run<T: 'static + Send>(
        &self,
        f: impl 'static + Send + FnOnce(RawEvent, Result<EventStatus>) -> T,
    ) -> Result<CallbackHandle<T>> {
        self.on_status(EventStatus::Running, f)
    }

    /// Runs `f` once the event reaches [`EventStatus::Complete`].
    ///
    /// Returns a handle to `f`'s result; see [`RawEvent::on_status`].
    #[inline(always)]
    pub fn on_complete<T: 'static + Send>(
        &self,
        f: impl 'static + Send + FnOnce(RawEvent, Result<EventStatus>) -> T,
    ) -> Result<CallbackHandle<T>> {
        self.on_status(EventStatus::Complete, f)
    }

    /// Registers `f` to run when the event reaches `status`.
    ///
    /// Returns a [`CallbackHandle`] through which `f`'s return value can be retrieved. If `f`
    /// panics, the handle yields the panic payload instead; the panic is caught rather than
    /// unwinding into the OpenCL runtime.
    ///
    /// # Errors
    /// Fails if the callback cannot be registered.
    pub fn on_status<T: 'static + Send>(
        &self,
        status: EventStatus,
        f: impl 'static + Send + FnOnce(RawEvent, Result<EventStatus>) -> T,
    ) -> Result<CallbackHandle<T>> {
        // Capacity 1: the callback sends exactly one value, so the send never blocks.
        let (send, recv) = std::sync::mpsc::sync_channel(1);
        let data = std::sync::Arc::new(CallbackHandleData {
            #[cfg(feature = "cl1_1")]
            flag: once_cell::sync::OnceCell::new(),
            #[cfg(feature = "futures")]
            waker: futures::task::AtomicWaker::new(),
        });
        let my_data = data.clone();

        self.on_status_silent(status, move |evt, status| {
            let f = std::panic::AssertUnwindSafe(|| f(evt, status.clone()));
            // A failed send means the handle was dropped: nobody is waiting for the result.
            if send.send(std::panic::catch_unwind(f)).is_ok() {
                // Publish completion to a flag installed by `ScopedCallbackHandle::into_event_in`.
                // If none is installed yet, store `None` so a later `into_event_in` can see that
                // the callback already ran and mark its own flag.
                #[cfg(feature = "cl1_1")]
                if let Some(flag) = my_data.flag.get_or_init(|| None) {
                    flag.try_mark(status.err().map(|x| x.ty)).unwrap();
                }
                #[cfg(feature = "futures")]
                my_data.waker.wake();
            }
        })?;

        Ok(CallbackHandle {
            recv,
            data,
            phtm: PhantomData,
        })
    }

    /// Runs `f` once the event reaches [`EventStatus::Submitted`], discarding nothing but also
    /// returning no handle.
    #[inline(always)]
    pub fn on_submit_silent(
        &self,
        f: impl 'static + FnOnce(RawEvent, Result<EventStatus>) + Send,
    ) -> Result<()> {
        self.on_status_silent(EventStatus::Submitted, f)
    }

    /// Runs `f` once the event reaches [`EventStatus::Running`], without returning a handle.
    #[inline(always)]
    pub fn on_run_silent(
        &self,
        f: impl 'static + FnOnce(RawEvent, Result<EventStatus>) + Send,
    ) -> Result<()> {
        self.on_status_silent(EventStatus::Running, f)
    }

    /// Runs `f` once the event reaches [`EventStatus::Complete`], without returning a handle.
    #[inline(always)]
    pub fn on_complete_silent(
        &self,
        f: impl 'static + FnOnce(RawEvent, Result<EventStatus>) + Send,
    ) -> Result<()> {
        self.on_status_silent(EventStatus::Complete, f)
    }

    /// Registers `f` to run when the event reaches `status`, without a result handle.
    ///
    /// `f` receives an owned [`RawEvent`] for this event. The reference it owns is taken here
    /// *before* the callback is registered. Once registered, the callback may run at any time,
    /// possibly on another thread and before `clSetEventCallback` returns, and it releases that
    /// reference when `f` drops the event. Retaining afterwards could touch an already
    /// destroyed event.
    ///
    /// # Errors
    /// Fails if the event cannot be retained or the callback cannot be registered. In both
    /// cases `f` is dropped without being called and no reference is leaked.
    #[inline(always)]
    pub fn on_status_silent(
        &self,
        status: EventStatus,
        f: impl 'static + FnOnce(RawEvent, Result<EventStatus>) + Send,
    ) -> Result<()> {
        type Callback = dyn 'static + Send + FnOnce(RawEvent, Result<EventStatus>);

        let user_data = ThinFn::into_raw(ThinFn::<Callback>::new_once(f));
        unsafe {
            // Reference handed over to the `RawEvent` built by `event_listener`.
            if let Err(e) = self.retain() {
                drop(ThinFn::<Callback>::from_raw(user_data));
                return Err(e);
            }
            if let Err(e) = self.on_status_raw(status, event_listener, user_data.cast()) {
                // The callback will never run: reclaim the closure (with the exact type it
                // was created with) and give back the reference taken for it. The release is
                // best-effort; `self` still holds its own reference, so it cannot free the event.
                drop(ThinFn::<Callback>::from_raw(user_data));
                let _ = clReleaseEvent(self.id());
                return Err(e);
            }
        }
        Ok(())
    }

    /// Registers a raw C callback for [`EventStatus::Submitted`].
    ///
    /// # Safety
    /// See [`RawEvent::on_status_raw`].
    #[inline(always)]
    pub unsafe fn on_submit_raw(
        &self,
        f: unsafe extern "C" fn(
            event: cl_event,
            event_command_status: cl_int,
            user_data: *mut c_void,
        ),
        user_data: *mut c_void,
    ) -> Result<()> {
        self.on_status_raw(EventStatus::Submitted, f, user_data)
    }

    /// Registers a raw C callback for [`EventStatus::Running`].
    ///
    /// # Safety
    /// See [`RawEvent::on_status_raw`].
    #[inline(always)]
    pub unsafe fn on_run_raw(
        &self,
        f: unsafe extern "C" fn(
            event: cl_event,
            event_command_status: cl_int,
            user_data: *mut c_void,
        ),
        user_data: *mut c_void,
    ) -> Result<()> {
        self.on_status_raw(EventStatus::Running, f, user_data)
    }

    /// Registers a raw C callback for [`EventStatus::Complete`].
    ///
    /// # Safety
    /// See [`RawEvent::on_status_raw`].
    #[inline(always)]
    pub unsafe fn on_complete_raw(
        &self,
        f: unsafe extern "C" fn(
            event: cl_event,
            event_command_status: cl_int,
            user_data: *mut c_void,
        ),
        user_data: *mut c_void,
    ) -> Result<()> {
        self.on_status_raw(EventStatus::Complete, f, user_data)
    }

    /// Registers `f` with `clSetEventCallback` for `status`, passing `user_data` through.
    ///
    /// # Safety
    /// `f` must be sound to call with this event, its status code and `user_data`, from
    /// whichever thread the OpenCL runtime chooses. `user_data` must remain valid until the
    /// callback has run. This function takes no extra reference to the event on the
    /// callback's behalf.
    #[inline(always)]
    pub unsafe fn on_status_raw(
        &self,
        status: EventStatus,
        f: unsafe extern "C" fn(
            event: cl_event,
            event_command_status: cl_int,
            user_data: *mut c_void,
        ),
        user_data: *mut c_void,
    ) -> Result<()> {
        tri!(opencl_sys::clSetEventCallback(
            self.id(),
            status as i32,
            Some(f),
            user_data
        ));
        Ok(())
    }
}
/// Wraps a [`RawEvent`] in a [`NoopEvent`] via `Event::new_noop`.
///
/// Implemented as `From` rather than `Into` (clippy `from_over_into`). The standard blanket
/// impl still provides `Into<NoopEvent> for RawEvent`, so existing `raw.into()` calls keep working.
impl From<RawEvent> for NoopEvent {
    #[inline(always)]
    fn from(event: RawEvent) -> Self {
        Event::new_noop(event)
    }
}
impl Clone for RawEvent {
    /// Produces a second owning handle to the same event.
    ///
    /// The event is retained first, so each handle owns, and later releases, its own
    /// reference. Errors from `clRetainEvent` are routed through `tri_panic!`.
    #[inline(always)]
    fn clone(&self) -> Self {
        let handle = self.0;
        unsafe { tri_panic!(clRetainEvent(handle.as_ptr())) }
        Self(handle)
    }
}
impl Drop for RawEvent {
    /// Gives back the single reference this handle owns by calling `clReleaseEvent`.
    ///
    /// Errors from the release are routed through `tri_panic!`.
    #[inline(always)]
    fn drop(&mut self) {
        let handle = self.0.as_ptr();
        unsafe { tri_panic!(clReleaseEvent(handle)) }
    }
}
// SAFETY: `RawEvent` only stores an opaque `cl_event` handle and only accesses it through
// OpenCL API calls; it has no Rust-side shared mutable state.
// NOTE(review): soundness relies on the OpenCL runtime allowing event objects to be used
// concurrently from several threads. Confirm this for the OpenCL versions targeted.
unsafe impl Send for RawEvent {}
unsafe impl Sync for RawEvent {}
/// `extern "C"` trampoline registered by `RawEvent::on_status_silent`.
///
/// Steps:
/// 1. Reclaims the closure passed through `user_data`.
/// 2. Wraps `event` in an owned [`RawEvent`], which releases one reference on drop;
///    `on_status_silent` retains one for this purpose.
/// 3. Converts the status code.
/// 4. Invokes the closure exactly once.
///
/// A panic from the closure (or from dropping its arguments) is caught and the process is
/// aborted. Unwinding out of an `extern "C"` function is undefined behaviour before Rust 1.81
/// and an abort since, so this makes the outcome explicit on every toolchain.
///
/// # Safety
/// `user_data` must come from `ThinFn::into_raw` on a
/// `ThinFn<dyn 'static + Send + FnOnce(RawEvent, Result<EventStatus>)>` and must not be used
/// again. `event` must be a valid, non-null event with one reference owned by this call.
pub(crate) unsafe extern "C" fn event_listener(
    event: cl_event,
    event_command_status: cl_int,
    user_data: *mut c_void,
) {
    let f = ThinFn::<dyn 'static + Send + FnOnce(RawEvent, Result<EventStatus>)>::from_raw(
        user_data.cast(),
    );
    let event = RawEvent::from_id_unchecked(event);
    let status = EventStatus::try_from(event_command_status);

    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
        f.call_once((event, status))
    }));
    if outcome.is_err() {
        // The panic message has already been printed by the panic hook.
        std::process::abort();
    }
}
/// [`ScopedCallbackHandle`] whose lifetime is `'static`, as returned by `RawEvent::on_status`.
pub type CallbackHandle<T> = ScopedCallbackHandle<'static, T>;
/// [`ScopedCallbackConsumer`] with a `'static` lifetime.
pub type CallbackConsumer<T> = ScopedCallbackConsumer<'static, T>;
/// State shared between a callback registered by `RawEvent::on_status` and the handle it returns.
#[cfg(any(feature = "cl1_1", feature = "futures"))]
pub(crate) struct CallbackHandleData {
    // Flag event to mark once the callback has run. Empty until either the callback runs first
    // (and stores `None`), or `ScopedCallbackHandle::into_event_in` installs `Some(flag)`.
    #[cfg(feature = "cl1_1")]
    pub(crate) flag: once_cell::sync::OnceCell<Option<super::FlagEvent>>,
    // Waker registered when the handle is polled as a future. It is woken after the callback
    // has sent its result.
    #[cfg(feature = "futures")]
    pub(crate) waker: futures::task::AtomicWaker,
}
/// Handle to the result of a callback registered with `RawEvent::on_status` (and its variants).
///
/// The result arrives as a `std::thread::Result<T>`: `Err` carries the panic payload if the
/// callback panicked.
pub struct ScopedCallbackHandle<'a, T> {
    // Receiving end of the channel the callback sends its single result on.
    pub(crate) recv: std::sync::mpsc::Receiver<std::thread::Result<T>>,
    // State shared with the callback (flag event and/or waker, depending on features).
    #[cfg(any(feature = "cl1_1", feature = "futures"))]
    pub(crate) data: std::sync::Arc<CallbackHandleData>,
    // `&'a mut &'a ()` makes the handle invariant in `'a`.
    pub(crate) phtm: PhantomData<&'a mut &'a ()>,
}
impl<'a, T> ScopedCallbackHandle<'a, T> {
    /// Converts the handle into an [`Event`] created in the global context.
    ///
    /// The event completes once the callback has run; its consumer yields the callback's
    /// result.
    #[docfg(feature = "cl1_1")]
    #[inline(always)]
    pub fn into_event(self) -> Result<Event<ScopedCallbackConsumer<'a, T>>> {
        self.into_event_in(crate::context::Global)
    }

    /// Converts the handle into an [`Event`] backed by a flag event created in `ctx`.
    ///
    /// # Errors
    /// Fails if the flag event cannot be created or marked.
    #[docfg(feature = "cl1_1")]
    pub fn into_event_in<C: crate::prelude::Context>(
        self,
        ctx: C,
    ) -> Result<Event<ScopedCallbackConsumer<'a, T>>> {
        let flag = super::FlagEvent::new_in(ctx.as_raw())?;
        let subscription = flag.subscribe();
        // If the callback already ran, it stored `None` in the cell. The insert is then
        // rejected and the flag handed back, and the flag must be marked here because the
        // callback will not do it.
        if let Err((_, rejected)) = self.data.flag.try_insert(Some(flag)) {
            // SAFETY: `rejected` is exactly the `Some(flag)` passed to `try_insert`.
            unsafe {
                rejected.unwrap_unchecked().try_mark(None)?;
            }
        }
        Ok(Event::new(subscription, ScopedCallbackConsumer(self)))
    }

    /// Blocks until the callback has run and returns its result (or its panic payload).
    ///
    /// # Panics
    /// Panics if the channel is disconnected without a result.
    #[inline]
    pub fn join(self) -> std::thread::Result<T> {
        self.recv
            .recv()
            .unwrap_or_else(|_| panic!("Handle already joined"))
    }

    /// Like [`ScopedCallbackHandle::join`], but re-raises a panic from the callback.
    #[inline]
    pub fn join_unwrap(self) -> T {
        match self.join() {
            Ok(value) => value,
            Err(payload) => resume_unwind(payload),
        }
    }

    /// Returns the callback's result if it is already available, otherwise gives the handle
    /// back.
    ///
    /// # Panics
    /// Panics if the channel is disconnected without a result.
    #[inline]
    pub fn try_join(self) -> ::core::result::Result<std::thread::Result<T>, Self> {
        let attempt = self.recv.try_recv();
        match attempt {
            Err(TryRecvError::Empty) => Err(self),
            Err(TryRecvError::Disconnected) => panic!("Handle already joined"),
            Ok(result) => Ok(result),
        }
    }

    /// Like [`ScopedCallbackHandle::try_join`], but re-raises a panic from the callback.
    #[inline]
    pub fn try_join_unwrap(self) -> ::core::result::Result<T, Self> {
        self.try_join().map(|result| match result {
            Ok(value) => value,
            Err(payload) => resume_unwind(payload),
        })
    }
}
/// [`Consumer`] that extracts a callback's result from a [`ScopedCallbackHandle`].
///
/// Produced by `ScopedCallbackHandle::into_event_in`.
#[repr(transparent)]
pub struct ScopedCallbackConsumer<'a, T>(ScopedCallbackHandle<'a, T>);
impl<'a, T> Consumer for ScopedCallbackConsumer<'a, T> {
    /// The callback's return value, or the payload it panicked with.
    type Output = ::core::result::Result<T, Box<dyn 'static + Any + Send>>;

    /// Takes the callback's result from the channel.
    ///
    /// The callback sends its result before marking the flag event, so the value is normally
    /// already present. If it is not, this blocks on the channel instead of busy-spinning on
    /// `try_recv`.
    ///
    /// # Panics
    /// Panics if the callback was dropped without ever sending a result.
    #[inline]
    unsafe fn consume(self) -> Result<Self::Output> {
        match self.0.recv.recv() {
            Ok(result) => Ok(result),
            Err(_) => panic!("callback was dropped without producing a result"),
        }
    }
}
cfg_if::cfg_if! {
    if #[cfg(feature = "futures")] {
        use futures::future::*;
        use std::task::*;

        // Lets a callback handle be awaited; resolves to the callback's result (or panic payload).
        #[cfg_attr(docsrs, doc(cfg(feature = "futures")))]
        impl<T> Future for ScopedCallbackHandle<'_, T> {
            type Output = std::thread::Result<T>;

            fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
                // Register before checking the channel: a result sent after the check still
                // wakes this task, because the callback wakes the waker after sending.
                self.data.waker.register(cx.waker());
                let received = self.recv.try_recv();
                match received {
                    Err(TryRecvError::Empty) => Poll::Pending,
                    Err(TryRecvError::Disconnected) => panic!("Handle already joined"),
                    Ok(result) => Poll::Ready(result),
                }
            }
        }
    }
}