qlora_paste/
lib.rs

1// Copyright 2019 David Tolnay
2// Copyright 2024 Tyler Zervas
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! [![github]](https://github.com/dtolnay/paste)&ensp;[![crates-io]](https://crates.io/crates/paste)&ensp;[![docs-rs]](https://docs.rs/paste)
11//!
12//! [github]: https://img.shields.io/badge/github-8da0cb?style=for-the-badge&labelColor=555555&logo=github
13//! [crates-io]: https://img.shields.io/badge/crates.io-fc8d62?style=for-the-badge&labelColor=555555&logo=rust
14//! [docs-rs]: https://img.shields.io/badge/docs.rs-66c2a5?style=for-the-badge&labelColor=555555&logo=docs.rs
15//!
16//! <br>
17//!
//! The nightly-only [`concat_idents!`] macro in the Rust standard library is
//! notoriously underpowered in that its concatenated identifiers can only refer to
//! existing items; they can never be used to define something new.
21//!
22//! [`concat_idents!`]: https://doc.rust-lang.org/std/macro.concat_idents.html
23//!
24//! This crate provides a flexible way to paste together identifiers in a macro,
25//! including using pasted identifiers to define new items.
26//!
27//! This approach works with any Rust compiler 1.31+.
28//!
29//! <br>
30//!
31//! # Pasting identifiers
32//!
33//! Within the `paste!` macro, identifiers inside `[<`...`>]` are pasted
34//! together to form a single identifier.
35//!
36//! ```
37//! use qlora_paste::paste;
38//!
39//! paste! {
40//!     // Defines a const called `QRST`.
41//!     const [<Q R S T>]: &str = "success!";
42//! }
43//!
44//! fn main() {
45//!     assert_eq!(
46//!         paste! { [<Q R S T>].len() },
47//!         8,
48//!     );
49//! }
50//! ```
51//!
52//! <br><br>
53//!
54//! # More elaborate example
55//!
56//! The next example shows a macro that generates accessor methods for some
57//! struct fields. It demonstrates how you might find it useful to bundle a
58//! paste invocation inside of a macro\_rules macro.
59//!
60//! ```
61//! use qlora_paste::paste;
62//!
63//! macro_rules! make_a_struct_and_getters {
64//!     ($name:ident { $($field:ident),* }) => {
65//!         // Define a struct. This expands to:
66//!         //
67//!         //     pub struct S {
68//!         //         a: String,
69//!         //         b: String,
70//!         //         c: String,
71//!         //     }
72//!         pub struct $name {
73//!             $(
74//!                 $field: String,
75//!             )*
76//!         }
77//!
78//!         // Build an impl block with getters. This expands to:
79//!         //
80//!         //     impl S {
81//!         //         pub fn get_a(&self) -> &str { &self.a }
82//!         //         pub fn get_b(&self) -> &str { &self.b }
83//!         //         pub fn get_c(&self) -> &str { &self.c }
84//!         //     }
85//!         paste! {
86//!             impl $name {
87//!                 $(
88//!                     pub fn [<get_ $field>](&self) -> &str {
89//!                         &self.$field
90//!                     }
91//!                 )*
92//!             }
93//!         }
94//!     }
95//! }
96//!
97//! make_a_struct_and_getters!(S { a, b, c });
98//!
99//! fn call_some_getters(s: &S) -> bool {
100//!     s.get_a() == s.get_b() && s.get_c().is_empty()
101//! }
102//! #
103//! # fn main() {}
104//! ```
105//!
106//! <br><br>
107//!
108//! # Case conversion
109//!
110//! Use `$var:lower` or `$var:upper` in the segment list to convert an
111//! interpolated segment to lower- or uppercase as part of the paste. For
112//! example, `[<ld_ $reg:lower _expr>]` would paste to `ld_bc_expr` if invoked
113//! with $reg=`Bc`.
114//!
115//! Use `$var:snake` to convert CamelCase input to snake\_case.
116//! Use `$var:camel` to convert snake\_case to CamelCase.
117//! These compose, so for example `$var:snake:upper` would give you SCREAMING\_CASE.
118//!
119//! The precise Unicode conversions are as defined by [`str::to_lowercase`] and
120//! [`str::to_uppercase`].
121//!
122//! [`str::to_lowercase`]: https://doc.rust-lang.org/std/primitive.str.html#method.to_lowercase
123//! [`str::to_uppercase`]: https://doc.rust-lang.org/std/primitive.str.html#method.to_uppercase
124//!
125//! <br>
126//!
127//! # Pasting documentation strings
128//!
129//! Within the `paste!` macro, arguments to a #\[doc ...\] attribute are
130//! implicitly concatenated together to form a coherent documentation string.
131//!
132//! ```
133//! use qlora_paste::paste;
134//!
135//! macro_rules! method_new {
136//!     ($ret:ident) => {
137//!         paste! {
138//!             #[doc = "Create a new `" $ret "` object."]
139//!             pub fn new() -> $ret { todo!() }
140//!         }
141//!     };
142//! }
143//!
144//! pub struct Paste {}
145//!
//! method_new!(Paste);  // expands to #[doc = "Create a new `Paste` object."]
147//! ```
148
149#![doc(html_root_url = "https://docs.rs/qlora-paste/1.0.16")]
150#![allow(
151    clippy::derive_partial_eq_without_eq,
152    clippy::doc_markdown,
153    clippy::match_same_arms,
154    clippy::module_name_repetitions,
155    clippy::needless_doctest_main,
156    clippy::too_many_lines
157)]
158
159extern crate proc_macro;
160
161mod attr;
162mod error;
163mod segment;
164
165use crate::attr::expand_attr;
166use crate::error::{Error, Result};
167use crate::segment::Segment;
168use proc_macro::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree};
169use std::char;
170use std::iter;
171use std::panic;
172
173#[proc_macro]
174pub fn paste(input: TokenStream) -> TokenStream {
175    let mut contains_paste = false;
176    let flatten_single_interpolation = true;
177    match expand(
178        input.clone(),
179        &mut contains_paste,
180        flatten_single_interpolation,
181    ) {
182        Ok(expanded) => {
183            if contains_paste {
184                expanded
185            } else {
186                input
187            }
188        }
189        Err(err) => err.to_compile_error(),
190    }
191}
192
/// Hidden alias of the `paste` macro: forwards `input` to `paste` unchanged
/// and returns whatever it produces.
// NOTE(review): presumably retained for backward compatibility with older
// call sites; confirm before removing.
#[doc(hidden)]
#[proc_macro]
pub fn item(input: TokenStream) -> TokenStream {
    paste(input)
}
198
/// Hidden alias of the `paste` macro: forwards `input` to `paste` unchanged
/// and returns whatever it produces.
// NOTE(review): presumably retained for backward compatibility with older
// call sites; confirm before removing.
#[doc(hidden)]
#[proc_macro]
pub fn expr(input: TokenStream) -> TokenStream {
    paste(input)
}
204
205fn expand(
206    input: TokenStream,
207    contains_paste: &mut bool,
208    flatten_single_interpolation: bool,
209) -> Result<TokenStream> {
210    let mut expanded = TokenStream::new();
211    let mut lookbehind = Lookbehind::Other;
212    let mut prev_none_group = None::<Group>;
213    let mut tokens = input.into_iter().peekable();
214    loop {
215        let token = tokens.next();
216        if let Some(group) = prev_none_group.take() {
217            if match (&token, tokens.peek()) {
218                (Some(TokenTree::Punct(fst)), Some(TokenTree::Punct(snd))) => {
219                    fst.as_char() == ':' && snd.as_char() == ':' && fst.spacing() == Spacing::Joint
220                }
221                _ => false,
222            } {
223                expanded.extend(group.stream());
224                *contains_paste = true;
225            } else {
226                expanded.extend(iter::once(TokenTree::Group(group)));
227            }
228        }
229        match token {
230            Some(TokenTree::Group(group)) => {
231                let delimiter = group.delimiter();
232                let content = group.stream();
233                let span = group.span();
234                if delimiter == Delimiter::Bracket && is_paste_operation(&content) {
235                    let segments = parse_bracket_as_segments(content, span)?;
236                    let pasted = segment::paste(&segments)?;
237                    let tokens = pasted_to_tokens(pasted, span)?;
238                    expanded.extend(tokens);
239                    *contains_paste = true;
240                } else if flatten_single_interpolation
241                    && delimiter == Delimiter::None
242                    && is_single_interpolation_group(&content)
243                {
244                    expanded.extend(content);
245                    *contains_paste = true;
246                } else {
247                    let mut group_contains_paste = false;
248                    let is_attribute = delimiter == Delimiter::Bracket
249                        && (lookbehind == Lookbehind::Pound || lookbehind == Lookbehind::PoundBang);
250                    let mut nested = expand(
251                        content,
252                        &mut group_contains_paste,
253                        flatten_single_interpolation && !is_attribute,
254                    )?;
255                    if is_attribute {
256                        nested = expand_attr(nested, span, &mut group_contains_paste)?;
257                    }
258                    let group = if group_contains_paste {
259                        let mut group = Group::new(delimiter, nested);
260                        group.set_span(span);
261                        *contains_paste = true;
262                        group
263                    } else {
264                        group.clone()
265                    };
266                    if delimiter != Delimiter::None {
267                        expanded.extend(iter::once(TokenTree::Group(group)));
268                    } else if lookbehind == Lookbehind::DoubleColon {
269                        expanded.extend(group.stream());
270                        *contains_paste = true;
271                    } else {
272                        prev_none_group = Some(group);
273                    }
274                }
275                lookbehind = Lookbehind::Other;
276            }
277            Some(TokenTree::Punct(punct)) => {
278                lookbehind = match punct.as_char() {
279                    ':' if lookbehind == Lookbehind::JointColon => Lookbehind::DoubleColon,
280                    ':' if punct.spacing() == Spacing::Joint => Lookbehind::JointColon,
281                    '#' => Lookbehind::Pound,
282                    '!' if lookbehind == Lookbehind::Pound => Lookbehind::PoundBang,
283                    _ => Lookbehind::Other,
284                };
285                expanded.extend(iter::once(TokenTree::Punct(punct)));
286            }
287            Some(other) => {
288                lookbehind = Lookbehind::Other;
289                expanded.extend(iter::once(other));
290            }
291            None => return Ok(expanded),
292        }
293    }
294}
295
/// What `expand` remembers about the token or tokens immediately preceding
/// the one it is currently handling.
#[derive(PartialEq)]
enum Lookbehind {
    /// The previous token was a `:` with `Spacing::Joint`, i.e. possibly the
    /// first half of a `::`.
    JointColon,
    /// The previous two tokens were a joint `:` followed by another `:`.
    DoubleColon,
    /// The previous token was `#`.
    Pound,
    /// The previous two tokens were `#` followed by `!`.
    PoundBang,
    /// Anything else; also the state right after a group or a non-punctuation
    /// token.
    Other,
}
304
// https://github.com/dtolnay/paste/issues/26
/// Reports whether `input` is exactly one of the following and nothing else:
///
/// * a path of identifiers separated by `::` (no leading or trailing `::`),
/// * a single literal,
/// * a lifetime, i.e. `'` followed by one identifier.
///
/// An empty stream yields `false`.
fn is_single_interpolation_group(input: &TokenStream) -> bool {
    // How much of one of the accepted shapes has been consumed so far.
    #[derive(PartialEq)]
    enum Seen {
        Nothing,
        PathSegment,
        Literal,
        Tick,
        Lifetime,
        FirstColon,
        SecondColon,
    }

    let mut seen = Seen::Nothing;
    for token in input.clone() {
        // Dispatch on the token kind first, then on what preceded it. Any
        // combination not listed ends the scan with a negative answer.
        seen = match token {
            TokenTree::Ident(_) => match seen {
                Seen::Nothing | Seen::SecondColon => Seen::PathSegment,
                Seen::Tick => Seen::Lifetime,
                _ => return false,
            },
            TokenTree::Literal(_) => match seen {
                Seen::Nothing => Seen::Literal,
                _ => return false,
            },
            TokenTree::Punct(ref punct) => match (seen, punct.as_char(), punct.spacing()) {
                (Seen::Nothing, '\'', _) => Seen::Tick,
                (Seen::PathSegment, ':', Spacing::Joint) => Seen::FirstColon,
                (Seen::FirstColon, ':', Spacing::Alone) => Seen::SecondColon,
                _ => return false,
            },
            _ => return false,
        };
    }

    // Only complete shapes count; e.g. a dangling `'` or `a::` does not.
    seen == Seen::PathSegment || seen == Seen::Literal || seen == Seen::Lifetime
}
342
/// Reports whether `input` has the shape `< ... >`: it starts with `<`, its
/// first `>` is also its last token, and at least one token sits in between.
fn is_paste_operation(input: &TokenStream) -> bool {
    let mut tokens = input.clone().into_iter();

    let starts_with_open_angle = match tokens.next() {
        Some(TokenTree::Punct(ref punct)) => punct.as_char() == '<',
        _ => false,
    };
    if !starts_with_open_angle {
        return false;
    }

    let mut has_token = false;
    while let Some(token) = tokens.next() {
        if let TokenTree::Punct(ref punct) = token {
            if punct.as_char() == '>' {
                // The first `>` decides: it must close a non-empty body and
                // nothing may follow it.
                return has_token && tokens.next().is_none();
            }
        }
        has_token = true;
    }

    // Ran out of tokens without ever seeing `>`.
    false
}
362
363fn parse_bracket_as_segments(input: TokenStream, scope: Span) -> Result<Vec<Segment>> {
364    let mut tokens = input.into_iter().peekable();
365
366    match &tokens.next() {
367        Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {}
368        Some(wrong) => return Err(Error::new(wrong.span(), "expected `<`")),
369        None => return Err(Error::new(scope, "expected `[< ... >]`")),
370    }
371
372    let mut segments = segment::parse(&mut tokens)?;
373
374    match &tokens.next() {
375        Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {}
376        Some(wrong) => return Err(Error::new(wrong.span(), "expected `>`")),
377        None => return Err(Error::new(scope, "expected `[< ... >]`")),
378    }
379
380    if let Some(unexpected) = tokens.next() {
381        return Err(Error::new(
382            unexpected.span(),
383            "unexpected input, expected `[< ... >]`",
384        ));
385    }
386
387    for segment in &mut segments {
388        if let Segment::String(string) = segment {
389            if string.value.starts_with("'\\u{") {
390                let hex = &string.value[4..string.value.len() - 2];
391                if let Ok(unsigned) = u32::from_str_radix(hex, 16) {
392                    if let Some(ch) = char::from_u32(unsigned) {
393                        string.value.clear();
394                        string.value.push(ch);
395                        continue;
396                    }
397                }
398            }
399            if string.value.contains(&['#', '\\', '.', '+'][..])
400                || string.value.starts_with("b'")
401                || string.value.starts_with("b\"")
402                || string.value.starts_with("br\"")
403            {
404                return Err(Error::new(string.span, "unsupported literal"));
405            }
406            let mut range = 0..string.value.len();
407            if string.value.starts_with("r\"") {
408                range.start += 2;
409                range.end -= 1;
410            } else if string.value.starts_with(&['"', '\''][..]) {
411                range.start += 1;
412                range.end -= 1;
413            }
414            string.value = string.value[range].replace('-', "_");
415        }
416    }
417
418    Ok(segments)
419}
420
421fn pasted_to_tokens(mut pasted: String, span: Span) -> Result<TokenStream> {
422    let mut tokens = TokenStream::new();
423
424    #[cfg(not(no_literal_fromstr))]
425    {
426        use proc_macro::{LexError, Literal};
427        use std::str::FromStr;
428
429        if pasted.starts_with(|ch: char| ch.is_ascii_digit()) {
430            let literal = match panic::catch_unwind(|| Literal::from_str(&pasted)) {
431                Ok(Ok(literal)) => TokenTree::Literal(literal),
432                Ok(Err(LexError { .. })) | Err(_) => {
433                    return Err(Error::new(
434                        span,
435                        &format!("`{:?}` is not a valid literal", pasted),
436                    ));
437                }
438            };
439            tokens.extend(iter::once(literal));
440            return Ok(tokens);
441        }
442    }
443
444    if pasted.starts_with('\'') {
445        let mut apostrophe = TokenTree::Punct(Punct::new('\'', Spacing::Joint));
446        apostrophe.set_span(span);
447        tokens.extend(iter::once(apostrophe));
448        pasted.remove(0);
449    }
450
451    let ident = match panic::catch_unwind(|| Ident::new(&pasted, span)) {
452        Ok(ident) => TokenTree::Ident(ident),
453        Err(_) => {
454            return Err(Error::new(
455                span,
456                &format!("`{:?}` is not a valid identifier", pasted),
457            ));
458        }
459    };
460
461    tokens.extend(iter::once(ident));
462    Ok(tokens)
463}