jlrs/data/managed/
background_task.rs

1//! Task delegated to a background thread that can call into Julia.
2
3use std::{
4    fmt,
5    marker::{PhantomData, PhantomPinned},
6    mem::{self, MaybeUninit},
7    os::raw::c_void,
8    ptr::NonNull,
9    thread::{self, JoinHandle},
10};
11
12use jl_sys::jl_gc_alloc_typed;
13use parking_lot::Mutex;
14
15use super::{
16    Managed, Weak,
17    module::JlrsCore,
18    private::ManagedPriv,
19    value::{Value, ValueData, ValueRet, WeakValue},
20};
21use crate::{
22    call::Call,
23    convert::ccall_types::{CCallArg, CCallReturn},
24    data::{
25        layout::{is_bits::IsBits, typed_layout::HasLayout, valid_layout::ValidLayout},
26        types::construct_type::{ConstructType, TypeVarEnv},
27    },
28    error::JlrsError,
29    inline_static_ref,
30    memory::{
31        gc::gc_safe,
32        get_tls,
33        scope::LocalScopeExt,
34        target::{TargetResult, unrooted::Unrooted},
35    },
36    prelude::{DataType, JlrsResult, Target, TargetType},
37    private::Private,
38    util::uv_async_send_func,
39    weak_handle_unchecked,
40};
41
/// A background task.
///
/// Call `Base.fetch` to wait for a background task to complete and fetch the result.
///
/// This is a pointer to a Julia-allocated [`BackgroundTaskLayout`]. Copying a `BackgroundTask`
/// copies the pointer, not the task itself. Background tasks are created with
/// [`spawn_background_task`], which runs a closure on a new thread and stores the `T::Layout`
/// value it returns in the task.
#[repr(transparent)]
pub struct BackgroundTask<'scope, T>(
    NonNull<BackgroundTaskLayout<'scope, T>>,
    // Ties the handle to the `'scope` it was rooted in.
    PhantomData<&'scope ()>,
)
where
    T: 'static + HasLayout<'static, 'static>,
    T::Layout: IsBits + Clone;
