elfo_macros_impl/
msg.rs

1use std::{char, collections::HashMap};
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    parse_macro_input, spanned::Spanned, Arm, ExprMatch, Ident, Pat, PatIdent, PatWild, Path, Token,
7};
8
9// TODO: use `proc-macro-error` instead of `panic!`.
10// TODO: use `proc-macro-crate`?
11
12#[derive(Debug)]
13struct MessageGroup {
14    kind: GroupKind,
15    arms: Vec<Arm>,
16}
17
18#[derive(Debug, Hash, PartialEq, Eq)]
19enum GroupKind {
20    // `msg @ Msg(..) => ...`
21    Regular(Path),
22    // `(msg @ Msg(..), token) => ...`
23    Request(Path),
24    // `_ =>`
25    // `msg =>`
26    Wild,
27}
28
29fn is_valid_token_ident(ident: &PatIdent) -> bool {
30    !ident.ident.to_string().starts_with('_')
31}
32
33fn is_type_ident(ident: &Ident) -> bool {
34    ident
35        .to_string()
36        .chars()
37        .next()
38        .map_or(false, char::is_uppercase)
39}
40
41fn extract_path_to_type(path: &Path) -> Path {
42    let mut ident_rev_it = path.segments.iter().rev();
43
44    // Handle enum variants:
45    // `some::Enum::Variant`
46    //        ^- must be uppercased
47    //
48    // Yep, it's crazy, but it seems to be a good assumption for now.
49    if let Some(prev) = ident_rev_it.nth(1) {
50        if is_type_ident(&prev.ident) {
51            let mut path = path.clone();
52            path.segments.pop().unwrap();
53
54            // Convert `Pair::Punctuated` to `Pair::End`.
55            let (last, _) = path.segments.pop().unwrap().into_tuple();
56            path.segments.push(last);
57            return path;
58        }
59    }
60
61    path.clone()
62}
63
64fn extract_kind(pat: &Pat) -> Result<GroupKind, &'static str> {
65    match pat {
66        Pat::Box(_) => Err("box patterns are forbidden"),
67        Pat::Ident(pat) => match pat.subpat.as_ref() {
68            Some(sp) => extract_kind(&sp.1),
69            None if is_type_ident(&pat.ident) => {
70                Ok(GroupKind::Regular(Path::from(pat.ident.clone())))
71            }
72            None => Ok(GroupKind::Wild),
73        },
74        Pat::Lit(_) => Err("literal patterns are forbidden"),
75        Pat::Macro(_) => Err("macros in pattern position are forbidden"),
76        Pat::Or(pat) => pat
77            .cases
78            .iter()
79            .find_map(|pat| extract_kind(pat).ok())
80            .ok_or("cannot determine the message's type"),
81        Pat::Path(pat) => Ok(GroupKind::Regular(extract_path_to_type(&pat.path))),
82        Pat::Range(_) => Err("range patterns are forbidden"),
83        Pat::Reference(pat) => extract_kind(&pat.pat),
84        Pat::Rest(_) => Err("rest patterns are forbidden"),
85        Pat::Slice(_) => Err("slice patterns are forbidden"),
86        Pat::Struct(pat) => Ok(GroupKind::Regular(extract_path_to_type(&pat.path))),
87        Pat::Tuple(pat) => {
88            assert_eq!(pat.elems.len(), 2, "invalid request pattern");
89
90            match pat.elems.last().unwrap() {
91                Pat::Ident(pat) if is_valid_token_ident(pat) => {}
92                _ => panic!("the token must be used"),
93            }
94
95            match extract_kind(pat.elems.first().unwrap())? {
96                GroupKind::Regular(path) => Ok(GroupKind::Request(path)),
97                _ => Err("cannot determine the request's type"),
98            }
99        }
100        Pat::TupleStruct(pat) => Ok(GroupKind::Regular(extract_path_to_type(&pat.path))),
101        Pat::Type(_) => Err("type ascription patterns are forbidden"),
102        Pat::Wild(_) => Ok(GroupKind::Wild),
103        _ => Err("unknown tokens"),
104    }
105}
106
107fn is_likely_type(pat: &Pat) -> bool {
108    match pat {
109        Pat::Ident(i) if i.subpat.is_none() && is_type_ident(&i.ident) => true,
110        Pat::Path(p) if extract_path_to_type(&p.path) == p.path => true,
111        _ => false,
112    }
113}
114
115/// Detects `a @ A` and `a @ some::A` patterns.
116fn is_binding_with_type(ident: &PatIdent) -> bool {
117    ident
118        .subpat
119        .as_ref()
120        .map_or(false, |sp| is_likely_type(&sp.1))
121}
122
123fn refine_pat(pat: &mut Pat) {
124    match pat {
125        // `e @ Enum`
126        // `s @ Struct` (~ `s @ Struct { .. }`)
127        Pat::Ident(ident) if is_binding_with_type(ident) => {
128            ident.subpat = None;
129        }
130        // `(e @ SomeType, token)`
131        // `(SomeType, token)`
132        Pat::Tuple(pat) => {
133            assert_eq!(pat.elems.len(), 2, "invalid request pattern");
134
135            match pat.elems.first_mut() {
136                Some(Pat::Ident(ident)) if is_binding_with_type(ident) => {
137                    ident.subpat = None;
138                }
139                Some(pat) if is_likely_type(pat) => {
140                    *pat = Pat::Wild(PatWild {
141                        attrs: Vec::new(),
142                        underscore_token: Token![_](pat.span()),
143                    });
144                }
145                _ => {}
146            }
147        }
148        // `SomeType => ...`
149        pat if is_likely_type(pat) => {
150            *pat = Pat::Wild(PatWild {
151                attrs: Vec::new(),
152                underscore_token: Token![_](pat.span()),
153            });
154        }
155        _ => {}
156    };
157}
158
159fn add_groups(groups: &mut Vec<MessageGroup>, mut arm: Arm) -> Result<(), &'static str> {
160    let mut add = |kind, arm: Arm| {
161        // println!("group {:?} {:#?}", kind, arm.pat);
162        match groups.iter_mut().find(|common| common.kind == kind) {
163            Some(common) => common.arms.push(arm),
164            None => groups.push(MessageGroup {
165                kind,
166                arms: vec![arm],
167            }),
168        }
169    };
170
171    if let Pat::Or(pat) = &arm.pat {
172        let mut map = HashMap::new();
173
174        for pat in &pat.cases {
175            let kind = extract_kind(pat)?;
176            let new_arm = map.entry(kind).or_insert_with(|| {
177                let mut arm = arm.clone();
178                if let Pat::Or(pat) = &mut arm.pat {
179                    pat.cases.clear();
180                }
181                arm
182            });
183
184            if let Pat::Or(new_pat) = &mut new_arm.pat {
185                let mut old_pat = pat.clone();
186                refine_pat(&mut old_pat);
187                new_pat.cases.push(old_pat);
188            }
189        }
190
191        for (kind, arm) in map {
192            add(kind, arm);
193        }
194    } else {
195        let kind = extract_kind(&arm.pat)?;
196        refine_pat(&mut arm.pat);
197        add(kind, arm);
198    }
199
200    Ok(())
201}
202
203pub fn msg_impl(input: TokenStream, path_to_elfo: Path) -> TokenStream {
204    let input = parse_macro_input!(input as ExprMatch);
205    let mut groups = Vec::<MessageGroup>::with_capacity(input.arms.len());
206    let crate_ = path_to_elfo;
207
208    for arm in input.arms.into_iter() {
209        add_groups(&mut groups, arm).expect("invalid pattern");
210    }
211
212    let envelope_ident = quote! { _elfo_envelope };
213
214    // println!(">>> HERE {:#?}", groups);
215
216    let groups = groups
217        .iter()
218        .map(|group| match (&group.kind, &group.arms[..]) {
219            (GroupKind::Regular(path), arms) => quote! {
220                else if #envelope_ident.is::<#path>() {
221                    // TODO: replace with `static_assertions`.
222                    trait Forbidden<A, E> { fn test(_: &E) {} }
223                    impl<E, M> Forbidden<(), E> for M {}
224                    struct Invalid;
225                    impl<E: EnvelopeOwned, M: #crate_::Request> Forbidden<Invalid, E> for M {}
226                    let _ = <#path as Forbidden<_, _>>::test(&#envelope_ident);
227                    // -----
228
229                    let message = #envelope_ident.unpack_regular();
230                    match message.downcast2::<#path>() {
231                        #(#arms)*
232                    }
233                }
234            },
235            (GroupKind::Request(path), arms) => quote! {
236                else if #envelope_ident.is::<#path>() {
237                    assert_impl_all!(#path: #crate_::Request);
238                    let (message, token) = #envelope_ident.unpack_request::<#path>();
239                    match (message.downcast2::<#path>(), token) {
240                        #(#arms)*
241                    }
242                }
243            },
244            (GroupKind::Wild, [arm]) => quote! {
245                else {
246                    match #envelope_ident { #arm }
247                }
248            },
249            (GroupKind::Wild, _) => panic!("too many default branches"),
250        });
251
252    let match_expr = input.expr;
253
254    // TODO: propagate `input.attrs`?
255    let expanded = quote! {{
256        use #crate_::_priv::*;
257        let #envelope_ident = #match_expr;
258        if false { unreachable!(); }
259        #(#groups)*
260    }};
261
262    TokenStream::from(expanded)
263}