Skip to main content

pyro_macro/format/
deep_ref.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use syn::{GenericArgument, Ident, ItemStruct, Path, PathArguments, Type, TypePath};
4
5/// Main function to generate the Deep Reference struct and implementation.
6/// Accepts a list of additional derives to apply to the generated struct.
7pub fn deep_ref(
8    input: &ItemStruct,
9    import_location: &Path,
10    derives_to_pass: &Vec<Ident>,
11) -> syn::Result<TokenStream> {
12    if !input.generics.params.is_empty() {
13        return Err(syn::Error::new_spanned(
14            &input.generics,
15            "DeepRef cannot be derived for structs with generic parameters (types, lifetimes, or consts)",
16        ));
17    }
18
19    let struct_name = &input.ident;
20    let ref_struct_name = format_ident!("{}Ref", struct_name);
21
22    // ItemStruct stores fields directly
23    let fields = &input.fields;
24
25    // 1. Generate the Reference Struct Definition
26    let mut lifetime_used = false;
27    let ref_fields = fields
28        .iter()
29        .map(|f| {
30            let name = &f.ident;
31            let vis = &f.vis;
32            let ty = &f.ty;
33            let (mapped_type, is_primitive) = map_type_to_ref(ty);
34
35            if !is_primitive {
36                lifetime_used = true;
37            }
38
39            quote! { #vis #name: #mapped_type }
40        })
41        .collect::<Vec<_>>();
42
43    let phantom_field = if !lifetime_used {
44        quote! { _phantom: std::marker::PhantomData<&'a ()> }
45    } else {
46        quote! {}
47    };
48
49    // Inject the user-provided derives here
50    let struct_def = quote! {
51        #[derive(#(#derives_to_pass),*)]
52        pub struct #ref_struct_name<'a> {
53            #(#ref_fields,)*
54            #phantom_field
55        }
56    };
57
58    // 2. Generate the DeepRef Implementation for the Owned type
59    let field_conversions = fields.iter().map(|f| {
60        let field_name = f.ident.as_ref().unwrap();
61        let ty = &f.ty;
62        generate_field_conversion(field_name, ty)
63    });
64
65    let phantom_init = if !lifetime_used {
66        quote! { _phantom: std::marker::PhantomData }
67    } else {
68        quote! {}
69    };
70
71    let impl_owned = quote! {
72        impl #import_location::format::DeepRef for #struct_name {
73            type Ref<'a> = #ref_struct_name<'a>;
74
75            fn as_deep_ref(&self) -> Self::Ref<'_> {
76                #ref_struct_name {
77                    #(#field_conversions,)*
78                    #phantom_init
79                }
80            }
81        }
82    };
83
84    Ok(quote! {
85        #struct_def
86        #impl_owned
87    })
88}
89
90/// Specialized function for Rkyv.
91/// Generates the standard DeepRef stuff, PLUS an implementation for the Archived variant.
92pub fn deep_ref_rkyv(input: &ItemStruct, import_location: &Path) -> syn::Result<TokenStream> {
93    let struct_name = &input.ident;
94    // Standard rkyv naming convention: Archived + StructName
95    let archived_struct_name = format_ident!("Archived{}", struct_name);
96    let ref_struct_name = format_ident!("{}Ref", struct_name);
97
98    // 2. Generate the conversions for the Archived fields
99    // We iterate over the *original* fields to determine types, but generate logic
100    // that assumes we are operating on the *Archived* struct.
101    let fields = &input.fields;
102    let rkyv_field_conversions = fields.iter().map(|f| {
103        let field_name = f.ident.as_ref().unwrap();
104        let ty = &f.ty;
105        generate_rkyv_field_conversion(field_name, ty)
106    });
107
108    // Determine if we need phantom data init (same logic as main function)
109    let mut lifetime_used = false;
110    for f in fields {
111        let (_, is_prim) = map_type_to_ref(&f.ty);
112        if !is_prim {
113            lifetime_used = true;
114            break;
115        }
116    }
117
118    let phantom_init = if !lifetime_used {
119        quote! { _phantom: std::marker::PhantomData }
120    } else {
121        quote! {}
122    };
123
124    // 3. Generate the DeepRef implementation for the Archived struct
125    let impl_archived = quote! {
126        #[cfg(target_endian = "little")]
127        impl #import_location::format::DeepRef for #archived_struct_name {
128            type Ref<'a> = #ref_struct_name<'a>;
129
130            fn as_deep_ref(&self) -> Self::Ref<'_> {
131                #ref_struct_name {
132                    #(#rkyv_field_conversions,)*
133                    #phantom_init
134                }
135            }
136        }
137    };
138
139    // Combine base (Ref definition + Owned impl) with the new Archived impl
140    Ok(quote! {
141        #impl_archived
142    })
143}
144
145// -------------------------------------------------------------------------
146// Helper Functions
147// -------------------------------------------------------------------------
148
149// Map Owned types to Borrowed types for the struct definition
150pub(crate) fn map_type_to_ref(ty: &Type) -> (TokenStream, bool) {
151    match ty {
152        Type::Path(TypePath { path, .. }) => {
153            let segment = path.segments.last().unwrap();
154            let ident_str = segment.ident.to_string();
155
156            // Check if it is a wrapper around `str` (Arc<str>, Box<str>, etc)
157            if is_string_like(ty) {
158                return (quote! { &'a str }, false);
159            }
160
161            match ident_str.as_str() {
162                "bool" | "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64"
163                | "usize" | "f16" | "f32" | "f64" => {
164                    let ident = &segment.ident;
165                    (quote! { #ident }, true)
166                }
167                "Vec" => {
168                    if let PathArguments::AngleBracketed(args) = &segment.arguments
169                        && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
170                    {
171                        let (inner_ref, is_prim) = map_type_to_ref(inner_ty);
172                        if is_prim {
173                            return (quote! { &'a [#inner_ref] }, false);
174                        } else {
175                            // For complex types, we return Vec<Ref>
176                            return (quote! { Vec<#inner_ref> }, false);
177                        }
178                    }
179                    (quote! { Vec<()> }, false)
180                }
181                "Option" => {
182                    if let PathArguments::AngleBracketed(args) = &segment.arguments
183                        && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
184                    {
185                        // Recursively map the inner type
186
187                        // 1. If inner is primitive (i32), Option<i32> is Copy (effectively),
188                        //    so we keep Option<i32>.
189                        if is_primitive(inner_ty) {
190                            return (quote! { Option<#inner_ty> }, true);
191                        }
192
193                        // 2. If inner is string-like (String, Arc<str>), we want Option<&'a str>
194                        if is_string_like(inner_ty) {
195                            return (quote! { Option<&'a str> }, false);
196                        }
197
198                        // 3. Otherwise, map normally (e.g., Option<MyStruct> -> Option<MyStructRef>)
199                        let (inner_ref, _) = map_type_to_ref(inner_ty);
200                        return (quote! { Option<#inner_ref> }, false);
201                    }
202                    (quote! { Option<()> }, false)
203                }
204                // Nested struct - assume it has a Ref variant
205                other => {
206                    let ref_name = format_ident!("{}Ref", other);
207                    (quote! { #ref_name<'a> }, false)
208                }
209            }
210        }
211        _ => (quote! { () }, true),
212    }
213}
214
215// Generate the conversion logic for as_deep_ref (Owned -> Borrowed)
216fn generate_field_conversion(field_name: &Ident, ty: &Type) -> TokenStream {
217    match ty {
218        Type::Path(TypePath { path, .. }) => {
219            let segment = path.segments.last().unwrap();
220            let ident_str = segment.ident.to_string();
221
222            // Direct String-like types (Arc<str>, String, Box<str>) -> &str
223            if is_string_like(ty) {
224                // as_deref works for Option, but for Arc<str>/String we usually use as_ref() or &*
225                // String: .as_str()
226                // Arc<str>: &**self.field or .as_ref()
227                if ident_str == "String" {
228                    return quote! { #field_name: self.#field_name.as_str() };
229                } else {
230                    return quote! { #field_name: &self.#field_name };
231                }
232            }
233
234            match ident_str.as_str() {
235                // Primitives: Copy
236                "bool" | "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64"
237                | "usize" | "f16" | "f32" | "f64" => {
238                    quote! { #field_name: self.#field_name }
239                }
240
241                // Vec
242                "Vec" => {
243                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
244                        if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
245                            if is_primitive(inner_ty) {
246                                // Primitive vec: borrow as slice
247                                quote! { #field_name: self.#field_name.as_slice() }
248                            } else if is_string_like(inner_ty) {
249                                // Vec<String>, Vec<Arc<str>> -> Vec<&str>
250                                if ident_str == "String" || is_string(inner_ty) {
251                                    quote! { #field_name: self.#field_name.iter().map(|x| x.as_str()).collect() }
252                                } else {
253                                    quote! { #field_name: self.#field_name.iter().map(|x| x.as_ref()).collect() }
254                                }
255                            } else {
256                                // Complex vec: map to Vec<Ref>
257                                quote! {
258                                    #field_name: self.#field_name.iter().map(|x| x.as_deep_ref()).collect()
259                                }
260                            }
261                        } else {
262                            quote! { #field_name: vec![] }
263                        }
264                    } else {
265                        quote! { #field_name: vec![] }
266                    }
267                }
268
269                // Option
270                "Option" => {
271                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
272                        if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
273                            if is_primitive(inner_ty) {
274                                // Option<i32>: Copy
275                                quote! { #field_name: self.#field_name }
276                            } else if is_string_like(inner_ty) {
277                                // Option<String>, Option<Arc<str>> -> .as_deref() gives Option<&str>
278                                quote! { #field_name: self.#field_name.as_deref() }
279                            } else {
280                                // Option<Struct> -> map(as_deep_ref)
281                                quote! { #field_name: self.#field_name.as_ref().map(|x| x.as_deep_ref()) }
282                            }
283                        } else {
284                            quote! { #field_name: None }
285                        }
286                    } else {
287                        quote! { #field_name: None }
288                    }
289                }
290
291                // Nested Structs
292                _ => {
293                    quote! { #field_name: self.#field_name.as_deep_ref() }
294                }
295            }
296        }
297        _ => quote! { #field_name: self.#field_name.as_deep_ref() },
298    }
299}
300
301// Generate the conversion logic for as_deep_ref (Archived -> Borrowed)
302// Handles rkyv specific types (ArchivedString, ArchivedVec, etc.)
303fn generate_rkyv_field_conversion(field_name: &Ident, ty: &Type) -> TokenStream {
304    match ty {
305        Type::Path(TypePath { path, .. }) => {
306            let segment = path.segments.last().unwrap();
307            let ident_str = segment.ident.to_string();
308
309            match ident_str.as_str() {
310                // 1. Single-byte or simple primitives (Usually Copy in Rkyv)
311                "bool" | "i8" | "u8" => {
312                    quote! { #field_name: self.#field_name }
313                }
314
315                // 2. Multi-byte Endian-Specific Primitives (Rkyv wrappers)
316                "i16" | "i16_le" | "i32" | "i32_le" | "i64" | "i64_le" | "isize" | "u16"
317                | "u16_le" | "u32" | "u32_le" | "u64" | "usize" | "u64_le" | "f16" | "f16_le"
318                | "f32" | "f32_le" | "f64" | "f64_le" => {
319                    quote! { #field_name: self.#field_name.to_native() as _ }
320                }
321
322                // 3. String: ArchivedString has .as_str()
323                "String" | "ArchivedString" => {
324                    quote! { #field_name: self.#field_name.as_str() }
325                }
326
327                // 4. Vec
328                "Vec" | "ArchivedVec" => {
329                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
330                        if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
331                            if is_primitive(inner_ty) {
332                                quote! { #field_name: unsafe { std::mem::transmute(self.#field_name.as_slice()) }}
333                            } else if is_string_like(inner_ty) {
334                                // ArchivedVec<ArchivedString>. Inner (in struct def) is String.
335                                // We iterate and get &ArchivedString. .as_str() works.
336                                quote! {
337                                    #field_name: self.#field_name.iter().map(|x| x.as_str()).collect()
338                                }
339                            } else {
340                                quote! {
341                                    #field_name: self.#field_name.iter().map(|x| x.as_deep_ref()).collect()
342                                }
343                            }
344                        } else {
345                            quote! { #field_name: vec![] }
346                        }
347                    } else {
348                        quote! { #field_name: vec![] }
349                    }
350                }
351
352                // 5. Option
353                "Option" | "ArchivedOption" => {
354                    if let PathArguments::AngleBracketed(args) = &segment.arguments {
355                        if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
356                            if is_primitive(inner_ty) {
357                                // ArchivedOption<ArchivedPrimitive> -> .as_ref() gives Option<&ArchivedPrimitive>
358                                // We need Option<Primitive>.
359                                // .to_native() converts &ArchivedPrimitive (Copy) to Primitive
360                                quote! {
361                                    #field_name: self.#field_name.as_ref().map(|x| x.to_native() as _)
362                                }
363                            } else if is_string_like(inner_ty) {
364                                // ArchivedOption<ArchivedString>.
365                                // We need Option<&str>.
366                                // as_ref() -> Option<&ArchivedString>
367                                // .map(|x| x.as_str()) -> Option<&str>
368                                quote! {
369                                    #field_name: self.#field_name.as_ref().map(|x| x.as_str())
370                                }
371                            } else {
372                                // Standard nested struct
373                                quote! { #field_name: self.#field_name.as_ref().map(|x| x.as_deep_ref()) }
374                            }
375                        } else {
376                            quote! { #field_name: None }
377                        }
378                    } else {
379                        quote! { #field_name: None }
380                    }
381                }
382
383                // Nested Structs
384                _ => {
385                    quote! { #field_name: self.#field_name.as_deep_ref() }
386                }
387            }
388        }
389        _ => quote! { #field_name: self.#field_name.as_deep_ref() },
390    }
391}
392
393// Helper to identify primitives
394fn is_primitive(ty: &Type) -> bool {
395    if let Type::Path(TypePath { path, .. }) = ty {
396        let ident = path.segments.last().unwrap().ident.to_string();
397        matches!(
398            ident.as_str(),
399            "i8" | "i16"
400                | "i32"
401                | "i64"
402                | "isize"
403                | "u8"
404                | "u16"
405                | "u32"
406                | "u64"
407                | "usize"
408                | "f16"
409                | "f32"
410                | "f64"
411                | "bool"
412        )
413    } else {
414        false
415    }
416}
417
418fn is_string(ty: &Type) -> bool {
419    if let Type::Path(TypePath { path, .. }) = ty {
420        let ident = path.segments.last().unwrap().ident.to_string();
421        ident == "String"
422    } else {
423        false
424    }
425}
426
427// Helper to identify "String-like" types that should become &'a str
428// Covers: String, Arc<str>, Box<str>, Cow<str>
429fn is_string_like(ty: &Type) -> bool {
430    if let Type::Path(TypePath { path, .. }) = ty {
431        let segment = path.segments.last().unwrap();
432        let ident = segment.ident.to_string();
433
434        // 1. Simple String
435        if ident == "String" {
436            return true;
437        }
438
439        // 2. Wrappers (Arc, Box, Cow)
440        if matches!(ident.as_str(), "Arc" | "Box" | "Cow" | "Rc")
441            && let PathArguments::AngleBracketed(args) = &segment.arguments
442            && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
443        {
444            // Check if inner is "str"
445            if let Type::Path(TypePath {
446                path: inner_path, ..
447            }) = inner_ty
448                && let Some(inner_seg) = inner_path.segments.last()
449            {
450                return inner_seg.ident == "str";
451            }
452        }
453    }
454    false
455}