53
54impl<T> Clone for BackgroundTask<'_, T>
55where
56    T: 'static + HasLayout<'static, 'static>,
57    T::Layout: IsBits + Clone,
58{
59    fn clone(&self) -> Self {
60        Self(self.0.clone(), PhantomData)
61    }
62}
63
// A `BackgroundTask` is only a pointer to the Julia-allocated layout, so copying it is cheap
// and does not duplicate the task.
impl<T> Copy for BackgroundTask<'_, T>
where
    T: 'static + HasLayout<'static, 'static>,
    T::Layout: IsBits + Clone,
{
}
70
71impl<'scope, T> BackgroundTask<'scope, T>
72where
73    T: 'static + HasLayout<'static, 'static>,
74    T::Layout: IsBits + Clone + CCallReturn,
75{
76    fn new<'target, Tgt: Target<'target>>(target: Tgt) -> BackgroundTaskData<'target, T, Tgt> {
77        unsafe {
78            target.with_local_scope::<_, 2>(|target, mut frame| {
79                let cond =
80                    inline_static_ref!(ASYNC_CONDITION, DataType, "Base.AsyncCondition", &frame);
81                let cond = cond.as_value().call_unchecked(&mut frame, []);
82
83                let ptls = get_tls();
84                let ty = Self::construct_type(&mut frame);
85                let ptr = jl_gc_alloc_typed(
86                    ptls,
87                    mem::size_of::<BackgroundTaskLayout<T>>(),
88                    ty.unwrap(Private).cast(),
89                ) as *mut MaybeUninit<BackgroundTaskLayout<T>>;
90
91                let layout = (&mut *ptr).write(BackgroundTaskLayout::<T>::new(cond));
92
93                let nn_ptr = NonNull::from(layout);
94                BackgroundTask(nn_ptr, PhantomData).root(target)
95            })
96        }
97    }
98
99    fn set(self, value: T::Layout) {
100        unsafe {
101            let layout = self.unwrap_non_null(Private).as_mut();
102            layout.atomic = value;
103        }
104    }
105
106    unsafe fn notify(self) {
107        unsafe {
108            let func = uv_async_send_func();
109            let cond = self.unwrap_non_null(Private).as_ref().cond;
110            let handle_ref = cond.ptr().cast::<*mut c_void>().as_ref();
111            let handle = *handle_ref;
112
113            func(handle);
114        }
115    }
116
117    unsafe fn set_join_handle(self, handle: JoinHandle<JlrsResult<()>>) {
118        unsafe {
119            let layout = self
120                .unwrap_non_null(Private)
121                .cast::<BackgroundTaskLayout<T>>()
122                .as_ref();
123            let mut guard = layout.thread_handle.lock();
124            *guard = Some(handle);
125        }
126    }
127}
128
129impl<T> fmt::Debug for BackgroundTask<'_, T>
130where
131    T: 'static + HasLayout<'static, 'static>,
132    T::Layout: IsBits + Clone,
133{
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        f.debug_tuple("BackgroundTask").finish()
136    }
137}
138
139impl<'scope, 'data, T> ManagedPriv<'scope, 'data> for BackgroundTask<'scope, T>
140where
141    T: 'static + HasLayout<'static, 'static>,
142    T::Layout: IsBits + Clone,
143{
144    type Wraps = BackgroundTaskLayout<'scope, T>;
145
146    type WithLifetimes<'target, 'da> = BackgroundTask<'target, T>;
147
148    const NAME: &'static str = "BackgroundTask";
149
150    unsafe fn wrap_non_null(inner: NonNull<Self::Wraps>, _: crate::private::Private) -> Self {
151        BackgroundTask(inner, PhantomData)
152    }
153
154    fn unwrap_non_null(self, _: crate::private::Private) -> NonNull<Self::Wraps> {
155        self.0
156    }
157}
158
159unsafe impl<'scope, T> ConstructType for BackgroundTask<'scope, T>
160where
161    T: 'static + HasLayout<'static, 'static>,
162    T::Layout: IsBits + Clone,
163{
164    type Static = BackgroundTask<'static, T>;
165
166    fn construct_type_uncached<'target, Tgt>(target: Tgt) -> ValueData<'target, 'static, Tgt>
167    where
168        Tgt: Target<'target>,
169    {
170        target.with_local_scope::<_, 1>(|target, mut frame| unsafe {
171            let t = T::construct_type(&mut frame);
172            let bgtask_ua = JlrsCore::background_task(&target);
173            bgtask_ua.apply_types_unchecked(target, [t])
174        })
175    }
176
177    fn construct_type_with_env_uncached<'target, Tgt>(
178        target: Tgt,
179        _: &TypeVarEnv,
180    ) -> ValueData<'target, 'static, Tgt>
181    where
182        Tgt: Target<'target>,
183    {
184        target.with_local_scope::<_, 1>(|target, mut frame| unsafe {
185            let t = T::construct_type(&mut frame);
186            let bgtask_ua = JlrsCore::background_task(&target);
187            bgtask_ua.apply_types_unchecked(target, [t])
188        })
189    }
190
191    fn base_type<'target, Tgt>(target: &Tgt) -> Option<Value<'target, 'static>>
192    where
193        Tgt: Target<'target>,
194    {
195        Some(JlrsCore::background_task(target).as_value())
196    }
197}
198
// In `ccall` signatures a background task's argument type is a `Value`. The Rust function
// receives it as a `BackgroundTask` (`FunctionArgType = Self`).
unsafe impl<'scope, T> CCallArg for BackgroundTask<'scope, T>
where
    T: 'static + HasLayout<'static, 'static>,
    T::Layout: IsBits + Clone,
{
    type CCallArgType = Value<'scope, 'static>;
    type FunctionArgType = Self;
}
207
// The layout associated with `BackgroundTask<T>` is `BackgroundTaskLayout<T>`.
unsafe impl<'scope, 'data, T> HasLayout<'scope, 'data> for BackgroundTask<'scope, T>
where
    T: 'static + HasLayout<'static, 'static>,
    T::Layout: IsBits + Clone,
{
    type Layout = BackgroundTaskLayout<'scope, T>;
}
215
/// A [`BackgroundTask`] that has not been explicitly rooted.
pub type WeakBackgroundTask<'scope, T> = Weak<'scope, 'static, BackgroundTask<'scope, T>>;

/// A [`WeakBackgroundTask`] with static lifetimes.
///
/// This is a useful shorthand for signatures of `ccall`able functions that return a
/// [`WeakBackgroundTask`].
pub type BackgroundTaskRet<T> = WeakBackgroundTask<'static, T>;

/// [`BackgroundTask`] or [`WeakBackgroundTask`], depending on the target type `Tgt`.
///
/// This is the type returned by [`spawn_background_task`].
pub type BackgroundTaskData<'target, T, Tgt> =
    <Tgt as TargetType<'target>>::Data<'static, BackgroundTask<'target, T>>;

/// `JuliaResult<BackgroundTask>` or `WeakJuliaResult<WeakBackgroundTask>`, depending on the target
/// type `Tgt`.
pub type BackgroundTaskResult<'target, T, Tgt> =
    TargetResult<'target, 'static, BackgroundTask<'target, T>, Tgt>;
