litext 1.1.0

Seamless proc-macro literal extraction.
Documentation
//! # litext
//!
//! Literal extraction for proc-macro authors. Pull typed values out of
//! `TokenStream` input without writing the parsing boilerplate yourself.
//!
//! ## Overview
//!
//! Procedural macros frequently receive a `TokenStream` that contains one or
//! more literal tokens and need to extract their actual values. This crate
//! provides the [`litext!`] macro and [`extract`] function for that purpose,
//! supporting every Rust literal kind with span tracking for precise diagnostics.
//!
//! ## Quick Start
//!
//! ```ignore
//! use litext::{litext, TokenStream};
//!
//! fn my_macro(input: TokenStream) -> TokenStream {
//!     // Return early with a compile error on failure:
//!     let text: String = litext!(input);
//!
//!     // Or keep the Result for explicit handling:
//!     let result: Result<String, TokenStream> = litext!(try input);
//!
//!     // Extract a non-string type:
//!     let count: u32 = litext!(input as u32);
//!
//!     // Extract multiple literals from one TokenStream:
//!     // (input must be: "name" , 42)
//!     let (name, count): (String, u32) = litext!(input as (String , u32));
//! }
//! ```
//!
//! ## Supported Literal Types
//!
//! | Target type | Accepted input | Notes |
//! |-------------|----------------|-------|
//! | `String`, [`LitStr`] | `"hello"`, `r#"raw"#` | Full escape sequence support |
//! | `i8` to `i128`, `isize`, [`LitInt<T>`] | `42`, `0xFF`, `0b1010`, `1_000u32` | All radixes, underscore separators, type suffixes |
//! | `u8` to `u128`, `usize`, [`LitInt<T>`] | same as signed integers | Overflow is a compile error |
//! | `f32`, `f64`, [`LitFloat<T>`] | `3.14`, `1e10`, `1_000.5f32` | Scientific notation, underscores, suffixes |
//! | `char`, [`LitChar`] | `'a'`, `'\n'`, `'\u{1F600}'` | Full Unicode and escape support |
//! | `bool`, [`LitBool`] | `true`, `false` | Parsed as identifier tokens |
//! | `u8`, [`LitByte`] | `b'a'`, `b'\xff'` | Full 0x00..=0xFF byte range |
//! | `Vec<u8>`, [`LitByteStr`] | `b"hello"`, `br#"..."#` | Full byte range for `\x` escapes |
//! | `CString`, [`LitCStr`] | `c"hello"`, `cr#"..."#` | Interior null bytes are a compile error |
//!
//! ## Span-Aware Types
//!
//! Every literal kind has a span-aware wrapper that bundles the parsed value
//! with its source location. Use these when you need to emit diagnostics that
//! point at the exact literal in the user's code.
//!
//! ```ignore
//! use litext::{LitStr, LitInt};
//!
//! let name: LitStr = litext!(input as LitStr);
//! if name.value().is_empty() {
//!     return comperr::error(name.span(), "name cannot be empty");
//! }
//!
//! let count: LitInt<u32> = litext!(input as LitInt<u32>);
//! if *count.value() == 0 {
//!     return comperr::error(count.span(), "count must be non-zero");
//! }
//! ```
//!
//! ## Multi-Extraction
//!
//! Extract several literals from one `TokenStream` in a single call by wrapping
//! the target types in a tuple with the separator token between them:
//!
//! ```ignore
//! // input: "hello" , 42
//! let (s, n): (String, i32) = litext!(input as (String , i32));
//!
//! // input: "key" ; 100
//! let (k, v): (String, u32) = litext!(input as (String ; u32));
//!
//! // Try form: returns Result instead of returning early
//! let result: Result<(String, i32), _> = litext!(try input as (String , i32));
//! ```
//!
//! Up to 12 values can be extracted per call. Type arguments in tuple position
//! must be single-token identifiers; `LitInt<u8>` does not fit. Use `LitInt`
//! (which defaults to `i32`) or a plain integer type instead.
//!
//! ## Limitations
//!
//! Negative numbers like `-42` are two tokens in Rust's token stream: a `-`
//! punctuation token and a positive literal. [`extract`] and [`litext!`] do not
//! handle them. Parse the sign at the expression level instead.
//!
//! ## See Also
//!
//! - [`literal`] module for span-aware types and the [`ToTokens`] round-trip trait
//! - [`FromLit`] trait for custom literal extraction
//! - [`extract`] for the lower-level function API

