impl_macro_internal/
lib.rs

1#![allow(clippy::panic)]
2#![allow(clippy::unwrap_used)]
3#![allow(clippy::expect_used)]
4
5use proc_macro2::TokenStream;
6use proc_macro2::TokenTree;
7use proc_macro2::Delimiter;
8use quote::quote;
9
10// ===============
11// === Helpers ===
12// ===============
13
14fn is_brace(tt: &TokenTree) -> bool {
15    if let TokenTree::Group(group) = tt {
16        group.delimiter() == Delimiter::Brace
17    } else {
18        false
19    }
20}
21
22fn is_for(tt: &TokenTree) -> bool {
23    if let TokenTree::Ident(ident) = tt {
24        ident.to_string() == "for"
25    } else {
26        false
27    }
28}
29
30fn is_where(tt: &TokenTree) -> bool {
31    if let TokenTree::Ident(ident) = tt {
32        ident.to_string() == "where"
33    } else {
34        false
35    }
36}
37
38fn is_comma(tt: &TokenTree) -> bool {
39    if let TokenTree::Punct(punct) = tt {
40        punct.as_char() == ','
41    } else {
42        false
43    }
44}
45
46fn is_open_angle(tt: &TokenTree) -> bool {
47    if let TokenTree::Punct(punct) = tt {
48        punct.as_char() == '<'
49    } else {
50        false
51    }
52}
53
54fn is_close_angle(tt: &TokenTree) -> bool {
55    if let TokenTree::Punct(punct) = tt {
56        punct.as_char() == '>'
57    } else {
58        false
59    }
60}
61
62// =================
63// === Imp Macro ===
64// =================
65
66#[proc_macro]
67pub fn imp(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
68    let input: TokenStream = input.into();
69
70    let mut prefix = Vec::new();
71    let mut targets = Vec::new();
72    let mut where_clause = Vec::new();
73    let mut body = TokenStream::new();
74    let mut tokens = input.into_iter().peekable();
75
76    while !tokens.peek().map(is_for).unwrap_or_default() {
77        prefix.push(tokens.next().unwrap());
78    }
79
80    if tokens.peek().map(is_for).unwrap_or_default() {
81        tokens.next();
82    } else {
83        panic!("Expected 'for' keyword");
84    }
85
86    // Extract for-targets (handle multiple "for ... where ..." patterns)
87    while let Some(tt) = tokens.peek() {
88        match tt {
89            _ if is_where(tt) => {
90                let mut clause = Vec::new();
91                clause.push(tokens.next().unwrap());
92                while !tokens.peek().map(is_brace).unwrap_or_default() {
93                    clause.push(tokens.next().unwrap());
94                }
95                where_clause.push(TokenStream::from_iter(clause));
96            }
97            _ if is_comma(tt) => {
98                tokens.next();
99            }
100            _ if is_for(tt) => {
101                tokens.next();
102            }
103            _ if is_brace(tt) => {
104                break;
105            }
106            _ => {
107                let mut depth: i32 = 0;
108                let mut target = Vec::new();
109                target.push(tokens.next().unwrap());
110                while let Some(tt2) = tokens.peek() {
111                    if is_brace(tt2) || is_for(tt2) || is_where(tt2) {
112                        break;
113                    }
114                    if is_open_angle(tt2) {
115                        depth += 1;
116                    } else if is_close_angle(tt2) {
117                        depth -= 1;
118                    }
119                    if depth == 0 && is_comma(tt2) {
120                        break;
121                    }
122                    target.push(tokens.next().unwrap());
123                }
124                if target.last().map(is_comma).unwrap_or_default() {
125                    target.pop();
126                }
127                targets.push(TokenStream::from_iter(target));
128            }
129        }
130    }
131
132    if let Some(tt) = tokens.peek() {
133        if is_brace(tt) {
134            body = tt.clone().into();
135            tokens.next();
136        }
137    } else {
138        panic!("Unexpected end of input");
139    }
140
141    let prefix = TokenStream::from_iter(prefix);
142    let where_clause = TokenStream::from_iter(where_clause);
143    let impls = targets.iter().map(|target| {
144        quote! { impl #prefix for #target #where_clause #body }
145    }).collect::<Vec<_>>();
146
147    let output = quote! {
148        #(#impls)*
149    };
150
151    // println!("{}", output);
152    output.into()
153}