1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TS2;
5use quote::quote;
6use syn::{self, Meta};
7
8#[derive(Default)]
9struct ClassifiedFields<'a> {
10 rs_flag_fields: Vec<&'a syn::Type>,
11 rs_align_fields: Vec<&'a syn::Type>,
12 rs_union_fields: Vec<&'a syn::Type>,
13 rs_non_union_fields: Vec<&'a syn::Type>,
14 jl_union_field_idxs: Vec<usize>,
15 jl_non_union_field_idxs: Vec<usize>,
16}
17
18impl<'a> ClassifiedFields<'a> {
19 fn classify<I>(fields_iter: I) -> Self
20 where
21 I: Iterator<Item = &'a syn::Field> + ExactSizeIterator,
22 {
23 let mut rs_flag_fields = vec![];
24 let mut rs_align_fields = vec![];
25 let mut rs_union_fields = vec![];
26 let mut rs_non_union_fields = vec![];
27 let mut jl_union_field_idxs = vec![];
28 let mut jl_non_union_field_idxs = vec![];
29 let mut offset = 0;
30
31 'outer: for (idx, field) in fields_iter.enumerate() {
32 for attr in &field.attrs {
33 match JlrsFieldAttr::parse(attr) {
34 Some(JlrsFieldAttr::BitsUnion) => {
35 rs_union_fields.push(&field.ty);
36 jl_union_field_idxs.push(idx - offset);
37 continue 'outer;
38 }
39 Some(JlrsFieldAttr::BitsUnionAlign) => {
40 rs_align_fields.push(&field.ty);
41 offset += 1;
42 continue 'outer;
43 }
44 Some(JlrsFieldAttr::BitsUnionFlag) => {
45 rs_flag_fields.push(&field.ty);
46 offset += 1;
47 continue 'outer;
48 }
49 _ => (),
50 }
51 }
52
53 rs_non_union_fields.push(&field.ty);
54 jl_non_union_field_idxs.push(idx - offset);
55 }
56
57 ClassifiedFields {
58 rs_flag_fields,
59 rs_align_fields,
60 rs_union_fields,
61 rs_non_union_fields,
62 jl_union_field_idxs,
63 jl_non_union_field_idxs,
64 }
65 }
66}
67
/// Container-level options read from the `#[jlrs(...)]` attributes of a derive input.
struct JlrsTypeAttrs {
    /// `true` when the bare `zst` flag is present.
    zst: bool,
    /// Julia type path given with `julia_type = "..."`, if any.
    julia_type: Option<String>,
}
72
73impl JlrsTypeAttrs {
74 fn parse(ast: &syn::DeriveInput) -> Self {
75 let mut julia_type = None;
76 let mut zst = false;
77 for attr in &ast.attrs {
78 if attr.path.is_ident("jlrs") {
79 if let Ok(Meta::List(p)) = attr.parse_meta() {
80 for item in &p.nested {
81 match item {
82 syn::NestedMeta::Meta(Meta::NameValue(nv)) => {
83 if nv.path.is_ident("julia_type") {
84 if let syn::Lit::Str(string) = &nv.lit {
85 julia_type = Some(string.value())
86 }
87 }
88 }
89 syn::NestedMeta::Meta(Meta::Path(pt)) => {
90 if pt.is_ident("zst") {
91 zst = true;
92 }
93 }
94 _ => continue,
95 }
96 }
97 }
98 }
99 }
100
101 JlrsTypeAttrs { julia_type, zst }
102 }
103}
104
/// Role of a struct field in a Julia bits-union layout, as marked by a `#[jlrs(...)]` field
/// attribute.
enum JlrsFieldAttr {
    /// `bits_union`: the field whose Julia type is checked to be a `Union`.
    BitsUnion,
    /// `bits_union_flag`: Rust-only field, passed as the flag type to `correct_layout_for`.
    BitsUnionFlag,
    /// `bits_union_align`: Rust-only field, passed as the alignment type to `correct_layout_for`.
    BitsUnionAlign,
}
110
111impl JlrsFieldAttr {
112 pub fn parse(attr: &syn::Attribute) -> Option<Self> {
113 if let Ok(Meta::List(p)) = attr.parse_meta() {
114 if let Some(syn::NestedMeta::Meta(syn::Meta::Path(m))) = p.nested.first() {
115 if m.is_ident("bits_union") {
116 return Some(JlrsFieldAttr::BitsUnion);
117 }
118
119 if m.is_ident("bits_union_align") {
120 return Some(JlrsFieldAttr::BitsUnionAlign);
121 }
122
123 if m.is_ident("bits_union_flag") {
124 return Some(JlrsFieldAttr::BitsUnionFlag);
125 }
126 }
127 }
128
129 None
130 }
131}
132
133#[proc_macro_derive(IntoJulia, attributes(jlrs))]
134pub fn into_julia_derive(input: TokenStream) -> TokenStream {
135 let ast = syn::parse(input).unwrap();
136 impl_into_julia(&ast)
137}
138
139#[proc_macro_derive(Unbox, attributes(jlrs))]
140pub fn unbox_derive(input: TokenStream) -> TokenStream {
141 let ast = syn::parse(input).unwrap();
142 impl_unbox(&ast)
143}
144
145#[proc_macro_derive(Typecheck, attributes(jlrs))]
146pub fn typecheck_derive(input: TokenStream) -> TokenStream {
147 let ast = syn::parse(input).unwrap();
148 impl_typecheck(&ast)
149}
150
151#[proc_macro_derive(ValidLayout, attributes(jlrs))]
152pub fn valid_layout_derive(input: TokenStream) -> TokenStream {
153 let ast = syn::parse(input).unwrap();
154 impl_valid_layout(&ast)
155}
156
157#[proc_macro_derive(ValidField, attributes(jlrs))]
158pub fn valid_field_derive(input: TokenStream) -> TokenStream {
159 let ast = syn::parse(input).unwrap();
160 impl_valid_field(&ast)
161}
162
163fn impl_into_julia(ast: &syn::DeriveInput) -> TokenStream {
164 let name = &ast.ident;
165 if !is_repr_c(ast) {
166 panic!("IntoJulia can only be derived for types with the attribute #[repr(C)].");
167 }
168
169 let mut attrs = JlrsTypeAttrs::parse(ast);
170 let jl_type = attrs.julia_type
171 .take()
172 .expect("IntoJulia can only be derived if the corresponding Julia type is set with #[julia_type = \"Main.MyModule.Submodule.StructType\"]");
173
174 let mut type_it = jl_type.split('.');
175 let func = match type_it.next() {
176 Some("Main") => quote::format_ident!("main"),
177 Some("Base") => quote::format_ident!("base"),
178 Some("Core") => quote::format_ident!("core"),
179 _ => panic!("IntoJulia can only be derived if the first module of \"julia_type\" is either \"Main\", \"Base\" or \"Core\"."),
180 };
181
182 let mut modules = type_it.collect::<Vec<_>>();
183 let ty = modules.pop().expect("IntoJulia can only be derived if the corresponding Julia type is set with #[jlrs(julia_type = \"Main.MyModule.Submodule.StructType\")]");
184 let modules_it = modules.iter();
185 let modules_it_b = modules_it.clone();
186
187 let into_julia_fn = impl_into_julia_fn(&attrs);
188
189 let into_julia_impl = quote! {
190 unsafe impl ::jlrs::convert::into_julia::IntoJulia for #name {
191 fn julia_type<'scope, T>(target: T) -> ::jlrs::wrappers::ptr::datatype::DataTypeData<'scope, T>
192 where
193 T: ::jlrs::memory::target::Target<'scope>,
194 {
195 unsafe {
196 let global = target.unrooted();
197 ::jlrs::wrappers::ptr::module::Module::#func(&global)
198 #(
199 .submodule(&global, #modules_it)
200 .expect(&format!("Submodule {} cannot be found", #modules_it_b))
201 .wrapper()
202 )*
203 .global(&global, #ty)
204 .expect(&format!("Type {} cannot be found in module", #ty))
205 .value()
206 .cast::<::jlrs::wrappers::ptr::datatype::DataType>()
207 .expect("Type is not a DataType")
208 .root(target)
209 }
210 }
211
212 #into_julia_fn
213 }
214 };
215
216 into_julia_impl.into()
217}
218
219fn impl_into_julia_fn(attrs: &JlrsTypeAttrs) -> Option<TS2> {
220 if attrs.zst {
221 Some(quote! {
222 unsafe fn into_julia<'target, T>(self, target: T) -> ::jlrs::wrappers::ptr::value::ValueData<'target, 'static, T>
223 where
224 T: ::jlrs::memory::target::Target<'scope>,
225 {
226 let ty = self.julia_type(global);
227 unsafe {
228 ty.wrapper()
229 .instance()
230 .value()
231 .expect("Instance is undefined")
232 .as_ref()
233 }
234 }
235 })
236 } else {
237 None
238 }
239}
240
241fn impl_unbox(ast: &syn::DeriveInput) -> TokenStream {
242 let name = &ast.ident;
243 if !is_repr_c(ast) {
244 panic!("Unbox can only be derived for types with the attribute #[repr(C)].");
245 }
246
247 let generics = &ast.generics;
248 let where_clause = &ast.generics.where_clause;
249
250 let unbox_impl = quote! {
251 unsafe impl #generics ::jlrs::convert::unbox::Unbox for #name #generics #where_clause {
252 type Output = Self;
253 }
254 };
255
256 unbox_impl.into()
257}
258
259fn impl_typecheck(ast: &syn::DeriveInput) -> TokenStream {
260 let name = &ast.ident;
261 if !is_repr_c(ast) {
262 panic!("Typecheck can only be derived for types with the attribute #[repr(C)].");
263 }
264
265 let generics = &ast.generics;
266 let where_clause = &ast.generics.where_clause;
267
268 let typecheck_impl = quote! {
269 unsafe impl #generics ::jlrs::layout::typecheck::Typecheck for #name #generics #where_clause {
270 fn typecheck(dt: ::jlrs::wrappers::ptr::datatype::DataType) -> bool {
271 <Self as ::jlrs::layout::valid_layout::ValidLayout>::valid_layout(dt.as_value())
272 }
273 }
274 };
275
276 typecheck_impl.into()
277}
278
279fn impl_valid_layout(ast: &syn::DeriveInput) -> TokenStream {
280 let name = &ast.ident;
281 if !is_repr_c(ast) {
282 panic!("ValidLayout can only be derived for types with the attribute #[repr(C)].");
283 }
284
285 let generics = &ast.generics;
286 let where_clause = &ast.generics.where_clause;
287
288 let fields = match &ast.data {
289 syn::Data::Struct(s) => &s.fields,
290 _ => panic!("Julia struct can only be derived for structs."),
291 };
292
293 let classified_fields = match fields {
294 syn::Fields::Named(n) => ClassifiedFields::classify(n.named.iter()),
295 syn::Fields::Unit => ClassifiedFields::default(),
296 _ => panic!("Julia struct cannot be derived for tuple structs."),
297 };
298
299 let rs_flag_fields = classified_fields.rs_flag_fields.iter();
300 let rs_align_fields = classified_fields.rs_align_fields.iter();
301 let rs_union_fields = classified_fields.rs_union_fields.iter();
302 let rs_non_union_fields = classified_fields.rs_non_union_fields.iter();
303 let jl_union_field_idxs = classified_fields.jl_union_field_idxs.iter();
304 let jl_non_union_field_idxs = classified_fields.jl_non_union_field_idxs.iter();
305
306 let n_fields = classified_fields.jl_union_field_idxs.len()
307 + classified_fields.jl_non_union_field_idxs.len();
308
309 let valid_layout_impl = quote! {
310 unsafe impl #generics ::jlrs::layout::valid_layout::ValidLayout for #name #generics #where_clause {
311 fn valid_layout(v: ::jlrs::wrappers::ptr::value::Value) -> bool {
312 unsafe {
313 if let Ok(dt) = v.cast::<::jlrs::wrappers::ptr::datatype::DataType>() {
314 if dt.n_fields() as usize != #n_fields {
315 return false;
316 }
317
318 let global = v.unrooted_target();
319 let field_types = dt.field_types(global);
320 let field_types_svec = field_types.wrapper();
321 let field_types_data = field_types_svec.data();
322 let field_types = field_types_data.as_slice();
323
324 #(
325 if !<#rs_non_union_fields as ::jlrs::layout::valid_layout::ValidField>::valid_field(field_types[#jl_non_union_field_idxs].unwrap().wrapper()) {
326 return false;
327 }
328 )*
329
330 #(
331 if let Ok(u) = field_types[#jl_union_field_idxs].unwrap().wrapper().cast::<::jlrs::wrappers::ptr::union::Union>() {
332 if !::jlrs::wrappers::inline::union::correct_layout_for::<#rs_align_fields, #rs_union_fields, #rs_flag_fields>(u) {
333 return false
334 }
335 } else {
336 return false
337 }
338 )*
339
340
341 return true;
342 }
343 }
344
345 false
346 }
347
348 const IS_REF: bool = false;
349 }
350 };
351
352 valid_layout_impl.into()
353}
354
355fn impl_valid_field(ast: &syn::DeriveInput) -> TokenStream {
356 let name = &ast.ident;
357 if !is_repr_c(ast) {
358 panic!("ValidLayout can only be derived for types with the attribute #[repr(C)].");
359 }
360
361 let generics = &ast.generics;
362 let where_clause = &ast.generics.where_clause;
363
364 let valid_field_impl = quote! {
365 unsafe impl #generics ::jlrs::layout::valid_layout::ValidField for #name #generics #where_clause {
366 fn valid_field(v: ::jlrs::wrappers::ptr::value::Value) -> bool {
367 <Self as ::jlrs::layout::valid_layout::ValidLayout>::valid_layout(v)
368 }
369 }
370 };
371
372 valid_field_impl.into()
373}
374
375fn is_repr_c(ast: &syn::DeriveInput) -> bool {
376 for attr in &ast.attrs {
377 if attr.path.is_ident("repr") {
378 if let Ok(Meta::List(p)) = attr.parse_meta() {
379 if let Some(syn::NestedMeta::Meta(syn::Meta::Path(m))) = p.nested.first() {
380 if m.is_ident("C") {
381 return true;
382 }
383 }
384 }
385 }
386 }
387
388 false
389}