using_param/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::*;
4use proc_macro::Spacing::*;
5use proc_macro_tool::*;
6
7#[proc_macro_attribute]
8pub fn using_param(attr: TokenStream, item: TokenStream) -> TokenStream {
9    process(Type::Param, attr, item)
10}
11
12#[proc_macro_attribute]
13pub fn using_generic(attr: TokenStream, item: TokenStream) -> TokenStream {
14    process(Type::Generic, attr, item)
15}
16
17#[proc_macro_attribute]
18pub fn using_return(attr: TokenStream, item: TokenStream) -> TokenStream {
19    process(Type::Ret, attr, item)
20}
21
/// Identifies which attribute macro is being expanded, so `process` knows how
/// to interpret the attribute arguments.
#[derive(Debug, Clone, Copy)]
enum Type {
    /// Arguments are parameters to inject (`using_param`).
    Param,
    /// Arguments are generic parameters to inject (`using_generic`).
    Generic,
    /// Arguments are the return type to add (`using_return`).
    Ret,
}
24
25fn process(ty: Type, attr: TokenStream, item: TokenStream) -> TokenStream {
26    let cfg = match ty {
27        Type::Param => param_cfg(attr),
28        Type::Generic => generic_cfg(attr),
29        Type::Ret => Conf { return_type: attr, ..Default::default() },
30    };
31
32    let mut iter = item.parse_iter();
33    let mut out = TokenStream::new();
34    out.extend(iter.next_attributes());
35
36    if !iter.peek_is(|i| i.is_keyword("impl")) {
37        err!("expected impl keyword", iter.span())
38    }
39    let block = iter.reduce(|a, b| {
40        out.push(a);
41        b
42    }).unwrap();
43
44    let items = block.to_brace_stream().unwrap();
45    match process_impl_block(&cfg, items) {
46        Err(e) => e,
47        Ok(b) => {
48            out.push(b.grouped_brace().tt());
49            out
50        }
51    }
52}
53
/// Settings derived from the attribute arguments and applied to every
/// function by `process_impl_block`. Empty streams mean "nothing to inject".
#[derive(Debug, Default)]
struct Conf {
    /// `true` when the parameter arguments started with `,`: the injected
    /// parameters then go after the function's own parameters, not before.
    params_after: bool,
    /// Parameter tokens to inject into each function's parameter list.
    param: TokenStream,
    /// Generic parameter tokens to inject into each function's generic list.
    generics: TokenStream,
    /// `true` when the generic arguments started with `,`: the injected
    /// generics then go after the function's own generics, not before.
    generics_after: bool,
    /// Return type added to functions that have no explicit `->`.
    return_type: TokenStream,
}
62
63fn fn_generic(iter: &mut ParseIter<impl Iterator<Item = TokenTree>>) -> TokenStream {
64    let mut out = TokenStream::new();
65
66    loop {
67        if let Some(arrow) = iter.next_puncts("->") {
68            out.extend(arrow);
69        } else if iter.is_puncts(">")
70            && iter.peek_i(1).is_none_or(|t| t.is_delimiter_paren())
71        {
72            break;
73        } else if let Some(tt) = iter.next() {
74            out.push(tt);
75        } else {
76            break;
77        }
78    }
79
80    out
81}
82
83fn param_cfg(attr: TokenStream) -> Conf {
84    let mut iter = attr.parse_iter();
85
86    Conf {
87        params_after: iter.next_if(|t| t.is_punch(',')).is_some(),
88        param: iter.collect(),
89        ..Default::default()
90    }
91}
92
93fn generic_cfg(attr: TokenStream) -> Conf {
94    let mut iter = attr.parse_iter();
95
96    Conf {
97        generics_after: iter.next_if(|t| t.is_punch(',')).is_some(),
98        generics: iter.collect(),
99        ..Default::default()
100    }
101}
102
103fn join_with_comma(a: TokenStream, b: TokenStream) -> TokenStream {
104    if a.is_empty() {
105        return b;
106    }
107    if b.is_empty() {
108        return a;
109    }
110
111    let mut left = a.into_iter().collect::<Vec<_>>();
112    let right = b.into_iter().collect::<Vec<_>>();
113
114    left.pop_if(|t| t.is_punch(','));
115
116    if !left.is_empty() {
117        left.push(','.punct(Alone).tt());
118    }
119
120    let rcom = right.first().is_some_and(|t| t.is_punch(','));
121    stream(left.into_iter().chain(right.into_iter().skip(rcom.into())))
122}
123
124fn self_param(iter: &mut ParseIter<impl Iterator<Item = TokenTree>>) -> TokenStream {
125    macro_rules! ok {
126        () => {
127            return iter
128                .split_puncts_include(",")
129                .unwrap_or_else(|| iter.collect());
130        };
131    }
132    if iter.peek_is(|t| t.is_punch('&')) {
133        if iter.peek_i_is(1, |t| t.is_punch('\''))
134        && iter.peek_i_is(2, |t| t.is_ident())
135        {
136            if iter.peek_i_is(3, |t| t.is_keyword("self")) {
137                ok!();
138            }
139
140            if iter.peek_i_is(3, |t| t.is_keyword("mut"))
141            && iter.peek_i_is(4, |t| t.is_keyword("self"))
142            {
143                ok!();
144            }
145        }
146
147        if iter.peek_i_is(1, |t| t.is_keyword("self")) {
148            ok!();
149        }
150
151        if iter.peek_i_is(1, |t| t.is_keyword("mut"))
152        && iter.peek_i_is(2, |t| t.is_keyword("self"))
153        {
154            ok!();
155        }
156    } else if iter.peek_is(|t| t.is_keyword("self")) {
157        ok!();
158    }
159    TokenStream::new()
160}
161
162fn process_impl_block(
163    cfg: &Conf,
164    items: TokenStream,
165) -> Result<TokenStream, TokenStream> {
166    let mut out = TokenStream::new();
167    let mut iter = items.parse_iter();
168
169    out.extend(iter.next_outer_attributes());
170
171    while iter.peek().is_some() {
172        out.extend(iter.next_attributes());
173        out.extend(iter.next_vis());
174
175        if iter.peek_is(|t| t.is_keyword("fn"))
176            && iter.peek_i_is(1, |t| t.is_ident())
177        {
178            out.extend(iter.next_tts::<2>());
179
180            if iter.push_if_to(&mut out, |t| t.is_punch('<')) {
181                let generic = fn_generic(&mut iter);
182
183                out.add(if cfg.generics_after {
184                    join_with_comma(generic, cfg.generics.clone())
185                } else {
186                    join_with_comma(cfg.generics.clone(), generic)
187                });
188
189                iter.push_if_to(&mut out, |t| t.is_punch('>'));
190            } else if !cfg.generics.is_empty() {
191                out.push('<'.punct(Alone).tt());
192                out.add(cfg.generics.clone());
193                out.push('>'.punct(Alone).tt());
194            }
195
196            if let Some(TokenTree::Group(paren))
197                = iter.next_if(|t| t.is_delimiter_paren())
198            {
199
200                let params_group = paren.map(|paren| {
201                    let mut iter = paren.parse_iter();
202                    let mut self_ = self_param(&mut iter);
203                    let mut param = cfg.param.clone().parse_iter();
204                    let other_self = self_param(&mut param);
205
206                    if self_.is_empty() {
207                        self_ = other_self;
208                    }
209
210                    join_with_comma(self_, if cfg.params_after {
211                        join_with_comma(iter.collect(), param.collect())
212                    } else {
213                        join_with_comma(param.collect(), iter.collect())
214                    })
215                });
216                out.push(params_group.tt());
217
218                if !cfg.return_type.is_empty() && !iter.is_puncts("->") {
219                    out.push('-'.punct(Joint).tt());
220                    out.push('>'.punct(Alone).tt());
221                    out.add(cfg.return_type.clone());
222                }
223            }
224        } else {
225            out.push(iter.next().unwrap());
226        }
227    }
228
229    Ok(out)
230}
231
232/// ```
233/// using_param::__test_join! {}
234/// ```
235#[doc(hidden)]
236#[proc_macro]
237pub fn __test_join(_: TokenStream) -> TokenStream {
238    let datas = [
239        ("", "", ""),
240        ("a", "", "a"),
241        ("a,", "", "a,"),
242        ("", "a", "a"),
243        ("", "a,", "a,"),
244        ("a", "b", "a, b"),
245        ("a,", "b", "a, b"),
246        ("a,", "b,", "a, b,"),
247        ("a,", "b,", "a, b,"),
248        ("a", "b,", "a, b,"),
249    ];
250    for (a, b, expected) in datas {
251        let out = join_with_comma(a.parse().unwrap(), b.parse().unwrap());
252        assert_eq!(out.to_string(), expected, "{a:?}, {b:?}");
253    }
254    TokenStream::new()
255}
256
257/// ```
258/// using_param::__test_before! {}
259/// ```
260#[doc(hidden)]
261#[proc_macro]
262pub fn __test_before(_: TokenStream) -> TokenStream {
263    let out = using_param("ctx: i32".parse().unwrap(), "
264impl Foo {
265    #[doc(hidden)]
266    pub fn foo(&self, s: &str) -> &str {
267        s
268    }
269    pub fn bar(&self) -> i32 {
270        ctx
271    }
272    pub fn baz() -> i32 {
273        ctx
274    }
275    pub fn f(self: &Self) -> i32 {
276        ctx
277    }
278    pub fn a(x: i32) -> i32 {
279        ctx+x
280    }
281    pub fn b(&mut self, a: i32, b: i32) -> i32 {
282        ctx+a+b
283    }
284    pub fn c(&'a mut self, a: i32, b: i32) -> i32 {
285        ctx+a+b
286    }
287    pub fn d(&'static mut self, a: i32, b: i32) -> i32 {
288        ctx+a+b
289    }
290}
291    ".parse().unwrap()).to_string();
292    assert_eq!(out, "
293impl Foo {
294    #[doc(hidden)]
295    pub fn foo(& self, ctx : i32, s : & str) -> & str {
296        s
297    }
298    pub fn bar(& self, ctx : i32) -> i32 {
299        ctx
300    }
301    pub fn baz(ctx : i32) -> i32 {
302        ctx
303    }
304    pub fn f(self : & Self, ctx : i32) -> i32 {
305        ctx
306    }
307    pub fn a(ctx : i32, x : i32) -> i32 {
308        ctx+x
309    }
310    pub fn b(& mut self, ctx : i32, a : i32, b : i32) -> i32 {
311        ctx+a+b
312    }
313    pub fn c(& 'a mut self, ctx : i32, a : i32, b : i32) -> i32 {
314        ctx+a+b
315    }
316    pub fn d(& 'static mut self, ctx : i32, a : i32, b : i32) -> i32 {
317        ctx+a+b
318    }
319}
320    ".parse::<TokenStream>().unwrap().to_string());
321    TokenStream::new()
322}
323
324
325/// ```
326/// using_param::__test_after! {}
327/// ```
328#[doc(hidden)]
329#[proc_macro]
330pub fn __test_after(_: TokenStream) -> TokenStream {
331    let out = using_param(", ctx: i32".parse().unwrap(), "
332impl Foo {
333    #[doc(hidden)]
334    pub fn foo(&self, s: &str) -> &str {
335        s
336    }
337    pub fn bar(&self) -> i32 {
338        ctx
339    }
340    pub fn baz() -> i32 {
341        ctx
342    }
343    pub fn f(self: &Self) -> i32 {
344        ctx
345    }
346    pub fn a(x: i32) -> i32 {
347        ctx+x
348    }
349    pub fn b(&mut self, a: i32, b: i32) -> i32 {
350        ctx+a+b
351    }
352}
353    ".parse().unwrap()).to_string();
354    assert_eq!(out, "
355impl Foo {
356    #[doc(hidden)]
357    pub fn foo(& self, s : & str, ctx : i32) -> & str {
358        s
359    }
360    pub fn bar(& self, ctx : i32) -> i32 {
361        ctx
362    }
363    pub fn baz(ctx : i32) -> i32 {
364        ctx
365    }
366    pub fn f(self : & Self, ctx : i32) -> i32 {
367        ctx
368    }
369    pub fn a(x : i32, ctx : i32) -> i32 {
370        ctx+x
371    }
372    pub fn b(& mut self, a : i32, b : i32, ctx : i32) -> i32 {
373        ctx+a+b
374    }
375}
376    ".parse::<TokenStream>().unwrap().to_string());
377    TokenStream::new()
378}
379
380
381/// ```
382/// using_param::__test_self_param! {}
383/// ```
384#[doc(hidden)]
385#[proc_macro]
386pub fn __test_self_param(_: TokenStream) -> TokenStream {
387    let out = using_param("&'static self, ctx: i32".parse().unwrap(), "
388impl Foo {
389    pub fn foo(&self, s: &str) -> &str {
390        s
391    }
392    pub fn bar(&mut self) -> i32 {
393        ctx
394    }
395    pub fn baz(self: &Self) -> i32 {
396        ctx
397    }
398    pub fn a(this: &Self) -> i32 {
399        ctx
400    }
401}
402    ".parse().unwrap()).to_string();
403    assert_eq!(out, "
404impl Foo {
405    pub fn foo(& self, ctx : i32, s : & str) -> & str {
406        s
407    }
408    pub fn bar(& mut self, ctx : i32) -> i32 {
409        ctx
410    }
411    pub fn baz(self : & Self, ctx : i32) -> i32 {
412        ctx
413    }
414    pub fn a(& 'static self, ctx : i32, this : & Self) -> i32 {
415        ctx
416    }
417}
418    ".parse::<TokenStream>().unwrap().to_string());
419    TokenStream::new()
420}
421
422
423/// ```
424/// using_param::__test_self_param! {}
425/// ```
426#[doc(hidden)]
427#[proc_macro]
428pub fn __test_self_param_after(_: TokenStream) -> TokenStream {
429    let out = using_param(", &'static self, ctx: i32".parse().unwrap(), "
430impl Foo {
431    pub fn foo(&self, s: &str) -> &str {
432        s
433    }
434    pub fn bar(&mut self) -> i32 {
435        ctx
436    }
437    pub fn baz(self: &Self) -> i32 {
438        ctx
439    }
440    pub fn a(this: &Self) -> i32 {
441        ctx
442    }
443}
444    ".parse().unwrap()).to_string();
445    assert_eq!(out, "
446impl Foo {
447    pub fn foo(& self, s : & str, ctx : i32) -> & str {
448        s
449    }
450    pub fn bar(& mut self, ctx : i32) -> i32 {
451        ctx
452    }
453    pub fn baz(self : & Self, ctx : i32) -> i32 {
454        ctx
455    }
456    pub fn a(& 'static self, this : & Self, ctx : i32) -> i32 {
457        ctx
458    }
459}
460    ".parse::<TokenStream>().unwrap().to_string());
461    TokenStream::new()
462}
463
464
465/// ```
466/// using_param::__test_generic_before! {}
467/// ```
468#[doc(hidden)]
469#[proc_macro]
470pub fn __test_generic_before(_: TokenStream) -> TokenStream {
471    let out = using_generic("'a".parse().unwrap(), "
472impl Foo {
473    fn foo() {}
474    fn bar<'b>() {}
475    fn baz<'b, T>() {}
476}
477    ".parse().unwrap()).to_string();
478    assert_eq!(out, "
479impl Foo {
480    fn foo < 'a > () {}
481    fn bar < 'a, 'b > () {}
482    fn baz < 'a, 'b, T > () {}
483}
484    ".parse::<TokenStream>().unwrap().to_string());
485    TokenStream::new()
486}
487
488
489/// ```
490/// using_param::__test_generic_after! {}
491/// ```
492#[doc(hidden)]
493#[proc_macro]
494pub fn __test_generic_after(_: TokenStream) -> TokenStream {
495    let out = using_generic(", 'a".parse().unwrap(), "
496impl Foo {
497    fn foo() {}
498    fn bar<'b>() {}
499}
500    ".parse().unwrap()).to_string();
501    assert_eq!(out, "
502impl Foo {
503    fn foo < 'a > () {}
504    fn bar < 'b, 'a > () {}
505}
506    ".parse::<TokenStream>().unwrap().to_string());
507    TokenStream::new()
508}
509
510
511/// ```
512/// using_param::__test_other_assoc_item! {}
513/// ```
514#[doc(hidden)]
515#[proc_macro]
516pub fn __test_other_assoc_item(_: TokenStream) -> TokenStream {
517    let out = using_param("ctx: i32".parse().unwrap(), "
518impl Foo {
519    pub const M: usize = 3;
520    pub type C = i32;
521    some_macro!();
522    fn foo() {}
523    fn bar(m: i32) { m+ctx }
524    fn baz(self, m: i32) { m+ctx }
525}
526    ".parse().unwrap()).to_string();
527    assert_eq!(out, "
528impl Foo {
529    pub const M : usize = 3;
530    pub type C = i32;
531    some_macro! ();
532    fn foo(ctx : i32) {}
533    fn bar(ctx : i32, m : i32) { m+ctx }
534    fn baz(self, ctx : i32, m : i32) { m+ctx }
535}
536    ".parse::<TokenStream>().unwrap().to_string());
537    TokenStream::new()
538}
539
540
541/// ```
542/// using_param::__test_return_type! {}
543/// ```
544#[doc(hidden)]
545#[proc_macro]
546pub fn __test_return_type(_: TokenStream) -> TokenStream {
547    let out = using_return("i32".parse().unwrap(), "
548impl Foo {
549    pub const M: usize = 3;
550    pub type C = i32;
551    some_macro!();
552    fn foo() {}
553    fn bar(m: i32) { m+ctx }
554    fn baz(self, m: i32) { m+ctx }
555    fn xfoo() -> u32 {}
556    fn xbar(m: i32) -> u32 { m+ctx }
557    fn xbaz(self, m: i32) -> u32 { m+ctx }
558}
559    ".parse().unwrap()).to_string();
560    assert_eq!(out, "
561impl Foo {
562    pub const M : usize = 3;
563    pub type C = i32;
564    some_macro! ();
565    fn foo() -> i32 {}
566    fn bar(m : i32) -> i32 { m+ctx }
567    fn baz(self, m : i32) -> i32 { m+ctx }
568    fn xfoo() -> u32 {}
569    fn xbar(m : i32) -> u32 { m+ctx }
570    fn xbaz(self, m : i32) -> u32 { m+ctx }
571}
572    ".parse::<TokenStream>().unwrap().to_string());
573    TokenStream::new()
574}