use crate::error::other_err;
use crate::error::{JlrsError, JlrsResult};
use crate::frame::AsyncFrame;
use crate::global::Global;
use crate::mode::Async;
use crate::stack::multitask::{MultitaskStack, TaskStack};
use crate::stack::{Dynamic, StackView};
use crate::traits::multitask::{JuliaTask, ReturnChannel};
use crate::value::module::Module;
use crate::value::Value;
use crate::{INIT, JLRS_JL};
use async_std::channel::{
bounded, Receiver as AsyncStdReceiver, RecvError, Sender as AsyncStdSender, TrySendError,
};
use async_std::future::timeout;
use async_std::sync::{Condvar as AsyncStdCondvar, Mutex as AsyncStdMutex};
use async_std::task::{self, JoinHandle as AsyncStdHandle};
use jl_sys::{jl_atexit_hook, jl_gc_safepoint, jl_init_with_image__threading, jl_is_initialized};
use std::ffi::{c_void, CString};
use std::io::{Error as IOError, ErrorKind};
use std::path::{Path, PathBuf};
use std::sync::atomic::Ordering;
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle as ThreadHandle};
use std::time::Duration;
/// A handle to the async Julia runtime.
///
/// The runtime itself lives on a dedicated thread (or blocking task); this
/// handle only holds the sending half of the message channel. Cloning the
/// handle clones the sender, so all clones talk to the same runtime.
#[derive(Clone)]
pub struct AsyncJulia<T, R>
where
    T: Send + Sync + 'static,
    R: ReturnChannel<T = T>,
{
    // Sends tasks and control messages (include, set-wake-fn) to the runtime loop.
    sender: AsyncStdSender<Message<T, R>>,
}
impl<T, R> AsyncJulia<T, R>
where
    T: Send + Sync + 'static,
    R: ReturnChannel<T = T>,
{
    /// Initialize Julia on its own thread and return a handle together with
    /// the thread's join handle.
    ///
    /// `channel_capacity` bounds the message channel, `n_threads` is the
    /// number of task stacks the runtime allocates, `stack_size` the size of
    /// each stack, and `process_events_ms` the timeout after which the
    /// runtime loop wakes up to let Julia process events while tasks run.
    ///
    /// # Safety
    ///
    /// Julia may be initialized at most once per process; this must not be
    /// called while another Julia runtime exists.
    pub unsafe fn init(
        channel_capacity: usize,
        n_threads: usize,
        stack_size: usize,
        process_events_ms: u64,
    ) -> JlrsResult<(Self, ThreadHandle<JlrsResult<()>>)> {
        let (sender, receiver) = bounded(channel_capacity);
        let julia = AsyncJulia { sender };
        let handle =
            thread::spawn(move || run_async(n_threads, stack_size, process_events_ms, receiver));
        // Blocks until the runtime thread has handled the message, which also
        // guarantees initialization has completed (or failed) on return.
        julia.try_set_wake_fn().map_err(other_err)?;
        Ok((julia, handle))
    }

    /// Async counterpart of [`AsyncJulia::init`]: the runtime runs on a
    /// blocking task instead of a dedicated thread.
    ///
    /// # Safety
    ///
    /// Julia may be initialized at most once per process.
    pub async unsafe fn init_async(
        channel_capacity: usize,
        n_threads: usize,
        stack_size: usize,
        process_events_ms: u64,
    ) -> JlrsResult<(Self, AsyncStdHandle<JlrsResult<()>>)> {
        let (sender, receiver) = bounded(channel_capacity);
        let julia = AsyncJulia { sender };
        let handle = task::spawn_blocking(move || {
            run_async(n_threads, stack_size, process_events_ms, receiver)
        });
        julia.set_wake_fn().await?;
        Ok((julia, handle))
    }

    /// Like [`AsyncJulia::init`], but initializes Julia with a custom system
    /// image. `julia_bindir` is the directory containing the Julia binary,
    /// `image_path` the path of the system image; both are validated on the
    /// runtime thread before initialization.
    ///
    /// # Safety
    ///
    /// Julia may be initialized at most once per process.
    pub unsafe fn init_with_image<P, Q>(
        channel_capacity: usize,
        n_threads: usize,
        stack_size: usize,
        process_events_ms: u64,
        julia_bindir: P,
        image_path: Q,
    ) -> JlrsResult<(Self, ThreadHandle<JlrsResult<()>>)>
    where
        P: AsRef<Path> + Send + 'static,
        Q: AsRef<Path> + Send + 'static,
    {
        let (sender, receiver) = bounded(channel_capacity);
        let julia = AsyncJulia { sender };
        let handle = thread::spawn(move || {
            run_async_with_image(
                n_threads,
                stack_size,
                process_events_ms,
                receiver,
                julia_bindir,
                image_path,
            )
        });
        // Blocks until the runtime thread has handled the message; see `init`.
        julia.try_set_wake_fn().map_err(other_err)?;
        Ok((julia, handle))
    }

    /// Async counterpart of [`AsyncJulia::init_with_image`].
    ///
    /// # Safety
    ///
    /// Julia may be initialized at most once per process.
    pub async unsafe fn init_with_image_async<P, Q>(
        channel_capacity: usize,
        n_threads: usize,
        stack_size: usize,
        process_events_ms: u64,
        julia_bindir: P,
        image_path: Q,
    ) -> JlrsResult<(Self, AsyncStdHandle<JlrsResult<()>>)>
    where
        P: AsRef<Path> + Send + 'static,
        Q: AsRef<Path> + Send + 'static,
    {
        let (sender, receiver) = bounded(channel_capacity);
        let julia = AsyncJulia { sender };
        let handle = task::spawn_blocking(move || {
            run_async_with_image(
                n_threads,
                stack_size,
                process_events_ms,
                receiver,
                julia_bindir,
                image_path,
            )
        });
        julia.set_wake_fn().await?;
        Ok((julia, handle))
    }

    /// Send a new task to the runtime, waiting if the channel is full.
    ///
    /// # Panics
    ///
    /// Panics if the runtime has shut down and closed the channel.
    pub async fn new_task<D: JuliaTask<T = T, R = R>>(&self, task: D) {
        // A clone of the sender travels with the task so the runtime can send
        // the completion message back to itself.
        let sender = self.sender.clone();
        self.sender
            .send(Message::Task(Box::new(task), sender))
            .await
            .expect("Channel was closed");
    }

    /// Try to send a new task to the runtime; fails immediately if the
    /// channel is full or closed.
    pub fn try_new_task<D: JuliaTask<T = T, R = R>>(&self, task: D) -> JlrsResult<()> {
        let sender = self.sender.clone();
        self.sender
            .try_send(Message::Task(Box::new(task), sender))
            .map_err(|e| match e {
                // try_send hands the rejected message back; unwrap it so the
                // error exposes the task rather than the internal message type.
                TrySendError::Full(Message::Task(t, _)) => {
                    Box::new(other_err(TrySendError::Full(t)))
                }
                TrySendError::Closed(Message::Task(t, _)) => {
                    Box::new(other_err(TrySendError::Closed(t)))
                }
                _ => unreachable!(),
            })
    }

    /// Include a Julia source file by calling `Main.include` on the runtime
    /// thread, waiting until it has finished.
    ///
    /// Returns an error if the path does not exist or the include fails.
    ///
    /// # Panics
    ///
    /// Panics if the runtime has shut down and closed the channel.
    pub async fn include<P: AsRef<Path>>(&self, path: P) -> JlrsResult<()> {
        if !path.as_ref().exists() {
            return Err(JlrsError::IncludeNotFound(path.as_ref().to_string_lossy().into()).into());
        }
        // Pending/Ok/Err status shared with the runtime; the condvar is
        // notified once the include has been executed.
        let completed = Arc::new((AsyncStdMutex::new(Status::Pending), AsyncStdCondvar::new()));
        self.sender
            .send(Message::Include(
                path.as_ref().to_path_buf(),
                completed.clone(),
            ))
            .await
            .expect("Channel was closed");
        let (lock, cvar) = &*completed;
        let mut completed = lock.lock().await;
        while (&*completed).is_pending() {
            completed = cvar.wait(completed).await;
        }
        (&mut *completed).as_jlrs_result()
    }

    /// Blocking variant of [`AsyncJulia::include`]; fails immediately if the
    /// channel is full or closed, otherwise blocks until the include result
    /// is available.
    pub fn try_include<P: AsRef<Path>>(&self, path: P) -> JlrsResult<()> {
        if !path.as_ref().exists() {
            return Err(JlrsError::IncludeNotFound(path.as_ref().to_string_lossy().into()).into());
        }
        let completed = Arc::new((Mutex::new(Status::Pending), Condvar::new()));
        self.sender
            .try_send(Message::TryInclude(
                path.as_ref().to_path_buf(),
                completed.clone(),
            ))
            .map_err(|e| match e {
                // BUG FIX: try_send returns the message that was sent, which
                // is `TryInclude`; the previous arms matched `Include` and so
                // a full or closed channel hit `unreachable!()` and panicked.
                TrySendError::Full(Message::TryInclude(t, _)) => {
                    Box::new(other_err(TrySendError::Full(t)))
                }
                TrySendError::Closed(Message::TryInclude(t, _)) => {
                    Box::new(other_err(TrySendError::Closed(t)))
                }
                _ => unreachable!(),
            })
            .and_then(|_| {
                let (lock, cvar) = &*completed;
                let mut completed = lock.lock().unwrap();
                while (&*completed).is_pending() {
                    completed = cvar.wait(completed).unwrap();
                }
                (&mut *completed).as_jlrs_result()
            })
    }

    /// The capacity of the message channel.
    pub fn capacity(&self) -> usize {
        // The channel is always created with `bounded`, so a capacity exists.
        self.sender.capacity().unwrap()
    }

    /// The number of messages currently queued in the channel.
    pub fn len(&self) -> usize {
        self.sender.len()
    }

    /// Returns `true` if no messages are queued.
    pub fn is_empty(&self) -> bool {
        self.sender.is_empty()
    }

    /// Returns `true` if the channel is at capacity.
    pub fn is_full(&self) -> bool {
        self.sender.is_full()
    }

    /// Ask the runtime to install the Rust callbacks Julia needs and block
    /// until that has happened. Used by the sync `init*` constructors.
    fn try_set_wake_fn(&self) -> JlrsResult<()> {
        let completed = Arc::new((Mutex::new(Status::Pending), Condvar::new()));
        self.sender
            .try_send(Message::TrySetWakeFn(completed.clone()))
            .map_err(|e| match e {
                TrySendError::Full(Message::TrySetWakeFn(_)) => {
                    Box::new(other_err(TrySendError::Full(())))
                }
                TrySendError::Closed(Message::TrySetWakeFn(_)) => {
                    Box::new(other_err(TrySendError::Closed(())))
                }
                _ => unreachable!(),
            })
            .and_then(|_| {
                let (lock, cvar) = &*completed;
                let mut completed = lock.lock().unwrap();
                while (&*completed).is_pending() {
                    completed = cvar.wait(completed).unwrap();
                }
                (&mut *completed).as_jlrs_result()
            })
    }

    /// Async variant of `try_set_wake_fn`, used by the async `init*`
    /// constructors.
    ///
    /// # Panics
    ///
    /// Panics if the runtime has shut down and closed the channel.
    async fn set_wake_fn(&self) -> JlrsResult<()> {
        let completed = Arc::new((AsyncStdMutex::new(Status::Pending), AsyncStdCondvar::new()));
        self.sender
            .send(Message::SetWakeFn(completed.clone()))
            .await
            .expect("Channel was closed");
        {
            let (lock, cvar) = &*completed;
            let mut completed = lock.lock().await;
            while (&*completed).is_pending() {
                completed = cvar.wait(completed).await;
            }
            (&mut *completed).as_jlrs_result()
        }
    }
}
/// Tracks completion of a control message (include / set-wake-fn) so the
/// requesting side can wait on a condvar until the runtime has responded.
enum Status {
    // Not yet handled by the runtime loop.
    Pending,
    // Handled successfully.
    Ok,
    // Handled with an error; the Option lets the waiter take ownership of it.
    Err(Option<Box<JlrsError>>),
}
impl Status {
    /// Returns `true` while the runtime has not yet reported a result.
    fn is_pending(&self) -> bool {
        // `matches!` replaces the manual match-to-bool pattern.
        matches!(self, Status::Pending)
    }

