format_args_conditional 0.1.1

A procedural macro that expands to one of two macros, depending on whether its `format_args!` input could be optimized into a plain `write_str` call.
Documentation
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate proc_macro;

use proc_macro::{Delimiter, Group, Literal, Punct, Spacing, Span, TokenStream, TokenTree};

mod litstr;

/// Builds an `Err` containing a `core::compile_error!("<msg>")` invocation in
/// which every token carries `span`, so the compiler reports `msg` there.
///
/// Generic over `T` so any parser returning `Result<_, TokenStream>` can
/// simply `return err(...)`.
fn err<T>(msg: &str, span: Span) -> Result<T, TokenStream> {
    let mut tokens = Vec::new();
    // Re-span each token of the `core::compile_error!` path onto `span`.
    for mut token in "core::compile_error!".parse::<TokenStream>().unwrap() {
        token.set_span(span);
        tokens.push(token);
    }

    let mut message = Literal::string(msg);
    message.set_span(span);
    let mut arguments = Group::new(
        Delimiter::Parenthesis,
        TokenStream::from(TokenTree::Literal(message)),
    );
    arguments.set_span(span);
    tokens.push(TokenTree::Group(arguments));

    Err(tokens.into_iter().collect())
}

/// One parsed macro specification, as returned by `read_macro_arg_with_comma`:
/// a macro path plus the optional prefix tokens that follow it.
struct MacroArg {
    /// Tokens of the macro path, without the trailing `!`
    /// (e.g. `$crate::opt_format`).
    macro_path: TokenStream,
    /// Inner tokens of the optional delimited group after the path; empty when
    /// no group was given. They are placed before the format string in the
    /// generated invocation.
    start_args: TokenStream,
    /// The comma that terminated this specification.
    // Not read by `branch_on_format_capture_impl` (it is skipped with `..`),
    // hence the `allow`.
    #[allow(dead_code)]
    comma: Punct,
}

/// Parses one macro specification from `iter`, consuming tokens up to and
/// including the comma that terminates it.
///
/// A specification is a macro path without its trailing `!` (identifiers
/// joined by `::`, optionally with a leading `::`). The path may be followed
/// by a delimited group, whose inner tokens become the prefix tokens
/// (`start_args`). The group's delimiter is not checked.
///
/// # Errors
///
/// Returns a `compile_error!` invocation (built by `err`) in two cases:
///
/// - the input ends before a terminating comma;
/// - the tokens are not a path, an optional group, and a comma, in that order.
///
/// A comma is only accepted after a complete path (one ending in an
/// identifier) or after the prefix group. An empty path, or one ending in
/// `:`/`::`, is therefore reported here rather than expanded into an invalid
/// macro invocation.
fn read_macro_arg_with_comma(
    iter: &mut proc_macro::token_stream::IntoIter,
) -> Result<MacroArg, TokenStream> {
    #[derive(Clone, Copy, PartialEq)]
    enum ParseState {
        // At the start: a path segment or a leading `::` may follow.
        ExpectIdentOrColon,
        // After a complete `::`: the next path segment must follow.
        ExpectIdent,
        // After a path segment: `::`, the prefix group, or the comma may follow.
        ExpectColonOrArgGroup,
        // After the first `:` of a `::`.
        ExpectSecondColon,
        // After the prefix group: only the comma may follow.
        ExpectComma,
    }
    use ParseState::*;
    let mut macro_path = TokenStream::default();
    let mut start_args = TokenStream::default();
    let mut state = ExpectIdentOrColon;
    let comma = loop {
        let Some(tt) = iter.next() else {
            return err("At least three arguments are required", Span::call_site());
        };
        state = match (&tt, state) {
            (TokenTree::Ident(_), ExpectIdentOrColon | ExpectIdent) => ExpectColonOrArgGroup,
            (TokenTree::Group(g), ExpectColonOrArgGroup) => {
                // `Group::stream` already returns an owned copy of the tokens.
                start_args = g.stream();
                ExpectComma
            }
            (TokenTree::Punct(p), ExpectColonOrArgGroup | ExpectComma) if p.as_char() == ',' => {
                break p.clone();
            }
            // A comma in any other state would leave the path empty or ending
            // in `:`/`::`, which cannot be invoked as a macro.
            (TokenTree::Punct(p), _) if p.as_char() == ',' => {
                return err("Expected path to macro", p.span());
            }
            (TokenTree::Punct(p), _) if p.as_char() == ':' => match state {
                ExpectIdentOrColon | ExpectColonOrArgGroup => ExpectSecondColon,
                ExpectSecondColon => ExpectIdent,
                ExpectIdent | ExpectComma => return err("Unexpected colon", p.span()),
            },
            (_, ExpectComma) => return err("Expected a comma", tt.span()),
            _ => return err("Expected path to macro", tt.span()),
        };
        // Every token except the prefix group (the only transition into
        // `ExpectComma`) is part of the macro path.
        if state != ExpectComma {
            macro_path.extend([tt]);
        }
    };
    Ok(MacroArg {
        macro_path,
        start_args,
        comma,
    })
}