#![warn(missing_docs)]

use proc_macro2::{Literal, Span, TokenStream, TokenTree};

mod literal;
mod macros;

pub use literal::*;

/// Extracts a typed value from a `TokenStream` containing exactly one literal.
///
/// This is the lower-level API. For most proc-macro code the [`litext!`] macro
/// is more convenient since it handles the early-return pattern automatically.
///
/// # Type Parameter
///
/// `T` must implement [`FromLit`]. All built-in literal types are supported:
/// `String`, `LitStr`, every integer and float type, `char`, `bool`, `u8`,
/// `Vec<u8>`, `CString`, and all span-aware wrappers (`LitInt<T>`, etc.).
///
/// # Errors
///
/// Returns `Err(TokenStream)` containing a `compile_error!` invocation when:
/// - The input is empty
/// - The input contains more than one token
/// - The input token is punctuation or a group, not a literal or identifier
/// - The literal value cannot be converted to `T` (e.g., integer overflow)
///
/// The error token stream, when returned from a proc-macro, causes the compiler
/// to display the error message at the correct source span.
///
/// # Limitations
///
/// Negative numbers (`-42`) are two tokens: a `-` punct and a positive literal.
/// This function does not handle them. Handle the sign at the expression level.
///
/// # Examples
///
/// ```ignore
/// use litext::{extract, LitStr};
/// use proc_macro2::TokenStream;
///
/// fn get_string(input: TokenStream) -> Result<String, TokenStream> {
///     extract::<String>(input)
/// }
///
/// fn get_int(input: TokenStream) -> Result<i64, TokenStream> {
///     extract::<i64>(input)
/// }
///
/// fn get_str_with_span(input: TokenStream) -> Result<LitStr, TokenStream> {
///     extract::<LitStr>(input)
/// }
/// ```
#[inline]
pub fn extract<T: literal::FromLit>(input: TokenStream) -> Result<T, TokenStream> {
    let mut iter = input.into_iter();

    let Some(token) = iter.next() else {
        #[cold]
        fn got_nothing() -> TokenStream {
            comperr::error(Span::call_site(), "expected a literal, got nothing")
        }
        return Err(got_nothing());
    };

    if let Some(next_token) = iter.next() {
        #[cold]
        fn too_many_tokens(span: Span) -> TokenStream {
            comperr::error(span, "expected exactly one literal")
        }
        return Err(too_many_tokens(next_token.span()));
    }

    match token {
        TokenTree::Literal(lit) => T::from_lit(lit),
        TokenTree::Ident(ident) => T::from_ident(ident),
        TokenTree::Punct(p) => Err(comperr::error(
            p.span(),
            "expected a literal, found punctuation",
        )),
        TokenTree::Group(g) => Err(comperr::error(g.span(), "expected a literal, found group")),
    }
}

#[doc(hidden)]
pub fn extract_from_iter<T: literal::FromLit>(
    iter: &mut proc_macro2::token_stream::IntoIter,
) -> Result<T, comperr::Error> {
    match iter.next() {
        Some(TokenTree::Literal(lit)) => {
            T::from_lit(lit).map_err(comperr::Error::from_token_stream)
        }
        Some(TokenTree::Ident(ident)) => {
            T::from_ident(ident).map_err(comperr::Error::from_token_stream)
        }
        Some(other) => Err(comperr::Error::new(other.span(), "expected a literal")),
        None => Err(comperr::Error::new(
            Span::call_site(),
            "expected a literal, got nothing",
        )),
    }
}

