use crate::bad_wire;
use crate::inventory::{Inventory, TypeId};
use crate::lang::meta::Visibility;
use crate::lang::types::{TypeInfo, TypeKind, TypePattern, WireIO};
use crate::wire::SerializationError;
use std::ffi::c_void;
use std::future::Future;
use std::io::{Read, Write};
use std::ops::Deref;
use std::pin::Pin;
use std::ptr::null;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
/// FFI-safe completion callback: an optional C function pointer (`.0`, `None` for
/// a null pointer) plus an opaque context pointer (`.1`) that is passed back as the
/// function's second argument on every invocation.
#[doc(hidden)]
#[repr(C)]
pub struct AsyncCallback<T>(Option<extern "C" fn(*const T, *const c_void)>, *const c_void);

// `Copy`/`Clone` are written by hand because `#[derive]` would add a `T: Copy` /
// `T: Clone` bound; the struct only stores pointers, never a `T` by value.
impl<T> Copy for AsyncCallback<T> {}

impl<T> Clone for AsyncCallback<T> {
    fn clone(&self) -> Self {
        *self
    }
}

// NOTE(review): these impls let the raw context pointer move and be shared across
// threads. Nothing here checks that the pointee tolerates that; callers must.
unsafe impl<T> Send for AsyncCallback<T> {}
unsafe impl<T> Sync for AsyncCallback<T> {}
impl<T: TypeInfo> AsyncCallback<T> {
    /// Wraps `func` with a null context pointer.
    pub fn new(func: extern "C" fn(*const T, *const c_void)) -> Self {
        Self::with_context(func, null())
    }

    /// Wraps `func` with an opaque `context` pointer that is handed back as the
    /// second argument on every invocation.
    pub fn with_context(func: extern "C" fn(*const T, *const c_void), context: *const c_void) -> Self {
        Self(Some(func), context)
    }

    /// Invokes the stored function with `t` and the stored context.
    ///
    /// # Panics
    /// Panics if no function pointer is stored.
    ///
    /// # Safety
    /// `t` and the context pointer are passed verbatim to foreign code; the caller
    /// must ensure they satisfy whatever that function expects.
    pub unsafe fn call(&self, t: *const T) {
        let func = self.0.expect("Assumed function would exist but it didn't.");
        func(t, self.1);
    }

    /// Invokes the stored function if there is one. Returns `Some(())` if it was
    /// called and `None` if no function pointer is stored.
    ///
    /// # Safety
    /// Same contract as [`AsyncCallback::call`].
    pub unsafe fn call_if_some(&self, t: *const T) -> Option<()> {
        self.0.map(|func| func(t, self.1))
    }
}
impl<T: TypeInfo> From<extern "C" fn(*const T, *const c_void)> for AsyncCallback<T> {
fn from(x: extern "C" fn(*const T, *const c_void) -> ()) -> Self {
Self(Some(x), null())
}
}
impl<T: TypeInfo> From<AsyncCallback<T>> for Option<extern "C" fn(*const T, *const c_void)> {
fn from(x: AsyncCallback<T>) -> Self {
x.0
}
}
unsafe impl<T: TypeInfo> TypeInfo for AsyncCallback<T> {
    // Never wire- or service-safe; raw/async safety follows the payload type.
    const WIRE_SAFE: bool = false;
    const RAW_SAFE: bool = T::RAW_SAFE;
    const ASYNC_SAFE: bool = T::ASYNC_SAFE;
    const SERVICE_SAFE: bool = false;
    const SERVICE_CTOR_SAFE: bool = false;

    /// Id derived from the payload's id with a fixed salt.
    fn id() -> TypeId {
        T::id().derive(0x3BA866E612BB2BEA769699B3476994B8)
    }

    /// An `AsyncCallback` type pattern keyed by the payload's id.
    fn kind() -> TypeKind {
        let pattern = crate::lang::types::TypePattern::AsyncCallback(T::id());
        TypeKind::TypePattern(pattern)
    }

    /// Public type named `AsyncCallback<Payload>` that reuses the payload's emission.
    fn ty() -> crate::lang::types::Type {
        let payload = T::ty();
        let emission = payload.emission.clone();
        let name = format!("AsyncCallback<{}>", payload.name);
        crate::lang::types::Type {
            emission,
            docs: crate::lang::meta::Docs::empty(),
            visibility: Visibility::Public,
            name,
            kind: Self::kind(),
        }
    }

