use crate::bad_wire;
use crate::inventory::{Inventory, TypeId};
use crate::lang::meta::Visibility;
use crate::lang::types::{TypeInfo, TypeKind, WireIO};
use crate::wire::SerializationError;
use std::ffi::c_void;
use std::future::Future;
use std::io::{Read, Write};
use std::ops::Deref;
use std::pin::Pin;
use std::ptr::null;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
/// FFI-compatible callback handle: an optional `extern "C"` function pointer
/// paired with an opaque context pointer.
///
/// The context pointer (field `1`) is handed back verbatim as the second
/// argument whenever the function pointer (field `0`) is invoked.
#[doc(hidden)]
#[derive(Clone, Copy)]
#[repr(C)]
pub struct AsyncCallback<T>(
    /// Function to invoke; `None` means no callback was supplied.
    Option<extern "C" fn(*const T, *const c_void)>,
    /// Opaque user context forwarded to the function on every call.
    *const c_void,
);

// NOTE(review): these marker impls are unconditional. Nothing in this type
// enforces that the context pointer may be moved to or shared with another
// thread, so whoever creates the callback has to guarantee that.
unsafe impl<T> Send for AsyncCallback<T> {}
unsafe impl<T> Sync for AsyncCallback<T> {}
impl<T: TypeInfo> AsyncCallback<T> {
    /// Wraps `func` with a null context pointer.
    pub fn new(func: extern "C" fn(*const T, *const c_void)) -> Self {
        Self::with_context(func, null())
    }

    /// Wraps `func` together with `context`, which is passed back to `func`
    /// as its second argument on every invocation.
    pub fn with_context(func: extern "C" fn(*const T, *const c_void), context: *const c_void) -> Self {
        Self(Some(func), context)
    }

    /// Invokes the stored function with `t` and the stored context.
    ///
    /// # Panics
    ///
    /// Panics if no function is stored.
    ///
    /// # Safety
    ///
    /// `t` and the stored context are forwarded as raw pointers; the caller
    /// must ensure they are valid for whatever the callee does with them.
    pub unsafe fn call(&self, t: *const T) {
        let callback = self.0.expect("Assumed function would exist but it didn't.");
        callback(t, self.1);
    }

    /// Invokes the stored function if there is one.
    ///
    /// Returns `Some(())` when the function was called and `None` when no
    /// function is stored.
    ///
    /// # Safety
    ///
    /// Same requirements as [`AsyncCallback::call`].
    pub unsafe fn call_if_some(&self, t: *const T) -> Option<()> {
        self.0.map(|callback| callback(t, self.1))
    }
}
impl<T: TypeInfo> From<extern "C" fn(*const T, *const c_void)> for AsyncCallback<T> {
fn from(x: extern "C" fn(*const T, *const c_void) -> ()) -> Self {
Self(Some(x), null())
}
}
impl<T: TypeInfo> From<AsyncCallback<T>> for Option<extern "C" fn(*const T, *const c_void)> {
fn from(x: AsyncCallback<T>) -> Self {
x.0
}
}
/// Type metadata for `AsyncCallback<T>`, derived from the metadata of `T`.
unsafe impl<T: TypeInfo> TypeInfo for AsyncCallback<T> {
    // A callback is never wire/service safe; raw and async safety follow `T`.
    const WIRE_SAFE: bool = false;
    const RAW_SAFE: bool = T::RAW_SAFE;
    const ASYNC_SAFE: bool = T::ASYNC_SAFE;
    const SERVICE_SAFE: bool = false;
    const SERVICE_CTOR_SAFE: bool = false;

    /// Identifier derived from `T`'s id and a fixed constant.
    fn id() -> TypeId {
        let inner = T::id();
        inner.derive(0x3BA866E612BB2BEA769699B3476994B8)
    }

    /// Reports this type as the `AsyncCallback` type pattern over `T`'s id.
    fn kind() -> TypeKind {
        let pattern = crate::lang::types::TypePattern::AsyncCallback(T::id());
        TypeKind::TypePattern(pattern)
    }

    /// Builds the type description: same emission as `T`, no docs, public
    /// visibility, and the name `AsyncCallback<{T's name}>`.
    fn ty() -> crate::lang::types::Type {
        let inner = T::ty();
        let name = format!("AsyncCallback<{}>", inner.name);
        let emission = inner.emission.clone();
        crate::lang::types::Type {
            emission,
            docs: crate::lang::meta::Docs::empty(),
            visibility: Visibility::Public,
            name,
            kind: Self::kind(),
        }
    }

    /// Registers `T` first, then this type, with `inventory`.
    fn register(inventory: &mut impl Inventory) {
        T::register(inventory);
        let (id, ty) = (Self::id(), Self::ty());
        inventory.register_type(id, ty);
    }
}
/// Wire (de)serialization stubs for `AsyncCallback<T>`.
///
/// The `TypeInfo` impl for this type declares `WIRE_SAFE = false`; every
/// method here defers entirely to `bad_wire!()`.
// NOTE(review): `bad_wire!` is defined elsewhere — presumably it rejects the
// operation (error or panic); confirm against the macro definition.
unsafe impl<T: WireIO> WireIO for AsyncCallback<T> {
    fn write(&self, _writer: &mut impl Write) -> Result<(), SerializationError> {
        bad_wire!()
    }

    fn read(_reader: &mut impl Read) -> Result<Self, SerializationError> {
        bad_wire!()
    }

