dync_derive/
lib.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::*;
4use lazy_static::lazy_static;
5use proc_macro2::{Span, TokenStream};
6use quote::{quote, TokenStreamExt};
7use syn::parse::{Parse, ParseStream};
8use syn::punctuated::Punctuated;
9use syn::*;
10
/// Bounds collected for each generic type parameter of a trait method signature, keyed by
/// the parameter's identifier (filled from inline bounds and `where` predicates in
/// `dync_fn_sig`).
type GenericsMap = HashMap<Ident, Punctuated<TypeParamBound, Token![+]>>;
/// Everything the macros know about traits (builtins plus traits found in the annotated
/// code), keyed by `Trait`.
type TraitMap = HashMap<Trait, TraitData>;
13
/// Options accepted in the argument list of the `dync_mod` / `dync_trait` attributes,
/// e.g. `#[dync_trait(dync_crate_name = "dync", suffix = "VTable")]`.
#[derive(Debug)]
struct Config {
    /// Name of the runtime crate; used as the first segment of generated paths such as
    /// `<crate>::traits::…`, `<crate>::downcast::impl_downcast!` and `<crate>::VTable`.
    dync_crate_name: String,
    /// Suffix intended for generated vtable type names.
    /// NOTE(review): vtable names visible in this file are built with a hard-coded
    /// `"VTable"` (`TraitData::vtable_name`), so this field appears unused — confirm.
    suffix: String,
    /// Set by the bare `build_vtable_only` flag.
    /// NOTE(review): not read anywhere in the code visible here — confirm intended effect.
    build_vtable_only: bool,
}
20
21impl Default for Config {
22    fn default() -> Self {
23        Config {
24            dync_crate_name: String::from("dync"),
25            suffix: String::from("VTable"),
26            build_vtable_only: false,
27        }
28    }
29}
30
/// One entry of an attribute argument list: an identifier optionally followed by `=` and
/// optionally by a literal (each part is parsed independently, see `DynAttrib::parse`).
#[derive(Debug)]
struct DynAttrib {
    /// Argument name, e.g. `suffix`.
    ident: Ident,
    /// The `=` between name and value, if written.
    eq: Option<Token![=]>,
    /// The literal value, e.g. `"VTable"`, if written.
    value: Option<Lit>,
}
37
38impl Parse for DynAttrib {
39    fn parse(input: ParseStream) -> Result<Self> {
40        let ident = input.parse()?;
41        let eq = input.parse()?;
42        let value = input.parse()?;
43        Ok(DynAttrib { ident, eq, value })
44    }
45}
46
47impl Parse for Config {
48    fn parse(input: ParseStream) -> Result<Self> {
49        let mut config = Config::default();
50        let attribs: Punctuated<DynAttrib, Token![,]> = Punctuated::parse_terminated(input)?;
51        for attrib in attribs.iter() {
52            let name = attrib.ident.to_string();
53            match (name.as_str(), &attrib.value) {
54                ("build_vtable_only", None) => config.build_vtable_only = true,
55                ("dync_crate_name", Some(Lit::Str(ref lit))) => {
56                    config.dync_crate_name = lit.value().clone()
57                }
58                ("suffix", Some(Lit::Str(ref lit))) => config.suffix = lit.value().clone(),
59                _ => {}
60            }
61        }
62        Ok(config)
63    }
64}
65
/// What the macros know about a single trait.
#[derive(Clone, Debug, PartialEq)]
struct TraitData {
    /// The trait's path as a string (e.g. `"std::hash::Hash"`). Its last segment names the
    /// generated `Has*`, `*Bytes` and `*VTable` identifiers.
    pub path: String,
    /// Methods this trait itself contributes as function-pointer slots.
    pub methods: Vec<TraitMethod>,
    /// Supertraits of this trait. `flatten_inheritance` may widen this to the transitive
    /// set, and `fill_drop_for_inheritance` adds `Trait::Drop` to every non-`Drop` entry.
    /// Iteration order (the derived `Ord` of `Trait`) determines the vtable field layout.
    pub super_traits: BTreeSet<Trait>,
}
72
73impl TraitData {
74    fn path(&self) -> Path {
75        syn::parse_str(&self.path).unwrap()
76    }
77    fn has_trait(&self) -> Ident {
78        let path = self.path();
79        let name = path.segments.last().unwrap().ident.clone();
80        Ident::new(&format!("Has{}", name), Span::call_site())
81    }
82    fn bytes_trait(&self) -> Ident {
83        let path = self.path();
84        let name = path.segments.last().unwrap().ident.clone();
85        Ident::new(&format!("{}Bytes", name), Span::call_site())
86    }
87    fn vtable_name(&self) -> Ident {
88        let path = self.path();
89        let seg = path.segments.last().unwrap();
90        let trait_name = &seg.ident;
91        Ident::new(&format!("{}VTable", &trait_name), Span::call_site())
92    }
93}
94
/// A trait method occupying one function-pointer slot in a vtable.
#[derive(Clone, Debug, PartialEq)]
struct TraitMethod {
    /// Method name with any single trailing `_fn` removed (see `TraitMethod::new`); the
    /// `*Fn` type, `*_bytes` and `*_fn` identifiers are derived from it.
    pub name: String,
}
99
100impl TraitMethod {
101    fn new(mut name: &str) -> Self {
102        if name.ends_with("_fn") {
103            name = &name[..name.len() - 3];
104        }
105        TraitMethod {
106            name: name.to_string(),
107        }
108    }
109    fn fn_type(&self) -> Ident {
110        Ident::new(
111            &format!("{}Fn", self.name.to_camel_case()),
112            Span::call_site(),
113        )
114    }
115    fn bytes_fn(&self) -> Ident {
116        Ident::new(
117            &format!("{}_bytes", self.name.to_snek_case()),
118            Span::call_site(),
119        )
120    }
121    fn has_fn(&self) -> Ident {
122        Ident::new(
123            &format!("{}_fn", self.name.to_snek_case()),
124            Span::call_site(),
125        )
126    }
127}
128
/// Identifies a trait: one of the builtin traits listed in `BUILTINS`, or a custom trait.
///
/// Variant order is significant: supertrait sets are `BTreeSet<Trait>`, so the derived
/// `Ord` decides the order of function-pointer fields in generated vtables. Do not reorder.
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Ord, Hash)]
enum Trait {
    Drop,
    Clone,
    PartialEq,
    Eq,
    Hash,
    Debug,
    Send,
    Sync,
    /// Any other trait, identified by its name / path rendered as a string.
    Custom(String),
}
141
142impl Trait {
143    fn is_unsafe(&self) -> bool {
144        matches!(self, Trait::Send) || matches!(self, Trait::Sync)
145    }
146    fn prefix(&self, crate_name: &str) -> TokenStream {
147        if BUILTINS.contains_key(self) {
148            let crate_name = Ident::new(crate_name, Span::call_site());
149            quote! { #crate_name::traits:: }
150        } else {
151            TokenStream::new()
152        }
153    }
154}
155
156impl From<String> for Trait {
157    fn from(s: String) -> Trait {
158        match s.as_str() {
159            "Drop" | "std::ops::Drop" => Trait::Drop,
160            "Clone" | "std::clone::Clone" => Trait::Clone,
161            "PartialEq" | "std::cmp::PartialEq" => Trait::PartialEq,
162            "Eq" | "std::cmp::Eq" => Trait::Eq,
163            "Hash" | "std::hash::Hash" => Trait::Hash,
164            "Debug" | "std::fmt::Debug" => Trait::Debug,
165            "Send" | "std::marker::Send" => Trait::Send,
166            "Sync" | "std::marker::Sync" => Trait::Sync,
167            x => Trait::Custom(x.to_string()),
168        }
169    }
170}
171
172impl<'a> From<Path> for Trait {
173    fn from(p: Path) -> Trait {
174        match p {
175            x if x == parse_quote! { Drop } => Trait::Drop,
176            x if x == parse_quote! { Clone } => Trait::Clone,
177            x if x == parse_quote! { PartialEq } => Trait::PartialEq,
178            x if x == parse_quote! { Eq } => Trait::Eq,
179            x if x == parse_quote! { std::hash::Hash } => Trait::Hash,
180            x if x == parse_quote! { std::fmt::Debug } => Trait::Debug,
181            x if x == parse_quote! { Send } => Trait::Send,
182            x if x == parse_quote! { Sync } => Trait::Sync,
183            x => Trait::Custom(format!("{}", quote! { #x })),
184        }
185    }
186}
187
188lazy_static! {
189    static ref BUILTINS: HashMap<Trait, TraitData> = {
190        let mut m = HashMap::new();
191        m.insert(
192            Trait::Drop,
193            TraitData {
194                path: "Drop".to_string(),
195                methods: vec![TraitMethod::new("drop")],
196                super_traits: BTreeSet::new(),
197            },
198        );
199        m.insert(
200            Trait::Clone,
201            TraitData {
202                path: "Clone".to_string(),
203                methods: vec![
204                    TraitMethod::new("clone"),
205                    TraitMethod::new("clone_from"),
206                    TraitMethod::new("clone_into_raw"),
207                ],
208                super_traits: BTreeSet::new(),
209            },
210        );
211        m.insert(
212            Trait::PartialEq,
213            TraitData {
214                path: "PartialEq".to_string(),
215                methods: vec![TraitMethod::new("eq")],
216                super_traits: BTreeSet::new(),
217            },
218        );
219        m.insert(
220            Trait::Eq,
221            TraitData {
222                path: "Eq".to_string(),
223                methods: vec![],
224                super_traits: [Trait::PartialEq].iter().cloned().collect(),
225            },
226        );
227        m.insert(
228            Trait::Hash,
229            TraitData {
230                path: "std::hash::Hash".to_string(),
231                methods: vec![TraitMethod::new("hash")],
232                super_traits: BTreeSet::new(),
233            },
234        );
235        m.insert(
236            Trait::Debug,
237            TraitData {
238                path: "std::fmt::Debug".to_string(),
239                methods: vec![TraitMethod::new("fmt")],
240                super_traits: BTreeSet::new(),
241            },
242        );
243        m.insert(
244            Trait::Send,
245            TraitData {
246                path: "Send".to_string(),
247                methods: vec![],
248                super_traits: BTreeSet::new(),
249            },
250        );
251        m.insert(
252            Trait::Sync,
253            TraitData {
254                path: "Sync".to_string(),
255                methods: vec![],
256                super_traits: BTreeSet::new(),
257            },
258        );
259        m
260    };
261}
262
263#[proc_macro_attribute]
264pub fn dync_mod(
265    attr: proc_macro::TokenStream,
266    item: proc_macro::TokenStream,
267) -> proc_macro::TokenStream {
268    let config: Config = syn::parse(attr).expect("Failed to parse attributes");
269
270    let mut item_mod: ItemMod =
271        syn::parse(item).expect("the dync_mod attribute applies only to mod definitions");
272
273    validate_item_mod(&item_mod);
274
275    let mut trait_map = BUILTINS.clone();
276
277    fill_and_flatten_trait_map_from_mod(&item_mod, &mut trait_map);
278
279    fill_drop_for_inheritance(&mut trait_map);
280
281    //fill_inherited_impls(&mut impls, &config, &trait_inheritance);
282
283    let mut dync_items = Vec::new();
284
285    for item in item_mod.content.as_ref().unwrap().1.iter() {
286        if let Item::Trait(item_trait) = item {
287            dync_items.append(&mut construct_dync_items(&item_trait, &config, &trait_map));
288        }
289    }
290
291    item_mod.content.as_mut().unwrap().1.append(&mut dync_items);
292
293    let tokens = quote! { #item_mod };
294
295    tokens.into()
296}
297
298// Reject unsupported item traits.
299fn validate_item_mod(item_mod: &ItemMod) {
300    assert!(
301        item_mod.content.is_some() && item_mod.semi.is_none(),
302        "dync_mod attribute only works on modules containing trait definitions"
303    );
304}
305
306//fn fill_inherited_impls(
307//    impls: &mut ImplMap,
308//    config: &Config,
309//    trait_inheritance: &InheritMap,
310//) {
311//    for (trait_path, _) in trait_inheritance.iter() {
312//        let seg = trait_path.segments.last().unwrap();
313//        let trait_name = &seg.ident;
314//        let vtable_name = Ident::new(
315//            &format!("{}{}", &trait_name, &config.suffix),
316//            Span::call_site(),
317//        );
318//        // TODO: Construct the trait functions.
319//        impls
320//            .entry(trait_path.clone())
321//            .or_insert_with(|| (parse_quote! { #vtable_name }, vec![]));
322//    }
323//}
324
325fn fill_and_flatten_trait_map_from_mod(item_mod: &ItemMod, trait_map: &mut TraitMap) {
326    // First we fill out all traits we know with their supertraits.
327    for item in item_mod.content.as_ref().unwrap().1.iter() {
328        if let Item::Trait(item_trait) = item {
329            fill_trait_map_from_item_trait(&item_trait, trait_map);
330        }
331    }
332
333    // Then we eliminate propagate inherited traits down to the base traits.
334    for item in item_mod.content.as_ref().unwrap().1.iter() {
335        if let Item::Trait(item_trait) = item {
336            flatten_inheritance(&item_trait, trait_map);
337        }
338    }
339}
340
341// Add the drop trait to everything, so that the drop function will always be included.
342fn fill_drop_for_inheritance(trait_map: &mut TraitMap) {
343    for (trait_key, trait_data) in trait_map.iter_mut() {
344        if !matches!(trait_key, &Trait::Drop) {
345            trait_data.super_traits.insert(Trait::Drop);
346        }
347    }
348}
349
350fn fill_trait_map_from_item_trait(item_trait: &ItemTrait, trait_map: &mut TraitMap) {
351    let trait_name = item_trait.ident.to_string();
352    trait_map
353        .entry(Trait::from(trait_name.clone()))
354        .or_insert_with(|| {
355            let super_traits: BTreeSet<Trait> = item_trait
356                .supertraits
357                .iter()
358                .filter_map(|bound| {
359                    if let TypeParamBound::Trait(bound) = bound {
360                        if bound.lifetimes.is_some() || bound.modifier != TraitBoundModifier::None {
361                            // We are looking for recognizable traits only
362                            None
363                        } else {
364                            // Generic traits are ignored
365                            if bound
366                                .path
367                                .segments
368                                .iter()
369                                .all(|seg| seg.arguments.is_empty())
370                            {
371                                Some(Trait::from(bound.path.clone()))
372                            } else {
373                                None
374                            }
375                        }
376                    } else {
377                        None
378                    }
379                })
380                .collect();
381            TraitData {
382                path: trait_name,
383                methods: vec![], // TODO: implement loading trait methods.
384                super_traits,
385            }
386        });
387}
388
389fn flatten_inheritance(item_trait: &ItemTrait, trait_map: &mut TraitMap) {
390    let trait_name = &item_trait.ident;
391    let trait_key = Trait::from(trait_name.to_string());
392    let mut trait_data = trait_map.remove(&trait_key).unwrap();
393    trait_data.super_traits = trait_data
394        .super_traits
395        .into_iter()
396        .flat_map(|super_trait| {
397            union_children(&super_trait, trait_map)
398                .into_iter()
399                .chain(std::iter::once(super_trait.clone()))
400        })
401        .collect();
402    assert!(trait_map.insert(trait_key, trait_data).is_none());
403}
404
405fn union_children(trait_key: &Trait, trait_map: &TraitMap) -> BTreeSet<Trait> {
406    let mut res = BTreeSet::new();
407    if let Some(trait_data) = trait_map.get(trait_key) {
408        res.extend(trait_data.super_traits.iter().cloned());
409        for super_trait in trait_data.super_traits.iter() {
410            res.extend(union_children(super_trait, trait_map).into_iter());
411        }
412    }
413    res
414}
415
416#[proc_macro_attribute]
417pub fn dync_trait(
418    attr: proc_macro::TokenStream,
419    item: proc_macro::TokenStream,
420) -> proc_macro::TokenStream {
421    let config: Config = syn::parse(attr).expect("Failed to parse attributes");
422
423    let item_trait: ItemTrait =
424        syn::parse(item).expect("the dync_trait attribute applies only to trait definitions");
425
426    validate_item_trait(&item_trait);
427
428    let mut trait_map = BUILTINS.clone();
429
430    fill_trait_map_from_item_trait(&item_trait, &mut trait_map);
431
432    fill_drop_for_inheritance(&mut trait_map);
433
434    //fill_inherited_impls(&mut impls, &config, &trait_inheritance);
435
436    let dync_items = construct_dync_items(&item_trait, &config, &trait_map);
437
438    let mut tokens = quote! { #item_trait };
439    for item in dync_items.into_iter() {
440        tokens.append_all(quote! { #item });
441    }
442    tokens.into()
443}
444
445// Reject unsupported item traits.
446fn validate_item_trait(item_trait: &ItemTrait) {
447    assert!(
448        item_trait.generics.params.is_empty(),
449        "trait generics are not supported by dync_trait"
450    );
451    assert!(
452        item_trait.generics.where_clause.is_none(),
453        "traits with where clauses are not supported by dync_trait"
454    );
455}
456
457fn vtable_struct_def(
458    trait_data: &TraitData,
459    vis: &Visibility,
460    config: &Config,
461    trait_map: &TraitMap,
462) -> Item {
463    let vtable_name = trait_data.vtable_name();
464
465    // The vtable is flattened.
466    let vtable_fields: Punctuated<Field, Token![,]> = trait_data
467        .super_traits
468        .iter()
469        .flat_map(|trait_key| {
470            trait_map
471                .get(&trait_key)
472                .into_iter()
473                .flat_map(move |trait_data| {
474                    let prefix = trait_key.prefix(&config.dync_crate_name);
475
476                    trait_data.methods.iter().map(move |method| {
477                        let fn_ty = method.fn_type();
478                        Field {
479                            attrs: Vec::new(),
480                            vis: Visibility::Inherited,
481                            ident: None,
482                            colon_token: None,
483                            ty: parse_quote! { #prefix #fn_ty },
484                        }
485                    })
486                })
487        })
488        .collect();
489
490    parse_quote! {
491        #[derive(Copy, Clone)]
492        #vis struct #vtable_name (#vtable_fields);
493    }
494}
495
496fn construct_dync_items(
497    item_trait: &ItemTrait,
498    config: &Config,
499    trait_map: &TraitMap,
500) -> Vec<Item> {
501    let vis = item_trait.vis.clone();
502    let crate_name = &config.dync_crate_name;
503
504    let trait_name = &item_trait.ident;
505    let trait_key = Trait::from(trait_name.to_string());
506    let trait_data = trait_map.get(&trait_key).unwrap(); // We should have already entered it.
507    let vtable_name = trait_data.vtable_name();
508
509    let vtable_def = vtable_struct_def(&trait_data, &vis, config, trait_map);
510    //eprintln!("{}", quote! { vtable_def });
511
512    // Construct HasTrait impls
513    let mut has_impls: Vec<Item> = Vec::new();
514    let mut has_impl_deps: Punctuated<Path, Token![+]> = Punctuated::new();
515
516    let mut fn_idx_usize = 0;
517    for super_trait_key in trait_data.super_traits.iter() {
518        //dbg!(super_trait_key);
519        let impl_entry = trait_map.get(&super_trait_key);
520        if impl_entry.is_none() {
521            continue;
522        }
523
524        let prefix = super_trait_key.prefix(crate_name);
525
526        let super_trait_data = impl_entry.unwrap();
527        let mut methods = TokenStream::new();
528        for method in super_trait_data.methods.iter() {
529            let fn_idx = syn::Index::from(fn_idx_usize);
530            let fn_ty = method.fn_type();
531            let has_fn = method.has_fn();
532            methods.append_all(quote! {
533                #[inline]
534                fn #has_fn ( &self ) -> &#prefix #fn_ty { &self.#fn_idx }
535            });
536            fn_idx_usize += 1;
537        }
538
539        let has_trait = super_trait_data.has_trait();
540
541        //eprintln!("{}", &methods);
542
543        let maybe_unsafe = if super_trait_key.is_unsafe() {
544            quote! { unsafe }
545        } else {
546            TokenStream::new()
547        };
548        has_impls.push(parse_quote! {
549            #maybe_unsafe impl #prefix #has_trait for #vtable_name {
550                #methods
551            }
552        });
553        has_impl_deps.push(parse_quote! { #prefix #has_trait });
554    }
555
556    // HasTrait for the current trait.
557    let has_trait = Ident::new(&format!("Has{}", trait_name.to_string()), Span::call_site());
558    has_impls.push(parse_quote! {
559        #vis trait #has_trait: #has_impl_deps {
560            // TODO: add has fns
561        }
562    });
563    has_impls.push(parse_quote! {
564        impl #has_trait for #vtable_name {
565            // TODO: add has fns impls
566        }
567    });
568
569    let vtable_constructor = trait_data
570        .super_traits
571        .iter()
572        .flat_map(|super_trait_key| {
573            let crate_name = &crate_name;
574            trait_map
575                .get(&super_trait_key)
576                .into_iter()
577                .flat_map(move |super_trait_data| {
578                    let prefix = super_trait_key.prefix(crate_name);
579                    let bytes_trait = super_trait_data.bytes_trait();
580                    super_trait_data.methods.iter().map(move |method| {
581                        let bytes_fn = method.bytes_fn();
582                        let tuple: Expr = parse_quote! { <T as #prefix #bytes_trait> :: #bytes_fn };
583                        //eprintln!("{}", quote! { #tuple });
584                        tuple
585                    })
586                })
587        })
588        .collect::<Punctuated<Expr, Token![,]>>();
589
590    let crate_name_ident = Ident::new(&crate_name, Span::call_site());
591
592    let mut res = has_impls;
593    res.push(parse_quote! {
594        #crate_name_ident::downcast::impl_downcast!(#has_trait);
595    });
596    res.push(vtable_def);
597    res.push(parse_quote! {
598        impl<T: #trait_name + 'static> #crate_name_ident::VTable<T> for #vtable_name {
599            #[inline]
600            fn build_vtable() -> #vtable_name {
601                #vtable_name(#vtable_constructor)
602            }
603        }
604    });
605
606    //let a = res.last().unwrap();
607    //eprintln!("{}", quote! { #a } );
608
609    // Conversions to base tables
610    for super_trait_key in trait_data.super_traits.iter() {
611        let mut conversion_exprs = Punctuated::<Expr, Token![,]>::new();
612        if let Some(super_trait_data) = trait_map.get(&super_trait_key) {
613            //let base_trait_path = &super_trait_data.path;
614            //eprintln!("base: {}", quote! { #base_trait_path });
615            for ss_trait_key in super_trait_data.super_traits.iter() {
616                if let Some(ss_trait_data) = trait_map.get(&ss_trait_key) {
617                    //let ss_trait_path = &ss_trait_data.path;
618                    //eprintln!("  inherit: {}", quote! { #ss_trait_path });
619                    for method in ss_trait_data.methods.iter() {
620                        let has_fn = method.has_fn();
621                        let expr: Expr = parse_quote! { *derived.#has_fn() };
622                        //eprintln!("    inherit_fn: {}", quote! { #expr });
623                        conversion_exprs.push(parse_quote! { #expr });
624                    }
625                }
626            }
627            for method in super_trait_data.methods.iter() {
628                let has_fn = method.has_fn();
629                let expr: Expr = parse_quote! { *derived.#has_fn() };
630                //eprintln!("  self_fn: {}", quote! { #expr});
631                conversion_exprs.push(expr);
632            }
633
634            let prefix = super_trait_key.prefix(crate_name);
635            let base_vtable_name = super_trait_data.vtable_name();
636
637            // super trait is a custom one, we should be able to define the conversion
638
639            //let convert_item = quote! {
640            //    impl From<#vtable_name> for #prefix #base_vtable_name {
641            //        fn from(derived: #vtable_name) -> Self {
642            //            use #crate_name_ident :: traits::*;
643            //            #base_vtable_name ( #conversion_exprs )
644            //        }
645            //    }
646            //};
647            //eprintln!("{}", quote! { #convert_item });
648            res.push(parse_quote! {
649                impl From<#vtable_name> for #prefix #base_vtable_name {
650                    fn from(derived: #vtable_name) -> Self {
651                        use #crate_name_ident :: traits::*;
652                        #base_vtable_name ( #conversion_exprs )
653                    }
654                }
655            });
656        }
657
658        //let a = res.last().unwrap();
659        //eprintln!("{}", quote! { #a } );
660    }
661
662    res
663}
664
665//struct UtilityFns {
666//    from_bytes_fn: ItemFn,
667//    from_bytes_mut_fn: ItemFn,
668//    as_bytes_fn: ItemFn,
669//    box_into_box_bytes_fn: ItemFn,
670//    clone_fn: (TypeBareFn, ItemFn),
671//    clone_from_fn: (TypeBareFn, ItemFn),
672//    clone_into_raw_fn: (TypeBareFn, ItemFn),
673//    eq_fn: (TypeBareFn, ItemFn),
674//    hash_fn: (TypeBareFn, ItemFn),
675//    fmt_fn: (TypeBareFn, ItemFn),
676//}
677//
678//impl UtilityFns {
679//    fn new() -> Self {
680//        // Byte Helpers
681//        let from_bytes_fn: ItemFn = parse_quote! {
682//            #[inline]
683//            unsafe fn from_bytes<S: 'static>(bytes: &[u8]) -> &S {
684//                assert_eq!(bytes.len(), std::mem::size_of::<S>());
685//                &*(bytes.as_ptr() as *const S)
686//            }
687//        };
688//
689//        let from_bytes_mut_fn: ItemFn = parse_quote! {
690//            #[inline]
691//            unsafe fn from_bytes_mut<S: 'static>(bytes: &mut [u8]) -> &mut S {
692//                assert_eq!(bytes.len(), std::mem::size_of::<S>());
693//                &mut *(bytes.as_mut_ptr() as *mut S)
694//            }
695//        };
696//
697//        let as_bytes_fn: ItemFn = parse_quote! {
698//            #[inline]
699//            unsafe fn as_bytes<S: 'static>(s: &S) -> &[u8] {
700//                // This is safe since any memory can be represented by bytes and we are looking at
701//                // sized types only.
702//                unsafe { std::slice::from_raw_parts(s as *const S as *const u8, std::mem::size_of::<S>()) }
703//            }
704//        };
705//
706//        let box_into_box_bytes_fn: ItemFn = parse_quote! {
707//            #[inline]
708//            fn box_into_box_bytes<S: 'static>(b: Box<S>) -> Box<[u8]> {
709//                let byte_ptr = Box::into_raw(b) as *mut u8;
710//                // This is safe since any memory can be represented by bytes and we are looking at
711//                // sized types only.
712//                unsafe { Box::from_raw(std::slice::from_raw_parts_mut(byte_ptr, std::mem::size_of::<S>())) }
713//            }
714//        };
715//
716//        // Implement known trait functions.
717//        let clone_fn: (TypeBareFn, ItemFn) = (
718//            parse_quote! { unsafe fn (&[u8]) -> Box<[u8]> },
719//            parse_quote! {
720//                #[inline]
721//                unsafe fn clone_fn<S: Clone + 'static>(src: &[u8]) -> Box<[u8]> {
722//                    let typed_src: &S = from_bytes(src);
723//                    box_into_box_bytes(Box::new(typed_src.clone()))
724//                }
725//            },
726//        );
727//        let clone_from_fn: (TypeBareFn, ItemFn) = (
728//            parse_quote! { unsafe fn (&mut [u8], &[u8]) },
729//            parse_quote! {
730//                #[inline]
731//                unsafe fn clone_from_fn<S: Clone + 'static>(dst: &mut [u8], src: &[u8]) {
732//                    let typed_src: &S = from_bytes(src);
733//                    let typed_dst: &mut S = from_bytes_mut(dst);
734//                    typed_dst.clone_from(typed_src);
735//                }
736//            },
737//        );
738//
739//        let clone_into_raw_fn: (TypeBareFn, ItemFn) = (
740//            parse_quote! { unsafe fn (&[u8], &mut [u8]) },
741//            parse_quote! {
742//                #[inline]
743//                unsafe fn clone_into_raw_fn<S: Clone + 'static>(src: &[u8], dst: &mut [u8]) {
744//                    let typed_src: &S = from_bytes(src);
745//                    let cloned = S::clone(typed_src);
746//                    let cloned_bytes = as_bytes(&cloned);
747//                    dst.copy_from_slice(cloned_bytes);
748//                    let _ = std::mem::ManuallyDrop::new(cloned);
749//                }
750//            },
751//        );
752//
753//        let eq_fn: (TypeBareFn, ItemFn) = (
754//            parse_quote! { unsafe fn (&[u8], &[u8]) -> bool },
755//            parse_quote! {
756//                #[inline]
757//                unsafe fn eq_fn<S: PartialEq + 'static>(a: &[u8], b: &[u8]) -> bool {
758//                    let (a, b): (&S, &S) = (from_bytes(a), from_bytes(b));
759//                    a.eq(b)
760//                }
761//            },
762//        );
763//        let hash_fn: (TypeBareFn, ItemFn) = (
764//            parse_quote! { unsafe fn (&[u8], &mut dyn std::hash::Hasher) },
765//            parse_quote! {
766//                #[inline]
767//                unsafe fn hash_fn<S: std::hash::Hash + 'static>(bytes: &[u8], mut state: &mut dyn std::hash::Hasher) {
768//                    let typed_data: &S = from_bytes(bytes);
769//                    typed_data.hash(&mut state)
770//                }
771//            },
772//        );
773//        let fmt_fn: (TypeBareFn, ItemFn) = (
774//            parse_quote! { unsafe fn (&[u8], &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> },
775//            parse_quote! {
776//                #[inline]
777//                unsafe fn fmt_fn<S: std::fmt::Debug + 'static>(bytes: &[u8], f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
778//                    let typed_data: &S = from_bytes(bytes);
779//                    typed_data.fmt(f)
780//                }
781//            },
782//        );
783//
784//        UtilityFns {
785//            from_bytes_fn,
786//            from_bytes_mut_fn,
787//            as_bytes_fn,
788//            box_into_box_bytes_fn,
789//            clone_fn,
790//            clone_from_fn,
791//            clone_into_raw_fn,
792//            eq_fn,
793//            hash_fn,
794//            fmt_fn,
795//        }
796//    }
797//}
798
799#[proc_macro_attribute]
800pub fn dync_trait_method(
801    _attr: proc_macro::TokenStream,
802    item: proc_macro::TokenStream,
803) -> proc_macro::TokenStream {
804    let mut trait_method: TraitItemMethod = syn::parse(item).expect(
805        "the dync_trait_function attribute applies only to trait function definitions only",
806    );
807
808    trait_method.sig = dync_fn_sig(trait_method.sig);
809
810    let tokens = quote! { #trait_method };
811
812    tokens.into()
813}
814
815/// Convert a function signature by replacing self types with bytes.
816fn dync_fn_sig(sig: Signature) -> Signature {
817    assert!(
818        sig.constness.is_none(),
819        "const functions not supported by dync_trait"
820    );
821    assert!(
822        sig.asyncness.is_none(),
823        "async functions not supported by dync_trait"
824    );
825    assert!(
826        sig.abi.is_none(),
827        "extern functions not supported by dync_trait"
828    );
829    assert!(
830        sig.variadic.is_none(),
831        "variadic functions not supported by dync_trait"
832    );
833
834    let dync_name = format!("{}_bytes", sig.ident);
835    let dync_ident = Ident::new(&dync_name, sig.ident.span());
836
837    let mut generics = GenericsMap::new();
838
839    // Convert generics into `dyn Trait` if possible.
840    for gen in sig.generics.params.iter() {
841        match gen {
842            GenericParam::Type(ty) => {
843                assert!(
844                    ty.attrs.is_empty(),
845                    "type parameter attributes are not supported by dync_trait"
846                );
847                assert!(
848                    ty.colon_token.is_some(),
849                    "unbound type parameters are not supported by dync_trait"
850                );
851                assert!(
852                    ty.eq_token.is_none() && ty.default.is_none(),
853                    "default type parameters are not supported by dync_trait"
854                );
855                generics.insert(ty.ident.clone(), ty.bounds.clone());
856            }
857            GenericParam::Lifetime(_) => {
858                panic!("lifetime parameters in trait functions are not supported by dync_trait");
859            }
860            GenericParam::Const(_) => {
861                panic!("const parameters in trait functions are not supported by dync_trait");
862            }
863        }
864    }
865    if let Some(where_clause) = sig.generics.where_clause {
866        for pred in where_clause.predicates.iter() {
867            match pred {
868                WherePredicate::Type(ty) => {
869                    assert!(
870                        ty.lifetimes.is_none(),
871                        "lifetimes in for bindings are not supported by dync_trait"
872                    );
873                    if let Type::Path(ty_path) = ty.bounded_ty.clone() {
874                        assert!(
875                            ty_path.qself.is_none(),
876                            "complex trait bounds are not supported by dync_trait"
877                        );
878                        assert!(
879                            ty_path.path.leading_colon.is_none(),
880                            "complex trait bounds are not supported by dync_trait"
881                        );
882                        assert!(
883                            ty_path.path.segments.len() != 1,
884                            "complex trait bounds are not supported by dync_trait"
885                        );
886                        let seg = ty_path.path.segments.first().unwrap();
887                        assert!(
888                            !seg.arguments.is_empty(),
889                            "complex trait bounds are not supported by dync_trait"
890                        );
891                        generics.insert(seg.ident.clone(), ty.bounds.clone());
892                    }
893                }
894                WherePredicate::Lifetime(_) => {
895                    panic!(
896                        "lifetime parameters in trait functions are not supported by dync_trait"
897                    );
898                }
899                _ => {}
900            }
901        }
902    }
903
904    // Convert inputs.
905    let dync_inputs: Punctuated<FnArg, Token![,]> = sig
906        .inputs
907        .iter()
908        .map(|fn_arg| {
909            FnArg::Typed(match fn_arg {
910                FnArg::Receiver(Receiver {
911                    attrs,
912                    reference,
913                    mutability,
914                    ..
915                }) => {
916                    let ty: Type = if let Some((_, lifetime)) = reference {
917                        syn::parse(
918                            quote! { & #lifetime #mutability [std::mem::MaybeUninit<u8>] }.into(),
919                        )
920                        .unwrap()
921                    } else {
922                        syn::parse(quote! { #mutability Box<[std::mem::MaybeUninit<u8>]> }.into())
923                            .unwrap()
924                    };
925                    PatType {
926                        attrs: attrs.to_vec(),
927                        pat: syn::parse(quote! { _self_ }.into()).unwrap(),
928                        colon_token: Token![:](Span::call_site()),
929                        ty: Box::new(ty),
930                    }
931                }
932                FnArg::Typed(pat_ty) => PatType {
933                    ty: Box::new(type_to_bytes(process_generics(
934                        *pat_ty.ty.clone(),
935                        &generics,
936                    ))),
937                    ..pat_ty.clone()
938                },
939            })
940        })
941        .collect();
942
943    // Convert return type.
944    let dync_output: Type = match sig.output {
945        ReturnType::Type(_, ty) => type_to_bytes(process_generics(*ty, &generics)),
946        ReturnType::Default => syn::parse(quote! { () }.into()).unwrap(),
947    };
948
949    Signature {
950        unsafety: Some(Token![unsafe](Span::call_site())),
951        ident: dync_ident,
952        generics: Generics {
953            lt_token: None,
954            params: Punctuated::new(),
955            gt_token: None,
956            where_clause: None,
957        },
958        inputs: dync_inputs,
959        output: ReturnType::Type(Token![->](Span::call_site()), Box::new(dync_output)),
960        ..sig
961    }
962}
963
964// Translate any generics occuring in types according to the accumulated generics map by converting
965// generic types into trait objects.
966fn process_generics(ty: Type, generics: &GenericsMap) -> Type {
967    match ty {
968        Type::Paren(paren) => Type::Paren(TypeParen {
969            elem: Box::new(process_generics(*paren.elem, generics)),
970            ..paren
971        }),
972        Type::Path(path) => process_generic_type_path(path, generics, true),
973        Type::Ptr(ptr) => Type::Ptr(TypePtr {
974            elem: Box::new(generic_ref_to_trait_object(*ptr.elem, generics)),
975            ..ptr
976        }),
977        Type::Reference(reference) => Type::Reference(TypeReference {
978            elem: Box::new(generic_ref_to_trait_object(*reference.elem, generics)),
979            ..reference
980        }),
981        pass_through => {
982            check_for_unsupported_generics(&pass_through, generics);
983            pass_through
984        }
985    }
986}
987
988// Convert Self type into a the given type or pass through
989fn process_generic_type_path(ty: TypePath, generics: &GenericsMap, owned: bool) -> Type {
990    if ty.path.leading_colon.is_some() || ty.path.segments.len() != 1 {
991        return Type::Path(ty);
992    }
993
994    let seg = ty.path.segments.first().unwrap();
995    if !seg.arguments.is_empty() {
996        return Type::Path(ty);
997    }
998
999    // Generic types wouldn't have arguments.
1000    if let Some(bounds) = generics.get(&seg.ident) {
1001        if owned {
1002            syn::parse(quote! { Box<dyn #bounds> }.into()).unwrap()
1003        } else {
1004            syn::parse(quote! { dyn #bounds }.into()).unwrap()
1005        }
1006    } else {
1007        Type::Path(ty)
1008    }
1009}
1010
1011// Convert reference or pointer to self into a reference to bytes or pass through
1012fn generic_ref_to_trait_object(ty: Type, generics: &GenericsMap) -> Type {
1013    match ty {
1014        Type::Path(path) => process_generic_type_path(path, generics, false),
1015        other => other,
1016    }
1017}
1018
1019// Check if there are instances of generics in unsupported places.
1020fn check_for_unsupported_generics(ty: &Type, generics: &GenericsMap) {
1021    match ty {
1022        Type::Array(arr) => check_for_unsupported_generics(&arr.elem, generics),
1023        Type::BareFn(barefn) => {
1024            for input in barefn.inputs.iter() {
1025                check_for_unsupported_generics(&input.ty, generics);
1026            }
1027            if let ReturnType::Type(_, output_ty) = &barefn.output {
1028                check_for_unsupported_generics(&*output_ty, generics);
1029            }
1030        }
1031        Type::Group(group) => check_for_unsupported_generics(&group.elem, generics),
1032        Type::Paren(paren) => check_for_unsupported_generics(&paren.elem, generics),
1033        Type::Path(path) => {
1034            assert!(
1035                path.qself.is_none(),
1036                "qualified paths not supported by dync_trait"
1037            );
1038            if path.path.leading_colon.is_none() && path.path.segments.len() == 1 {
1039                let seg = path.path.segments.first().unwrap();
1040                assert!(
1041                    seg.arguments.is_empty() && seg.ident == "Self",
1042                    "using Self in this context is not supported by dync_trait"
1043                );
1044            }
1045        }
1046        Type::Ptr(ptr) => check_for_unsupported_generics(&ptr.elem, generics),
1047        Type::Reference(reference) => check_for_unsupported_generics(&reference.elem, generics),
1048        Type::Slice(slice) => check_for_unsupported_generics(&slice.elem, generics),
1049        Type::Tuple(tuple) => {
1050            for elem in tuple.elems.iter() {
1051                check_for_unsupported_generics(elem, generics);
1052            }
1053        }
1054        _ => {}
1055    }
1056}
1057
1058fn type_to_bytes(ty: Type) -> Type {
1059    // It is quite difficult to convert occurances of Self in a function signature to the
1060    // corresponding byte representation because of composability of types. Each type containing
1061    // self must know how to convert its contents to bytes, which is completely out of the scope
1062    // here.
1063    //
1064    // However some builtin types (like arrays, tuples and slices) and std library types can be
1065    // handled.  This probably one of the reasons why trait objects don't support traits with
1066    // functions that take in `Self` as a parameter. We will try to relax this constraint as much
1067    // as we can in this function.
1068
1069    match ty {
1070        //Type::Array(arr) => Type::Array(TypeArray {
1071        //    elem: Box::new(type_to_bytes(*arr.elem)),
1072        //    ..arr
1073        //}),
1074        //Type::Group(group) => Type::Group(TypeGroup {
1075        //    elem: Box::new(type_to_bytes(*group.elem),
1076        //    ..group
1077        //}),
1078        Type::ImplTrait(impl_trait) => Type::TraitObject(TypeTraitObject {
1079            // Convert `impl Trait` to `dyn Trait`.
1080            dyn_token: Some(Token![dyn](Span::call_site())),
1081            bounds: impl_trait.bounds,
1082        }),
1083        Type::Paren(paren) => Type::Paren(TypeParen {
1084            elem: Box::new(type_to_bytes(*paren.elem)),
1085            ..paren
1086        }),
1087        Type::Path(path) => self_type_path_into(
1088            path,
1089            syn::parse(quote! { Box<[std::mem::MaybeUninit<u8>]> }.into()).unwrap(),
1090        ),
1091        Type::Ptr(ptr) => Type::Ptr(TypePtr {
1092            elem: Box::new(self_to_byte_slice(*ptr.elem)),
1093            ..ptr
1094        }),
1095        Type::Reference(reference) => Type::Reference(TypeReference {
1096            elem: Box::new(self_to_byte_slice(*reference.elem)),
1097            ..reference
1098        }),
1099        //Type::Slice(slice) => Type::Slice(TypeSlice {
1100        //    elem: Box::new(type_to_bytes(*slice.elem)),
1101        //    ..slice
1102        //}),
1103        //Type::Tuple(tuple) => Type::Tuple(TypeTuple {
1104        //    elems: elems.into_iter().map(|elem| type_to_bytes(elem)),
1105        //    ..tuple
1106        //}),
1107        pass_through => {
1108            check_for_unsupported_self(&pass_through);
1109            pass_through
1110        }
1111    }
1112}
1113
1114// Convert Self type into a the given type or pass through
1115fn self_type_path_into(path: TypePath, into_ty: Type) -> Type {
1116    assert!(
1117        path.qself.is_none(),
1118        "qualified paths not supported by dync_trait"
1119    );
1120    if path.path.leading_colon.is_none() && path.path.segments.len() == 1 {
1121        let seg = path.path.segments.first().unwrap();
1122        if seg.arguments.is_empty() // Self types wouldn't have arguments.
1123            && seg.ident == "Self"
1124        {
1125            into_ty
1126        } else {
1127            Type::Path(path)
1128        }
1129    } else {
1130        Type::Path(path)
1131    }
1132}
1133
1134// Convert reference or pointer to self into a reference to bytes or pass through
1135fn self_to_byte_slice(ty: Type) -> Type {
1136    match ty {
1137        Type::Path(path) => self_type_path_into(
1138            path,
1139            syn::parse(quote! { [std::mem::MaybeUninit<u8>] }.into()).unwrap(),
1140        ),
1141        other => other,
1142    }
1143}
1144
1145// Check if there are instances of Self in the given type, and panic if there are.
1146fn check_for_unsupported_self(ty: &Type) {
1147    match ty {
1148        Type::Array(arr) => check_for_unsupported_self(&arr.elem),
1149        Type::BareFn(barefn) => {
1150            for input in barefn.inputs.iter() {
1151                check_for_unsupported_self(&input.ty);
1152            }
1153            if let ReturnType::Type(_, output_ty) = &barefn.output {
1154                check_for_unsupported_self(&*output_ty);
1155            }
1156        }
1157        Type::Group(group) => check_for_unsupported_self(&group.elem),
1158        Type::Paren(paren) => check_for_unsupported_self(&paren.elem),
1159        Type::Path(path) => {
1160            assert!(
1161                path.qself.is_none(),
1162                "qualified paths not supported by dync_trait"
1163            );
1164            if path.path.leading_colon.is_none() && path.path.segments.len() == 1 {
1165                let seg = path.path.segments.first().unwrap();
1166                assert!(
1167                    seg.arguments.is_empty() && seg.ident == "Self",
1168                    "using Self in this context is not supported by dync_trait"
1169                );
1170            }
1171        }
1172        Type::Ptr(ptr) => check_for_unsupported_self(&ptr.elem),
1173        Type::Reference(reference) => check_for_unsupported_self(&reference.elem),
1174        Type::Slice(slice) => check_for_unsupported_self(&slice.elem),
1175        Type::Tuple(tuple) => {
1176            for elem in tuple.elems.iter() {
1177                check_for_unsupported_self(elem);
1178            }
1179        }
1180        _ => {}
1181    }
1182}