/// Returns `Some` if `s` does *not* contain any `format_args!` format arguments,
/// and so can be passed directly to a lower-level string printing function.
///
/// The returned string has the escapes `{{` and `}}` replaced by `{` and `}`.
/// Any other brace makes this return `None`: either it opens or closes a
/// placeholder, or it is a stray brace that `format_args!` rejects. The input
/// is then left to `format_args!`, which still reports any error in the format
/// string. A lone trailing brace (e.g. `"abc}"`) is therefore *not* treated as
/// literal text.
///
/// When this function returns `Some`, [`fmt::Arguments::as_str`] must return `Some`
/// with the same input string, so long as `s` was accepted by `format_args!`.
///
/// [`fmt::Arguments::as_str`]: core::fmt::Arguments::as_str
fn non_capturing_format_to_str(s: &str) -> Option<Box<str>> {
    let bytes = s.as_bytes();
    let mut unescaped = String::with_capacity(s.len());
    let mut pos = 0;
    while let Some(offset) = s[pos..].find(&['{', '}']) {
        let i = pos + offset;
        // Only the doubled escapes `{{` / `}}` are plain text. Anything else,
        // including a brace at the very end of `s`, needs `format_args!`.
        if bytes.get(i + 1) != Some(&bytes[i]) {
            return None;
        }
        // Keep the text up to and including the first brace of the pair.
        // Both braces are ASCII, so `i + 2` is a char boundary.
        unescaped.push_str(&s[pos..=i]);
        pos = i + 2;
    }
    unescaped.push_str(&s[pos..]);
    Some(unescaped.into_boxed_str())
}

/// Parses `capturing_spec, constant_spec, "format string" rest...`.
///
/// It then builds `<path>!(<prefix tokens> <literal> <rest...>)` for whichever
/// specification applies:
///
/// - the capturing one when the format string has placeholders;
/// - otherwise the constant one, with the literal replaced by its unescaped
///   contents.
///
/// On malformed input the `Err` holds a `compile_error!` invocation.
fn branch_on_format_capture_impl(input: TokenStream) -> Result<TokenStream, TokenStream> {
    let mut rest = input.into_iter();

    // The two macro specifications come first, each terminated by a comma.
    let capturing = read_macro_arg_with_comma(&mut rest)?;
    let constant = read_macro_arg_with_comma(&mut rest)?;

    // Then the format string literal itself.
    let literal = match rest.next() {
        Some(TokenTree::Literal(lit)) => lit,
        Some(other) => return err("Expected string literal", other.span()),
        None => return err("Unexpected end of input", Span::call_site()),
    };
    let Ok(contents) = litstr::parse_lit_str(&literal.to_string()) else {
        return err("Expected string literal", literal.span());
    };

    // Choose the target macro. The constant branch receives a fresh literal
    // holding the unescaped text, keeping the original literal's span.
    let (chosen, literal) = if let Some(unescaped) = non_capturing_format_to_str(&contents) {
        let mut plain = Literal::string(&unescaped);
        plain.set_span(literal.span());
        (constant, plain)
    } else {
        (capturing, literal)
    };
    let MacroArg {
        macro_path: mut expansion,
        start_args: mut invocation_args,
        ..
    } = chosen;

    invocation_args.extend([TokenTree::Literal(literal)]);
    invocation_args.extend(rest);
    expansion.extend([
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, invocation_args)),
    ]);
    Ok(expansion)
}