    /// Converts a finished status into a `JlrsResult`, taking ownership of
    /// the stored error.
    ///
    /// # Panics
    ///
    /// Panics if the status is still `Pending`, or if it is `Err` but the
    /// error was already taken by an earlier call.
    fn as_jlrs_result(&mut self) -> JlrsResult<()> {
        match self {
            Status::Ok => Ok(()),
            Status::Err(ref mut e) => Err(e.take().expect("Status is Err, but no error is set")),
            Status::Pending => panic!("Cannot convert Status::Pending to JlrsResult"),
        }
    }
}
/// Messages handled by the runtime loop in `run_async` / `run_async_with_image`.
enum Message<T, R> {
    // A user task plus a sender clone the spawned task uses to report
    // completion back to the runtime.
    Task(
        Box<dyn JuliaTask<T = T, R = R>>,
        AsyncStdSender<Message<T, R>>,
    ),
    // Include a Julia file; status shared via async mutex/condvar (async API).
    Include(PathBuf, Arc<(AsyncStdMutex<Status>, AsyncStdCondvar)>),
    // Include a Julia file; status shared via std mutex/condvar (blocking API).
    TryInclude(PathBuf, Arc<(Mutex<Status>, Condvar)>),
    // Sent by a finished task: returns its stack slot so it can be reused.
    Complete(Wrapper, AsyncStdSender<Message<T, R>>),
    // Install the Rust callbacks in the Jlrs module (async API).
    SetWakeFn(Arc<(AsyncStdMutex<Status>, AsyncStdCondvar)>),
    // Install the Rust callbacks in the Jlrs module (blocking API).
    TrySetWakeFn(Arc<(Mutex<Status>, Condvar)>),
}
/// Pairs a task slot index with the task stack borrowed for it, so a finished
/// task can hand the stack back to the runtime loop for reuse.
struct Wrapper(usize, TaskStack);
// SAFETY: the wrapper only travels from a locally-spawned task back to the
// runtime thread that owns the MultitaskStack, so its contents are never
// accessed from two threads at once. NOTE(review): TaskStack's definition is
// not visible here — confirm it holds no thread-affine state.
unsafe impl Send for Wrapper {}
/// Spawn a local async task that runs `jl_task` on `task_stack` and, once it
/// finishes, sends a `Complete` message so the runtime loop can reuse the
/// stack slot `task_idx`.
fn run_task<T: Send + Sync + 'static, R>(
    mut jl_task: Box<dyn JuliaTask<T = T, R = R>>,
    task_idx: usize,
    mut task_stack: TaskStack,
    rt_sender: AsyncStdSender<Message<T, R>>,
) -> AsyncStdHandle<()>
where
    R: ReturnChannel<T = T> + 'static,
{
    unsafe {
        // spawn_local: the task touches Julia data that must stay on the
        // runtime thread.
        task::spawn_local(async move {
            let mut tv = StackView::<Async, Dynamic>::new(&mut task_stack.raw);
            match tv.new_frame() {
                Ok(frame_idx) => {
                    let global = Global::new();
                    let mut frame = AsyncFrame {
                        idx: frame_idx,
                        memory: tv,
                        len: 0,
                    };
                    // Run the user task and forward its result through the
                    // task's return channel, if it provides one.
                    let res = jl_task.run(global, &mut frame).await;
                    if let Some(sender) = jl_task.return_channel() {
                        sender.send(res).await;
                    }
                }
                Err(e) => {
                    // Frame allocation failed: report the error instead of
                    // running the task.
                    if let Some(sender) = jl_task.return_channel() {
                        sender.send(Err(e)).await;
                    }
                }
            }
            // Hand the stack back to the runtime loop; a clone of the sender
            // rides along so the loop can reuse it for a pending task.
            let rt_c = rt_sender.clone();
            rt_sender
                .send(Message::Complete(Wrapper(task_idx, task_stack), rt_c))
                .await
                .expect("Channel was closed");
        })
    }
}
/// Runtime entry point: initializes Julia, then drives the message loop until
/// the channel closes, finally awaits outstanding tasks and shuts Julia down.
///
/// Runs on its own thread (see `AsyncJulia::init`) with a local executor so
/// non-Send Julia tasks can be spawned.
fn run_async<T, R>(
    n_threads: usize,
    stack_size: usize,
    process_events_ms: u64,
    receiver: AsyncStdReceiver<Message<T, R>>,
) -> JlrsResult<()>
where
    T: Send + Sync + 'static,
    R: ReturnChannel<T = T> + 'static,
{
    task::block_on(async {
        let mut mt_stack: MultitaskStack<T, R> = unsafe {
            // Guard against double initialization, both by this crate (INIT)
            // and by anything else that already initialized Julia.
            if jl_is_initialized() != 0 || INIT.swap(true, Ordering::SeqCst) {
                return Err(JlrsError::AlreadyInitialized.into());
            }
            jl_sys::jl_init();
            // Evaluate the bundled Jlrs module so wakerust/droparray exist.
            let jlrs_jl = CString::new(JLRS_JL).expect("Invalid Jlrs module");
            jl_sys::jl_eval_string(jlrs_jl.as_ptr());
            MultitaskStack::new(n_threads, stack_size)
        };
        loop {
            // Wake up every `process_events_ms` even if no message arrives so
            // Julia can make progress while tasks are running.
            match timeout(Duration::from_millis(process_events_ms), receiver.recv()).await {
                Err(_) => unsafe {
                    // Timeout: let Julia process events while tasks are live.
                    // NOTE(review): `run_async_with_image` calls
                    // `jl_gc_safepoint()` here instead — confirm which of the
                    // two is intended; they should probably agree.
                    if mt_stack.n > 0 {
                        jl_sys::jl_process_events();
                    }
                },
                Ok(Ok(Message::Task(jl_task, sender))) => {
                    // Run immediately if a stack slot is free, else queue it.
                    if let Some((task_idx, task_stack)) = mt_stack.acquire_task_frame() {
                        mt_stack.n += 1;
                        mt_stack.running[task_idx] =
                            Some(run_task(jl_task, task_idx, task_stack, sender));
                    } else {
                        mt_stack.add_pending(jl_task);
                    }
                }
                Ok(Ok(Message::Complete(Wrapper(task_idx, task_stack), sender))) => {
                    // A task finished: reuse its slot for a pending task, or
                    // return the stack to the pool.
                    if let Some(jl_task) = mt_stack.pop_pending() {
                        mt_stack.running[task_idx] =
                            Some(run_task(jl_task, task_idx, task_stack, sender));
                    } else {
                        mt_stack.n -= 1;
                        mt_stack.running[task_idx] = None;
                        mt_stack.return_task_frame(task_idx, task_stack);
                    }
                }
                Ok(Ok(Message::Include(path, completed))) => {
                    include(&mut mt_stack.raw, path, completed).await
                }
                Ok(Ok(Message::TryInclude(path, completed))) => {
                    try_include(&mut mt_stack.raw, path, completed)
                }
                Ok(Ok(Message::SetWakeFn(completed))) => {
                    set_wake_fn(&mut mt_stack.raw, completed).await
                }
                Ok(Ok(Message::TrySetWakeFn(completed))) => {
                    try_set_wake_fn(&mut mt_stack.raw, completed)
                }
                // All senders dropped: begin shutdown.
                Ok(Err(RecvError)) => break,
            }
        }
        // Wait for still-running tasks before tearing Julia down.
        for running in mt_stack.running.iter_mut() {
            if let Some(handle) = running.take() {
                handle.await;
            }
        }
        unsafe {
            jl_atexit_hook(0);
        }
        Ok(())
    })
}
/// Like `run_async`, but initializes Julia from a custom system image found
/// at `image_path`, with the Julia binary directory `julia_bindir`. Both
/// paths are validated before `jl_init_with_image__threading` is called.
fn run_async_with_image<T, R, P, Q>(
    n_threads: usize,
    stack_size: usize,
    process_events_ms: u64,
    receiver: AsyncStdReceiver<Message<T, R>>,
    julia_bindir: P,
    image_path: Q,
) -> JlrsResult<()>
where
    T: Send + Sync + 'static,
    R: ReturnChannel<T = T> + 'static,
    P: AsRef<Path>,
    Q: AsRef<Path>,
{
    task::block_on(async {
        let mut mt_stack: MultitaskStack<T, R> = unsafe {
            // Guard against double initialization (see `run_async`).
            if jl_is_initialized() != 0 || INIT.swap(true, Ordering::SeqCst) {
                return Err(JlrsError::AlreadyInitialized.into());
            }
            let julia_bindir_str = julia_bindir.as_ref().to_string_lossy().to_string();
            let image_path_str = image_path.as_ref().to_string_lossy().to_string();
            // Fail early with an I/O error if either path is missing.
            if !julia_bindir.as_ref().exists() {
                let io_err = IOError::new(ErrorKind::NotFound, julia_bindir_str);
                return Err(other_err(io_err))?;
            }
            if !image_path.as_ref().exists() {
                let io_err = IOError::new(ErrorKind::NotFound, image_path_str);
                return Err(other_err(io_err))?;
            }
            let bindir = std::ffi::CString::new(julia_bindir_str).unwrap();
            let im_rel_path = std::ffi::CString::new(image_path_str).unwrap();
            jl_init_with_image__threading(bindir.as_ptr(), im_rel_path.as_ptr());
            // Evaluate the bundled Jlrs module so wakerust/droparray exist.
            let jlrs_jl = CString::new(JLRS_JL).expect("Invalid Jlrs module");
            jl_sys::jl_eval_string(jlrs_jl.as_ptr());
            MultitaskStack::new(n_threads, stack_size)
        };
        loop {
            match timeout(Duration::from_millis(process_events_ms), receiver.recv()).await {
                Err(_) => unsafe {
                    // Timeout: reach a GC safepoint while tasks are live.
                    // NOTE(review): `run_async` calls `jl_process_events()`
                    // here instead — confirm which of the two is intended.
                    if mt_stack.n > 0 {
                        jl_gc_safepoint();
                    }
                },
                Ok(Ok(Message::Task(jl_task, sender))) => {
                    // Run immediately if a stack slot is free, else queue it.
                    if let Some((task_idx, task_stack)) = mt_stack.acquire_task_frame() {
                        mt_stack.n += 1;
                        mt_stack.running[task_idx] =
                            Some(run_task(jl_task, task_idx, task_stack, sender));
                    } else {
                        mt_stack.add_pending(jl_task);
                    }
                }
                Ok(Ok(Message::Complete(Wrapper(task_idx, task_stack), sender))) => {
                    // A task finished: reuse its slot for a pending task, or
                    // return the stack to the pool.
                    if let Some(jl_task) = mt_stack.pop_pending() {
                        mt_stack.running[task_idx] =
                            Some(run_task(jl_task, task_idx, task_stack, sender));
                    } else {
                        mt_stack.n -= 1;
                        mt_stack.running[task_idx] = None;
                        mt_stack.return_task_frame(task_idx, task_stack);
                    }
                }
                Ok(Ok(Message::Include(path, completed))) => {
                    include(&mut mt_stack.raw, path, completed).await
                }
                Ok(Ok(Message::TryInclude(path, completed))) => {
                    try_include(&mut mt_stack.raw, path, completed)
                }
                Ok(Ok(Message::SetWakeFn(completed))) => {
                    set_wake_fn(&mut mt_stack.raw, completed).await
                }
                Ok(Ok(Message::TrySetWakeFn(completed))) => {
                    try_set_wake_fn(&mut mt_stack.raw, completed)
                }
                // All senders dropped: begin shutdown.
                Ok(Err(RecvError)) => break,
            }
        }
        // Wait for still-running tasks before tearing Julia down.
        for pending in mt_stack.running.iter_mut() {
            if let Some(handle) = pending.take() {
                handle.await;
            }
        }
        unsafe {
            jl_atexit_hook(0);
        }
        Ok(())
    })
}
/// Install the Rust callbacks Julia needs by storing their function pointers
/// in the `Main.Jlrs.wakerust` and `Main.Jlrs.droparray` globals (written
/// into field 0 of each global's container).
fn call_set_wake_fn(stack: &mut [*mut c_void]) -> JlrsResult<()> {
    unsafe {
        let global = Global::new();
        let mut view = StackView::<Async, Dynamic>::new(stack);
        let idx = view.new_frame()?;
        let mut frame = AsyncFrame {
            idx,
            len: 0,
            memory: view,
        };
        // Pointer used by Julia to wake a pending Rust future.
        let waker = Value::new(&mut frame, crate::julia_future::wake_task as *mut c_void)?;
        Module::main(global)
            .submodule("Jlrs")?
            .global("wakerust")?
            .set_nth_field(0, waker)?;
        // Pointer used by Julia to drop array data owned by Rust.
        let dropper = Value::new(&mut frame, crate::droparray as *mut c_void)?;
        Module::main(global)
            .submodule("Jlrs")?
            .global("droparray")?
            .set_nth_field(0, dropper)?;
    }
    Ok(())
}
/// Handle a `SetWakeFn` message: install the Rust callbacks using the
/// runtime's reserved GC stack, then publish the outcome through the shared
/// status and wake the waiter.
async fn set_wake_fn(
    stacks: &mut [Option<TaskStack>],
    completed: Arc<(AsyncStdMutex<Status>, AsyncStdCondvar)>,
) {
    // The last slot is reserved for runtime-internal calls; borrow it,
    // do the work, and put it back.
    let last = stacks.len() - 1;
    let mut gc_stack = stacks[last].take().expect("GC stack is corrupted.");
    let outcome = call_set_wake_fn(&mut gc_stack.raw);
    stacks[last] = Some(gc_stack);

    let (lock, condvar) = &*completed;
    let mut status = lock.lock().await;
    *status = match outcome {
        Ok(()) => Status::Ok,
        Err(e) => Status::Err(Some(e)),
    };
    condvar.notify_one();
}
/// Handle a `TrySetWakeFn` message: blocking-mutex counterpart of
/// `set_wake_fn`, used by the synchronous API.
fn try_set_wake_fn(stacks: &mut [Option<TaskStack>], completed: Arc<(Mutex<Status>, Condvar)>) {
    // Borrow the reserved runtime stack, do the work, and put it back.
    let last = stacks.len() - 1;
    let mut gc_stack = stacks[last].take().expect("GC stack is corrupted.");
    let outcome = call_set_wake_fn(&mut gc_stack.raw);
    stacks[last] = Some(gc_stack);

    let (lock, condvar) = &*completed;
    let mut status = lock.lock().expect("Cannot lock");
    *status = match outcome {
        Ok(()) => Status::Ok,
        Err(e) => Status::Err(Some(e)),
    };
    condvar.notify_one();
}
fn call_include(stack: &mut [*mut c_void], path: PathBuf) -> JlrsResult<()> {
unsafe {
let global = Global::new();
let mut view = StackView::<Async, Dynamic>::new(stack);
let idx = view.new_frame()?;
let mut frame = AsyncFrame {
idx,
len: 0,
memory: view,
};
match path.to_str() {
Some(path) => {
let path = Value::new(&mut frame, path)?;
Module::main(global)
.function("include")?
.call1(&mut frame, path)?
.map_err(|_e| {
crate::error::exception::<Value>("Include error".into()).unwrap_err()
})?;
}
None => {}
}
Ok(())
}
}
/// Handle an `Include` message: run `Main.include` on the runtime's reserved
/// GC stack, then publish the outcome through the shared status and wake the
/// async waiter.
async fn include(
    stacks: &mut [Option<TaskStack>],
    path: PathBuf,
    completed: Arc<(AsyncStdMutex<Status>, AsyncStdCondvar)>,
) {
    // Borrow the reserved runtime stack, run the include, and put it back.
    let last = stacks.len() - 1;
    let outcome = {
        let mut gc_stack = stacks[last].take().expect("GC stack is corrupted.");
        let res = call_include(&mut gc_stack.raw, path);
        stacks[last] = Some(gc_stack);
        res
    };

    let (lock, condvar) = &*completed;
    let mut status = lock.lock().await;
    *status = match outcome {
        Ok(()) => Status::Ok,
        Err(e) => Status::Err(Some(e)),
    };
    condvar.notify_one();
}
/// Handle a `TryInclude` message: blocking-mutex counterpart of `include`,
/// used by the synchronous API.
fn try_include(
    stacks: &mut [Option<TaskStack>],
    path: PathBuf,
    completed: Arc<(Mutex<Status>, Condvar)>,
) {
    // Borrow the reserved runtime stack, run the include, and put it back.
    let last = stacks.len() - 1;
    let outcome = {
        let mut gc_stack = stacks[last].take().expect("GC stack is corrupted.");
        let res = call_include(&mut gc_stack.raw, path);
        stacks[last] = Some(gc_stack);
        res
    };

    let (lock, condvar) = &*completed;
    let mut status = lock.lock().expect("Cannot lock");
    *status = match outcome {
        Ok(()) => Status::Ok,
        Err(e) => Status::Err(Some(e)),
    };
    condvar.notify_one();
}