impl_trait/
lib.rs

1// impl_trait - Rust proc macro that significantly reduces boilerplate
2// Copyright (C) 2021  Soni L.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Affero General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12// GNU Affero General Public License for more details.
13//
14// You should have received a copy of the GNU Affero General Public License
15// along with this program.  If not, see <https://www.gnu.org/licenses/>.
16
17extern crate proc_macro;
18use proc_macro::{TokenStream, TokenTree, Delimiter, Span};
19//use syn::parse::{Parse, ParseStream, Result as ParseResult};
20use syn::{Generics, GenericParam};
21use std::cmp::Ordering;
22use quote::ToTokens;
23
24#[proc_macro]
25#[allow(unreachable_code)]
26pub fn impl_trait(item: TokenStream) -> TokenStream {
27    //eprintln!("INPUT: {:#?}", item);
28    let mut output: Vec<_> = item.into_iter().collect();
29    let attributes: Vec<TokenTree> = {
30        let mut pos = 0;
31        let mut len = 0;
32        let mut in_attr = false;
33        while pos != output.len() {
34            let tt = &output[pos];
35            pos += 1;
36            match tt {
37                &TokenTree::Punct(ref punct) => {
38                    if punct.as_char() == '#' && !in_attr {
39                        in_attr = true;
40                        continue;
41                    }
42                }
43                &TokenTree::Group(ref group) => {
44                    if group.delimiter() == Delimiter::Bracket && in_attr {
45                        in_attr = false;
46                        len = pos;
47                        continue;
48                    }
49                }
50                _ => {}
51            }
52            break;
53        }
54        output.drain(0..len).collect()
55    };
56    //eprintln!("attributes: {:#?}", attributes);
57    // check for impl.
58    // unsafe impls are only available for traits and are automatically rejected.
59    'check_impl: loop { break {
60        if let &TokenTree::Ident(ref ident) = &output[0] {
61            if format!("{}", ident) == "impl" {
62                break 'check_impl;
63            }
64        }
65        panic!("impl_trait! may only be applied to inherent impls");
66    } }
67    let mut has_where: Option<&TokenTree> = None;
68    'check_no_for_before_where: loop { break {
69        for tt in &output {
70            if let &TokenTree::Ident(ref ident) = tt {
71                let formatted = format!("{}", ident);
72                if formatted == "where" {
73                    has_where = Some(tt);
74                    break 'check_no_for_before_where;
75                } else if formatted == "for" {
76                    panic!("impl_trait! may only be applied to inherent impls");
77                }
78            }
79        }
80    } }
81    // this is the "where [...]" part, including the "where".
82    let mut where_bounds = Vec::new();
83    if let Some(where_in) = has_where {
84        where_bounds = output.split_last().unwrap().1.into_iter().skip_while(|&tt| {
85            !std::ptr::eq(tt, where_in)
86        }).cloned().collect();
87    }
88    let where_bounds = where_bounds;
89    drop(has_where);
90    let mut count = 0;
91    // this is the "<...>" part, immediately after the "impl", and including the "<>".
92    let generics = output.split_first().unwrap().1.into_iter().take_while(|&tt| {
93        let mut result = count > 0;
94        if let &TokenTree::Punct(ref punct) = tt {
95            let c = punct.as_char();
96            if c == '<' {
97                count += 1;
98                result = true;
99            } else if c == '>' {
100                count -= 1;
101            }
102        }
103        result
104    }).cloned().collect::<Vec<_>>();
105    // so how do you find the target? well...
106    // "impl" + [_; generics.len()] + [_; target.len()] + [_; where_bounds.len()] + "{}"
107    // we have generics and where_bounds, and the total, so we can easily find target!
108    let target_start = 1 + generics.len();
109    let target_end = output.len() - 1 - where_bounds.len();
110    let target_range = target_start..target_end;
111    let target = (&output[target_range]).into_iter().cloned().collect::<Vec<_>>();
112    //eprintln!("generics: {:#?}", generics);
113    //eprintln!("target: {:#?}", target);
114    //eprintln!("where_bounds: {:#?}", where_bounds);
115    let items = output.last_mut();
116    if let &mut TokenTree::Group(ref mut group) = items.unwrap() {
117        // TODO: parse "[unsafe] impl trait" somehow. use syn for it maybe (after swallowing the "trait")
118        // luckily for us, there's only one thing that can come after an impl trait: a path
119        // (and optional generics).
120        // but we can't figure out how to parse the `where`.
121        //todo!();
122        let span = group.span();
123        let mut items = group.stream().into_iter().collect::<Vec<_>>();
124        let mut in_unsafe = false;
125        let mut in_impl = false;
126        let mut in_path = false;
127        let mut in_generic = false;
128        let mut in_attr = false;
129        let mut in_attr_cont = false;
130        let mut has_injected_generics = false;
131        let mut in_where = false;
132        let mut start = 0;
133        let mut found: Vec<Vec<TokenTree>> = Vec::new();
134        let mut to_remove: Vec<std::ops::Range<usize>> = Vec::new();
135        let mut generics_scratchpad = Vec::new();
136        let mut count = 0;
137        let mut trait_span: Option<Span> = None;
138        'main_loop: for (pos, tt) in (&items).into_iter().enumerate() {
139            if in_generic {
140                // collect the generics
141                let mut result = count > 0;
142                if let &TokenTree::Punct(ref punct) = tt {
143                    let c = punct.as_char();
144                    if c == '<' {
145                        count += 1;
146                        result = true;
147                    } else if c == '>' {
148                        count -= 1;
149                        if count == 0 {
150                            in_generic = false;
151                            in_path = true;
152                        }
153                    }
154                }
155                if result {
156                    generics_scratchpad.push(tt.clone());
157                    continue;
158                }
159            }
160            if in_path {
161                // inject the generics
162                if !has_injected_generics {
163                    has_injected_generics = true;
164                    if generics_scratchpad.is_empty() {
165                        found.last_mut().unwrap().extend(generics.clone());
166                    } else if generics.is_empty() {
167                        found.last_mut().unwrap().extend(generics_scratchpad.clone());
168                    } else {
169                        // need to *combine* generics. this is not exactly trivial.
170                        // thankfully we don't need to worry about defaults on impls.
171                        let mut this_generics: Generics = syn::parse(generics_scratchpad.drain(..).collect()).unwrap();
172                        let parent_generics: Generics = syn::parse(generics.clone().into_iter().collect()).unwrap();
173                        let mut target = parent_generics.params.into_pairs().chain(this_generics.params.clone().into_pairs()).collect::<Vec<_>>();
174                        target.sort_by(|a, b| {
175                            match (a.value(), b.value()) {
176                                (&GenericParam::Lifetime(_), &GenericParam::Const(_)) => Ordering::Less,
177                                (&GenericParam::Type(_), &GenericParam::Const(_)) => Ordering::Less,
178                                (&GenericParam::Lifetime(_), &GenericParam::Type(_)) => Ordering::Less,
179                                (&GenericParam::Lifetime(_), &GenericParam::Lifetime(_)) => Ordering::Equal,
180                                (&GenericParam::Type(_), &GenericParam::Type(_)) => Ordering::Equal,
181                                (&GenericParam::Const(_), &GenericParam::Const(_)) => Ordering::Equal,
182                                (&GenericParam::Type(_), &GenericParam::Lifetime(_)) => Ordering::Greater,
183                                (&GenericParam::Const(_), &GenericParam::Type(_)) => Ordering::Greater,
184                                (&GenericParam::Const(_), &GenericParam::Lifetime(_)) => Ordering::Greater,
185                            }
186                        });
187                        // just need to fix the one Pair::End in the middle of the thing.
188                        for item in &mut target {
189                            if matches!(item, syn::punctuated::Pair::End(_)) {
190                                let value = item.value().clone();
191                                *item = syn::punctuated::Pair::Punctuated(value, syn::token::Comma { spans: [trait_span.unwrap().into()] });
192                                break;
193                            }
194                        }
195                        this_generics.params = target.into_iter().collect();
196                        let new_generics = TokenStream::from(this_generics.into_token_stream());
197                        found.last_mut().unwrap().extend(new_generics);
198                    }
199                }
200                in_generic = false;
201                if let &TokenTree::Ident(ref ident) = tt {
202                    let formatted = format!("{}", ident);
203                    if count == 0 && formatted == "where" {
204                        in_path = false;
205                        in_where = true;
206                        // add "for"
207                        found.last_mut().unwrap().push(proc_macro::Ident::new("for", trait_span.unwrap()).into());
208                        // add Target
209                        found.last_mut().unwrap().extend(target.clone());
210                        // *then* add the "where" (from the impl-trait)
211                        found.last_mut().unwrap().push(tt.clone());
212                        // and the parent bounds (except the "where")
213                        if !where_bounds.is_empty() {
214                            found.last_mut().unwrap().extend((&where_bounds).into_iter().skip(1).cloned());
215                            // also make sure that there's an ',' at the correct place
216                            if let Some(&TokenTree::Punct(ref x)) = where_bounds.last() {
217                                if x.as_char() == ',' {
218                                    continue 'main_loop;
219                                }
220                            }
221                            found.last_mut().unwrap().push(proc_macro::Punct::new(',', proc_macro::Spacing::Alone).into());
222                        }
223                        continue 'main_loop;
224                    }
225                }
226                if let &TokenTree::Punct(ref punct) = tt {
227                    let c = punct.as_char();
228                    if c == '<' {
229                        count += 1;
230                    } else if c == '>' {
231                        // this is broken so just give up
232                        // FIXME better error handling
233                        if count == 0 {
234                            in_path = false;
235                            continue 'main_loop;
236                        }
237                        count -= 1;
238                    }
239                }
240                if let &TokenTree::Group(ref group) = tt {
241                    if group.delimiter() == Delimiter::Brace && count == 0 {
242                        to_remove.push(start..pos+1);
243                        // add "for"
244                        found.last_mut().unwrap().push(proc_macro::Ident::new("for", tt.span()).into());
245                        // add Target
246                        found.last_mut().unwrap().extend(target.clone());
247                        // and the parent bounds (including the "where")
248                        found.last_mut().unwrap().extend(where_bounds.clone());
249                        in_path = false;
250                        in_where = false;
251                        // fall through to add the block
252                    }
253                }
254                found.last_mut().unwrap().push(tt.clone());
255                continue 'main_loop;
256            }
257            if in_where {
258                // just try to find the block, and add all the stuff.
259                if let &TokenTree::Punct(ref punct) = tt {
260                    let c = punct.as_char();
261                    if c == '<' {
262                        count += 1;
263                    } else if c == '>' {
264                        // this is broken so just give up
265                        // FIXME better error handling
266                        if count == 0 {
267                            in_where = false;
268                            continue 'main_loop;
269                        }
270                        count -= 1;
271                    }
272                }
273                if let &TokenTree::Group(ref group) = tt {
274                    if group.delimiter() == Delimiter::Brace && count == 0 {
275                        // call it done!
276                        to_remove.push(start..pos+1);
277                        in_where = false;
278                    }
279                }
280                found.last_mut().unwrap().push(tt.clone());
281                continue 'main_loop;
282            }
283            if found.len() == to_remove.len() {
284                found.push(Vec::new());
285                in_unsafe = false;
286                in_impl = false;
287                in_where = false;
288                in_path = false;
289                in_attr_cont = false;
290                in_generic = false;
291                has_injected_generics = false;
292                count = 0;
293            }
294            match tt {
295                &TokenTree::Ident(ref ident) => {
296                    let formatted = format!("{}", ident);
297                    if formatted == "unsafe" && !in_impl {
298                        found.last_mut().unwrap().push(tt.clone());
299                        if !in_attr_cont {
300                            start = pos;
301                        }
302                        in_attr = false;
303                        in_unsafe = true;
304                        continue;
305                    } else if formatted == "impl" && !in_impl {
306                        if !in_attr_cont && !in_unsafe {
307                            start = pos;
308                        }
309                        found.last_mut().unwrap().push(tt.clone());
310                        in_unsafe = false;
311                        in_attr = false;
312                        in_impl = true;
313                        continue;
314                    } else if formatted == "trait" && in_impl {
315                        // swallowed. doesn't go into found.
316                        trait_span = Some(tt.span());
317                        in_generic = true;
318                        in_path = true;
319                        in_impl = false;
320                        has_injected_generics = false;
321                        continue;
322                    }
323                },
324                &TokenTree::Punct(ref punct) => {
325                    if punct.as_char() == '#' && !in_attr {
326                        found.last_mut().unwrap().push(tt.clone());
327                        if !in_attr_cont {
328                            start = pos;
329                        }
330                        in_attr = true;
331                        continue;
332                    }
333                }
334                &TokenTree::Group(ref group) => {
335                    if group.delimiter() == Delimiter::Bracket && in_attr {
336                        found.last_mut().unwrap().push(tt.clone());
337                        in_attr = false;
338                        in_attr_cont = true;
339                        continue;
340                    }
341                }
342                _ => {}
343            }
344            found.truncate(to_remove.len());
345            in_unsafe = false;
346            in_impl = false;
347            in_where = false;
348            in_path = false;
349            in_attr_cont = false;
350            in_generic = false;
351            has_injected_generics = false;
352            count = 0;
353        }
354        // must be iterated backwards
355        for range in to_remove.into_iter().rev() {
356            items.drain(range);
357        }
358        *group = proc_macro::Group::new(group.delimiter(), items.into_iter().collect());
359        group.set_span(span);
360        output.extend(found.into_iter().flatten());
361    }
362    drop(generics);
363    drop(target);
364    drop(where_bounds);
365    //eprintln!("attributes: {:#?}", attributes);
366    //eprintln!("OUTPUT: {:#?}", output);
367    //eprintln!("OUTPUT: {}", (&output).into_iter().cloned().collect::<TokenStream>());
368    attributes.into_iter().chain(output.into_iter()).collect()
369}