/// Expands to one of two macro invocations, chosen by whether a
/// `format_args!`-style format string contains any formatting placeholders.
///
/// Written as a `macro_rules!` matcher, the accepted input is roughly:
///
/// ```
/// # macro_rules! branch { (
/// $capturing_macro:path $({$($capturing_prefix:tt)*})?,
/// $constant_macro:path $({$($constant_prefix:tt)*})?,
/// $format_string:literal
/// $($remaining_tokens:tt)*
/// # ) => {}; }
/// ```
///
/// This is a higher-order macro: each macro specification is the path to
/// another macro, minus its final `!`, optionally followed by a `{}` group of
/// _prefix tokens_.
///
/// - If the format string would make `format_args!` capture arguments (it
///   contains placeholders), the first macro is invoked. Its arguments are its
///   prefix tokens, then the format string, then the remaining tokens.
/// - Otherwise the second macro is invoked in the same way. The format string
///   it receives is a plain string literal with `{{` and `}}` already
///   unescaped to `{` and `}`.
///
/// Malformed input expands to a `compile_error!` invocation.
///
/// # Example
///
/// An optimizing `format!` macro that uses `String::from` instead of
/// `format!` when the format string is the only argument and has no
/// placeholders.
///
/// ```
/// # use format_args_conditional::branch_on_format_capture;
/// #[macro_export]
/// macro_rules! opt_format {
///     () => {
///         (alloc::string::String::new(), true)
///     };
///     (@capture $($args:tt)*) => {
///         (format!($($args)*), false)
///     };
///     (@nocapture $str:literal) => {
///         (alloc::string::String::from($str), true)
///     };
///     ($msg:tt $(,)?) => {
///         format_args_conditional::branch_on_format_capture!(
///             $crate::opt_format {@capture},
///             $crate::opt_format {@nocapture},
///             $msg
///         )
///     };
///     ($($args:tt)*) => {
///         opt_format!(@capture $($args)*)
///     };
/// }
/// ```
#[proc_macro]
pub fn branch_on_format_capture(input: TokenStream) -> TokenStream {
    // Both outcomes are token streams: the expansion, or a `compile_error!`.
    branch_on_format_capture_impl(input).unwrap_or_else(|compile_error| compile_error)
}

#[cfg(test)]
mod tests {
    use super::non_capturing_format_to_str;

    /// Checks which format strings are classified as argument-free, and the
    /// unescaped text produced for them.
    #[test]
    fn test_non_capturing_format_to_str() {
        let cases: &[(&str, Option<&str>)] = &[
            ("", Some("")),
            ("x", Some("x")),
            ("{x}", None),
            ("{{x}}", Some("{x}")),
            ("{{}}x{{{{}}}}x}}", Some("{}x{{}}x}")),
        ];
        for &(input, expected) in cases {
            assert_eq!(non_capturing_format_to_str(input).as_deref(), expected);
        }

        // `format_args!` rejects these strings; classifying them must simply
        // not panic.
        for malformed in ["{{x}", "{{x}x}", "{x}}"] {
            let _ = non_capturing_format_to_str(malformed);
        }
    }
}