1#![doc(html_root_url = "https://docs.rs/qlora-paste/1.0.20")]
145#![allow(
146 clippy::derive_partial_eq_without_eq,
147 clippy::doc_markdown,
148 clippy::match_same_arms,
149 clippy::module_name_repetitions,
150 clippy::needless_doctest_main,
151 clippy::too_many_lines
152)]
153
154extern crate proc_macro;
155
156mod attr;
157mod error;
158mod segment;
159
160use crate::attr::expand_attr;
161use crate::error::{Error, Result};
162use crate::segment::Segment;
163use proc_macro::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree};
164use std::char;
165use std::iter;
166use std::panic;
167
168#[proc_macro]
169pub fn paste(input: TokenStream) -> TokenStream {
170 let mut contains_paste = false;
171 let flatten_single_interpolation = true;
172 match expand(
173 input.clone(),
174 &mut contains_paste,
175 flatten_single_interpolation,
176 ) {
177 Ok(expanded) => {
178 if contains_paste {
179 expanded
180 } else {
181 input
182 }
183 }
184 Err(err) => err.to_compile_error(),
185 }
186}
187
188#[doc(hidden)]
189#[proc_macro]
190pub fn item(input: TokenStream) -> TokenStream {
191 paste(input)
192}
193
194#[doc(hidden)]
195#[proc_macro]
196pub fn expr(input: TokenStream) -> TokenStream {
197 paste(input)
198}
199
200fn expand(
201 input: TokenStream,
202 contains_paste: &mut bool,
203 flatten_single_interpolation: bool,
204) -> Result<TokenStream> {
205 let mut expanded = TokenStream::new();
206 let mut lookbehind = Lookbehind::Other;
207 let mut prev_none_group = None::<Group>;
208 let mut tokens = input.into_iter().peekable();
209 loop {
210 let token = tokens.next();
211 if let Some(group) = prev_none_group.take() {
212 if match (&token, tokens.peek()) {
213 (Some(TokenTree::Punct(fst)), Some(TokenTree::Punct(snd))) => {
214 fst.as_char() == ':' && snd.as_char() == ':' && fst.spacing() == Spacing::Joint
215 }
216 _ => false,
217 } {
218 expanded.extend(group.stream());
219 *contains_paste = true;
220 } else {
221 expanded.extend(iter::once(TokenTree::Group(group)));
222 }
223 }
224 match token {
225 Some(TokenTree::Group(group)) => {
226 let delimiter = group.delimiter();
227 let content = group.stream();
228 let span = group.span();
229 if delimiter == Delimiter::Bracket && is_paste_operation(&content) {
230 let segments = parse_bracket_as_segments(content, span)?;
231 let pasted = segment::paste(&segments)?;
232 let tokens = pasted_to_tokens(pasted, span)?;
233 expanded.extend(tokens);
234 *contains_paste = true;
235 } else if flatten_single_interpolation
236 && delimiter == Delimiter::None
237 && is_single_interpolation_group(&content)
238 {
239 expanded.extend(content);
240 *contains_paste = true;
241 } else {
242 let mut group_contains_paste = false;
243 let is_attribute = delimiter == Delimiter::Bracket
244 && (lookbehind == Lookbehind::Pound || lookbehind == Lookbehind::PoundBang);
245 let mut nested = expand(
246 content,
247 &mut group_contains_paste,
248 flatten_single_interpolation && !is_attribute,
249 )?;
250 if is_attribute {
251 nested = expand_attr(nested, span, &mut group_contains_paste)?;
252 }
253 let group = if group_contains_paste {
254 let mut group = Group::new(delimiter, nested);
255 group.set_span(span);
256 *contains_paste = true;
257 group
258 } else {
259 group.clone()
260 };
261 if delimiter != Delimiter::None {
262 expanded.extend(iter::once(TokenTree::Group(group)));
263 } else if lookbehind == Lookbehind::DoubleColon {
264 expanded.extend(group.stream());
265 *contains_paste = true;
266 } else {
267 prev_none_group = Some(group);
268 }
269 }
270 lookbehind = Lookbehind::Other;
271 }
272 Some(TokenTree::Punct(punct)) => {
273 lookbehind = match punct.as_char() {
274 ':' if lookbehind == Lookbehind::JointColon => Lookbehind::DoubleColon,
275 ':' if punct.spacing() == Spacing::Joint => Lookbehind::JointColon,
276 '#' => Lookbehind::Pound,
277 '!' if lookbehind == Lookbehind::Pound => Lookbehind::PoundBang,
278 _ => Lookbehind::Other,
279 };
280 expanded.extend(iter::once(TokenTree::Punct(punct)));
281 }
282 Some(other) => {
283 lookbehind = Lookbehind::Other;
284 expanded.extend(iter::once(other));
285 }
286 None => return Ok(expanded),
287 }
288 }
289}
290
/// Summary of the token(s) `expand` processed just before the current one.
///
/// It is used to recognize a `::` path separator in front of a group, and a
/// `#` or `#!` in front of a bracket group (an attribute).
#[derive(PartialEq)]
enum Lookbehind {
    /// A `:` with joint spacing, possibly the first half of `::`.
    JointColon,
    /// A `:` that completed a joint `::`.
    DoubleColon,
    /// A `#`.
    Pound,
    /// A `!` directly after a `#`.
    PoundBang,
    /// Any other token, including any group.
    Other,
}
299
/// Returns `true` if `input` consists of exactly one of:
///
/// - a path of identifiers joined by `::`, such as `a` or `a::b::c`, where the
///   first `:` of each separator is joint and the second is alone;
/// - a single literal;
/// - a `'` punct followed by one identifier (a lifetime).
///
/// Any other sequence, including an empty stream, returns `false`.
fn is_single_interpolation_group(input: &TokenStream) -> bool {
    enum Seen {
        Nothing,
        PathSegment,
        Literal,
        Quote,
        Lifetime,
        FirstColon,
        PathSeparator,
    }

    // Drive a small state machine over the tokens; `None` means the sequence
    // left the accepted grammar.
    let last = input.clone().into_iter().try_fold(Seen::Nothing, |seen, tt| {
        let next = match (seen, &tt) {
            (Seen::Nothing, TokenTree::Ident(_)) => Seen::PathSegment,
            (Seen::Nothing, TokenTree::Literal(_)) => Seen::Literal,
            (Seen::Nothing, TokenTree::Punct(p)) if p.as_char() == '\'' => Seen::Quote,
            (Seen::Quote, TokenTree::Ident(_)) => Seen::Lifetime,
            (Seen::PathSegment, TokenTree::Punct(p))
                if p.as_char() == ':' && p.spacing() == Spacing::Joint =>
            {
                Seen::FirstColon
            }
            (Seen::FirstColon, TokenTree::Punct(p))
                if p.as_char() == ':' && p.spacing() == Spacing::Alone =>
            {
                Seen::PathSeparator
            }
            (Seen::PathSeparator, TokenTree::Ident(_)) => Seen::PathSegment,
            _ => return None,
        };
        Some(next)
    });

    match last {
        Some(Seen::PathSegment) | Some(Seen::Literal) | Some(Seen::Lifetime) => true,
        _ => false,
    }
}
337
/// Returns `true` if `input` has the shape `< tokens... >`:
///
/// - it starts with a `<` punct;
/// - at least one token comes before the first `>` punct;
/// - that first `>` is the last token.
///
/// Nested groups count as single tokens, so a `>` inside them is not seen.
fn is_paste_operation(input: &TokenStream) -> bool {
    let mut tokens = input.clone().into_iter();

    let opens_with_lt = match tokens.next() {
        Some(TokenTree::Punct(ref punct)) => punct.as_char() == '<',
        _ => false,
    };
    if !opens_with_lt {
        return false;
    }

    // Consumes everything up to and including the first `>`; the index is
    // the number of tokens that sat between `<` and `>`.
    let closing = tokens.position(|tt| match tt {
        TokenTree::Punct(ref punct) => punct.as_char() == '>',
        _ => false,
    });

    match closing {
        Some(inner_len) => inner_len > 0 && tokens.next().is_none(),
        None => false,
    }
}
357
358fn parse_bracket_as_segments(input: TokenStream, scope: Span) -> Result<Vec<Segment>> {
359 let mut tokens = input.into_iter().peekable();
360
361 match &tokens.next() {
362 Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {}
363 Some(wrong) => return Err(Error::new(wrong.span(), "expected `<`")),
364 None => return Err(Error::new(scope, "expected `[< ... >]`")),
365 }
366
367 let mut segments = segment::parse(&mut tokens)?;
368
369 match &tokens.next() {
370 Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {}
371 Some(wrong) => return Err(Error::new(wrong.span(), "expected `>`")),
372 None => return Err(Error::new(scope, "expected `[< ... >]`")),
373 }
374
375 if let Some(unexpected) = tokens.next() {
376 return Err(Error::new(
377 unexpected.span(),
378 "unexpected input, expected `[< ... >]`",
379 ));
380 }
381
382 for segment in &mut segments {
383 if let Segment::String(string) = segment {
384 if string.value.starts_with("'\\u{") {
385 let hex = &string.value[4..string.value.len() - 2];
386 if let Ok(unsigned) = u32::from_str_radix(hex, 16) {
387 if let Some(ch) = char::from_u32(unsigned) {
388 string.value.clear();
389 string.value.push(ch);
390 continue;
391 }
392 }
393 }
394 if string.value.contains(&['#', '\\', '.', '+'][..])
395 || string.value.starts_with("b'")
396 || string.value.starts_with("b\"")
397 || string.value.starts_with("br\"")
398 {
399 return Err(Error::new(string.span, "unsupported literal"));
400 }
401 let mut range = 0..string.value.len();
402 if string.value.starts_with("r\"") {
403 range.start += 2;
404 range.end -= 1;
405 } else if string.value.starts_with(&['"', '\''][..]) {
406 range.start += 1;
407 range.end -= 1;
408 }
409 string.value = string.value[range].replace('-', "_");
410 }
411 }
412
413 Ok(segments)
414}
415
416fn pasted_to_tokens(mut pasted: String, span: Span) -> Result<TokenStream> {
417 let mut tokens = TokenStream::new();
418
419 #[cfg(not(no_literal_fromstr))]
420 {
421 use proc_macro::{LexError, Literal};
422 use std::str::FromStr;
423
424 if pasted.starts_with(|ch: char| ch.is_ascii_digit()) {
425 let literal = match panic::catch_unwind(|| Literal::from_str(&pasted)) {
426 Ok(Ok(literal)) => TokenTree::Literal(literal),
427 Ok(Err(LexError { .. })) | Err(_) => {
428 return Err(Error::new(
429 span,
430 &format!("`{:?}` is not a valid literal", pasted),
431 ));
432 }
433 };
434 tokens.extend(iter::once(literal));
435 return Ok(tokens);
436 }
437 }
438
439 if pasted.starts_with('\'') {
440 let mut apostrophe = TokenTree::Punct(Punct::new('\'', Spacing::Joint));
441 apostrophe.set_span(span);
442 tokens.extend(iter::once(apostrophe));
443 pasted.remove(0);
444 }
445
446 let ident = match panic::catch_unwind(|| Ident::new(&pasted, span)) {
447 Ok(ident) => TokenTree::Ident(ident),
448 Err(_) => {
449 return Err(Error::new(
450 span,
451 &format!("`{:?}` is not a valid identifier", pasted),
452 ));
453 }
454 };
455
456 tokens.extend(iter::once(ident));
457 Ok(tokens)
458}