using_param/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::*;
4use proc_macro::Spacing::*;
5use proc_macro_tool::*;
6
7/// Apply on impl block, add to parameter
8///
9/// # Examples
10///
11/// ```
12/// struct Foo;
13/// #[using_param::using_param(id: i32)]
14/// impl Foo {
15///     fn foo(&self, s: &str) {}
16///     // like this:
17///     // fn foo(&self, id: i32, s: &str) {}
18/// }
19/// #[using_param::using_param(, patch: i32)]
20/// impl Foo {
21///     fn bar(&self, s: &str) {}
22///     // like this:
23///     // fn bar(&self, s: &str, patch: i32) {}
24/// }
25/// ```
26#[proc_macro_attribute]
27pub fn using_param(attr: TokenStream, item: TokenStream) -> TokenStream {
28    process(Type::Param, attr, item)
29}
30
31/// Apply on impl block, add to generics parameter
32///
33/// # Examples
34///
35/// ```
36/// struct Foo;
37/// #[using_param::using_generic(T)]
38/// impl Foo {
39///     fn foo<U>(&self, t: T, u: U) {}
40///     // like this:
41///     // fn foo<T, U>(&self, t: T, u: U) {}
42/// }
43/// #[using_param::using_generic(, T)]
44/// impl Foo {
45///     fn bar<U>(&self, u: U, t: T) {}
46///     // like this:
47///     // fn bar<U, T>(&self, u: U, t: T) {}
48/// }
49/// ```
50#[proc_macro_attribute]
51pub fn using_generic(attr: TokenStream, item: TokenStream) -> TokenStream {
52    process(Type::Generic, attr, item)
53}
54
55/// Apply on impl block, replace the default return type
56///
57/// # Examples
58///
59/// ```
60/// struct Foo;
61/// #[using_param::using_return(i32)]
62/// impl Foo {
63///     fn foo() { 2 }
64///     // like this:
65///     // fn foo() -> i32 { 2 }
66///
67///     fn bar() -> () {}
68///     // like this:
69///     // fn bar() {}
70/// }
71/// ```
72#[proc_macro_attribute]
73pub fn using_return(attr: TokenStream, item: TokenStream) -> TokenStream {
74    process(Type::Ret, attr, item)
75}
76
/// Which rewrite `process` performs; one variant per attribute macro.
#[derive(Debug, Clone, Copy)]
enum Type {
    /// `#[using_param]`: insert function parameters.
    Param,
    /// `#[using_generic]`: insert generic parameters.
    Generic,
    /// `#[using_return]`: add a default return type.
    Ret,
}
79
80fn process(ty: Type, attr: TokenStream, item: TokenStream) -> TokenStream {
81    let cfg = match ty {
82        Type::Param => param_cfg(attr),
83        Type::Generic => generic_cfg(attr),
84        Type::Ret => Conf { return_type: attr, ..Default::default() },
85    };
86
87    let mut iter = item.parse_iter();
88    let mut out = TokenStream::new();
89    out.extend(iter.next_attributes());
90
91    if !iter.peek_is(|i| i.is_keyword("impl")) {
92        err!("expected impl keyword", iter.span())
93    }
94    let block = iter.reduce(|a, b| {
95        out.push(a);
96        b
97    }).unwrap();
98
99    let items = block.to_brace_stream().unwrap();
100    match process_impl_block(&cfg, items) {
101        Err(e) => e,
102        Ok(b) => {
103            out.push(b.grouped_brace().tt());
104            out
105        }
106    }
107}
108
/// Rewrite settings built from one macro's attribute arguments.
///
/// Each macro fills only the fields it needs. The rest stay at their
/// `Default` values (empty streams / `false`), under which the matching
/// rewrite in `process_impl_block` inserts nothing.
#[derive(Debug, Default)]
struct Conf {
    /// `#[using_param]`: `true` when the arguments began with `,`, meaning
    /// `param` goes after the function's own parameters rather than before.
    params_after: bool,
    /// `#[using_param]`: parameters to insert. May begin with a receiver,
    /// which is only used for functions that lack one.
    param: TokenStream,
    /// `#[using_generic]`: generic parameters to insert.
    generics: TokenStream,
    /// `#[using_generic]`: `true` when the arguments began with `,`, meaning
    /// `generics` goes after the function's own generics rather than before.
    generics_after: bool,
    /// `#[using_return]`: return type for functions without an explicit
    /// `-> ...`.
    return_type: TokenStream,
}
117
118fn fn_generic(iter: &mut ParseIter<impl Iterator<Item = TokenTree>>) -> TokenStream {
119    let mut out = TokenStream::new();
120
121    loop {
122        if let Some(arrow) = iter.next_puncts("->") {
123            out.extend(arrow);
124        } else if iter.is_puncts(">")
125            && iter.peek_i(1).is_none_or(|t| t.is_delimiter_paren())
126        {
127            break;
128        } else if let Some(tt) = iter.next() {
129            out.push(tt);
130        } else {
131            break;
132        }
133    }
134
135    out
136}
137
138fn param_cfg(attr: TokenStream) -> Conf {
139    let mut iter = attr.parse_iter();
140
141    Conf {
142        params_after: iter.next_if(|t| t.is_punch(',')).is_some(),
143        param: iter.collect(),
144        ..Default::default()
145    }
146}
147
148fn generic_cfg(attr: TokenStream) -> Conf {
149    let mut iter = attr.parse_iter();
150
151    Conf {
152        generics_after: iter.next_if(|t| t.is_punch(',')).is_some(),
153        generics: iter.collect(),
154        ..Default::default()
155    }
156}
157
158fn join_with_comma(a: TokenStream, b: TokenStream) -> TokenStream {
159    if a.is_empty() {
160        return b;
161    }
162    if b.is_empty() {
163        return a;
164    }
165
166    let mut left = a.into_iter().collect::<Vec<_>>();
167    let right = b.into_iter().collect::<Vec<_>>();
168
169    left.pop_if(|t| t.is_punch(','));
170
171    if !left.is_empty() {
172        left.push(','.punct(Alone).tt());
173    }
174
175    let rcom = right.first().is_some_and(|t| t.is_punch(','));
176    stream(left.into_iter().chain(right.into_iter().skip(rcom.into())))
177}
178
179fn self_param(iter: &mut ParseIter<impl Iterator<Item = TokenTree>>) -> TokenStream {
180    macro_rules! ok {
181        () => {
182            return iter
183                .split_puncts_include(",")
184                .unwrap_or_else(|| iter.collect());
185        };
186    }
187    if iter.peek_is(|t| t.is_punch('&')) {
188        if iter.peek_i_is(1, |t| t.is_punch('\''))
189        && iter.peek_i_is(2, |t| t.is_ident())
190        {
191            if iter.peek_i_is(3, |t| t.is_keyword("self")) {
192                ok!();
193            }
194
195            if iter.peek_i_is(3, |t| t.is_keyword("mut"))
196            && iter.peek_i_is(4, |t| t.is_keyword("self"))
197            {
198                ok!();
199            }
200        }
201
202        if iter.peek_i_is(1, |t| t.is_keyword("self")) {
203            ok!();
204        }
205
206        if iter.peek_i_is(1, |t| t.is_keyword("mut"))
207        && iter.peek_i_is(2, |t| t.is_keyword("self"))
208        {
209            ok!();
210        }
211    } else if iter.peek_is(|t| t.is_keyword("self")) {
212        ok!();
213    }
214    TokenStream::new()
215}
216
217fn process_impl_block(
218    cfg: &Conf,
219    items: TokenStream,
220) -> Result<TokenStream, TokenStream> {
221    let mut out = TokenStream::new();
222    let mut iter = items.parse_iter();
223
224    out.extend(iter.next_outer_attributes());
225
226    while iter.peek().is_some() {
227        out.extend(iter.next_attributes());
228        out.extend(iter.next_vis());
229
230        if iter.peek_is(|t| t.is_keyword("fn"))
231            && iter.peek_i_is(1, |t| t.is_ident())
232        {
233            out.extend(iter.next_tts::<2>());
234
235            if iter.push_if_to(&mut out, |t| t.is_punch('<')) {
236                let generic = fn_generic(&mut iter);
237
238                out.add(if cfg.generics_after {
239                    join_with_comma(generic, cfg.generics.clone())
240                } else {
241                    join_with_comma(cfg.generics.clone(), generic)
242                });
243
244                iter.push_if_to(&mut out, |t| t.is_punch('>'));
245            } else if !cfg.generics.is_empty() {
246                out.push('<'.punct(Alone).tt());
247                out.add(cfg.generics.clone());
248                out.push('>'.punct(Alone).tt());
249            }
250
251            if let Some(TokenTree::Group(paren))
252                = iter.next_if(|t| t.is_delimiter_paren())
253            {
254
255                let params_group = paren.map(|paren| {
256                    let mut iter = paren.parse_iter();
257                    let mut self_ = self_param(&mut iter);
258                    let mut param = cfg.param.clone().parse_iter();
259                    let other_self = self_param(&mut param);
260
261                    if self_.is_empty() {
262                        self_ = other_self;
263                    }
264
265                    join_with_comma(self_, if cfg.params_after {
266                        join_with_comma(iter.collect(), param.collect())
267                    } else {
268                        join_with_comma(param.collect(), iter.collect())
269                    })
270                });
271                out.push(params_group.tt());
272
273                if !cfg.return_type.is_empty() && !iter.is_puncts("->") {
274                    out.push('-'.punct(Joint).tt());
275                    out.push('>'.punct(Alone).tt());
276                    out.add(cfg.return_type.clone());
277                }
278            }
279        } else {
280            out.push(iter.next().unwrap());
281        }
282    }
283
284    Ok(out)
285}
286
287/// ```
288/// using_param::__test_join! {}
289/// ```
290#[doc(hidden)]
291#[proc_macro]
292pub fn __test_join(_: TokenStream) -> TokenStream {
293    let datas = [
294        ("", "", ""),
295        ("a", "", "a"),
296        ("a,", "", "a,"),
297        ("", "a", "a"),
298        ("", "a,", "a,"),
299        ("a", "b", "a, b"),
300        ("a,", "b", "a, b"),
301        ("a,", "b,", "a, b,"),
302        ("a,", "b,", "a, b,"),
303        ("a", "b,", "a, b,"),
304    ];
305    for (a, b, expected) in datas {
306        let out = join_with_comma(a.parse().unwrap(), b.parse().unwrap());
307        assert_eq!(out.to_string(), expected, "{a:?}, {b:?}");
308    }
309    TokenStream::new()
310}
311
312/// ```
313/// using_param::__test_before! {}
314/// ```
315#[doc(hidden)]
316#[proc_macro]
317pub fn __test_before(_: TokenStream) -> TokenStream {
318    let out = using_param("ctx: i32".parse().unwrap(), "
319impl Foo {
320    #[doc(hidden)]
321    pub fn foo(&self, s: &str) -> &str {
322        s
323    }
324    pub fn bar(&self) -> i32 {
325        ctx
326    }
327    pub fn baz() -> i32 {
328        ctx
329    }
330    pub fn f(self: &Self) -> i32 {
331        ctx
332    }
333    pub fn a(x: i32) -> i32 {
334        ctx+x
335    }
336    pub fn b(&mut self, a: i32, b: i32) -> i32 {
337        ctx+a+b
338    }
339    pub fn c(&'a mut self, a: i32, b: i32) -> i32 {
340        ctx+a+b
341    }
342    pub fn d(&'static mut self, a: i32, b: i32) -> i32 {
343        ctx+a+b
344    }
345}
346    ".parse().unwrap()).to_string();
347    assert_eq!(out, "
348impl Foo {
349    #[doc(hidden)]
350    pub fn foo(& self, ctx : i32, s : & str) -> & str {
351        s
352    }
353    pub fn bar(& self, ctx : i32) -> i32 {
354        ctx
355    }
356    pub fn baz(ctx : i32) -> i32 {
357        ctx
358    }
359    pub fn f(self : & Self, ctx : i32) -> i32 {
360        ctx
361    }
362    pub fn a(ctx : i32, x : i32) -> i32 {
363        ctx+x
364    }
365    pub fn b(& mut self, ctx : i32, a : i32, b : i32) -> i32 {
366        ctx+a+b
367    }
368    pub fn c(& 'a mut self, ctx : i32, a : i32, b : i32) -> i32 {
369        ctx+a+b
370    }
371    pub fn d(& 'static mut self, ctx : i32, a : i32, b : i32) -> i32 {
372        ctx+a+b
373    }
374}
375    ".parse::<TokenStream>().unwrap().to_string());
376    TokenStream::new()
377}
378
379
380/// ```
381/// using_param::__test_after! {}
382/// ```
383#[doc(hidden)]
384#[proc_macro]
385pub fn __test_after(_: TokenStream) -> TokenStream {
386    let out = using_param(", ctx: i32".parse().unwrap(), "
387impl Foo {
388    #[doc(hidden)]
389    pub fn foo(&self, s: &str) -> &str {
390        s
391    }
392    pub fn bar(&self) -> i32 {
393        ctx
394    }
395    pub fn baz() -> i32 {
396        ctx
397    }
398    pub fn f(self: &Self) -> i32 {
399        ctx
400    }
401    pub fn a(x: i32) -> i32 {
402        ctx+x
403    }
404    pub fn b(&mut self, a: i32, b: i32) -> i32 {
405        ctx+a+b
406    }
407}
408    ".parse().unwrap()).to_string();
409    assert_eq!(out, "
410impl Foo {
411    #[doc(hidden)]
412    pub fn foo(& self, s : & str, ctx : i32) -> & str {
413        s
414    }
415    pub fn bar(& self, ctx : i32) -> i32 {
416        ctx
417    }
418    pub fn baz(ctx : i32) -> i32 {
419        ctx
420    }
421    pub fn f(self : & Self, ctx : i32) -> i32 {
422        ctx
423    }
424    pub fn a(x : i32, ctx : i32) -> i32 {
425        ctx+x
426    }
427    pub fn b(& mut self, a : i32, b : i32, ctx : i32) -> i32 {
428        ctx+a+b
429    }
430}
431    ".parse::<TokenStream>().unwrap().to_string());
432    TokenStream::new()
433}
434
435
436/// ```
437/// using_param::__test_self_param! {}
438/// ```
439#[doc(hidden)]
440#[proc_macro]
441pub fn __test_self_param(_: TokenStream) -> TokenStream {
442    let out = using_param("&'static self, ctx: i32".parse().unwrap(), "
443impl Foo {
444    pub fn foo(&self, s: &str) -> &str {
445        s
446    }
447    pub fn bar(&mut self) -> i32 {
448        ctx
449    }
450    pub fn baz(self: &Self) -> i32 {
451        ctx
452    }
453    pub fn a(this: &Self) -> i32 {
454        ctx
455    }
456}
457    ".parse().unwrap()).to_string();
458    assert_eq!(out, "
459impl Foo {
460    pub fn foo(& self, ctx : i32, s : & str) -> & str {
461        s
462    }
463    pub fn bar(& mut self, ctx : i32) -> i32 {
464        ctx
465    }
466    pub fn baz(self : & Self, ctx : i32) -> i32 {
467        ctx
468    }
469    pub fn a(& 'static self, ctx : i32, this : & Self) -> i32 {
470        ctx
471    }
472}
473    ".parse::<TokenStream>().unwrap().to_string());
474    TokenStream::new()
475}
476
477
478/// ```
479/// using_param::__test_self_param! {}
480/// ```
481#[doc(hidden)]
482#[proc_macro]
483pub fn __test_self_param_after(_: TokenStream) -> TokenStream {
484    let out = using_param(", &'static self, ctx: i32".parse().unwrap(), "
485impl Foo {
486    pub fn foo(&self, s: &str) -> &str {
487        s
488    }
489    pub fn bar(&mut self) -> i32 {
490        ctx
491    }
492    pub fn baz(self: &Self) -> i32 {
493        ctx
494    }
495    pub fn a(this: &Self) -> i32 {
496        ctx
497    }
498}
499    ".parse().unwrap()).to_string();
500    assert_eq!(out, "
501impl Foo {
502    pub fn foo(& self, s : & str, ctx : i32) -> & str {
503        s
504    }
505    pub fn bar(& mut self, ctx : i32) -> i32 {
506        ctx
507    }
508    pub fn baz(self : & Self, ctx : i32) -> i32 {
509        ctx
510    }
511    pub fn a(& 'static self, this : & Self, ctx : i32) -> i32 {
512        ctx
513    }
514}
515    ".parse::<TokenStream>().unwrap().to_string());
516    TokenStream::new()
517}
518
519
520/// ```
521/// using_param::__test_generic_before! {}
522/// ```
523#[doc(hidden)]
524#[proc_macro]
525pub fn __test_generic_before(_: TokenStream) -> TokenStream {
526    let out = using_generic("'a".parse().unwrap(), "
527impl Foo {
528    fn foo() {}
529    fn bar<'b>() {}
530    fn baz<'b, T>() {}
531}
532    ".parse().unwrap()).to_string();
533    assert_eq!(out, "
534impl Foo {
535    fn foo < 'a > () {}
536    fn bar < 'a, 'b > () {}
537    fn baz < 'a, 'b, T > () {}
538}
539    ".parse::<TokenStream>().unwrap().to_string());
540    TokenStream::new()
541}
542
543
544/// ```
545/// using_param::__test_generic_after! {}
546/// ```
547#[doc(hidden)]
548#[proc_macro]
549pub fn __test_generic_after(_: TokenStream) -> TokenStream {
550    let out = using_generic(", 'a".parse().unwrap(), "
551impl Foo {
552    fn foo() {}
553    fn bar<'b>() {}
554}
555    ".parse().unwrap()).to_string();
556    assert_eq!(out, "
557impl Foo {
558    fn foo < 'a > () {}
559    fn bar < 'b, 'a > () {}
560}
561    ".parse::<TokenStream>().unwrap().to_string());
562    TokenStream::new()
563}
564
565
566/// ```
567/// using_param::__test_other_assoc_item! {}
568/// ```
569#[doc(hidden)]
570#[proc_macro]
571pub fn __test_other_assoc_item(_: TokenStream) -> TokenStream {
572    let out = using_param("ctx: i32".parse().unwrap(), "
573impl Foo {
574    pub const M: usize = 3;
575    pub type C = i32;
576    some_macro!();
577    fn foo() {}
578    fn bar(m: i32) { m+ctx }
579    fn baz(self, m: i32) { m+ctx }
580}
581    ".parse().unwrap()).to_string();
582    assert_eq!(out, "
583impl Foo {
584    pub const M : usize = 3;
585    pub type C = i32;
586    some_macro! ();
587    fn foo(ctx : i32) {}
588    fn bar(ctx : i32, m : i32) { m+ctx }
589    fn baz(self, ctx : i32, m : i32) { m+ctx }
590}
591    ".parse::<TokenStream>().unwrap().to_string());
592    TokenStream::new()
593}
594
595
596/// ```
597/// using_param::__test_return_type! {}
598/// ```
599#[doc(hidden)]
600#[proc_macro]
601pub fn __test_return_type(_: TokenStream) -> TokenStream {
602    let out = using_return("i32".parse().unwrap(), "
603impl Foo {
604    pub const M: usize = 3;
605    pub type C = i32;
606    some_macro!();
607    fn foo() {}
608    fn bar(m: i32) { m+ctx }
609    fn baz(self, m: i32) { m+ctx }
610    fn xfoo() -> u32 {}
611    fn xbar(m: i32) -> u32 { m+ctx }
612    fn xbaz(self, m: i32) -> u32 { m+ctx }
613}
614    ".parse().unwrap()).to_string();
615    assert_eq!(out, "
616impl Foo {
617    pub const M : usize = 3;
618    pub type C = i32;
619    some_macro! ();
620    fn foo() -> i32 {}
621    fn bar(m : i32) -> i32 { m+ctx }
622    fn baz(self, m : i32) -> i32 { m+ctx }
623    fn xfoo() -> u32 {}
624    fn xbar(m : i32) -> u32 { m+ctx }
625    fn xbaz(self, m : i32) -> u32 { m+ctx }
626}
627    ".parse::<TokenStream>().unwrap().to_string());
628    TokenStream::new()
629}