despatma_lib/
trait_specifier.rs

1use proc_macro2::TokenStream;
2use quote::ToTokens;
3use std::collections::HashMap;
4use syn::parse::{Parse, ParseStream, Result};
5use syn::{Token, Type};
6use tokenstream2_tmpl::{interpolate, Interpolate};
7
8/// Type that holds an abstract type and how it will map to a concrete type.
9///
10/// An acceptable stream will have the following form:
11/// ```text
12/// trait => concrete
13/// ```
14#[cfg_attr(any(test, feature = "extra-traits"), derive(Eq, PartialEq, Debug))]
15pub struct TraitSpecifier {
16    pub abstract_trait: Type,
17    pub arrow_token: Token![=>],
18    pub concrete: Type,
19}
20
21/// Make TraitSpecifier parsable from a token stream
22impl Parse for TraitSpecifier {
23    fn parse(input: ParseStream) -> Result<Self> {
24        Ok(TraitSpecifier {
25            abstract_trait: input.parse()?,
26            arrow_token: input.parse()?,
27            concrete: input.parse()?,
28        })
29    }
30}
31
32/// Make TraitSpecifier interpolatible
33impl Interpolate for TraitSpecifier {
34    fn interpolate(&self, stream: TokenStream) -> TokenStream {
35        let mut replacements: HashMap<_, &dyn ToTokens> = HashMap::new();
36
37        // Replace each "TRAIT" with the absract trait
38        replacements.insert("TRAIT", &self.abstract_trait);
39
40        // Replace each "CONCRETE" with the concrete type
41        replacements.insert("CONCRETE", &self.concrete);
42
43        interpolate(stream, &replacements)
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use despatma_test_helpers::reformat;
    use pretty_assertions::assert_eq;
    use quote::quote;
    use syn::parse_str;

    type Result = std::result::Result<(), Box<dyn std::error::Error>>;

    /// Builds `abstract_trait => concrete` field by field, bypassing `Parse`,
    /// so it can serve as an independent expected value.
    fn sample_specifier() -> syn::Result<TraitSpecifier> {
        Ok(TraitSpecifier {
            abstract_trait: parse_str("abstract_trait")?,
            arrow_token: Default::default(),
            concrete: parse_str("concrete")?,
        })
    }

    /// Parses `source` as a [`TraitSpecifier`].
    ///
    /// On failure it panics through `unwrap`, so the panic message carries
    /// syn's error text.
    fn parse_or_panic(source: &str) -> TraitSpecifier {
        parse_str(source).unwrap()
    }

    #[test]
    fn parse() -> Result {
        let parsed: TraitSpecifier = parse_str("abstract_trait => concrete")?;

        assert_eq!(parsed, sample_specifier()?);
        Ok(())
    }

    #[test]
    #[should_panic(expected = "expected one of")]
    fn missing_trait() {
        parse_or_panic("=> concrete");
    }

    #[test]
    #[should_panic(expected = "expected `=>`")]
    fn missing_arrow_joiner() {
        parse_or_panic("IButton -> RoundButton");
    }

    #[test]
    #[should_panic(expected = "unexpected end of input")]
    fn missing_concrete() {
        parse_or_panic("abstract_trait => ");
    }

    #[test]
    fn interpolate() -> Result {
        let template = quote! {
            impl Factory<TRAIT> for Gnome {
                fn create(&self) -> CONCRETE {
                    CONCRETE{}
                }
            }
        };
        let expected = quote! {
            impl Factory<abstract_trait> for Gnome {
                fn create(&self) -> concrete {
                    concrete{}
                }
            }
        };

        let actual = sample_specifier()?.interpolate(template);
        assert_eq!(reformat(&actual), reformat(&expected));

        Ok(())
    }
}