1#![doc(html_root_url = "https://docs.rs/qlora-paste/1.0.16")]
150#![allow(
151 clippy::derive_partial_eq_without_eq,
152 clippy::doc_markdown,
153 clippy::match_same_arms,
154 clippy::module_name_repetitions,
155 clippy::needless_doctest_main,
156 clippy::too_many_lines
157)]
158
159extern crate proc_macro;
160
161mod attr;
162mod error;
163mod segment;
164
165use crate::attr::expand_attr;
166use crate::error::{Error, Result};
167use crate::segment::Segment;
168use proc_macro::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree};
169use std::char;
170use std::iter;
171use std::panic;
172
173#[proc_macro]
174pub fn paste(input: TokenStream) -> TokenStream {
175 let mut contains_paste = false;
176 let flatten_single_interpolation = true;
177 match expand(
178 input.clone(),
179 &mut contains_paste,
180 flatten_single_interpolation,
181 ) {
182 Ok(expanded) => {
183 if contains_paste {
184 expanded
185 } else {
186 input
187 }
188 }
189 Err(err) => err.to_compile_error(),
190 }
191}
192
193#[doc(hidden)]
194#[proc_macro]
195pub fn item(input: TokenStream) -> TokenStream {
196 paste(input)
197}
198
199#[doc(hidden)]
200#[proc_macro]
201pub fn expr(input: TokenStream) -> TokenStream {
202 paste(input)
203}
204
/// Recursively rewrites `input`, performing every paste operation it contains,
/// and returns the re-assembled token stream.
///
/// * `contains_paste` is set to `true` whenever this call (or a nested call)
///   changed anything: a `[< ... >]` paste, a flattened `None`-delimited group,
///   or a rewritten attribute. This function only ever sets it to `true`.
/// * `flatten_single_interpolation` controls whether a `None`-delimited group
///   whose contents pass `is_single_interpolation_group` is spliced into the
///   output without its invisible delimiters. It is switched off for the
///   contents of attributes (`#[...]` / `#![...]`).
///
/// # Errors
///
/// Propagates errors from `parse_bracket_as_segments`, `segment::paste`,
/// `pasted_to_tokens`, and `expand_attr`.
fn expand(
    input: TokenStream,
    contains_paste: &mut bool,
    flatten_single_interpolation: bool,
) -> Result<TokenStream> {
    let mut expanded = TokenStream::new();
    // Classification of the previously emitted token(s); drives detection of
    // `::` paths and of `#[...]` / `#![...]` attributes.
    let mut lookbehind = Lookbehind::Other;
    // A `None`-delimited group whose emission is deferred by one token so we
    // can check whether it is immediately followed by `::`.
    let mut prev_none_group = None::<Group>;
    let mut tokens = input.into_iter().peekable();
    loop {
        let token = tokens.next();
        // Resolve a deferred group before handling `token`. This also runs on
        // the final iteration (`token == None`), so a deferred group is never lost.
        if let Some(group) = prev_none_group.take() {
            // Followed by a joint `:` and then another `:` (i.e. `::`): splice the
            // group's contents (presumably so a path like `$ty::item` parses);
            // otherwise emit the group as-is.
            if match (&token, tokens.peek()) {
                (Some(TokenTree::Punct(fst)), Some(TokenTree::Punct(snd))) => {
                    fst.as_char() == ':' && snd.as_char() == ':' && fst.spacing() == Spacing::Joint
                }
                _ => false,
            } {
                expanded.extend(group.stream());
                *contains_paste = true;
            } else {
                expanded.extend(iter::once(TokenTree::Group(group)));
            }
        }
        match token {
            Some(TokenTree::Group(group)) => {
                let delimiter = group.delimiter();
                let content = group.stream();
                let span = group.span();
                if delimiter == Delimiter::Bracket && is_paste_operation(&content) {
                    // `[< ... >]`: concatenate the segments and emit the result.
                    let segments = parse_bracket_as_segments(content, span)?;
                    let pasted = segment::paste(&segments)?;
                    let tokens = pasted_to_tokens(pasted, span)?;
                    expanded.extend(tokens);
                    *contains_paste = true;
                } else if flatten_single_interpolation
                    && delimiter == Delimiter::None
                    && is_single_interpolation_group(&content)
                {
                    // Drop the invisible delimiters around a simple interpolation.
                    expanded.extend(content);
                    *contains_paste = true;
                } else {
                    let mut group_contains_paste = false;
                    // A bracket group right after `#` or `#!` is an attribute body.
                    let is_attribute = delimiter == Delimiter::Bracket
                        && (lookbehind == Lookbehind::Pound || lookbehind == Lookbehind::PoundBang);
                    let mut nested = expand(
                        content,
                        &mut group_contains_paste,
                        flatten_single_interpolation && !is_attribute,
                    )?;
                    if is_attribute {
                        nested = expand_attr(nested, span, &mut group_contains_paste)?;
                    }
                    // Rebuild the group (keeping its original span) only if
                    // something inside changed; otherwise reuse the original.
                    let group = if group_contains_paste {
                        let mut group = Group::new(delimiter, nested);
                        group.set_span(span);
                        *contains_paste = true;
                        group
                    } else {
                        group.clone()
                    };
                    if delimiter != Delimiter::None {
                        expanded.extend(iter::once(TokenTree::Group(group)));
                    } else if lookbehind == Lookbehind::DoubleColon {
                        // A `None` group directly after `::` is spliced in.
                        expanded.extend(group.stream());
                        *contains_paste = true;
                    } else {
                        // Decide on the next iteration, once we can see whether `::` follows.
                        prev_none_group = Some(group);
                    }
                }
                lookbehind = Lookbehind::Other;
            }
            Some(TokenTree::Punct(punct)) => {
                // Track `::` (joint `:` then `:`) and `#` / `#!` prefixes.
                lookbehind = match punct.as_char() {
                    ':' if lookbehind == Lookbehind::JointColon => Lookbehind::DoubleColon,
                    ':' if punct.spacing() == Spacing::Joint => Lookbehind::JointColon,
                    '#' => Lookbehind::Pound,
                    '!' if lookbehind == Lookbehind::Pound => Lookbehind::PoundBang,
                    _ => Lookbehind::Other,
                };
                expanded.extend(iter::once(TokenTree::Punct(punct)));
            }
            Some(other) => {
                lookbehind = Lookbehind::Other;
                expanded.extend(iter::once(other));
            }
            None => return Ok(expanded),
        }
    }
}
295
/// Summary of the punctuation most recently emitted by `expand`, used to
/// recognise path separators (`::`) and attribute openers (`#`, `#!`).
#[derive(PartialEq)]
enum Lookbehind {
    /// The previous token was a `#`.
    Pound,
    /// The previous two tokens were `#` followed by `!`.
    PoundBang,
    /// The previous token was a `:` with joint spacing.
    JointColon,
    /// The previous token was a `:` that followed a joint `:` (end of `::`).
    DoubleColon,
    /// Anything else.
    Other,
}
304
/// Reports whether the contents of a `None`-delimited group form one "simple"
/// interpolation that can be spliced without its invisible delimiters.
///
/// Accepted shapes:
/// * a single identifier,
/// * a single literal,
/// * a lifetime (a `'` punct followed by an identifier),
/// * identifiers joined by `::`, e.g. `a::b::c` (first `:` joint, second alone).
///
/// Anything else is rejected, including an empty stream, a trailing `::`,
/// and any nested group.
fn is_single_interpolation_group(input: &TokenStream) -> bool {
    #[derive(PartialEq)]
    enum Seen {
        Nothing,
        PathSegment,
        Lit,
        Tick,
        Lifetime,
        FirstColon,
        PathSep,
    }

    let mut seen = Seen::Nothing;
    for token in input.clone() {
        let next = match token {
            TokenTree::Ident(_) => match seen {
                Seen::Nothing | Seen::PathSep => Seen::PathSegment,
                Seen::Tick => Seen::Lifetime,
                _ => return false,
            },
            TokenTree::Literal(_) => match seen {
                Seen::Nothing => Seen::Lit,
                _ => return false,
            },
            TokenTree::Punct(ref punct) => match (seen, punct.as_char(), punct.spacing()) {
                (Seen::Nothing, '\'', _) => Seen::Tick,
                (Seen::PathSegment, ':', Spacing::Joint) => Seen::FirstColon,
                (Seen::FirstColon, ':', Spacing::Alone) => Seen::PathSep,
                _ => return false,
            },
            TokenTree::Group(_) => return false,
        };
        seen = next;
    }

    seen == Seen::PathSegment || seen == Seen::Lit || seen == Seen::Lifetime
}
342
/// Reports whether the contents of a bracket group look like a paste operation:
/// a `<` punct, at least one token, and a `>` punct that is the final token.
///
/// The first `>` encountered closes the operation, so anything after it makes
/// this return `false`.
fn is_paste_operation(input: &TokenStream) -> bool {
    let mut rest = input.clone().into_iter();

    // The bracket contents must open with a `<` punct.
    let opens_with_lt = match rest.next() {
        Some(TokenTree::Punct(ref punct)) => punct.as_char() == '<',
        _ => false,
    };
    if !opens_with_lt {
        return false;
    }

    // Walk forward to the first `>`; something must sit between the angle
    // brackets, and nothing may follow the closing `>`.
    let mut saw_inner_token = false;
    while let Some(tt) = rest.next() {
        if let TokenTree::Punct(ref punct) = tt {
            if punct.as_char() == '>' {
                return saw_inner_token && rest.next().is_none();
            }
        }
        saw_inner_token = true;
    }
    false
}
362
/// Parses the contents of a `[< ... >]` bracket group into paste segments.
///
/// `input` is the token stream inside the brackets and `scope` is the span of
/// the bracket group, used for error messages when the input ends early.
///
/// After `segment::parse`, every `Segment::String` is normalised in place:
/// * a char literal written as a `'\u{...}'` escape is decoded to that char;
/// * text containing `#`, `\`, `.` or `+`, or starting with `b'`, `b"` or
///   `br"`, is rejected as an unsupported literal;
/// * the surrounding quotes of `r"..."`, `"..."` and `'...'` are stripped;
/// * every `-` is replaced with `_`.
///
/// # Errors
///
/// Returns an error if the stream is not exactly `<`, segments, `>`; if
/// `segment::parse` fails; or if a string segment is an unsupported literal.
fn parse_bracket_as_segments(input: TokenStream, scope: Span) -> Result<Vec<Segment>> {
    let mut tokens = input.into_iter().peekable();

    match &tokens.next() {
        Some(TokenTree::Punct(punct)) if punct.as_char() == '<' => {}
        Some(wrong) => return Err(Error::new(wrong.span(), "expected `<`")),
        None => return Err(Error::new(scope, "expected `[< ... >]`")),
    }

    // Consumes segment tokens, leaving the closing `>` to be checked below.
    let mut segments = segment::parse(&mut tokens)?;

    match &tokens.next() {
        Some(TokenTree::Punct(punct)) if punct.as_char() == '>' => {}
        Some(wrong) => return Err(Error::new(wrong.span(), "expected `>`")),
        None => return Err(Error::new(scope, "expected `[< ... >]`")),
    }

    if let Some(unexpected) = tokens.next() {
        return Err(Error::new(
            unexpected.span(),
            "unexpected input, expected `[< ... >]`",
        ));
    }

    for segment in &mut segments {
        if let Segment::String(string) = segment {
            // `'\u{` is 4 bytes; `- 2` drops a trailing `}'`. This assumes the
            // text is a complete char literal — TODO confirm segment::parse only
            // yields such text with this prefix.
            if string.value.starts_with("'\\u{") {
                let hex = &string.value[4..string.value.len() - 2];
                if let Ok(unsigned) = u32::from_str_radix(hex, 16) {
                    if let Some(ch) = char::from_u32(unsigned) {
                        string.value.clear();
                        string.value.push(ch);
                        continue;
                    }
                }
                // Undecodable escapes fall through and are rejected below (they contain `\`).
            }
            // `#` would appear in `r#"..."#` raw strings, which the quote
            // stripping below does not handle; `.`/`+` presumably come from float
            // literals; byte literals are rejected by prefix.
            if string.value.contains(&['#', '\\', '.', '+'][..])
                || string.value.starts_with("b'")
                || string.value.starts_with("b\"")
                || string.value.starts_with("br\"")
            {
                return Err(Error::new(string.span, "unsupported literal"));
            }
            // Byte range of the text between the quotes (the whole value if unquoted).
            let mut range = 0..string.value.len();
            if string.value.starts_with("r\"") {
                range.start += 2;
                range.end -= 1;
            } else if string.value.starts_with(&['"', '\''][..]) {
                range.start += 1;
                range.end -= 1;
            }
            string.value = string.value[range].replace('-', "_");
        }
    }

    Ok(segments)
}
420
421fn pasted_to_tokens(mut pasted: String, span: Span) -> Result<TokenStream> {
422 let mut tokens = TokenStream::new();
423
424 #[cfg(not(no_literal_fromstr))]
425 {
426 use proc_macro::{LexError, Literal};
427 use std::str::FromStr;
428
429 if pasted.starts_with(|ch: char| ch.is_ascii_digit()) {
430 let literal = match panic::catch_unwind(|| Literal::from_str(&pasted)) {
431 Ok(Ok(literal)) => TokenTree::Literal(literal),
432 Ok(Err(LexError { .. })) | Err(_) => {
433 return Err(Error::new(
434 span,
435 &format!("`{:?}` is not a valid literal", pasted),
436 ));
437 }
438 };
439 tokens.extend(iter::once(literal));
440 return Ok(tokens);
441 }
442 }
443
444 if pasted.starts_with('\'') {
445 let mut apostrophe = TokenTree::Punct(Punct::new('\'', Spacing::Joint));
446 apostrophe.set_span(span);
447 tokens.extend(iter::once(apostrophe));
448 pasted.remove(0);
449 }
450
451 let ident = match panic::catch_unwind(|| Ident::new(&pasted, span)) {
452 Ok(ident) => TokenTree::Ident(ident),
453 Err(_) => {
454 return Err(Error::new(
455 span,
456 &format!("`{:?}` is not a valid identifier", pasted),
457 ));
458 }
459 };
460
461 tokens.extend(iter::once(ident));
462 Ok(tokens)
463}