1use std::{char, collections::HashMap};
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6 parse_macro_input, spanned::Spanned, Arm, ExprMatch, Ident, Pat, PatIdent, PatWild, Path, Token,
7};
8
9#[derive(Debug)]
13struct MessageGroup {
14 kind: GroupKind,
15 arms: Vec<Arm>,
16}
17
18#[derive(Debug, Hash, PartialEq, Eq)]
19enum GroupKind {
20 Regular(Path),
22 Request(Path),
24 Wild,
27}
28
29fn is_valid_token_ident(ident: &PatIdent) -> bool {
30 !ident.ident.to_string().starts_with('_')
31}
32
33fn is_type_ident(ident: &Ident) -> bool {
34 ident
35 .to_string()
36 .chars()
37 .next()
38 .map_or(false, char::is_uppercase)
39}
40
41fn extract_path_to_type(path: &Path) -> Path {
42 let mut ident_rev_it = path.segments.iter().rev();
43
44 if let Some(prev) = ident_rev_it.nth(1) {
50 if is_type_ident(&prev.ident) {
51 let mut path = path.clone();
52 path.segments.pop().unwrap();
53
54 let (last, _) = path.segments.pop().unwrap().into_tuple();
56 path.segments.push(last);
57 return path;
58 }
59 }
60
61 path.clone()
62}
63
64fn extract_kind(pat: &Pat) -> Result<GroupKind, &'static str> {
65 match pat {
66 Pat::Box(_) => Err("box patterns are forbidden"),
67 Pat::Ident(pat) => match pat.subpat.as_ref() {
68 Some(sp) => extract_kind(&sp.1),
69 None if is_type_ident(&pat.ident) => {
70 Ok(GroupKind::Regular(Path::from(pat.ident.clone())))
71 }
72 None => Ok(GroupKind::Wild),
73 },
74 Pat::Lit(_) => Err("literal patterns are forbidden"),
75 Pat::Macro(_) => Err("macros in pattern position are forbidden"),
76 Pat::Or(pat) => pat
77 .cases
78 .iter()
79 .find_map(|pat| extract_kind(pat).ok())
80 .ok_or("cannot determine the message's type"),
81 Pat::Path(pat) => Ok(GroupKind::Regular(extract_path_to_type(&pat.path))),
82 Pat::Range(_) => Err("range patterns are forbidden"),
83 Pat::Reference(pat) => extract_kind(&pat.pat),
84 Pat::Rest(_) => Err("rest patterns are forbidden"),
85 Pat::Slice(_) => Err("slice patterns are forbidden"),
86 Pat::Struct(pat) => Ok(GroupKind::Regular(extract_path_to_type(&pat.path))),
87 Pat::Tuple(pat) => {
88 assert_eq!(pat.elems.len(), 2, "invalid request pattern");
89
90 match pat.elems.last().unwrap() {
91 Pat::Ident(pat) if is_valid_token_ident(pat) => {}
92 _ => panic!("the token must be used"),
93 }
94
95 match extract_kind(pat.elems.first().unwrap())? {
96 GroupKind::Regular(path) => Ok(GroupKind::Request(path)),
97 _ => Err("cannot determine the request's type"),
98 }
99 }
100 Pat::TupleStruct(pat) => Ok(GroupKind::Regular(extract_path_to_type(&pat.path))),
101 Pat::Type(_) => Err("type ascription patterns are forbidden"),
102 Pat::Wild(_) => Ok(GroupKind::Wild),
103 _ => Err("unknown tokens"),
104 }
105}
106
107fn is_likely_type(pat: &Pat) -> bool {
108 match pat {
109 Pat::Ident(i) if i.subpat.is_none() && is_type_ident(&i.ident) => true,
110 Pat::Path(p) if extract_path_to_type(&p.path) == p.path => true,
111 _ => false,
112 }
113}
114
115fn is_binding_with_type(ident: &PatIdent) -> bool {
117 ident
118 .subpat
119 .as_ref()
120 .map_or(false, |sp| is_likely_type(&sp.1))
121}
122
123fn refine_pat(pat: &mut Pat) {
124 match pat {
125 Pat::Ident(ident) if is_binding_with_type(ident) => {
128 ident.subpat = None;
129 }
130 Pat::Tuple(pat) => {
133 assert_eq!(pat.elems.len(), 2, "invalid request pattern");
134
135 match pat.elems.first_mut() {
136 Some(Pat::Ident(ident)) if is_binding_with_type(ident) => {
137 ident.subpat = None;
138 }
139 Some(pat) if is_likely_type(pat) => {
140 *pat = Pat::Wild(PatWild {
141 attrs: Vec::new(),
142 underscore_token: Token),
143 });
144 }
145 _ => {}
146 }
147 }
148 pat if is_likely_type(pat) => {
150 *pat = Pat::Wild(PatWild {
151 attrs: Vec::new(),
152 underscore_token: Token),
153 });
154 }
155 _ => {}
156 };
157}
158
159fn add_groups(groups: &mut Vec<MessageGroup>, mut arm: Arm) -> Result<(), &'static str> {
160 let mut add = |kind, arm: Arm| {
161 match groups.iter_mut().find(|common| common.kind == kind) {
163 Some(common) => common.arms.push(arm),
164 None => groups.push(MessageGroup {
165 kind,
166 arms: vec![arm],
167 }),
168 }
169 };
170
171 if let Pat::Or(pat) = &arm.pat {
172 let mut map = HashMap::new();
173
174 for pat in &pat.cases {
175 let kind = extract_kind(pat)?;
176 let new_arm = map.entry(kind).or_insert_with(|| {
177 let mut arm = arm.clone();
178 if let Pat::Or(pat) = &mut arm.pat {
179 pat.cases.clear();
180 }
181 arm
182 });
183
184 if let Pat::Or(new_pat) = &mut new_arm.pat {
185 let mut old_pat = pat.clone();
186 refine_pat(&mut old_pat);
187 new_pat.cases.push(old_pat);
188 }
189 }
190
191 for (kind, arm) in map {
192 add(kind, arm);
193 }
194 } else {
195 let kind = extract_kind(&arm.pat)?;
196 refine_pat(&mut arm.pat);
197 add(kind, arm);
198 }
199
200 Ok(())
201}
202
203pub fn msg_impl(input: TokenStream, path_to_elfo: Path) -> TokenStream {
204 let input = parse_macro_input!(input as ExprMatch);
205 let mut groups = Vec::<MessageGroup>::with_capacity(input.arms.len());
206 let crate_ = path_to_elfo;
207
208 for arm in input.arms.into_iter() {
209 add_groups(&mut groups, arm).expect("invalid pattern");
210 }
211
212 let envelope_ident = quote! { _elfo_envelope };
213
214 let groups = groups
217 .iter()
218 .map(|group| match (&group.kind, &group.arms[..]) {
219 (GroupKind::Regular(path), arms) => quote! {
220 else if #envelope_ident.is::<#path>() {
221 trait Forbidden<A, E> { fn test(_: &E) {} }
223 impl<E, M> Forbidden<(), E> for M {}
224 struct Invalid;
225 impl<E: EnvelopeOwned, M: #crate_::Request> Forbidden<Invalid, E> for M {}
226 let _ = <#path as Forbidden<_, _>>::test(&#envelope_ident);
227 let message = #envelope_ident.unpack_regular();
230 match message.downcast2::<#path>() {
231 #(#arms)*
232 }
233 }
234 },
235 (GroupKind::Request(path), arms) => quote! {
236 else if #envelope_ident.is::<#path>() {
237 assert_impl_all!(#path: #crate_::Request);
238 let (message, token) = #envelope_ident.unpack_request::<#path>();
239 match (message.downcast2::<#path>(), token) {
240 #(#arms)*
241 }
242 }
243 },
244 (GroupKind::Wild, [arm]) => quote! {
245 else {
246 match #envelope_ident { #arm }
247 }
248 },
249 (GroupKind::Wild, _) => panic!("too many default branches"),
250 });
251
252 let match_expr = input.expr;
253
254 let expanded = quote! {{
256 use #crate_::_priv::*;
257 let #envelope_ident = #match_expr;
258 if false { unreachable!(); }
259 #(#groups)*
260 }};
261
262 TokenStream::from(expanded)
263}