    /// Registers the payload type first, then this callback type.
    fn register(inventory: &mut impl Inventory) {
        T::register(inventory);
        let (id, ty) = (Self::id(), Self::ty());
        inventory.register_type(id, ty);
    }
}
// `AsyncCallback` declares `WIRE_SAFE = false` in its `TypeInfo` impl, so none of
// the serialization methods are supported; each one defers to `bad_wire!()`.
// NOTE(review): `bad_wire!()` is used where both a `Result` and a `usize` are
// expected, so it presumably diverges (panic or similar) — confirm at its definition.
unsafe impl<T: WireIO> WireIO for AsyncCallback<T> {
fn write(&self, _: &mut impl Write) -> Result<(), SerializationError> {
bad_wire!()
}
fn read(_: &mut impl Read) -> Result<Self, SerializationError> {
bad_wire!()
}
fn live_size(&self) -> usize {
bad_wire!()
}
}
/// State shared between an `AsyncCallbackFuture` and the C callback that completes it.
struct FutureState<T> {
    /// Value delivered by the callback; taken by the first `poll` that observes it.
    result: Option<T>,
    /// Waker registered by the most recent `poll` that returned `Pending`.
    waker: Option<Waker>,
    /// Optional hook run once when the callback fires.
    on_complete: Option<Box<dyn FnOnce() + Send + 'static>>,
}

/// C-ABI completion trampoline used by `AsyncCallbackFuture`.
///
/// `context` must be a pointer obtained from `Arc::into_raw` on an
/// `Arc<Mutex<FutureState<T>>>`. This function reclaims (and drops) that one strong
/// reference, so it must be invoked at most once per context pointer. `*value` is
/// moved out with `ptr::read`, so the caller must not drop its copy afterwards.
///
/// `on_complete` runs while the lock is held, so it finishes before `poll` can
/// observe the result. The waker is invoked only after the lock is released. The
/// mutex is not reentrant, so an executor that polls synchronously from inside
/// `wake` would otherwise deadlock.
extern "C" fn async_callback_complete<T: Send + 'static>(value: *const T, context: *const c_void) {
    // SAFETY: `context` was produced by `Arc::into_raw` for this exact type;
    // reclaiming it consumes the strong reference leaked when the callback was built.
    let state = unsafe { Arc::from_raw(context.cast::<Mutex<FutureState<T>>>()) };
    let waker = {
        // Poisoning only means another lock holder panicked. The state is plain data,
        // so recover it rather than panic across the `extern "C"` boundary.
        let mut lock = state.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        // SAFETY: the caller transfers ownership of `*value` and must not drop it.
        lock.result = Some(unsafe { std::ptr::read(value) });
        if let Some(on_complete) = lock.on_complete.take() {
            on_complete();
        }
        lock.waker.take()
    };
    // The lock guard is gone by now; waking here cannot re-enter a held mutex.
    if let Some(waker) = waker {
        waker.wake();
    }
}
/// Future that resolves to the value delivered through its paired [`AsyncCallback`].
#[doc(hidden)]
pub struct AsyncCallbackFuture<T> {
    state: Arc<Mutex<FutureState<T>>>,
}

