1use proc_macro2::{Span, TokenStream};
2use quote::{format_ident, quote};
3use syn::{
4 GenericArgument, GenericParam, ItemStruct, Lifetime, LifetimeParam, Path, PathArguments,
5 TraitBound, Type, TypeParamBound, TypePath,
6};
7
8use crate::format::deep_ref::map_type_to_ref;
9
10pub fn from_row(input: &ItemStruct, import_location: &Path) -> syn::Result<TokenStream> {
12 let struct_name = &input.ident;
15
16 let mut impl_generics = input.generics.clone();
19 let lifetime = Lifetime::new("'a", Span::call_site());
20
21 let bound_tokens = quote! { std::convert::TryFrom<#import_location::PyroValue<'a>> };
23 let bound: TypeParamBound = syn::parse2(bound_tokens)?;
24
25 for param in impl_generics.params.iter_mut() {
27 if let GenericParam::Type(t) = param {
28 t.bounds.push(bound.clone());
29 }
30 }
31
32 impl_generics.params.insert(
34 0,
35 GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())),
36 );
37
38 let (impl_g, _, where_clause) = impl_generics.split_for_impl();
39 let (_, ty_g, _) = input.generics.split_for_impl(); let mut owned_field_extractions = Vec::with_capacity(input.fields.len());
45
46 for f in &input.fields {
47 let name = f
48 .ident
49 .as_ref()
50 .ok_or_else(|| syn::Error::new_spanned(f, "FromRow requires named fields"))?;
51
52 let name_str = name.to_string();
53 let ty = &f.ty;
54 let missing_err = format!("Missing field: {}", name_str);
55 let field_err = format!("Failed to convert field '{}'", name_str);
56
57 let stream = generate_field_try_from_owned(
58 name,
59 &name_str,
60 &missing_err,
61 &field_err,
62 ty,
63 import_location,
64 )?;
65 owned_field_extractions.push(stream);
66 }
67
68 let expanded = quote! {
69 impl #impl_g std::convert::TryFrom<#import_location::PyroRow<'a>> for #struct_name #ty_g #where_clause {
73 type Error = #import_location::PyroRow<'a>;
74
75 fn try_from(row: #import_location::PyroRow<'a>) -> Result<Self, Self::Error> {
76 let result = (|| -> Result<Self, &'static str> {
77 Ok(Self {
78 #(#owned_field_extractions,)*
79 })
80 })();
81
82 result.map_err(|_| row)
83 }
84 }
85
86 impl #impl_g std::convert::TryFrom<& #import_location::PyroRow<'a>> for #struct_name #ty_g #where_clause {
90 type Error = &'static str;
91
92 fn try_from(row: & #import_location::PyroRow<'a>) -> Result<Self, Self::Error> {
93 Ok(Self {
94 #(#owned_field_extractions,)*
95 })
96 }
97 }
98
99 impl #impl_g std::convert::TryFrom<#import_location::PyroValue<'a>> for #struct_name #ty_g #where_clause {
103 type Error = #import_location::PyroValue<'a>;
104
105 fn try_from(value: #import_location::PyroValue<'a>) -> Result<Self, Self::Error> {
106 match value {
107 #import_location::PyroValue::Group(r) => match <Self as std::convert::TryFrom<#import_location::PyroRow<'a>>>::try_from(r) {
108 Ok(s) => Ok(s),
109 Err(r) => Err(#import_location::PyroValue::Group(r)),
110 },
111 v => Err(v)
112 }
113 }
114 }
115
116 impl #impl_g std::convert::TryFrom<& #import_location::PyroValue<'a>> for #struct_name #ty_g #where_clause {
120 type Error = &'static str;
121
122 fn try_from(value: & #import_location::PyroValue<'a>) -> Result<Self, Self::Error> {
123 match value {
124 #import_location::PyroValue::Group(r) => {
125 <Self as std::convert::TryFrom<& #import_location::PyroRow<'a>>>::try_from(r)
126 }
127 _ => Err("Expected Group")
128 }
129 }
130 }
131 };
132
133 Ok(expanded)
134}
135
136pub fn ref_from_row(input: &ItemStruct, import_location: &Path) -> syn::Result<TokenStream> {
138 let struct_name = &input.ident;
140 let ref_struct_name = format_ident!("{}Ref", struct_name);
141
142 let mut impl_generics = input.generics.clone();
145 let lifetime = Lifetime::new("'a", Span::call_site());
146
147 let mut deep_ref_bound_path = import_location.clone();
149 deep_ref_bound_path
150 .segments
151 .push(syn::PathSegment::from(format_ident!("DeepRef")));
152
153 for param in impl_generics.params.iter_mut() {
154 if let GenericParam::Type(t) = param {
155 t.bounds.push(TypeParamBound::Trait(TraitBound {
156 paren_token: None,
157 modifier: syn::TraitBoundModifier::None,
158 lifetimes: None,
159 path: deep_ref_bound_path.clone(),
160 }));
161 t.bounds.push(TypeParamBound::Lifetime(lifetime.clone()));
163 }
164 }
165
166 impl_generics.params.insert(
168 0,
169 GenericParam::Lifetime(LifetimeParam::new(lifetime.clone())),
170 );
171
172 let (impl_g, _, where_clause) = impl_generics.split_for_impl();
173
174 let mut ref_struct_args = Vec::new();
177 ref_struct_args.push(quote! { #lifetime });
178
179 for param in &input.generics.params {
180 match param {
181 GenericParam::Type(t) => {
182 let ident = &t.ident;
183 ref_struct_args
184 .push(quote! { <#ident as #import_location::format::DeepRef>::Ref<#lifetime> });
185 }
186 GenericParam::Const(c) => {
187 let ident = &c.ident;
188 ref_struct_args.push(quote! { #ident });
189 }
190 GenericParam::Lifetime(l) => {
191 let ident = &l.lifetime;
192 ref_struct_args.push(quote! { #ident });
193 }
194 }
195 }
196
197 let mut ref_field_extractions = Vec::with_capacity(input.fields.len());
201 let mut lifetime_used = false;
202
203 for f in &input.fields {
204 let name = f
205 .ident
206 .as_ref()
207 .ok_or_else(|| syn::Error::new_spanned(f, "FromRow requires named fields"))?;
208
209 let name_str = name.to_string();
210 let ty = &f.ty;
211
212 let (mapped_type, is_primitive) = map_type_to_ref(ty);
215 if !is_primitive {
216 lifetime_used = true;
217 }
218
219 let missing_err = format!("Missing field: {}", name_str);
220 let field_err = format!("Failed to convert field '{}'", name_str);
221
222 let stream = generate_field_try_from_ref(
223 name,
224 &name_str,
225 &missing_err,
226 &field_err,
227 &mapped_type,
228 ty,
229 import_location,
230 )?;
231 ref_field_extractions.push(stream);
232 }
233
234 let phantom_init = if !lifetime_used {
235 quote! { _phantom: std::marker::PhantomData }
236 } else {
237 quote! {}
238 };
239
240 let expanded = quote! {
241 impl #impl_g std::convert::TryFrom<#import_location::PyroRow<'a>> for #ref_struct_name < #(#ref_struct_args),* > #where_clause {
245 type Error = #import_location::PyroRow<'a>;
246
247 fn try_from(row: #import_location::PyroRow<'a>) -> Result<Self, Self::Error> {
248 let result = (|| -> Result<Self, &'static str> {
249 Ok(Self {
250 #(#ref_field_extractions,)*
251 #phantom_init
252 })
253 })();
254
255 result.map_err(|_| row)
256 }
257 }
258
259 impl #impl_g std::convert::TryFrom<& #import_location::PyroRow<'a>> for #ref_struct_name < #(#ref_struct_args),* > #where_clause {
263 type Error = &'static str;
264
265 fn try_from(row: & #import_location::PyroRow<'a>) -> Result<Self, Self::Error> {
266 Ok(Self {
267 #(#ref_field_extractions,)*
268 #phantom_init
269 })
270 }
271 }
272
273 impl #impl_g std::convert::TryFrom<#import_location::PyroValue<'a>> for #ref_struct_name < #(#ref_struct_args),* > #where_clause {
277 type Error = #import_location::PyroValue<'a>;
278
279 fn try_from(value: #import_location::PyroValue<'a>) -> Result<Self, Self::Error> {
280 match value {
281 #import_location::PyroValue::Group(r) => match <Self as std::convert::TryFrom<#import_location::PyroRow<'a>>>::try_from(r) {
282 Ok(s) => Ok(s),
283 Err(r) => Err(#import_location::PyroValue::Group(r)),
284 },
285 v => Err(v)
286 }
287 }
288 }
289
290 impl #impl_g std::convert::TryFrom<& #import_location::PyroValue<'a>> for #ref_struct_name < #(#ref_struct_args),* > #where_clause {
294 type Error = &'static str;
295
296 fn try_from(value: & #import_location::PyroValue<'a>) -> Result<Self, Self::Error> {
297 match value {
298 #import_location::PyroValue::Group(r) => {
299 <Self as std::convert::TryFrom<& #import_location::PyroRow<'a>>>::try_from(r)
300 }
301 _ => Err("Expected Group")
302 }
303 }
304 }
305 };
306
307 Ok(expanded)
308}
309
310fn generate_field_try_from_ref(
315 name: &syn::Ident,
316 name_str: &str,
317 missing_err: &str,
318 field_err: &str,
319 _mapped_type: &TokenStream, original_ty: &Type,
321 import_location: &Path,
322) -> syn::Result<TokenStream> {
323 if is_option(original_ty) {
324 let inner_ty = get_option_inner(original_ty)
326 .ok_or_else(|| syn::Error::new_spanned(original_ty, "Malformed Option type"))?;
327
328 let (inner_mapped, _) = map_type_to_ref(inner_ty);
331
332 Ok(quote! {
333 #name: {
334 match row.get(#name_str) {
335 Some(#import_location::PyroValue::Null) | None => None,
336 Some(val) => Some(
337 <#inner_mapped as std::convert::TryFrom<#import_location::PyroValue<'a>>>::try_from(val.clone())
338 .map_err(|_| #field_err)?
339 ),
340 }
341 }
342 })
343 } else {
344 let (mapped_type, _) = map_type_to_ref(original_ty);
347
348 Ok(quote! {
349 #name: {
350 let val = row.get(#name_str)
351 .ok_or_else(|| #missing_err)?
352 .clone();
353 <#mapped_type as std::convert::TryFrom<#import_location::PyroValue<'a>>>::try_from(val)
354 .map_err(|_| #field_err)?
355 }
356 })
357 }
358}
359
360fn generate_field_try_from_owned(
365 name: &syn::Ident,
366 name_str: &str,
367 missing_err: &str,
368 field_err: &str,
369 ty: &Type,
370 import_location: &Path,
371) -> syn::Result<TokenStream> {
372 if is_option(ty) {
373 let inner_ty = get_option_inner(ty)
374 .ok_or_else(|| syn::Error::new_spanned(ty, "Malformed Option type"))?;
375
376 Ok(quote! {
377 #name: {
378 match row.get(#name_str) {
379 Some(#import_location::PyroValue::Null) | None => None,
380 Some(val) => {
381 let owned: #inner_ty = val.clone().try_into()
382 .map_err(|_| #field_err)?;
383 Some(owned)
384 }
385 }
386 }
387 })
388 } else if is_nested_struct(ty) {
389 Ok(quote! {
390 #name: {
391 let val = row.get(#name_str)
392 .ok_or_else(|| #missing_err)?
393 .clone();
394 val.try_into()
395 .map_err(|_| #field_err)?
396 }
397 })
398 } else if is_vec_of_struct(ty) {
399 let inner_ty =
400 get_vec_inner(ty).ok_or_else(|| syn::Error::new_spanned(ty, "Malformed Vec type"))?;
401 let fail = format!("Failed to convert element in field '{}'", name_str);
402 let unexpected = format!("Expected List for field '{}'", name_str);
403
404 Ok(quote! {
406 #name: {
407 match row.get(#name_str)
408 .ok_or_else(|| #missing_err)?
409 {
410 #import_location::PyroValue::List(items) => {
411 items.iter()
412 .map(|v| v.clone().try_into().map_err(|_| #fail))
413 .collect::<Result<Vec<#inner_ty>, _>>()?
414 }
415 _ => return Err(#unexpected),
416 }
417 }
418 })
419 } else {
420 Ok(quote! {
421 #name: {
422 let val = row.get(#name_str)
423 .ok_or_else(|| #missing_err)?
424 .clone();
425 val.try_into()
426 .map_err(|_| #field_err)?
427 }
428 })
429 }
430}
431
432fn is_option(ty: &Type) -> bool {
437 if let Type::Path(TypePath { path, .. }) = ty
438 && let Some(seg) = path.segments.last()
439 {
440 return seg.ident == "Option";
441 }
442 false
443}
444
445fn get_option_inner(ty: &Type) -> Option<&Type> {
446 if let Type::Path(TypePath { path, .. }) = ty
447 && let Some(seg) = path.segments.last()
448 && let PathArguments::AngleBracketed(args) = &seg.arguments
449 && let Some(GenericArgument::Type(inner)) = args.args.first()
450 {
451 return Some(inner);
452 }
453 None
454}
455
456fn get_vec_inner(ty: &Type) -> Option<&Type> {
457 if let Type::Path(TypePath { path, .. }) = ty
458 && let Some(seg) = path.segments.last()
459 && let PathArguments::AngleBracketed(args) = &seg.arguments
460 && let Some(GenericArgument::Type(inner)) = args.args.first()
461 {
462 return Some(inner);
463 }
464 None
465}
466
467fn is_nested_struct(ty: &Type) -> bool {
468 if let Type::Path(TypePath { path, .. }) = ty
469 && let Some(seg) = path.segments.last()
470 {
471 let ident_str = seg.ident.to_string();
472 return !matches!(
474 ident_str.as_str(),
475 "bool"
476 | "i8"
477 | "i16"
478 | "i32"
479 | "i64"
480 | "isize"
481 | "u8"
482 | "u16"
483 | "u32"
484 | "u64"
485 | "usize"
486 | "f16"
487 | "f32"
488 | "f64"
489 | "String"
490 | "Vec"
491 | "Option"
492 );
493 }
494 false
495}
496
497fn is_vec_of_struct(ty: &Type) -> bool {
498 if let Type::Path(TypePath { path, .. }) = ty
499 && let Some(seg) = path.segments.last()
500 && seg.ident == "Vec"
501 && let PathArguments::AngleBracketed(args) = &seg.arguments
502 && let Some(GenericArgument::Type(inner)) = args.args.first()
503 {
504 return is_nested_struct(inner);
505 }
506 false
507}