jlrs/data/managed/union_all.rs

//! Managed type for `UnionAll`, a union of types over all values of a type parameter.
2
3use std::{marker::PhantomData, ptr::NonNull};
4
5use jl_sys::{
6    jl_abstractarray_type, jl_anytuple_type_type, jl_apply_type, jl_array_type, jl_densearray_type,
7    jl_llvmpointer_type, jl_namedtuple_type, jl_pointer_type, jl_ref_type, jl_type_type,
8    jl_type_unionall, jl_unionall_t, jl_unionall_type, jl_value_t,
9};
10use jlrs_macros::julia_version;
11use jlrs_sys::{jlrs_unionall_body, jlrs_unionall_tvar};
12
13use super::{
14    Managed, Weak, erase_scope_lifetime,
15    value::{ValueData, ValueResult},
16};
17use crate::{
18    catch::{catch_exceptions, unwrap_exc},
19    data::managed::{datatype::DataType, private::ManagedPriv, type_var::TypeVar, value::Value},
20    impl_julia_typecheck,
21    memory::{
22        scope::LocalScopeExt,
23        target::{Target, TargetResult},
24    },
25    private::Private,
26};
27
28/// An iterated union of types. If a struct field has a parametric type with some of its
29/// parameters unknown, its type is represented by a `UnionAll`.
30#[derive(Copy, Clone)]
31#[repr(transparent)]
32pub struct UnionAll<'scope>(NonNull<jl_unionall_t>, PhantomData<&'scope ()>);
33
34impl<'scope> UnionAll<'scope> {
35    /// Create a new `UnionAll`. If an exception is thrown, it's caught and returned.
36    pub fn new<'target, Tgt>(
37        target: Tgt,
38        tvar: TypeVar,
39        body: Value<'_, 'static>,
40    ) -> ValueResult<'target, 'static, Tgt>
41    where
42        Tgt: Target<'target>,
43    {
44        // Safety: if an exception is thrown it's caught, the result is immediately rooted
45        unsafe {
46            let callback = || jl_type_unionall(tvar.unwrap(Private), body.unwrap(Private));
47
48            let res = match catch_exceptions(callback, unwrap_exc) {
49                Ok(ptr) => Ok(NonNull::new_unchecked(ptr)),
50                Err(e) => Err(e),
51            };
52
53            target.result_from_ptr(res, Private)
54        }
55    }
56
57    /// Create a new `UnionAll`. If an exception is thrown it isn't caught
58    ///
59    /// Safety: an exception must not be thrown if this method is called from a `ccall`ed
60    /// function.
61    #[inline]
62    pub unsafe fn new_unchecked<'target, Tgt>(
63        target: Tgt,
64        tvar: TypeVar,
65        body: Value<'_, 'static>,
66    ) -> ValueData<'target, 'static, Tgt>
67    where
68        Tgt: Target<'target>,
69    {
70        unsafe {
71            let ua = jl_type_unionall(tvar.unwrap(Private), body.unwrap(Private));
72            target.data_from_ptr(NonNull::new_unchecked(ua), Private)
73        }
74    }
75
76    /// The type at the bottom of this `UnionAll`.
77    #[inline]
78    pub fn base_type(self) -> DataType<'scope> {
79        let mut b = self;
80
81        unsafe {
82            // Safety: pointer points to valid data
83            while b.body().is::<UnionAll>() {
84                b = b.body().cast_unchecked();
85            }
86
87            // Safety: type at the base must be a DataType
88            b.body().cast_unchecked::<DataType>()
89        }
90    }
91
92    /// The body of this `UnionAll`. This is either another `UnionAll` or a `DataType`.
93    #[inline]
94    pub fn body(self) -> Value<'scope, 'static> {
95        // Safety: pointer points to valid data
96        unsafe {
97            let body = jlrs_unionall_body(self.unwrap(Private));
98            debug_assert!(!body.is_null());
99            Value::wrap_non_null(NonNull::new_unchecked(body), Private)
100        }
101    }
102
103    /// The type variable associated with this "layer" of the `UnionAll`.
104    #[inline]
105    pub fn var(self) -> TypeVar<'scope> {
106        // Safety: pointer points to valid data
107        unsafe {
108            let var = jlrs_unionall_tvar(self.unwrap(Private));
109            debug_assert!(!var.is_null());
110            TypeVar::wrap_non_null(NonNull::new_unchecked(var), Private)
111        }
112    }
113
114    /// Apply `types` to this `UnionAll`.
115    ///
116    /// If the result has free type parameters, it's returned as a `DataType` with free type
117    /// parameters. Call `UnionAll::rewrap` to turn such a type into a `UnionAll`.
118    pub unsafe fn apply_types<'target, 'params, V, Tgt>(
119        self,
120        target: Tgt,
121        types: V,
122    ) -> ValueResult<'target, 'static, Tgt>
123    where
124        V: AsRef<[Value<'params, 'static>]>,
125        Tgt: Target<'target>,
126    {
127        let types = types.as_ref();
128        let n = types.len();
129        let types_ptr = types.as_ptr() as *mut *mut jl_value_t;
130        unsafe {
131            let callback = || jl_apply_type(self.as_value().unwrap(Private), types_ptr, n);
132
133            let res = match catch_exceptions(callback, unwrap_exc) {
134                Ok(ptr) => Ok(NonNull::new_unchecked(ptr)),
135                Err(e) => Err(e),
136            };
137
138            target.result_from_ptr(res, Private)
139        }
140    }
141
142    /// Apply `types` to this `UnionAll` without catching exceptions.
143    ///
144    /// If the result has free type parameters, it's returned as a `DataType` with free type
145    /// parameters. Call `UnionAll::rewrap` to turn such a type into a `UnionAll`.
146    ///
147    /// Safety: if an exception is throw it isn't caught.
148    #[inline]
149    pub unsafe fn apply_types_unchecked<'target, 'params, V, Tgt>(
150        self,
151        target: Tgt,
152        types: V,
153    ) -> ValueData<'target, 'static, Tgt>
154    where
155        V: AsRef<[Value<'params, 'static>]>,
156        Tgt: Target<'target>,
157    {
158        unsafe {
159            let types = types.as_ref();
160            let n = types.len();
161            let types_ptr = types.as_ptr() as *mut *mut jl_value_t;
162            let applied = jl_apply_type(self.as_value().unwrap(Private), types_ptr, n);
163            debug_assert!(!applied.is_null());
164            target.data_from_ptr(NonNull::new_unchecked(applied), Private)
165        }
166    }
167
168    /// Wrap `ty` with its free type parameters.
169    pub fn rewrap<'target, Tgt: Target<'target>>(
170        target: Tgt,
171        ty: DataType,
172    ) -> ValueData<'target, 'static, Tgt> {
173        target.with_local_scope::<_, 1>(|target, mut frame| unsafe {
174            let params = ty.parameters();
175            let params = params.data();
176            let mut output = frame.output();
177            let mut body = erase_scope_lifetime(ty.as_value());
178
179            for pidx in (0..params.len()).rev() {
180                let param = params.get(&target, pidx);
181                let param = param.unwrap_unchecked().as_value();
182                if param.is::<TypeVar>() {
183                    let tvar = param.cast_unchecked::<TypeVar>();
184                    let b = UnionAll::new_unchecked(&mut output, tvar, body).as_value();
185                    body = erase_scope_lifetime(b);
186                }
187            }
188
189            body.root(target)
190        })
191    }
192}
193
194impl<'base> UnionAll<'base> {
195    /// The `UnionAll` `Type`.
196    #[inline]
197    pub fn type_type<Tgt>(_: &Tgt) -> Self
198    where
199        Tgt: Target<'base>,
200    {
201        // Safety: global constant
202        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_type_type), Private) }
203    }
204
205    /// `Type{Tgt} where Tgt<:Tuple`
206    #[inline]
207    pub fn anytuple_type_type<Tgt>(_: &Tgt) -> Self
208    where
209        Tgt: Target<'base>,
210    {
211        // Safety: global constant
212        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_anytuple_type_type), Private) }
213    }
214
215    /// The `UnionAll` `AbstractArray`.
216    #[inline]
217    pub fn abstractarray_type<Tgt>(_: &Tgt) -> Self
218    where
219        Tgt: Target<'base>,
220    {
221        // Safety: global constant
222        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_abstractarray_type), Private) }
223    }
224
225    /// The `UnionAll` `DenseArray`.
226    #[inline]
227    pub fn densearray_type<Tgt>(_: &Tgt) -> Self
228    where
229        Tgt: Target<'base>,
230    {
231        // Safety: global constant
232        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_densearray_type), Private) }
233    }
234
235    /// The `UnionAll` `Array`.
236    #[inline]
237    pub fn array_type<Tgt>(_: &Tgt) -> Self
238    where
239        Tgt: Target<'base>,
240    {
241        // Safety: global constant
242        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_array_type), Private) }
243    }
244
245    /// The `UnionAll` `Ptr`.
246    #[inline]
247    pub fn pointer_type<Tgt>(_: &Tgt) -> Self
248    where
249        Tgt: Target<'base>,
250    {
251        // Safety: global constant
252        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_pointer_type), Private) }
253    }
254
255    /// The `UnionAll` `LLVMPtr`.
256    #[inline]
257    pub fn llvmpointer_type<Tgt>(_: &Tgt) -> Self
258    where
259        Tgt: Target<'base>,
260    {
261        // Safety: global constant
262        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_llvmpointer_type), Private) }
263    }
264
265    /// The `UnionAll` `Ref`.
266    #[inline]
267    pub fn ref_type<Tgt>(_: &Tgt) -> Self
268    where
269        Tgt: Target<'base>,
270    {
271        // Safety: global constant
272        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_ref_type), Private) }
273    }
274
275    /// The `UnionAll` `NamedTuple`.
276    #[inline]
277    pub fn namedtuple_type<Tgt>(_: &Tgt) -> Self
278    where
279        Tgt: Target<'base>,
280    {
281        // Safety: global constant
282        unsafe { Self::wrap_non_null(NonNull::new_unchecked(jl_namedtuple_type), Private) }
283    }
284
285    #[julia_version(since = "1.11")]
286    /// The `UnionAll` `GenericMemory`.
287    #[inline]
288    pub fn genericmemory_type<Tgt>(_: &Tgt) -> Self
289    where
290        Tgt: Target<'base>,
291    {
292        // Safety: global constant
293        unsafe {
294            Self::wrap_non_null(
295                NonNull::new_unchecked(jl_sys::jl_genericmemory_type),
296                Private,
297            )
298        }
299    }
300
301    #[julia_version(since = "1.11")]
302    /// The `UnionAll` `GenericMemoryRef`.
303    #[inline]
304    pub fn genericmemoryref_type<Tgt>(_: &Tgt) -> Self
305    where
306        Tgt: Target<'base>,
307    {
308        // Safety: global constant
309        unsafe {
310            Self::wrap_non_null(
311                NonNull::new_unchecked(jl_sys::jl_genericmemoryref_type),
312                Private,
313            )
314        }
315    }
316}
317
// Typecheck support for `UnionAll`, keyed on Julia's `jl_unionall_type`. Presumably this is
// what makes `value.is::<UnionAll>()` (used in `base_type`) work. Confirm against the macro
// definition.
impl_julia_typecheck!(UnionAll<'scope>, jl_unionall_type, 'scope);
// `Debug` implementation for `UnionAll` generated by the crate's shared macro.
impl_debug!(UnionAll<'_>);
320
321impl<'scope> ManagedPriv<'scope, '_> for UnionAll<'scope> {
322    type Wraps = jl_unionall_t;
323    type WithLifetimes<'target, 'da> = UnionAll<'target>;
324    const NAME: &'static str = "UnionAll";
325
326    // Safety: `inner` must not have been freed yet, the result must never be
327    // used after the GC might have freed it.
328    #[inline]
329    unsafe fn wrap_non_null(inner: NonNull<Self::Wraps>, _: Private) -> Self {
330        Self(inner, PhantomData)
331    }
332
333    #[inline]
334    fn unwrap_non_null(self, _: Private) -> NonNull<Self::Wraps> {
335        self.0
336    }
337}
338
// Type-construction support for `UnionAll`, generated by the crate's shared macro from
// `jl_unionall_type`. NOTE(review): the meaning of the `1` argument is defined by the macro,
// not this file.
impl_construct_type_managed!(UnionAll, 1, jl_unionall_type);
340
341/// A [`UnionAll`] that has not been explicitly rooted.
342pub type WeakUnionAll<'scope> = Weak<'scope, 'static, UnionAll<'scope>>;
343
344/// A [`WeakUnionAll`] with static lifetimes. This is a useful shorthand for signatures of
345/// `ccall`able functions that return a [`UnionAll`].
346pub type UnionAllRet = WeakUnionAll<'static>;
347
348impl_valid_layout!(WeakUnionAll, UnionAll, jl_unionall_type);
349
350use crate::memory::target::TargetType;
351
352/// `UnionAll` or `WeakUnionAll`, depending on the target type `Tgt`.
353pub type UnionAllData<'target, Tgt> =
354    <Tgt as TargetType<'target>>::Data<'static, UnionAll<'target>>;
355
356/// `JuliaResult<UnionAll>` or `WeakJuliaResult<WeakUnionAll>`, depending on the target type `Tgt`.
357pub type UnionAllResult<'target, Tgt> = TargetResult<'target, 'static, UnionAll<'target>, Tgt>;
358
359impl_ccall_arg_managed!(UnionAll, 1);
360impl_into_typed!(UnionAll);