233
/// Layout of [`BackgroundTask`].
///
/// NOTE(review): this type is `#[repr(C)]`. It is allocated with the Julia type constructed for
/// `BackgroundTask<T>`, so its field order presumably has to mirror the fields of
/// `JlrsCore.BackgroundTask`. Confirm against the Julia definition before reordering fields.
#[repr(C)]
pub struct BackgroundTaskLayout<'scope, T>
where
    T: 'static + HasLayout<'static, 'static>,
    T::Layout: IsBits + Clone,
{
    /// Function that joins the thread and returns the result; set to `background_task_fetch`.
    fetch_fn: unsafe extern "C" fn(handle: BackgroundTask<T>) -> ValueRet,
    /// Join handle of the background thread.
    ///
    /// `None` until it is stored after spawning, and `None` again once `fetch` has taken it.
    /// NOTE(review): nothing visible here frees this box when the Julia object is collected.
    thread_handle: Box<Mutex<Option<JoinHandle<JlrsResult<()>>>>>,
    /// The `Base.AsyncCondition` signalled by the background thread.
    cond: WeakValue<'scope, 'static>,
    /// Result written by the background thread; zero-initialized until then.
    atomic: T::Layout,
    _pinned: PhantomPinned,
}
247
248unsafe impl<'scope, T> ValidLayout for BackgroundTaskLayout<'scope, T>
249where
250    T: 'static + HasLayout<'static, 'static>,
251    T::Layout: IsBits + Clone,
252{
253    fn valid_layout(ty: Value) -> bool {
254        if ty.is::<DataType>() {
255            unsafe {
256                let weak_handle = weak_handle_unchecked!();
257                let ty = ty.cast_unchecked::<DataType>();
258                let constructed = BackgroundTask::<T>::construct_type(&weak_handle).as_managed();
259                ty == constructed
260            }
261        } else {
262            false
263        }
264    }
265
266    fn type_object<'target, Tgt: Target<'target>>(target: &Tgt) -> Value<'target, 'static> {
267        JlrsCore::background_task(target).as_value()
268    }
269}
270
271impl<'scope, T> BackgroundTaskLayout<'scope, T>
272where
273    T: 'static + HasLayout<'static, 'static>,
274    T::Layout: IsBits + Clone + CCallReturn,
275{
276    fn new(cond: Value<'_, 'static>) -> Self {
277        let ptr = cond.unwrap_non_null(Private);
278        let cond = WeakValue::wrap(ptr);
279
280        unsafe {
281            BackgroundTaskLayout {
282                fetch_fn: background_task_fetch,
283                thread_handle: Box::new(Mutex::new(None)),
284                cond,
285                atomic: std::mem::zeroed::<T::Layout>(),
286                _pinned: PhantomPinned,
287            }
288        }
289    }
290
291    fn fetch(&self) -> JlrsResult<T::Layout> {
292        // This blocks Julia
293        match self.thread_handle.lock().take() {
294            Some(x) => match unsafe { gc_safe(|| x.join()) } {
295                Ok(Ok(_)) => Ok(self.atomic.clone()),
296                Ok(Err(e)) => Err(e)?,
297                Err(_e) => Err(JlrsError::exception("background task panicked"))?,
298            },
299            _ => Err(JlrsError::exception("already joined"))?,
300        }
301    }
302}
303
304/// Spawn a new background task.
305pub fn spawn_background_task<'scope, 'target, T, F, Tgt>(
306    target: Tgt,
307    func: F,
308) -> BackgroundTaskData<'target, T, Tgt>
309where
310    F: 'static + Send + FnOnce() -> JlrsResult<T::Layout>,
311    T: 'static + HasLayout<'static, 'static>,
312    T::Layout: IsBits + Clone + CCallReturn,
313    Tgt: Target<'target>,
314{
315    struct Sendable<L>(L);
316    impl<L> Sendable<L> {
317        fn inner(self) -> L {
318            self.0
319        }
320    }
321
322    unsafe impl<L> Send for Sendable<L> {}
323
324    unsafe {
325        target.with_local_scope::<_, 1>(|target, mut frame| {
326            let task = BackgroundTask::new(&mut frame);
327            let task_ref = Sendable(task.as_weak().leak());
328
329            let handle = thread::spawn(move || {
330                let task_ref = task_ref.inner();
331                let task = task_ref.as_managed();
332
333                match func() {
334                    Ok(res) => {
335                        task.set(res);
336                        task.notify();
337                        Ok(())
338                    }
339                    Err(e) => Err(e),
340                }
341            });
342
343            task.set_join_handle(handle);
344            task.root(target)
345        })
346    }
347}
348
349// Should only be called from Julia.
350unsafe extern "C" fn background_task_fetch<T>(handle: BackgroundTask<T>) -> ValueRet
351where
352    T: 'static + HasLayout<'static, 'static>,
353    T::Layout: IsBits + Clone + CCallReturn,
354{
355    unsafe {
356        let res = handle
357            .unwrap_non_null(Private)
358            .as_ref()
359            .fetch()
360            .return_or_throw();
361
362        let unrooted = Unrooted::new();
363        Value::try_new_with::<T, _, _>(&unrooted, res)
364            .return_or_throw()
365            .leak()
366    }
367}