Skip to main content

pyro_macro/format/
from_row.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{
4    GenericArgument, GenericParam, ItemStruct, Lifetime, LifetimeParam, Path, PathArguments,
5    TraitBound, Type, TypeParamBound, TypePath,
6};
7
8use crate::format::deep_ref::map_type_to_ref;
9
10/// Generates `impl TryFrom<PyroRow> for Struct` and `impl TryFrom<PyroValue> for Struct`
11pub fn from_row(input: &ItemStruct, import_location: &Path) -> syn::Result<TokenStream> {
12    // 1. Parse Input
13
14    let struct_name = &input.ident;
15
16    // 3. Prepare Generics
17    // We need impl<'a, T: TryFrom<PyroValue<'a>>> ... for Struct<T>
18    let mut impl_generics = input.generics.clone();
19    let lifetime = Lifetime::new("'a", Span::call_site());
20
21    // Construct the bound: std::convert::TryFrom<#import_location::PyroValue<'a>>
22    let bound_tokens = quote! { std::convert::TryFrom<#import_location::PyroValue<'a>> };
23    let bound: TypeParamBound = syn::parse2(bound_tokens)?;
24
25    // Add the bound to every Type parameter in the impl generics
26    for param in impl_generics.params.iter_mut() {
27        if let GenericParam::Type(t) = param {
28            t.bounds.push(bound.clone());
29        }
30    }
31
32    // Insert the lifetime 'a at the beginning of the params list
33    impl_generics.params.insert(
34        0,
35        GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())),
36    );
37
38    let (impl_g, _, where_clause) = impl_generics.split_for_impl();
39    let (_, ty_g, _) = input.generics.split_for_impl(); // Original generics for the Struct<T>
40
41    // =========================================================================
42    // Field Extraction Logic (Owned)
43    // =========================================================================
44    let mut owned_field_extractions = Vec::with_capacity(input.fields.len());
45
46    for f in &input.fields {
47        let name = f
48            .ident
49            .as_ref()
50            .ok_or_else(|| syn::Error::new_spanned(f, "FromRow requires named fields"))?;
51
52        let name_str = name.to_string();
53        let ty = &f.ty;
54        let missing_err = format!("Missing field: {}", name_str);
55        let field_err = format!("Failed to convert field '{}'", name_str);
56
57        let stream = generate_field_try_from_owned(
58            name,
59            &name_str,
60            &missing_err,
61            &field_err,
62            ty,
63            import_location,
64        )?;
65        owned_field_extractions.push(stream);
66    }
67
68    let expanded = quote! {
69        // -----------------------------------------------------------------
70        // TryFrom<PyroRow<'a>> for Struct (Owned)
71        // -----------------------------------------------------------------
72        impl #impl_g std::convert::TryFrom<#import_location::PyroRow<'a>> for #struct_name #ty_g #where_clause {
73            type Error = #import_location::PyroRow<'a>;
74
75            fn try_from(row: #import_location::PyroRow<'a>) -> Result<Self, Self::Error> {
76                let result = (|| -> Result<Self, &'static str> {
77                    Ok(Self {
78                        #(#owned_field_extractions,)*
79                    })
80                })();
81
82                result.map_err(|_| row)
83            }
84        }
85
86        // -----------------------------------------------------------------
87        // TryFrom<&PyroRow<'a>> for Struct (Reference)
88        // -----------------------------------------------------------------
89        impl #impl_g std::convert::TryFrom<& #import_location::PyroRow<'a>> for #struct_name #ty_g #where_clause {
90            type Error = &'static str;
91
92            fn try_from(row: & #import_location::PyroRow<'a>) -> Result<Self, Self::Error> {
93                Ok(Self {
94                    #(#owned_field_extractions,)*
95                })
96            }
97        }
98
99        // -----------------------------------------------------------------
100        // TryFrom<PyroValue<'a>> for Struct (Owned)
101        // -----------------------------------------------------------------
102        impl #impl_g std::convert::TryFrom<#import_location::PyroValue<'a>> for #struct_name #ty_g #where_clause {
103            type Error = #import_location::PyroValue<'a>;
104
105            fn try_from(value: #import_location::PyroValue<'a>) -> Result<Self, Self::Error> {
106                match value {
107                    #import_location::PyroValue::Group(r) => match <Self as std::convert::TryFrom<#import_location::PyroRow<'a>>>::try_from(r) {
108                        Ok(s) => Ok(s),
109                        Err(r) => Err(#import_location::PyroValue::Group(r)),
110                    },
111                    v => Err(v)
112                }
113            }
114        }
115
116        // -----------------------------------------------------------------
117        // TryFrom<&PyroValue<'a>> for Struct (Reference)
118        // -----------------------------------------------------------------
119        impl #impl_g std::convert::TryFrom<& #import_location::PyroValue<'a>> for #struct_name #ty_g #where_clause {
120            type Error = &'static str;
121
122            fn try_from(value: & #import_location::PyroValue<'a>) -> Result<Self, Self::Error> {
123                match value {
124                    #import_location::PyroValue::Group(r) => {
125                        <Self as std::convert::TryFrom<& #import_location::PyroRow<'a>>>::try_from(r)
126                    }
127                    _ => Err("Expected Group")
128                }
129            }
130        }
131    };
132
133    Ok(expanded)
134}
135
136/// Generates `impl TryFrom<PyroRow> for StructRef` and `impl TryFrom<PyroValue> for StructRef`
137pub fn ref_from_row(input: &ItemStruct, import_location: &Path) -> syn::Result<TokenStream> {
138    // 1. Parse Input
139    let struct_name = &input.ident;
140    let ref_struct_name = format_ident!("{}Ref", struct_name);
141
142    // 2. Prepare Generics for Impl
143    // We need impl<'a, T: DeepRef> ... for StructRef<'a, <T as DeepRef>::Ref<'a>>
144    let mut impl_generics = input.generics.clone();
145    let lifetime = Lifetime::new("'a", Span::call_site());
146
147    // Add DeepRef bound to all Type params
148    let mut deep_ref_bound_path = import_location.clone();
149    deep_ref_bound_path
150        .segments
151        .push(syn::PathSegment::from(format_ident!("DeepRef")));
152
153    for param in impl_generics.params.iter_mut() {
154        if let GenericParam::Type(t) = param {
155            t.bounds.push(TypeParamBound::Trait(TraitBound {
156                paren_token: None,
157                modifier: syn::TraitBoundModifier::None,
158                lifetimes: None,
159                path: deep_ref_bound_path.clone(),
160            }));
161            // Also ensure T lives as long as 'a if necessary, but usually DeepRef handles this projection
162            t.bounds.push(TypeParamBound::Lifetime(lifetime.clone()));
163        }
164    }
165
166    // Insert 'a into impl params
167    impl_generics.params.insert(
168        0,
169        GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())),
170    );
171
172    let (impl_g, _, where_clause) = impl_generics.split_for_impl();
173
174    // 3. Prepare Arguments for the Ref Struct
175    // StructRef < 'a, <T as DeepRef>::Ref<'a>, ... >
176    let mut ref_struct_args = Vec::new();
177    ref_struct_args.push(quote! { #lifetime });
178
179    for param in &input.generics.params {
180        match param {
181            GenericParam::Type(t) => {
182                let ident = &t.ident;
183                ref_struct_args
184                    .push(quote! { <#ident as #import_location::format::DeepRef>::Ref<#lifetime> });
185            }
186            GenericParam::Const(c) => {
187                let ident = &c.ident;
188                ref_struct_args.push(quote! { #ident });
189            }
190            GenericParam::Lifetime(l) => {
191                let ident = &l.lifetime;
192                ref_struct_args.push(quote! { #ident });
193            }
194        }
195    }
196
197    // =========================================================================
198    // Field Extraction Logic
199    // =========================================================================
200    let mut ref_field_extractions = Vec::with_capacity(input.fields.len());
201    let mut lifetime_used = false;
202
203    for f in &input.fields {
204        let name = f
205            .ident
206            .as_ref()
207            .ok_or_else(|| syn::Error::new_spanned(f, "FromRow requires named fields"))?;
208
209        let name_str = name.to_string();
210        let ty = &f.ty;
211
212        // IMPORTANT: The mapped_type provided here (via map_type_to_ref) is correct.
213        // It correctly maps Option<String> -> Option<&'a str> and Option<i32> -> Option<i32>.
214        let (mapped_type, is_primitive) = map_type_to_ref(ty);
215        if !is_primitive {
216            lifetime_used = true;
217        }
218
219        let missing_err = format!("Missing field: {}", name_str);
220        let field_err = format!("Failed to convert field '{}'", name_str);
221
222        let stream = generate_field_try_from_ref(
223            name,
224            &name_str,
225            &missing_err,
226            &field_err,
227            &mapped_type,
228            ty,
229            import_location,
230        )?;
231        ref_field_extractions.push(stream);
232    }
233
234    let phantom_init = if !lifetime_used {
235        quote! { _phantom: std::marker::PhantomData }
236    } else {
237        quote! {}
238    };
239
240    let expanded = quote! {
241        // -----------------------------------------------------------------
242        // TryFrom<PyroRow<'a>> for StructRef<'a> (Owned)
243        // -----------------------------------------------------------------
244        impl #impl_g std::convert::TryFrom<#import_location::PyroRow<'a>> for #ref_struct_name < #(#ref_struct_args),* > #where_clause {
245            type Error = #import_location::PyroRow<'a>;
246
247            fn try_from(row: #import_location::PyroRow<'a>) -> Result<Self, Self::Error> {
248                let result = (|| -> Result<Self, &'static str> {
249                    Ok(Self {
250                        #(#ref_field_extractions,)*
251                        #phantom_init
252                    })
253                })();
254
255                result.map_err(|_| row)
256            }
257        }
258
259        // -----------------------------------------------------------------
260        // TryFrom<&PyroRow<'a>> for StructRef<'a> (Reference)
261        // -----------------------------------------------------------------
262        impl #impl_g std::convert::TryFrom<& #import_location::PyroRow<'a>> for #ref_struct_name < #(#ref_struct_args),* > #where_clause {
263            type Error = &'static str;
264
265            fn try_from(row: & #import_location::PyroRow<'a>) -> Result<Self, Self::Error> {
266                Ok(Self {
267                    #(#ref_field_extractions,)*
268                    #phantom_init
269                })
270            }
271        }
272
273        // -----------------------------------------------------------------
274        // TryFrom<PyroValue<'a>> for StructRef<'a> (Owned)
275        // -----------------------------------------------------------------
276        impl #impl_g std::convert::TryFrom<#import_location::PyroValue<'a>> for #ref_struct_name < #(#ref_struct_args),* > #where_clause {
277            type Error = #import_location::PyroValue<'a>;
278
279            fn try_from(value: #import_location::PyroValue<'a>) -> Result<Self, Self::Error> {
280                match value {
281                    #import_location::PyroValue::Group(r) => match <Self as std::convert::TryFrom<#import_location::PyroRow<'a>>>::try_from(r) {
282                        Ok(s) => Ok(s),
283                        Err(r) => Err(#import_location::PyroValue::Group(r)),
284                    },
285                    v => Err(v)
286                }
287            }
288        }
289
290        // -----------------------------------------------------------------
291        // TryFrom<&PyroValue<'a>> for StructRef<'a> (Reference)
292        // -----------------------------------------------------------------
293        impl #impl_g std::convert::TryFrom<& #import_location::PyroValue<'a>> for #ref_struct_name < #(#ref_struct_args),* > #where_clause {
294            type Error = &'static str;
295
296            fn try_from(value: & #import_location::PyroValue<'a>) -> Result<Self, Self::Error> {
297                match value {
298                    #import_location::PyroValue::Group(r) => {
299                        <Self as std::convert::TryFrom<& #import_location::PyroRow<'a>>>::try_from(r)
300                    }
301                    _ => Err("Expected Group")
302                }
303            }
304        }
305    };
306
307    Ok(expanded)
308}
309
310// =============================================================================
311// Code generation helpers for REF fields
312// =============================================================================
313
314fn generate_field_try_from_ref(
315    name: &syn::Ident,
316    name_str: &str,
317    missing_err: &str,
318    field_err: &str,
319    _mapped_type: &TokenStream, // Unused variable, we calculate inner type logic manually
320    original_ty: &Type,
321    import_location: &Path,
322) -> syn::Result<TokenStream> {
323    if is_option(original_ty) {
324        // For Option, we need to unwrap the inner type and map it to its ref type
325        let inner_ty = get_option_inner(original_ty)
326            .ok_or_else(|| syn::Error::new_spanned(original_ty, "Malformed Option type"))?;
327
328        // FIX: Do not blindly project <inner_ty as DeepRef>.
329        // Use map_type_to_ref to get the actual target type (e.g. i32, &'a str, or Ref).
330        let (inner_mapped, _) = map_type_to_ref(inner_ty);
331
332        Ok(quote! {
333            #name: {
334                match row.get(#name_str) {
335                    Some(#import_location::PyroValue::Null) | None => None,
336                    Some(val) => Some(
337                        <#inner_mapped as std::convert::TryFrom<#import_location::PyroValue<'a>>>::try_from(val.clone())
338                            .map_err(|_| #field_err)?
339                    ),
340                }
341            }
342        })
343    } else {
344        // For non-Option, mapped_type (from argument) is already correct.
345        // Recalculating it to be safe and consistent with Option fix above.
346        let (mapped_type, _) = map_type_to_ref(original_ty);
347
348        Ok(quote! {
349            #name: {
350                let val = row.get(#name_str)
351                    .ok_or_else(|| #missing_err)?
352                    .clone();
353                <#mapped_type as std::convert::TryFrom<#import_location::PyroValue<'a>>>::try_from(val)
354                    .map_err(|_| #field_err)?
355            }
356        })
357    }
358}
359
360// =============================================================================
361// Code generation helpers for OWNED fields
362// =============================================================================
363
364fn generate_field_try_from_owned(
365    name: &syn::Ident,
366    name_str: &str,
367    missing_err: &str,
368    field_err: &str,
369    ty: &Type,
370    import_location: &Path,
371) -> syn::Result<TokenStream> {
372    if is_option(ty) {
373        let inner_ty = get_option_inner(ty)
374            .ok_or_else(|| syn::Error::new_spanned(ty, "Malformed Option type"))?;
375
376        Ok(quote! {
377            #name: {
378                match row.get(#name_str) {
379                    Some(#import_location::PyroValue::Null) | None => None,
380                    Some(val) => {
381                        let owned: #inner_ty = val.clone().try_into()
382                            .map_err(|_| #field_err)?;
383                        Some(owned)
384                    }
385                }
386            }
387        })
388    } else if is_nested_struct(ty) {
389        Ok(quote! {
390            #name: {
391                let val = row.get(#name_str)
392                    .ok_or_else(|| #missing_err)?
393                    .clone();
394                val.try_into()
395                    .map_err(|_| #field_err)?
396            }
397        })
398    } else if is_vec_of_struct(ty) {
399        let inner_ty =
400            get_vec_inner(ty).ok_or_else(|| syn::Error::new_spanned(ty, "Malformed Vec type"))?;
401        let fail = format!("Failed to convert element in field '{}'", name_str);
402        let unexpected = format!("Expected List for field '{}'", name_str);
403
404        // FIX: map_err must be applied inside the closure, on the Result returned by try_into()
405        Ok(quote! {
406            #name: {
407                match row.get(#name_str)
408                    .ok_or_else(|| #missing_err)?
409                {
410                    #import_location::PyroValue::List(items) => {
411                        items.iter()
412                            .map(|v| v.clone().try_into().map_err(|_| #fail))
413                            .collect::<Result<Vec<#inner_ty>, _>>()?
414                    }
415                    _ => return Err(#unexpected),
416                }
417            }
418        })
419    } else {
420        Ok(quote! {
421            #name: {
422                let val = row.get(#name_str)
423                    .ok_or_else(|| #missing_err)?
424                    .clone();
425                val.try_into()
426                    .map_err(|_| #field_err)?
427            }
428        })
429    }
430}
431
432// =============================================================================
433// Type inspection helpers
434// =============================================================================
435
436fn is_option(ty: &Type) -> bool {
437    if let Type::Path(TypePath { path, .. }) = ty
438        && let Some(seg) = path.segments.last()
439    {
440        return seg.ident == "Option";
441    }
442    false
443}
444
445fn get_option_inner(ty: &Type) -> Option<&Type> {
446    if let Type::Path(TypePath { path, .. }) = ty
447        && let Some(seg) = path.segments.last()
448        && let PathArguments::AngleBracketed(args) = &seg.arguments
449        && let Some(GenericArgument::Type(inner)) = args.args.first()
450    {
451        return Some(inner);
452    }
453    None
454}
455
456fn get_vec_inner(ty: &Type) -> Option<&Type> {
457    if let Type::Path(TypePath { path, .. }) = ty
458        && let Some(seg) = path.segments.last()
459        && let PathArguments::AngleBracketed(args) = &seg.arguments
460        && let Some(GenericArgument::Type(inner)) = args.args.first()
461    {
462        return Some(inner);
463    }
464    None
465}
466
467fn is_nested_struct(ty: &Type) -> bool {
468    if let Type::Path(TypePath { path, .. }) = ty
469        && let Some(seg) = path.segments.last()
470    {
471        let ident_str = seg.ident.to_string();
472        // Expanded list to be safe, but generic T will return true (good)
473        return !matches!(
474            ident_str.as_str(),
475            "bool"
476                | "i8"
477                | "i16"
478                | "i32"
479                | "i64"
480                | "isize"
481                | "u8"
482                | "u16"
483                | "u32"
484                | "u64"
485                | "usize"
486                | "f16"
487                | "f32"
488                | "f64"
489                | "String"
490                | "Vec"
491                | "Option"
492        );
493    }
494    false
495}
496
497fn is_vec_of_struct(ty: &Type) -> bool {
498    if let Type::Path(TypePath { path, .. }) = ty
499        && let Some(seg) = path.segments.last()
500        && seg.ident == "Vec"
501        && let PathArguments::AngleBracketed(args) = &seg.arguments
502        && let Some(GenericArgument::Type(inner)) = args.args.first()
503    {
504        return is_nested_struct(inner);
505    }
506    false
507}