use std::{
ffi::c_void,
fmt::Display,
future::Future,
marker::PhantomData,
pin::Pin,
ptr::NonNull,
sync::Arc,
task::{Context, Poll, Waker},
};
use jl_sys::{jl_call, jl_call1, jl_exception_occurred};
use jlrs_sys::jlrs_current_task;
use crate::{
args::Values,
call::Call,
data::managed::{
Managed, erase_scope_lifetime,
module::{JlrsCore, Module},
named_tuple::NamedTuple,
private::ManagedPriv,
value::Value,
},
error::{CANNOT_DISPLAY_VALUE, JuliaResult},
gc_safe::GcSafeMutex,
memory::{
PTls,
gc::{gc_safe_with, gc_unsafe_with},
get_tls,
target::{frame::AsyncGcFrame, private::TargetPriv, unrooted::Unrooted},
},
prelude::Target,
private::Private,
util::kwcall_function,
};
pub struct GcSafeFuture<T, F>
where
F: Future<Output = T>,
{
fut: F,
ptls: PTls,
_marker_t: PhantomData<T>,
}
impl<T, F> GcSafeFuture<T, F>
where
F: Future<Output = T>,
{
pub unsafe fn new(fut: F) -> Self {
unsafe {
assert!(!jlrs_current_task().is_null(), "No task");
let ptls = get_tls();
assert!(!ptls.is_null(), "no TLS");
GcSafeFuture {
fut,
ptls,
_marker_t: PhantomData,
}
}
}
}
impl<T, F> Future for GcSafeFuture<T, F>
where
F: Future<Output = T>,
{
type Output = T;
fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
unsafe {
gc_safe_with(self.ptls, || {
let pinned: Pin<&mut F> =
std::pin::Pin::new_unchecked(&mut self.get_unchecked_mut().fut);
pinned.poll(cx)
})
}
}
}
pub struct GcUnsafeFuture<T, F>
where
F: Future<Output = T>,
{
fut: F,
ptls: PTls,
_marker_t: PhantomData<T>,
}
impl<T, F> GcUnsafeFuture<T, F>
where
F: Future<Output = T>,
{
pub fn new(fut: F) -> Self {
unsafe {
debug_assert!(!jlrs_current_task().is_null(), "invalid_thread");
let ptls = get_tls();
GcUnsafeFuture {
fut,
ptls,
_marker_t: PhantomData,
}
}
}
}
impl<T, F> Future for GcUnsafeFuture<T, F>
where
F: Future<Output = T>,
{
type Output = T;
fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
unsafe {
gc_unsafe_with(self.ptls, |_| {
let pinned: Pin<&mut F> =
std::pin::Pin::new_unchecked(&mut self.get_unchecked_mut().fut);
pinned.poll(cx)
})
}
}
}
pub(crate) struct TaskState<'frame, 'data> {
completed: bool,
waker: Option<Waker>,
task: Option<Value<'frame, 'data>>,
}
impl<'frame, 'data> TaskState<'frame, 'data> {
fn new() -> Arc<GcSafeMutex<Self>> {
Arc::new(GcSafeMutex::new(TaskState {
completed: false,
waker: None,
task: None,
}))
}
}
enum AsyncMethod {
AsyncCall,
InteractiveCall,
}
impl AsyncMethod {
fn method<'target, Tgt>(&self, target: &Tgt) -> Value<'target, 'static>
where
Tgt: Target<'target>,
{
match self {
AsyncMethod::AsyncCall => return JlrsCore::async_call(target),
AsyncMethod::InteractiveCall => return JlrsCore::interactive_call(target),
}
}
}
impl Display for AsyncMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AsyncMethod::AsyncCall => f.write_str("asynccall"),
AsyncMethod::InteractiveCall => f.write_str("interactivecall"),
}
}
}
pub(crate) struct JuliaFuture<'frame, 'data> {
shared_state: Arc<GcSafeMutex<TaskState<'frame, 'data>>>,
}
impl<'frame, 'data> JuliaFuture<'frame, 'data> {
#[inline]
pub(crate) fn new<'value, V, const N: usize>(
frame: &mut AsyncGcFrame<'frame>,
func: Value<'value, 'data>,
values: V,
) -> Self
where
V: Values<'value, 'data, N>,
{
Self::new_future(frame, func, values, AsyncMethod::AsyncCall)
}
#[inline]
pub(crate) fn new_interactive<'value, V, const N: usize>(
frame: &mut AsyncGcFrame<'frame>,
func: Value<'value, 'data>,
values: V,
) -> Self
where
V: Values<'value, 'data, N>,
{
Self::new_future(frame, func, values, AsyncMethod::InteractiveCall)
}
#[inline]
pub(crate) fn new_with_keywords<'value, V, const N: usize>(
frame: &mut AsyncGcFrame<'frame>,
func: Value<'_, 'data>,
values: V,
kwargs: NamedTuple<'_, 'data>,
) -> Self
where
V: Values<'value, 'data, N>,
{
Self::new_future_with_keywords(frame, func, kwargs, values, AsyncMethod::AsyncCall)
}
#[inline]
pub(crate) fn new_interactive_with_keywords<'value, V, const N: usize>(
frame: &mut AsyncGcFrame<'frame>,
func: Value<'_, 'data>,
values: V,
kwargs: NamedTuple<'_, 'data>,
) -> Self
where
V: Values<'value, 'data, N>,
{
Self::new_future_with_keywords(frame, func, kwargs, values, AsyncMethod::InteractiveCall)
}
fn new_future<'value, V, const N: usize>(
frame: &mut AsyncGcFrame<'frame>,
func: Value<'value, 'data>,
values: V,
method: AsyncMethod,
) -> Self
where
V: Values<'value, 'data, N>,
{
let shared_state = TaskState::new();
let state_ptr = Arc::into_raw(shared_state.clone()) as *mut c_void;
let state_ptr_boxed = Value::new(&mut *frame, state_ptr);
let task = unsafe {
let values = values.into_extended_with_start(
[
erase_scope_lifetime(func),
erase_scope_lifetime(state_ptr_boxed),
],
Private,
);
method
.method(&frame)
.call(&mut *frame, values.as_ref())
.unwrap_or_else(|e| {
let msg = e.display_string_or(CANNOT_DISPLAY_VALUE);
panic!("{} threw an exception: {}", method, msg)
})
};
{
let mut locked = shared_state.lock();
locked.task = Some(task);
}
JuliaFuture { shared_state }
}
fn new_future_with_keywords<'value, V, const N: usize>(
frame: &mut AsyncGcFrame<'frame>,
func: Value<'_, 'data>,
kwargs: NamedTuple<'_, 'data>,
values: V,
method: AsyncMethod,
) -> Self
where
V: Values<'value, 'data, N>,
{
let shared_state = TaskState::new();
let state_ptr = Arc::into_raw(shared_state.clone()) as *mut c_void;
let state_ptr_boxed = Value::new(&mut *frame, state_ptr);
let task = unsafe {
let f = method.method(&frame);
let kw_call = kwcall_function(&frame);
let values = values.into_extended_pointers_with_start(
[
kwargs.unwrap(Private),
f.unwrap(Private),
func.unwrap(Private),
state_ptr_boxed.unwrap(Private),
],
Private,
);
let values = values.as_ref();
let res = jl_call(kw_call, values.as_ptr() as *mut _, values.len() as _);
let exc = jl_exception_occurred();
let res = if exc.is_null() {
Ok(NonNull::new_unchecked(res))
} else {
Err(NonNull::new_unchecked(exc))
};
frame
.result_from_ptr::<Value>(res, Private)
.unwrap_or_else(|e| {
let msg = e.display_string_or(CANNOT_DISPLAY_VALUE);
panic!("{} threw an exception: {}", method, msg)
})
};
{
let mut locked = shared_state.lock();
locked.task = Some(task);
}
JuliaFuture { shared_state }
}
}
impl<'frame, 'data> Future for JuliaFuture<'frame, 'data> {
type Output = JuliaResult<'frame, 'data>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut shared_state = self.shared_state.lock();
if shared_state.completed {
if let Some(task) = shared_state.task {
unsafe {
let global = Unrooted::new();
let f = Module::base(&global)
.global(&global, "fetch")
.unwrap()
.as_managed();
let res = jl_call1(f.unwrap(Private), task.unwrap(Private).cast());
let exc = jl_exception_occurred();
if exc.is_null() {
Poll::Ready(Ok(Value::wrap_non_null(
NonNull::new_unchecked(res),
Private,
)))
} else {
Poll::Ready(Err(Value::wrap_non_null(
NonNull::new_unchecked(exc),
Private,
)))
}
}
} else {
unreachable!()
}
} else if shared_state.waker.is_none() {
shared_state.waker = Some(cx.waker().clone());
Poll::Pending
} else {
Poll::Pending
}
}
}
#[cfg(feature = "async-rt")]
pub(crate) unsafe extern "C" fn wake_task(state: *const GcSafeMutex<TaskState>) {
unsafe {
let state = Arc::from_raw(state);
let mut state = state.lock();
state.completed = true;
state.waker.take().map(|waker| waker.wake());
}
}