    fn live_size(&self) -> usize {
        bad_wire!()
    }
}
/// State shared between an `AsyncCallbackFuture` and the completion callback.
struct FutureState<T> {
    /// Value delivered by the callback; taken by the future when polled.
    result: Option<T>,
    /// Waker registered by the most recent pending poll.
    waker: Option<Waker>,
    /// Optional one-shot hook run when the callback fires.
    on_complete: Option<Box<dyn FnOnce() + Send + 'static>>,
}

/// `extern "C"` trampoline that completes the future whose state is behind `context`.
///
/// Steps, in order:
/// 1. reclaims the `Arc` reference stored in `context`,
/// 2. reads the value behind `value`,
/// 3. runs the `on_complete` hook (if any),
/// 4. stores the value as the result and wakes the registered waker (if any).
///
/// The state mutex is **not** held while the hook or the waker run. Both are
/// user-supplied code; running them under the lock would deadlock as soon as
/// either touches the same state (for example a waker that polls the future
/// inline). The result only becomes visible after the hook has returned.
///
/// # Contract
///
/// * `context` must come from `Arc::into_raw` on an
///   `Arc<Mutex<FutureState<T>>>`, and this function must be called at most
///   once per such pointer, because it consumes that reference.
/// * `value` must point to a valid `T`. The value is bitwise-read out of it,
///   so for non-`Copy` types the caller must not use or drop the original
///   afterwards.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
extern "C" fn async_callback_complete<T: Send + 'static>(value: *const T, context: *const c_void) {
    // SAFETY: per the contract above, `context` is an `Arc::into_raw` pointer
    // to this exact type and has not been reclaimed before.
    let state = unsafe { Arc::from_raw(context.cast::<Mutex<FutureState<T>>>()) };
    // SAFETY: per the contract above, `value` points to a valid `T` whose
    // ownership is handed to us.
    let result = unsafe { std::ptr::read(value) };

    // Take the hook under the lock; the guard is a temporary and is released
    // at the end of this statement, before the hook runs.
    let on_complete = state.lock().unwrap().on_complete.take();
    if let Some(on_complete) = on_complete {
        on_complete();
    }

    // Publish the result and grab the waker in one short critical section.
    let waker = {
        let mut lock = state.lock().unwrap();
        lock.result = Some(result);
        lock.waker.take()
    };
    if let Some(waker) = waker {
        waker.wake();
    }
}
/// Future that resolves once its paired [`AsyncCallback`] has been invoked.
pub struct AsyncCallbackFuture<T> {
    /// State shared with the callback's context pointer.
    state: Arc<Mutex<FutureState<T>>>,
}

impl<T: Send + 'static + TypeInfo> AsyncCallbackFuture<T> {
    /// Creates a future and the callback that completes it.
    pub fn new() -> (Self, AsyncCallback<T>) {
        Self::paired_with(None)
    }

    /// Like [`AsyncCallbackFuture::new`], but stores `on_complete` as a
    /// one-shot hook in the shared state for the completion callback to run.
    pub fn new_with_on_complete(on_complete: impl FnOnce() + Send + 'static) -> (Self, AsyncCallback<T>) {
        let hook: Box<dyn FnOnce() + Send + 'static> = Box::new(on_complete);
        Self::paired_with(Some(hook))
    }

    /// Builds the shared state and wires up the callback.
    ///
    /// One extra `Arc` reference is turned into a raw pointer and stored as
    /// the callback's context; `async_callback_complete` reclaims it.
    fn paired_with(on_complete: Option<Box<dyn FnOnce() + Send + 'static>>) -> (Self, AsyncCallback<T>) {
        let shared = Arc::new(Mutex::new(FutureState { result: None, waker: None, on_complete }));
        let context = Arc::into_raw(Arc::clone(&shared)).cast::<c_void>();
        let callback = AsyncCallback::with_context(async_callback_complete::<T>, context);
        (Self { state: shared }, callback)
    }
}
/// Resolves with the value stored in the shared state.
impl<T: Send + 'static> Future for AsyncCallbackFuture<T> {
    type Output = T;

    /// Returns `Ready` with the stored result if one is present; otherwise
    /// records the current waker in the shared state and returns `Pending`.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        let mut shared = self.state.lock().unwrap();
        match shared.result.take() {
            Some(value) => Poll::Ready(value),
            None => {
                // Always replace the stored waker with the one from this poll.
                shared.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
/// Pairs a shared runtime handle with a value of the runtime's associated
/// type `S::T`.
pub struct Async<S: AsyncRuntime> {
    /// Shared handle to the runtime; exposed through `Deref`.
    s: Arc<S>,
    /// Associated context value; exposed through [`Async::context`].
    t: S::T,
}

impl<S: AsyncRuntime> Async<S> {
    /// Bundles the runtime handle `s` with the context value `t`.
    pub fn new(s: Arc<S>, t: S::T) -> Self {
        Async { s, t }
    }

    /// Borrows the context value.
    pub fn context(&self) -> &S::T {
        let Self { t, .. } = self;
        t
    }
}

/// Dereferences to the shared runtime handle.
impl<S: AsyncRuntime> Deref for Async<S> {
    type Target = Arc<S>;

    fn deref(&self) -> &Arc<S> {
        let Self { s, .. } = self;
        s
    }
}
/// Abstraction over something that can run futures.
pub trait AsyncRuntime {
    /// Value handed to each task factory passed to [`AsyncRuntime::spawn`].
    type T;

    /// Accepts a one-shot factory `f` that receives a `Self::T` and returns a
    /// unit-output future. Both the factory and the future must be
    /// `Send + 'static`.
    ///
    /// How and when the future is driven is up to the implementor.
    fn spawn<Job, Fut>(&self, f: Job)
    where
        Job: FnOnce(Self::T) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static;
}