#[doc(hidden)]
pub fn consume_sep(
    iter: &mut proc_macro2::token_stream::IntoIter,
    expected: &str,
    span: Span,
) -> Result<(), comperr::Error> {
    match iter.next() {
        Some(TokenTree::Punct(p)) if p.to_string() == expected => Ok(()),
        Some(other) => Err(comperr::Error::new(
            other.span(),
            format!("expected `{expected}` separator"),
        )),
        None => Err(comperr::Error::new(
            span,
            format!("expected `{expected}` separator, got nothing"),
        )),
    }
}

/// Parses a string or raw string literal token and returns the unescaped content.
///
/// Rejects byte strings (`b"..."`) since those require [`unescape_bytes`].
#[inline]
pub(crate) fn parse_lit(lit: &Literal) -> Result<String, TokenStream> {
    let raw = lit.to_string();
    let span = lit.span();

    if raw.starts_with('b') && raw.len() > 1 {
        let c = raw.chars().nth(1).unwrap();
        if c == '"' || c == 'r' {
            return Err(comperr::error(
                span,
                "expected a string literal, not a byte string",
            ));
        }
    }

    if raw.starts_with('r') {
        return parse_raw(&raw).ok_or_else(|| comperr::error(span, "malformed raw string literal"));
    }

    if raw.starts_with('"') && raw.ends_with('"') && raw.len() >= 2 {
        return unescape(&raw[1..raw.len() - 1], span);
    }

    Err(comperr::error(span, "expected a string literal"))
}

/// Returns the text between the delimiters of a raw string literal.
///
/// `raw` is the full source text of the literal: `r`, then zero or more `#`,
/// a `"`, the content, a closing `"` and the same number of `#` again. The
/// content is copied out unchanged; no escape processing takes place.
///
/// Returns `None` when `raw` does not have that shape, for example when the
/// `r` prefix is missing or the closing hashes do not match the opening ones.
#[inline]
pub(crate) fn parse_raw(raw: &str) -> Option<String> {
    let after_r = raw.strip_prefix('r')?;

    // `#` is a single byte in UTF-8, so the number of leading hashes is also
    // the byte offset at which the opening quote must sit.
    let fence_len = after_r.bytes().take_while(|&b| b == b'#').count();
    let (fence, quoted) = after_r.split_at(fence_len);

    let content = quoted
        .strip_prefix('"')?
        .strip_suffix(fence)?
        .strip_suffix('"')?;
    Some(content.to_owned())
}

/// Processes escape sequences in a string literal and returns the unescaped content.
///
/// Handles `\n`, `\r`, `\t`, `\\`, `\"`, `\0`, `\xNN` (0x00..=0x7F only),
/// `\u{NNNN}`, and line continuation (`\` followed by newline).
/// Use [`unescape_bytes`] for byte literal contexts where `\xFF` is valid.
#[inline]
pub(crate) fn unescape(s: &str, span: Span) -> Result<String, TokenStream> {
    let mut output = String::with_capacity(s.len());
    let mut chars = s.chars();

    while let Some(c) = chars.next() {
        if c != '\\' {
            output.push(c);
            continue;
        }

        match chars.next() {
            Some('n') => output.push('\n'),
            Some('r') => output.push('\r'),
            Some('t') => output.push('\t'),
            Some('\\') => output.push('\\'),
            Some('"') => output.push('"'),
            Some('0') => output.push('\0'),

            Some('x') => {
                let h1 = chars
                    .next()
                    .ok_or_else(|| comperr::error(span, "invalid \\x escape"))?;
                let h2 = chars
                    .next()
                    .ok_or_else(|| comperr::error(span, "invalid \\x escape"))?;
                let hex = format!("{h1}{h2}");
                let byte = u8::from_str_radix(&hex, 16)
                    .map_err(|_| comperr::error(span, "invalid \\x escape"))?;
                if byte > 0x7F {
                    return Err(comperr::error(
                        span,
                        "\\x escape must be in range 0x00..=0x7F",
                    ));
                }
                output.push(byte as char);
            }

            Some('u') => {
                match chars.next() {
                    Some('{') => {}
                    _ => return Err(comperr::error(span, "invalid \\u escape, expected '{'")),
                }
                let mut hex = String::new();
                loop {
                    match chars.next() {
                        Some('}') => break,
                        Some(c) => hex.push(c),
                        None => return Err(comperr::error(span, "unterminated \\u escape")),
                    }
                }
                let codepoint = u32::from_str_radix(&hex, 16)
                    .map_err(|_| comperr::error(span, "invalid \\u codepoint"))?;
                let ch = char::from_u32(codepoint)
                    .ok_or_else(|| comperr::error(span, "invalid unicode codepoint"))?;
                output.push(ch);
            }

            Some('\n') => {
                while let Some(&c) = chars.as_str().chars().next().as_ref() {
                    if c.is_whitespace() {
                        chars.next();
                    } else {
                        break;
                    }
                }
            }

            _ => return Err(comperr::error(span, "invalid escape sequence")),
        }
    }

    Ok(output)
}

