extern crate proc_macro;
use proc_macro::{Delimiter, Group, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
mod litstr;
fn err<T>(msg: &str, span: Span) -> Result<T, TokenStream> {
    // Builds `core::compile_error!("<msg>")` with every token carrying `span`, so the
    // compiler points the diagnostic at the offending input. It is returned as the
    // `Err` variant so callers can propagate it with `?`.
    let macro_name = "core::compile_error!".parse::<TokenStream>().unwrap();
    let mut invocation = TokenStream::new();
    for mut token in macro_name {
        token.set_span(span);
        invocation.extend([token]);
    }

    let mut message = Literal::string(msg);
    message.set_span(span);
    let mut parenthesized = Group::new(
        Delimiter::Parenthesis,
        TokenStream::from(TokenTree::Literal(message)),
    );
    parenthesized.set_span(span);

    invocation.extend([TokenTree::Group(parenthesized)]);
    Err(invocation)
}
/// One parsed macro argument of `branch_on_format_capture`: a macro path with
/// optional leading arguments, terminated by a comma.
struct MacroArg {
    /// Tokens of the macro path (for example `a::b`), without the trailing `!`.
    macro_path: TokenStream,
    /// Contents of the optional delimited group that followed the path (the group's
    /// delimiters are not kept); empty when no group was given. These tokens are
    /// emitted immediately before the format-string literal in the generated call.
    start_args: TokenStream,
    /// The comma that terminated the argument. Nothing reads it at the moment, which
    /// is why `dead_code` is allowed here; presumably kept for future diagnostics —
    /// confirm before removing.
    #[allow(dead_code)]
    comma: Punct,
}
/// Reads one `macro_path[(start args)],` argument from `iter`, consuming tokens up to
/// and including the terminating comma.
///
/// The accepted shape is:
/// - an optional leading `::`;
/// - then one or more identifiers separated by `::`;
/// - then an optional delimited group;
/// - then `,`.
///
/// The identifiers and colons are collected into `macro_path`. The contents of the
/// group, with its delimiters dropped, become `start_args`, which stays empty when no
/// group is given.
///
/// # Errors
///
/// Returns a `compile_error!` token stream (built by `err`) when:
/// - the input ends before the terminating comma;
/// - a colon appears where it cannot continue the path;
/// - anything other than a comma follows the argument group;
/// - the path is missing or incomplete. This includes a comma that arrives before any
///   identifier, or directly after `:`/`::`.
fn read_macro_arg_with_comma(
    iter: &mut proc_macro::token_stream::IntoIter,
) -> Result<MacroArg, TokenStream> {
    #[derive(Clone, Copy, PartialEq)]
    enum ParseState {
        ExpectIdentOrColon,
        ExpectIdent,
        ExpectColonOrArgGroup,
        ExpectSecondColon,
        ExpectComma,
    }
    use ParseState::*;
    let mut macro_path = TokenStream::default();
    let mut start_args = TokenStream::default();
    let mut state = ExpectIdentOrColon;
    let comma = loop {
        let Some(tt) = iter.next() else {
            return err("At least three arguments are required", Span::call_site());
        };
        state = match (&tt, state) {
            (TokenTree::Ident(_), ExpectIdentOrColon | ExpectIdent) => ExpectColonOrArgGroup,
            (TokenTree::Group(g), ExpectColonOrArgGroup) => {
                // `Group::stream` already returns an owned copy of the contents.
                start_args = g.stream();
                ExpectComma
            }
            // A comma only ends the argument once a complete path (optionally followed
            // by its argument group) has been read. Anywhere else it falls through to
            // the "Expected path to macro" error below, instead of yielding an empty
            // or dangling path such as `!(…)` or `a::!(…)`.
            (TokenTree::Punct(p), ExpectColonOrArgGroup | ExpectComma) if p.as_char() == ',' => {
                break p.clone();
            }
            (TokenTree::Punct(p), _) if p.as_char() == ':' => match state {
                ExpectIdentOrColon | ExpectColonOrArgGroup => ExpectSecondColon,
                ExpectSecondColon => ExpectIdent,
                ExpectIdent | ExpectComma => return err("Unexpected colon", p.span()),
            },
            (_, ExpectComma) => return err("Expected a comma", tt.span()),
            _ => return err("Expected path to macro", tt.span()),
        };
        // Every accepted token except the argument group belongs to the path.
        if state != ExpectComma {
            macro_path.extend([tt]);
        }
    };
    Ok(MacroArg {
        macro_path,
        start_args,
        comma,
    })
}
/// Converts a format string that contains no placeholders into the literal text it
/// would produce.
///
/// Returns `Some(text)` when every `{` and `}` in `s` is part of an escaped pair
/// (`{{` or `}}`); each pair is collapsed to a single brace in `text`.
///
/// Returns `None` as soon as any brace is not escaped. That covers placeholders such as
/// `{x}` and also stray braces, including one at the very end of the string. The caller
/// then falls back to the real formatting macro, which reports malformed format
/// strings itself.
fn non_capturing_format_to_str(s: &str) -> Option<Box<str>> {
    let mut out = String::with_capacity(s.len());
    let mut rest = s;
    while let Some(i) = rest.find(&['{', '}']) {
        let brace = rest.as_bytes()[i];
        // Braces are ASCII, so `i + 1` and `i + 2` are char boundaries. A brace is
        // only escaped when the very next byte is the same brace; a brace at the end
        // of the string (no next byte) is therefore unescaped as well.
        if rest.as_bytes().get(i + 1) != Some(&brace) {
            return None;
        }
        // Keep the text before the pair plus a single copy of the brace.
        out.push_str(&rest[..=i]);
        rest = &rest[i + 2..];
    }
    out.push_str(rest);
    Some(out.into_boxed_str())
}
/// Implementation of `branch_on_format_capture`. Errors are returned in `Err` as
/// `compile_error!` token streams.
///
/// Expected input: `capturing[(prefix)], constant[(prefix)], "literal" rest...`.
///
/// If `non_capturing_format_to_str` can turn the literal's contents into plain text,
/// the result is `constant!(prefix "text" rest...)`, where the new literal reuses the
/// original literal's span. Otherwise the result is `capturing!(prefix "literal" rest...)`
/// with the literal passed through unchanged.
///
/// The prefix tokens are emitted directly before the literal with no separator
/// inserted, so a non-empty prefix presumably needs its own trailing comma — confirm
/// with callers.
fn branch_on_format_capture_impl(input: TokenStream) -> Result<TokenStream, TokenStream> {
    let mut rest = input.into_iter();
    let capturing = read_macro_arg_with_comma(&mut rest)?;
    let constant = read_macro_arg_with_comma(&mut rest)?;

    let literal = match rest.next() {
        Some(TokenTree::Literal(lit)) => lit,
        Some(other) => return err("Expected string literal", other.span()),
        None => return err("Unexpected end of input", Span::call_site()),
    };
    let Ok(contents) = litstr::parse_lit_str(&literal.to_string()) else {
        return err("Expected string literal", literal.span());
    };

    // A literal without placeholders is replaced by its unescaped text and routed to
    // the constant-string macro; anything else goes untouched to the capturing macro.
    let (chosen, literal) = if let Some(text) = non_capturing_format_to_str(&contents) {
        let mut unescaped = Literal::string(&text);
        unescaped.set_span(literal.span());
        (constant, unescaped)
    } else {
        (capturing, literal)
    };

    let MacroArg {
        macro_path: mut expansion,
        start_args: mut call_args,
        ..
    } = chosen;
    call_args.extend([TokenTree::Literal(literal)]);
    call_args.extend(rest);
    expansion.extend([
        TokenTree::Punct(Punct::new('!', Spacing::Alone)),
        TokenTree::Group(Group::new(Delimiter::Parenthesis, call_args)),
    ]);
    Ok(expansion)
}
/// Expands to a call of one of two user-supplied macros. The choice depends on
/// whether the format-string literal needs formatting.
///
/// Input: `capturing_macro[(prefix)], constant_str_macro[(prefix)], "format string" rest...`
///
/// - If the format string has no placeholders (as decided by
///   `non_capturing_format_to_str`), the expansion is
///   `constant_str_macro!(prefix "unescaped text" rest...)`.
/// - Otherwise it is `capturing_macro!(prefix "format string" rest...)`.
///
/// Malformed input expands to a `compile_error!` invocation instead.
#[proc_macro]
pub fn branch_on_format_capture(input: TokenStream) -> TokenStream {
    // Both the success and error paths already produce the final token stream.
    branch_on_format_capture_impl(input).unwrap_or_else(|compile_error| compile_error)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_non_capturing_format_to_str() {
        // Inputs whose conversion result is pinned down exactly.
        let pinned: [(&str, Option<&str>); 5] = [
            ("", Some("")),
            ("x", Some("x")),
            ("{x}", None),
            ("{{x}}", Some("{x}")),
            ("{{}}x{{{{}}}}x}}", Some("{}x{{}}x}")),
        ];
        for (input, expected) in pinned {
            assert_eq!(
                non_capturing_format_to_str(input).as_deref(),
                expected,
                "input: {input:?}"
            );
        }

        // Malformed format strings: only require that the conversion does not panic.
        for malformed in ["{{x}", "{{x}x}", "{x}}"] {
            let _ = non_capturing_format_to_str(malformed);
        }
    }
}