impl<T: Send + 'static + TypeInfo> AsyncCallbackFuture<T> {
    /// Creates a pending future and the callback that completes it.
    ///
    /// The callback's context pointer holds one strong reference to the shared state
    /// (leaked via `Arc::into_raw`), which is released when the callback fires.
    /// NOTE(review): if the callback is never invoked, that reference and the state
    /// leak. Because `AsyncCallback` is `Copy`, invoking it twice would release the
    /// reference twice.
    pub fn new() -> (Self, AsyncCallback<T>) {
        Self::with_completion_hook(None)
    }

    /// Like [`AsyncCallbackFuture::new`], but `on_complete` runs when the callback fires.
    pub fn new_with_on_complete(on_complete: impl FnOnce() + Send + 'static) -> (Self, AsyncCallback<T>) {
        Self::with_completion_hook(Some(Box::new(on_complete)))
    }

    /// Shared constructor: builds the state and a callback that owns a leaked `Arc` to it.
    fn with_completion_hook(
        on_complete: Option<Box<dyn FnOnce() + Send + 'static>>,
    ) -> (Self, AsyncCallback<T>) {
        let state = Arc::new(Mutex::new(FutureState { result: None, waker: None, on_complete }));
        let context = Arc::into_raw(Arc::clone(&state)).cast::<c_void>();
        let callback = AsyncCallback::with_context(async_callback_complete::<T>, context);
        (Self { state }, callback)
    }
}
impl<T: Send + 'static> Future for AsyncCallbackFuture<T> {
    type Output = T;

    /// Returns the delivered value once it is available. Otherwise it stores the
    /// current waker, replacing any earlier one, and stays pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        let mut guard = self.state.lock().unwrap();
        match guard.result.take() {
            Some(value) => Poll::Ready(value),
            None => {
                guard.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
/// Bundles a shared runtime `s` with a context value `t` of type `S::T`. This is the
/// same type `AsyncRuntime::spawn` hands to spawned closures.
///
/// Dereferences to the inner `Arc<S>`, so runtime methods can be called directly.
pub struct Async<S: AsyncRuntime> {
    s: Arc<S>,
    t: S::T,
}

impl<S: AsyncRuntime> Async<S> {
    /// Bundles the runtime `s` with the context value `t`.
    pub fn new(s: Arc<S>, t: S::T) -> Self {
        Async { s, t }
    }

    /// Borrows the context value.
    pub fn context(&self) -> &S::T {
        let Self { t, .. } = self;
        t
    }
}

impl<S: AsyncRuntime> Deref for Async<S> {
    type Target = Arc<S>;

    fn deref(&self) -> &Arc<S> {
        &self.s
    }
}
/// Executor abstraction that can spawn futures.
pub trait AsyncRuntime {
    /// Value passed to each closure given to [`AsyncRuntime::spawn`].
    type T;

    /// Spawns work: `f` receives a `Self::T` and produces the future to run.
    /// Implementations decide when and where `f` is called. Returns a [`TaskHandle`],
    /// presumably for the spawned task.
    fn spawn<Make, Fut>(&self, f: Make) -> TaskHandle
    where
        Make: FnOnce(Self::T) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static;
}
/// FFI-safe, type-erased handle around a value supplied to [`TaskHandle::from_handle`].
///
/// `data` points at a boxed `TaskHandleInner<T>`. The two C-ABI function pointers
/// are monomorphised for the erased `T`, and both are `None` for [`TaskHandle::dummy`].
#[repr(C)]
pub struct TaskHandle {
    data: *mut (),
    abort_fn: Option<unsafe extern "C" fn(*mut ())>,
    drop_fn: Option<unsafe extern "C" fn(*mut ())>,
}

// SAFETY: `from_handle` requires `T: Send`, so moving the owned box across threads is fine.
// NOTE(review): `Sync` also lets `abort` call `fn(&T)` from several threads at once,
// which is only sound if `T: Sync`. `from_handle` does not require that bound.
unsafe impl Send for TaskHandle {}
unsafe impl Sync for TaskHandle {}

impl TaskHandle {
    /// Boxes `handle` together with `abort`, which `TaskHandle::abort` will call on it.
    pub fn from_handle<T: Send + 'static>(handle: T, abort: fn(&T)) -> Self {
        let inner = Box::new(TaskHandleInner { handle, abort });
        let data = Box::into_raw(inner).cast::<()>();
        Self {
            data,
            abort_fn: Some(trampoline_abort::<T>),
            drop_fn: Some(trampoline_drop::<T>),
        }
    }

    /// Runs the stored abort function; does nothing for a dummy handle.
    pub fn abort(&self) {
        let Some(abort_fn) = self.abort_fn else {
            return;
        };
        // SAFETY: `abort_fn` is only `Some` when `data` came from `from_handle` with the
        // matching `T`, and the box stays alive until `Drop` frees it.
        unsafe { abort_fn(self.data) }
    }

    /// A handle that owns nothing; `abort` and dropping it are no-ops.
    #[must_use]
    pub fn dummy() -> Self {
        Self {
            data: std::ptr::null_mut(),
            abort_fn: None,
            drop_fn: None,
        }
    }
}

/// Heap payload behind `TaskHandle::data`.
struct TaskHandleInner<T> {
    handle: T,
    abort: fn(&T),
}

/// Abort trampoline for a `TaskHandleInner<T>`.
///
/// # Safety
/// `data` must point to a live `TaskHandleInner<T>` created by `TaskHandle::from_handle`.
unsafe extern "C" fn trampoline_abort<T>(data: *mut ()) {
    // SAFETY: guaranteed by the caller per the contract above.
    let inner = unsafe { &*data.cast::<TaskHandleInner<T>>() };
    (inner.abort)(&inner.handle);
}

/// Drop trampoline: frees the boxed `TaskHandleInner<T>`.
///
/// # Safety
/// `data` must come from `TaskHandle::from_handle` with the same `T` and must not be
/// used again afterwards.
unsafe extern "C" fn trampoline_drop<T>(data: *mut ()) {
    // SAFETY: reclaims the box leaked by `from_handle`; called at most once, from `Drop`.
    let boxed = unsafe { Box::from_raw(data.cast::<TaskHandleInner<T>>()) };
    std::mem::drop(boxed);
}

impl Drop for TaskHandle {
    fn drop(&mut self) {
        // `take` guarantees the drop trampoline can run at most once.
        if let Some(free) = self.drop_fn.take() {
            // SAFETY: `drop_fn` is only `Some` when `data` came from `from_handle`.
            unsafe { free(self.data) };
        }
    }
}
unsafe impl TypeInfo for TaskHandle {
    // Only raw-passable; not wire-, async- or service-safe.
    const WIRE_SAFE: bool = false;
    const RAW_SAFE: bool = true;
    const ASYNC_SAFE: bool = false;
    const SERVICE_SAFE: bool = false;
    const SERVICE_CTOR_SAFE: bool = false;

    /// Fixed id for the builtin `TaskHandle` type.
    fn id() -> TypeId {
        TypeId::new(0xA4B3C2D1E0F98765_4321ABCDEF012345)
    }

    /// The `TaskHandle` type pattern.
    fn kind() -> TypeKind {
        TypeKind::TypePattern(TypePattern::TaskHandle)
    }

    /// Public builtin type named `TaskHandle`.
    fn ty() -> crate::lang::types::Type {
        let name = String::from("TaskHandle");
        crate::lang::types::Type {
            emission: crate::lang::meta::Emission::Builtin,
            docs: crate::lang::meta::Docs::empty(),
            visibility: Visibility::Public,
            name,
            kind: Self::kind(),
        }
    }

    /// Registers this type (it has no dependent types to register).
    fn register(inventory: &mut impl Inventory) {
        let (id, ty) = (Self::id(), Self::ty());
        inventory.register_type(id, ty);
    }
}
// `TaskHandle` declares `WIRE_SAFE = false` in its `TypeInfo` impl, so none of the
// serialization methods are supported; each one defers to `bad_wire!()`.
// NOTE(review): `bad_wire!()` is used where both a `Result` and a `usize` are
// expected, so it presumably diverges (panic or similar) — confirm at its definition.
unsafe impl WireIO for TaskHandle {
fn write(&self, _: &mut impl Write) -> Result<(), SerializationError> {
bad_wire!()
}
fn read(_: &mut impl Read) -> Result<Self, SerializationError> {
bad_wire!()
}
fn live_size(&self) -> usize {
bad_wire!()
}
}
/// Supplies the value an [`AsyncCallbackGuard`] reports when dropped before completion.
#[doc(hidden)]
pub trait CancelValue {
    /// Returns the cancellation value.
    fn cancel_value() -> Self;
}

/// Wraps an [`AsyncCallback`] and tracks whether it has been completed. If the guard
/// is dropped while still incomplete, its `Drop` impl reports `T::cancel_value()`
/// through the callback.
#[doc(hidden)]
pub struct AsyncCallbackGuard<T: TypeInfo + CancelValue> {
    completed: AtomicBool,
    callback: AsyncCallback<T>,
}

impl<T: TypeInfo + CancelValue> AsyncCallbackGuard<T> {
    /// Wraps `callback` in a guard that is not yet completed.
    #[must_use]
    pub fn new(callback: AsyncCallback<T>) -> Self {
        let completed = AtomicBool::new(false);
        Self { completed, callback }
    }

    /// Marks the guard as completed. Returns `true` only for the call that flipped
    /// the flag and `false` if it was already set.
    pub fn mark_completed(&self) -> bool {
        let was_already_completed = self.completed.swap(true, Ordering::AcqRel);
        !was_already_completed
    }
}
impl<T: TypeInfo + CancelValue> Drop for AsyncCallbackGuard<T> {
    /// Reports `T::cancel_value()` through the callback unless the guard was already
    /// completed.
    ///
    /// The callee is treated as taking ownership of the pointed-to value (for example,
    /// `async_callback_complete` moves it out with `ptr::read`). The value is therefore
    /// held in `ManuallyDrop` so it is not dropped a second time here.
    ///
    /// If no function pointer is stored, nothing takes ownership and the value is
    /// dropped normally. Panicking inside `Drop` (which `AsyncCallback::call` would do)
    /// could abort the process during unwinding, so that path is avoided.
    fn drop(&mut self) {
        if self.completed.swap(true, Ordering::AcqRel) {
            return;
        }
        let value = std::mem::ManuallyDrop::new(T::cancel_value());
        let ptr: *const T = &*value;
        // SAFETY: `ptr` points to a live, initialised `T` for the whole call, and
        // `value` is never dropped here if the callee took ownership of it.
        let delivered = unsafe { self.callback.call_if_some(ptr) };
        if delivered.is_none() {
            // No callback ran, so ownership never left this frame.
            std::mem::drop(std::mem::ManuallyDrop::into_inner(value));
        }
    }
}