pyro_macro/format/
deep_ref.rs1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use syn::{GenericArgument, Ident, ItemStruct, Path, PathArguments, Type, TypePath};
4
5pub fn deep_ref(
8 input: &ItemStruct,
9 import_location: &Path,
10 derives_to_pass: &Vec<Ident>,
11) -> syn::Result<TokenStream> {
12 if !input.generics.params.is_empty() {
13 return Err(syn::Error::new_spanned(
14 &input.generics,
15 "DeepRef cannot be derived for structs with generic parameters (types, lifetimes, or consts)",
16 ));
17 }
18
19 let struct_name = &input.ident;
20 let ref_struct_name = format_ident!("{}Ref", struct_name);
21
22 let fields = &input.fields;
24
25 let mut lifetime_used = false;
27 let ref_fields = fields
28 .iter()
29 .map(|f| {
30 let name = &f.ident;
31 let vis = &f.vis;
32 let ty = &f.ty;
33 let (mapped_type, is_primitive) = map_type_to_ref(ty);
34
35 if !is_primitive {
36 lifetime_used = true;
37 }
38
39 quote! { #vis #name: #mapped_type }
40 })
41 .collect::<Vec<_>>();
42
43 let phantom_field = if !lifetime_used {
44 quote! { _phantom: std::marker::PhantomData<&'a ()> }
45 } else {
46 quote! {}
47 };
48
49 let struct_def = quote! {
51 #[derive(#(#derives_to_pass),*)]
52 pub struct #ref_struct_name<'a> {
53 #(#ref_fields,)*
54 #phantom_field
55 }
56 };
57
58 let field_conversions = fields.iter().map(|f| {
60 let field_name = f.ident.as_ref().unwrap();
61 let ty = &f.ty;
62 generate_field_conversion(field_name, ty)
63 });
64
65 let phantom_init = if !lifetime_used {
66 quote! { _phantom: std::marker::PhantomData }
67 } else {
68 quote! {}
69 };
70
71 let impl_owned = quote! {
72 impl #import_location::format::DeepRef for #struct_name {
73 type Ref<'a> = #ref_struct_name<'a>;
74
75 fn as_deep_ref(&self) -> Self::Ref<'_> {
76 #ref_struct_name {
77 #(#field_conversions,)*
78 #phantom_init
79 }
80 }
81 }
82 };
83
84 Ok(quote! {
85 #struct_def
86 #impl_owned
87 })
88}
89
90pub fn deep_ref_rkyv(input: &ItemStruct, import_location: &Path) -> syn::Result<TokenStream> {
93 let struct_name = &input.ident;
94 let archived_struct_name = format_ident!("Archived{}", struct_name);
96 let ref_struct_name = format_ident!("{}Ref", struct_name);
97
98 let fields = &input.fields;
102 let rkyv_field_conversions = fields.iter().map(|f| {
103 let field_name = f.ident.as_ref().unwrap();
104 let ty = &f.ty;
105 generate_rkyv_field_conversion(field_name, ty)
106 });
107
108 let mut lifetime_used = false;
110 for f in fields {
111 let (_, is_prim) = map_type_to_ref(&f.ty);
112 if !is_prim {
113 lifetime_used = true;
114 break;
115 }
116 }
117
118 let phantom_init = if !lifetime_used {
119 quote! { _phantom: std::marker::PhantomData }
120 } else {
121 quote! {}
122 };
123
124 let impl_archived = quote! {
126 #[cfg(target_endian = "little")]
127 impl #import_location::format::DeepRef for #archived_struct_name {
128 type Ref<'a> = #ref_struct_name<'a>;
129
130 fn as_deep_ref(&self) -> Self::Ref<'_> {
131 #ref_struct_name {
132 #(#rkyv_field_conversions,)*
133 #phantom_init
134 }
135 }
136 }
137 };
138
139 Ok(quote! {
141 #impl_archived
142 })
143}
144
145pub(crate) fn map_type_to_ref(ty: &Type) -> (TokenStream, bool) {
151 match ty {
152 Type::Path(TypePath { path, .. }) => {
153 let segment = path.segments.last().unwrap();
154 let ident_str = segment.ident.to_string();
155
156 if is_string_like(ty) {
158 return (quote! { &'a str }, false);
159 }
160
161 match ident_str.as_str() {
162 "bool" | "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64"
163 | "usize" | "f16" | "f32" | "f64" => {
164 let ident = &segment.ident;
165 (quote! { #ident }, true)
166 }
167 "Vec" => {
168 if let PathArguments::AngleBracketed(args) = &segment.arguments
169 && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
170 {
171 let (inner_ref, is_prim) = map_type_to_ref(inner_ty);
172 if is_prim {
173 return (quote! { &'a [#inner_ref] }, false);
174 } else {
175 return (quote! { Vec<#inner_ref> }, false);
177 }
178 }
179 (quote! { Vec<()> }, false)
180 }
181 "Option" => {
182 if let PathArguments::AngleBracketed(args) = &segment.arguments
183 && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
184 {
185 if is_primitive(inner_ty) {
190 return (quote! { Option<#inner_ty> }, true);
191 }
192
193 if is_string_like(inner_ty) {
195 return (quote! { Option<&'a str> }, false);
196 }
197
198 let (inner_ref, _) = map_type_to_ref(inner_ty);
200 return (quote! { Option<#inner_ref> }, false);
201 }
202 (quote! { Option<()> }, false)
203 }
204 other => {
206 let ref_name = format_ident!("{}Ref", other);
207 (quote! { #ref_name<'a> }, false)
208 }
209 }
210 }
211 _ => (quote! { () }, true),
212 }
213}
214
215fn generate_field_conversion(field_name: &Ident, ty: &Type) -> TokenStream {
217 match ty {
218 Type::Path(TypePath { path, .. }) => {
219 let segment = path.segments.last().unwrap();
220 let ident_str = segment.ident.to_string();
221
222 if is_string_like(ty) {
224 if ident_str == "String" {
228 return quote! { #field_name: self.#field_name.as_str() };
229 } else {
230 return quote! { #field_name: &self.#field_name };
231 }
232 }
233
234 match ident_str.as_str() {
235 "bool" | "i8" | "i16" | "i32" | "i64" | "isize" | "u8" | "u16" | "u32" | "u64"
237 | "usize" | "f16" | "f32" | "f64" => {
238 quote! { #field_name: self.#field_name }
239 }
240
241 "Vec" => {
243 if let PathArguments::AngleBracketed(args) = &segment.arguments {
244 if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
245 if is_primitive(inner_ty) {
246 quote! { #field_name: self.#field_name.as_slice() }
248 } else if is_string_like(inner_ty) {
249 if ident_str == "String" || is_string(inner_ty) {
251 quote! { #field_name: self.#field_name.iter().map(|x| x.as_str()).collect() }
252 } else {
253 quote! { #field_name: self.#field_name.iter().map(|x| x.as_ref()).collect() }
254 }
255 } else {
256 quote! {
258 #field_name: self.#field_name.iter().map(|x| x.as_deep_ref()).collect()
259 }
260 }
261 } else {
262 quote! { #field_name: vec![] }
263 }
264 } else {
265 quote! { #field_name: vec![] }
266 }
267 }
268
269 "Option" => {
271 if let PathArguments::AngleBracketed(args) = &segment.arguments {
272 if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
273 if is_primitive(inner_ty) {
274 quote! { #field_name: self.#field_name }
276 } else if is_string_like(inner_ty) {
277 quote! { #field_name: self.#field_name.as_deref() }
279 } else {
280 quote! { #field_name: self.#field_name.as_ref().map(|x| x.as_deep_ref()) }
282 }
283 } else {
284 quote! { #field_name: None }
285 }
286 } else {
287 quote! { #field_name: None }
288 }
289 }
290
291 _ => {
293 quote! { #field_name: self.#field_name.as_deep_ref() }
294 }
295 }
296 }
297 _ => quote! { #field_name: self.#field_name.as_deep_ref() },
298 }
299}
300
301fn generate_rkyv_field_conversion(field_name: &Ident, ty: &Type) -> TokenStream {
304 match ty {
305 Type::Path(TypePath { path, .. }) => {
306 let segment = path.segments.last().unwrap();
307 let ident_str = segment.ident.to_string();
308
309 match ident_str.as_str() {
310 "bool" | "i8" | "u8" => {
312 quote! { #field_name: self.#field_name }
313 }
314
315 "i16" | "i16_le" | "i32" | "i32_le" | "i64" | "i64_le" | "isize" | "u16"
317 | "u16_le" | "u32" | "u32_le" | "u64" | "usize" | "u64_le" | "f16" | "f16_le"
318 | "f32" | "f32_le" | "f64" | "f64_le" => {
319 quote! { #field_name: self.#field_name.to_native() as _ }
320 }
321
322 "String" | "ArchivedString" => {
324 quote! { #field_name: self.#field_name.as_str() }
325 }
326
327 "Vec" | "ArchivedVec" => {
329 if let PathArguments::AngleBracketed(args) = &segment.arguments {
330 if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
331 if is_primitive(inner_ty) {
332 quote! { #field_name: unsafe { std::mem::transmute(self.#field_name.as_slice()) }}
333 } else if is_string_like(inner_ty) {
334 quote! {
337 #field_name: self.#field_name.iter().map(|x| x.as_str()).collect()
338 }
339 } else {
340 quote! {
341 #field_name: self.#field_name.iter().map(|x| x.as_deep_ref()).collect()
342 }
343 }
344 } else {
345 quote! { #field_name: vec![] }
346 }
347 } else {
348 quote! { #field_name: vec![] }
349 }
350 }
351
352 "Option" | "ArchivedOption" => {
354 if let PathArguments::AngleBracketed(args) = &segment.arguments {
355 if let Some(GenericArgument::Type(inner_ty)) = args.args.first() {
356 if is_primitive(inner_ty) {
357 quote! {
361 #field_name: self.#field_name.as_ref().map(|x| x.to_native() as _)
362 }
363 } else if is_string_like(inner_ty) {
364 quote! {
369 #field_name: self.#field_name.as_ref().map(|x| x.as_str())
370 }
371 } else {
372 quote! { #field_name: self.#field_name.as_ref().map(|x| x.as_deep_ref()) }
374 }
375 } else {
376 quote! { #field_name: None }
377 }
378 } else {
379 quote! { #field_name: None }
380 }
381 }
382
383 _ => {
385 quote! { #field_name: self.#field_name.as_deep_ref() }
386 }
387 }
388 }
389 _ => quote! { #field_name: self.#field_name.as_deep_ref() },
390 }
391}
392
393fn is_primitive(ty: &Type) -> bool {
395 if let Type::Path(TypePath { path, .. }) = ty {
396 let ident = path.segments.last().unwrap().ident.to_string();
397 matches!(
398 ident.as_str(),
399 "i8" | "i16"
400 | "i32"
401 | "i64"
402 | "isize"
403 | "u8"
404 | "u16"
405 | "u32"
406 | "u64"
407 | "usize"
408 | "f16"
409 | "f32"
410 | "f64"
411 | "bool"
412 )
413 } else {
414 false
415 }
416}
417
418fn is_string(ty: &Type) -> bool {
419 if let Type::Path(TypePath { path, .. }) = ty {
420 let ident = path.segments.last().unwrap().ident.to_string();
421 ident == "String"
422 } else {
423 false
424 }
425}
426
427fn is_string_like(ty: &Type) -> bool {
430 if let Type::Path(TypePath { path, .. }) = ty {
431 let segment = path.segments.last().unwrap();
432 let ident = segment.ident.to_string();
433
434 if ident == "String" {
436 return true;
437 }
438
439 if matches!(ident.as_str(), "Arc" | "Box" | "Cow" | "Rc")
441 && let PathArguments::AngleBracketed(args) = &segment.arguments
442 && let Some(GenericArgument::Type(inner_ty)) = args.args.first()
443 {
444 if let Type::Path(TypePath {
446 path: inner_path, ..
447 }) = inner_ty
448 && let Some(inner_seg) = inner_path.segments.last()
449 {
450 return inner_seg.ident == "str";
451 }
452 }
453 }
454 false
455}