/// Processes escape sequences in a byte string or byte literal and returns the raw bytes.
///
/// `s` is the text between the quotes of the literal. `span` is attached to
/// every error so diagnostics point at the literal.
///
/// Unlike [`unescape`], this allows `\x` escapes in the full 0x00..=0xFF range, which
/// is valid in byte literals (`b'\xff'`) and byte string literals (`b"\xff"`).
///
/// Supported escapes: `\n`, `\r`, `\t`, `\\`, `\"`, `\'`, `\0`, `\xNN` with
/// exactly two hex digits, and line continuation (`\` followed by a newline
/// skips that newline and the following run of space, tab, LF and CR).
///
/// Unescaped characters are appended as their UTF-8 encoding. For ASCII, the
/// only thing a byte literal may contain, that is the character's single byte;
/// a non-ASCII character becomes its multi-byte encoding rather than being
/// truncated to the low 8 bits of its code point.
///
/// # Errors
///
/// Returns a `compile_error!` token stream at `span` for an unknown escape, a
/// trailing backslash, or a `\x` escape not followed by two hex digits.
#[inline]
pub(crate) fn unescape_bytes(s: &str, span: Span) -> Result<Vec<u8>, TokenStream> {
    let mut output = Vec::with_capacity(s.len());
    let mut chars = s.chars();

    while let Some(c) = chars.next() {
        if c != '\\' {
            // A plain `c as u8` would silently drop the high bits of any
            // non-ASCII character; encoding keeps the data intact.
            let mut utf8 = [0u8; 4];
            output.extend_from_slice(c.encode_utf8(&mut utf8).as_bytes());
            continue;
        }

        match chars.next() {
            Some('n') => output.push(b'\n'),
            Some('r') => output.push(b'\r'),
            Some('t') => output.push(b'\t'),
            Some('\\') => output.push(b'\\'),
            Some('"') => output.push(b'"'),
            Some('\'') => output.push(b'\''),
            Some('0') => output.push(0),

            Some('x') => {
                // Exactly two hex digits. `to_digit` accepts nothing but
                // 0-9/a-f/A-F, so stray signs such as `\x+7` are rejected.
                let mut byte = 0u8;
                for _ in 0..2 {
                    let digit = chars
                        .next()
                        .and_then(|d| d.to_digit(16))
                        .ok_or_else(|| comperr::error(span, "invalid \\x escape"))?;
                    // `digit` is at most 15, so the cast cannot truncate.
                    byte = (byte << 4) | digit as u8;
                }
                output.push(byte);
            }

            Some('\n') => {
                // Line continuation skips only the four ASCII whitespace
                // characters; anything else is content and is kept.
                chars = chars
                    .as_str()
                    .trim_start_matches([' ', '\t', '\n', '\r'])
                    .chars();
            }

            _ => return Err(comperr::error(span, "invalid escape sequence")),
        }
    }

    Ok(output)
}