Skip to main content

jlrs_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TS2;
5use quote::quote;
6use syn::{self, Meta};
7
8#[derive(Default)]
9struct ClassifiedFields<'a> {
10    rs_flag_fields: Vec<&'a syn::Type>,
11    rs_align_fields: Vec<&'a syn::Type>,
12    rs_union_fields: Vec<&'a syn::Type>,
13    rs_non_union_fields: Vec<&'a syn::Type>,
14    jl_union_field_idxs: Vec<usize>,
15    jl_non_union_field_idxs: Vec<usize>,
16}
17
18impl<'a> ClassifiedFields<'a> {
19    fn classify<I>(fields_iter: I) -> Self
20    where
21        I: Iterator<Item = &'a syn::Field> + ExactSizeIterator,
22    {
23        let mut rs_flag_fields = vec![];
24        let mut rs_align_fields = vec![];
25        let mut rs_union_fields = vec![];
26        let mut rs_non_union_fields = vec![];
27        let mut jl_union_field_idxs = vec![];
28        let mut jl_non_union_field_idxs = vec![];
29        let mut offset = 0;
30
31        'outer: for (idx, field) in fields_iter.enumerate() {
32            for attr in &field.attrs {
33                match JlrsFieldAttr::parse(attr) {
34                    Some(JlrsFieldAttr::BitsUnion) => {
35                        rs_union_fields.push(&field.ty);
36                        jl_union_field_idxs.push(idx - offset);
37                        continue 'outer;
38                    }
39                    Some(JlrsFieldAttr::BitsUnionAlign) => {
40                        rs_align_fields.push(&field.ty);
41                        offset += 1;
42                        continue 'outer;
43                    }
44                    Some(JlrsFieldAttr::BitsUnionFlag) => {
45                        rs_flag_fields.push(&field.ty);
46                        offset += 1;
47                        continue 'outer;
48                    }
49                    _ => (),
50                }
51            }
52
53            rs_non_union_fields.push(&field.ty);
54            jl_non_union_field_idxs.push(idx - offset);
55        }
56
57        ClassifiedFields {
58            rs_flag_fields,
59            rs_align_fields,
60            rs_union_fields,
61            rs_non_union_fields,
62            jl_union_field_idxs,
63            jl_non_union_field_idxs,
64        }
65    }
66}
67
68struct JlrsTypeAttrs {
69    julia_type: Option<String>,
70    zst: bool,
71}
72
73impl JlrsTypeAttrs {
74    fn parse(ast: &syn::DeriveInput) -> Self {
75        let mut julia_type = None;
76        let mut zst = false;
77        for attr in &ast.attrs {
78            if attr.path.is_ident("jlrs") {
79                if let Ok(Meta::List(p)) = attr.parse_meta() {
80                    for item in &p.nested {
81                        match item {
82                            syn::NestedMeta::Meta(Meta::NameValue(nv)) => {
83                                if nv.path.is_ident("julia_type") {
84                                    if let syn::Lit::Str(string) = &nv.lit {
85                                        julia_type = Some(string.value())
86                                    }
87                                }
88                            }
89                            syn::NestedMeta::Meta(Meta::Path(pt)) => {
90                                if pt.is_ident("zst") {
91                                    zst = true;
92                                }
93                            }
94                            _ => continue,
95                        }
96                    }
97                }
98            }
99        }
100
101        JlrsTypeAttrs { julia_type, zst }
102    }
103}
104
105enum JlrsFieldAttr {
106    BitsUnionAlign,
107    BitsUnion,
108    BitsUnionFlag,
109}
110
111impl JlrsFieldAttr {
112    pub fn parse(attr: &syn::Attribute) -> Option<Self> {
113        if let Ok(Meta::List(p)) = attr.parse_meta() {
114            if let Some(syn::NestedMeta::Meta(syn::Meta::Path(m))) = p.nested.first() {
115                if m.is_ident("bits_union") {
116                    return Some(JlrsFieldAttr::BitsUnion);
117                }
118
119                if m.is_ident("bits_union_align") {
120                    return Some(JlrsFieldAttr::BitsUnionAlign);
121                }
122
123                if m.is_ident("bits_union_flag") {
124                    return Some(JlrsFieldAttr::BitsUnionFlag);
125                }
126            }
127        }
128
129        None
130    }
131}
132
133#[proc_macro_derive(IntoJulia, attributes(jlrs))]
134pub fn into_julia_derive(input: TokenStream) -> TokenStream {
135    let ast = syn::parse(input).unwrap();
136    impl_into_julia(&ast)
137}
138
139#[proc_macro_derive(Unbox, attributes(jlrs))]
140pub fn unbox_derive(input: TokenStream) -> TokenStream {
141    let ast = syn::parse(input).unwrap();
142    impl_unbox(&ast)
143}
144
145#[proc_macro_derive(Typecheck, attributes(jlrs))]
146pub fn typecheck_derive(input: TokenStream) -> TokenStream {
147    let ast = syn::parse(input).unwrap();
148    impl_typecheck(&ast)
149}
150
151#[proc_macro_derive(ValidLayout, attributes(jlrs))]
152pub fn valid_layout_derive(input: TokenStream) -> TokenStream {
153    let ast = syn::parse(input).unwrap();
154    impl_valid_layout(&ast)
155}
156
157#[proc_macro_derive(ValidField, attributes(jlrs))]
158pub fn valid_field_derive(input: TokenStream) -> TokenStream {
159    let ast = syn::parse(input).unwrap();
160    impl_valid_field(&ast)
161}
162
163fn impl_into_julia(ast: &syn::DeriveInput) -> TokenStream {
164    let name = &ast.ident;
165    if !is_repr_c(ast) {
166        panic!("IntoJulia can only be derived for types with the attribute #[repr(C)].");
167    }
168
169    let mut attrs = JlrsTypeAttrs::parse(ast);
170    let jl_type = attrs.julia_type
171        .take()
172        .expect("IntoJulia can only be derived if the corresponding Julia type is set with #[julia_type = \"Main.MyModule.Submodule.StructType\"]");
173
174    let mut type_it = jl_type.split('.');
175    let func = match type_it.next() {
176        Some("Main") => quote::format_ident!("main"),
177        Some("Base") => quote::format_ident!("base"),
178        Some("Core") => quote::format_ident!("core"),
179        _ => panic!("IntoJulia can only be derived if the first module of \"julia_type\" is either \"Main\", \"Base\" or \"Core\"."),
180    };
181
182    let mut modules = type_it.collect::<Vec<_>>();
183    let ty = modules.pop().expect("IntoJulia can only be derived if the corresponding Julia type is set with #[jlrs(julia_type = \"Main.MyModule.Submodule.StructType\")]");
184    let modules_it = modules.iter();
185    let modules_it_b = modules_it.clone();
186
187    let into_julia_fn = impl_into_julia_fn(&attrs);
188
189    let into_julia_impl = quote! {
190        unsafe impl ::jlrs::convert::into_julia::IntoJulia for #name {
191            fn julia_type<'scope, T>(target: T) -> ::jlrs::wrappers::ptr::datatype::DataTypeData<'scope, T>
192            where
193                T: ::jlrs::memory::target::Target<'scope>,
194            {
195                unsafe {
196                    let global = target.unrooted();
197                    ::jlrs::wrappers::ptr::module::Module::#func(&global)
198                        #(
199                            .submodule(&global, #modules_it)
200                            .expect(&format!("Submodule {} cannot be found", #modules_it_b))
201                            .wrapper()
202                        )*
203                        .global(&global, #ty)
204                        .expect(&format!("Type {} cannot be found in module", #ty))
205                        .value()
206                        .cast::<::jlrs::wrappers::ptr::datatype::DataType>()
207                        .expect("Type is not a DataType")
208                        .root(target)
209                }
210            }
211
212            #into_julia_fn
213        }
214    };
215
216    into_julia_impl.into()
217}
218
219fn impl_into_julia_fn(attrs: &JlrsTypeAttrs) -> Option<TS2> {
220    if attrs.zst {
221        Some(quote! {
222            unsafe fn into_julia<'target, T>(self, target: T) -> ::jlrs::wrappers::ptr::value::ValueData<'target, 'static, T>
223            where
224                T: ::jlrs::memory::target::Target<'scope>,
225            {
226                let ty = self.julia_type(global);
227                unsafe {
228                    ty.wrapper()
229                        .instance()
230                        .value()
231                        .expect("Instance is undefined")
232                        .as_ref()
233                }
234            }
235        })
236    } else {
237        None
238    }
239}
240
241fn impl_unbox(ast: &syn::DeriveInput) -> TokenStream {
242    let name = &ast.ident;
243    if !is_repr_c(ast) {
244        panic!("Unbox can only be derived for types with the attribute #[repr(C)].");
245    }
246
247    let generics = &ast.generics;
248    let where_clause = &ast.generics.where_clause;
249
250    let unbox_impl = quote! {
251        unsafe impl #generics ::jlrs::convert::unbox::Unbox for #name #generics #where_clause {
252            type Output = Self;
253        }
254    };
255
256    unbox_impl.into()
257}
258
259fn impl_typecheck(ast: &syn::DeriveInput) -> TokenStream {
260    let name = &ast.ident;
261    if !is_repr_c(ast) {
262        panic!("Typecheck can only be derived for types with the attribute #[repr(C)].");
263    }
264
265    let generics = &ast.generics;
266    let where_clause = &ast.generics.where_clause;
267
268    let typecheck_impl = quote! {
269        unsafe impl #generics ::jlrs::layout::typecheck::Typecheck for #name #generics #where_clause {
270            fn typecheck(dt: ::jlrs::wrappers::ptr::datatype::DataType) -> bool {
271                <Self as ::jlrs::layout::valid_layout::ValidLayout>::valid_layout(dt.as_value())
272            }
273        }
274    };
275
276    typecheck_impl.into()
277}
278
279fn impl_valid_layout(ast: &syn::DeriveInput) -> TokenStream {
280    let name = &ast.ident;
281    if !is_repr_c(ast) {
282        panic!("ValidLayout can only be derived for types with the attribute #[repr(C)].");
283    }
284
285    let generics = &ast.generics;
286    let where_clause = &ast.generics.where_clause;
287
288    let fields = match &ast.data {
289        syn::Data::Struct(s) => &s.fields,
290        _ => panic!("Julia struct can only be derived for structs."),
291    };
292
293    let classified_fields = match fields {
294        syn::Fields::Named(n) => ClassifiedFields::classify(n.named.iter()),
295        syn::Fields::Unit => ClassifiedFields::default(),
296        _ => panic!("Julia struct cannot be derived for tuple structs."),
297    };
298
299    let rs_flag_fields = classified_fields.rs_flag_fields.iter();
300    let rs_align_fields = classified_fields.rs_align_fields.iter();
301    let rs_union_fields = classified_fields.rs_union_fields.iter();
302    let rs_non_union_fields = classified_fields.rs_non_union_fields.iter();
303    let jl_union_field_idxs = classified_fields.jl_union_field_idxs.iter();
304    let jl_non_union_field_idxs = classified_fields.jl_non_union_field_idxs.iter();
305
306    let n_fields = classified_fields.jl_union_field_idxs.len()
307        + classified_fields.jl_non_union_field_idxs.len();
308
309    let valid_layout_impl = quote! {
310        unsafe impl #generics ::jlrs::layout::valid_layout::ValidLayout for #name #generics #where_clause {
311            fn valid_layout(v: ::jlrs::wrappers::ptr::value::Value) -> bool {
312                unsafe {
313                    if let Ok(dt) = v.cast::<::jlrs::wrappers::ptr::datatype::DataType>() {
314                        if dt.n_fields() as usize != #n_fields {
315                            return false;
316                        }
317
318                        let global = v.unrooted_target();
319                        let field_types = dt.field_types(global);
320                        let field_types_svec = field_types.wrapper();
321                        let field_types_data = field_types_svec.data();
322                        let field_types = field_types_data.as_slice();
323
324                        #(
325                            if !<#rs_non_union_fields as ::jlrs::layout::valid_layout::ValidField>::valid_field(field_types[#jl_non_union_field_idxs].unwrap().wrapper()) {
326                                return false;
327                            }
328                        )*
329
330                        #(
331                            if let Ok(u) = field_types[#jl_union_field_idxs].unwrap().wrapper().cast::<::jlrs::wrappers::ptr::union::Union>() {
332                                if !::jlrs::wrappers::inline::union::correct_layout_for::<#rs_align_fields, #rs_union_fields, #rs_flag_fields>(u) {
333                                    return false
334                                }
335                            } else {
336                                return false
337                            }
338                        )*
339
340
341                        return true;
342                    }
343                }
344
345                false
346            }
347
348            const IS_REF: bool = false;
349        }
350    };
351
352    valid_layout_impl.into()
353}
354
355fn impl_valid_field(ast: &syn::DeriveInput) -> TokenStream {
356    let name = &ast.ident;
357    if !is_repr_c(ast) {
358        panic!("ValidLayout can only be derived for types with the attribute #[repr(C)].");
359    }
360
361    let generics = &ast.generics;
362    let where_clause = &ast.generics.where_clause;
363
364    let valid_field_impl = quote! {
365        unsafe impl #generics ::jlrs::layout::valid_layout::ValidField for #name #generics #where_clause {
366            fn valid_field(v: ::jlrs::wrappers::ptr::value::Value) -> bool {
367                <Self as ::jlrs::layout::valid_layout::ValidLayout>::valid_layout(v)
368            }
369        }
370    };
371
372    valid_field_impl.into()
373}
374
375fn is_repr_c(ast: &syn::DeriveInput) -> bool {
376    for attr in &ast.attrs {
377        if attr.path.is_ident("repr") {
378            if let Ok(Meta::List(p)) = attr.parse_meta() {
379                if let Some(syn::NestedMeta::Meta(syn::Meta::Path(m))) = p.nested.first() {
380                    if m.is_ident("C") {
381                        return true;
382                    }
383                }
384            }
385        }
386    }
387
388    false
389}