use proc_macro::TokenStream;
use proc_macro2::{Delimiter, Group, Spacing, Span, TokenStream as TokenStream2, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::parse::{Parse, ParseStream, Parser};
use syn::{Error, Expr, Pat, Result as SynResult, Token, Type};
/// One statement of an `mdo!` body, i.e. any `;`-terminated line before the
/// final expression.
enum Stmt {
    /// A `let` statement, kept token-for-token (without its terminating `;`) and
    /// emitted as an ordinary `let` in front of the rest of the block. `pure(...)`
    /// calls inside it are not rewritten.
    Let(TokenStream2),
    /// `pattern <- expr`: the monadic `expr` is bound and its result is matched
    /// against `pattern` for the remainder of the block.
    Bind(Pat, Expr),
    /// `guard(cond)`: the tokens between the parentheses (with `pure(...)`
    /// already rewritten), handed to `MdoGuard::guard` and sequenced.
    Guard(TokenStream2),
    /// Any other statement: a monadic expression that is sequenced but whose
    /// result is ignored.
    Bare(Expr),
}
/// Parsed form of an `mdo!` invocation: `Marker; stmt; …; final_expr`.
struct MdoInput {
    /// Expression that becomes the innermost body of the generated code.
    final_expr: Expr,
    /// Statements preceding the final expression, in source order.
    stmts: Vec<Stmt>,
    /// Block marker type (checked to be a type path while parsing). It is used
    /// as the `Self` type of the `Bind`, `Applicative` and `MdoGuard` calls the
    /// macro generates.
    marker: Type,
}
/// Returns `true` when `tt` is a `:` punctuation token, whatever its spacing.
fn is_colon(tt: &TokenTree) -> bool {
    if let TokenTree::Punct(punct) = tt {
        punct.as_char() == ':'
    } else {
        false
    }
}
/// Returns `true` when `tt` is a `:` with `Spacing::Joint`, i.e. immediately
/// followed by another punctuation character (as the first `:` of `::` is).
fn is_colon_joint(tt: &TokenTree) -> bool {
    match tt {
        TokenTree::Punct(punct) => punct.as_char() == ':' && punct.spacing() == Spacing::Joint,
        _ => false,
    }
}
/// Returns `true` when `tt` is a `.` punctuation token, whatever its spacing.
fn is_dot(tt: &TokenTree) -> bool {
    match tt {
        TokenTree::Punct(punct) => punct.as_char() == '.',
        _ => false,
    }
}
fn rewrite_pure(ts: TokenStream2, marker: &Type) -> TokenStream2 {
let mut out = TokenStream2::new();
let mut iter = ts.into_iter().peekable();
let mut prev_is_colon = false;
let mut prev_colon_is_joint = false; let mut prev_is_dot = false;
let mut prev2_is_colon_joint = false; let mut prev2_is_dot = false;
while let Some(tt) = iter.next() {
if let TokenTree::Ident(ref id) = tt {
if id == "pure" {
let next_is_paren = matches!(
iter.peek(),
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis
);
if next_is_paren {
let is_path_qual = prev_is_colon && prev2_is_colon_joint;
let is_method = prev_is_dot && !prev2_is_dot;
if !is_path_qual && !is_method {
let span = id.span();
let rewritten = quote_spanned! { span =>
<#marker as ::monadify::Applicative<_>>::pure
};
out.extend(rewritten);
prev2_is_colon_joint = prev_is_colon && prev_colon_is_joint;
prev2_is_dot = prev_is_dot;
prev_is_colon = false;
prev_colon_is_joint = false;
prev_is_dot = false;
continue;
}
}
}
}
prev2_is_colon_joint = prev_is_colon && prev_colon_is_joint;
prev2_is_dot = prev_is_dot;
prev_is_colon = is_colon(&tt);
prev_colon_is_joint = is_colon_joint(&tt);
prev_is_dot = is_dot(&tt);
match tt {
TokenTree::Group(g) => {
let inner = rewrite_pure(g.stream(), marker);
let mut new_g = Group::new(g.delimiter(), inner);
new_g.set_span(g.span());
out.extend(std::iter::once(TokenTree::Group(new_g)));
}
other => {
out.extend(std::iter::once(other));
}
}
}
out
}
/// Returns the index of the first top-level `<-` in `tokens`, if any.
///
/// A `<-` is a `<` with joint spacing immediately followed by `-`. Writing
/// `a < -b` (with a space) therefore does not match, but `a<-b` does. Tokens
/// nested inside groups are not searched.
fn find_left_arrow(tokens: &[TokenTree]) -> Option<usize> {
    tokens.windows(2).position(|pair| match pair {
        [TokenTree::Punct(lt), TokenTree::Punct(minus)] => {
            lt.as_char() == '<' && lt.spacing() == Spacing::Joint && minus.as_char() == '-'
        }
        _ => false,
    })
}
/// Classifies one `;`-terminated statement of an `mdo!` body.
///
/// Rules are tried in order:
/// 1. a top-level `<-` yields [`Stmt::Bind`] (`pattern <- expr`);
/// 2. a leading `let` yields [`Stmt::Let`], kept verbatim;
/// 3. exactly `guard( … )` yields [`Stmt::Guard`] with the parenthesised tokens;
/// 4. anything else must parse as an expression and yields [`Stmt::Bare`].
///
/// `pure(...)` is rewritten via [`rewrite_pure`] in bind expressions, guard
/// conditions and bare expressions, but not in `let` statements.
///
/// # Errors
/// Returns an error for an empty statement, for a bind missing either side of
/// `<-`, or when the pattern or expression fails to parse.
fn classify(tokens: Vec<TokenTree>, marker: &Type) -> SynResult<Stmt> {
    if tokens.is_empty() {
        return Err(Error::new(
            Span::call_site(),
            "empty statement in mdo! block (stray `;`?)",
        ));
    }

    if let Some(arrow) = find_left_arrow(&tokens) {
        let arrow_span = tokens[arrow].span();
        // Skip both tokens of `<-`.
        let (lhs, rhs) = (&tokens[..arrow], &tokens[arrow + 2..]);
        if lhs.is_empty() {
            return Err(Error::new(
                arrow_span,
                "mdo! bind requires a pattern before `<-`",
            ));
        }
        if rhs.is_empty() {
            return Err(Error::new(
                arrow_span,
                "mdo! bind requires a monadic expression after `<-`",
            ));
        }
        let pat = Pat::parse_single.parse2(lhs.iter().cloned().collect())?;
        let expr = syn::parse2::<Expr>(rewrite_pure(rhs.iter().cloned().collect(), marker))?;
        return Ok(Stmt::Bind(pat, expr));
    }

    let head_keyword = match &tokens[0] {
        TokenTree::Ident(id) => Some(id.to_string()),
        _ => None,
    };
    match head_keyword.as_deref() {
        Some("let") => return Ok(Stmt::Let(tokens.into_iter().collect())),
        Some("guard") => {
            if let [_, TokenTree::Group(cond)] = tokens.as_slice() {
                if cond.delimiter() == Delimiter::Parenthesis {
                    return Ok(Stmt::Guard(rewrite_pure(cond.stream(), marker)));
                }
            }
        }
        _ => {}
    }

    let expr: Expr = syn::parse2(rewrite_pure(tokens.into_iter().collect(), marker))?;
    Ok(Stmt::Bare(expr))
}
impl Parse for MdoInput {
fn parse(input: ParseStream) -> SynResult<Self> {
if input.is_empty() {
return Err(Error::new(
Span::call_site(),
"mdo! requires a block marker followed by at least a final expression, \
e.g. `mdo! { OptionKind; OptionKind::pure(1) }`",
));
}
let marker: Type = input.parse().map_err(|_| {
Error::new(
Span::call_site(),
"mdo! must start with a block marker type, e.g. `mdo! { OptionKind; … }`",
)
})?;
if !matches!(marker, Type::Path(_)) {
return Err(Error::new_spanned(
&marker,
"mdo! block marker must be a type path, e.g. `OptionKind` or `ResultKind::<String>`",
));
}
input.parse::<Token![;]>().map_err(|_| {
Error::new(
Span::call_site(),
"mdo! marker must be followed by `;`, e.g. `mdo! { OptionKind; … }`",
)
})?;
let rest: TokenStream2 = input.parse()?;
let mut segments: Vec<(Vec<TokenTree>, bool)> = Vec::new();
let mut cur: Vec<TokenTree> = Vec::new();
for tt in rest {
if let TokenTree::Punct(p) = &tt {
if p.as_char() == ';' {
segments.push((std::mem::take(&mut cur), true));
continue;
}
}
cur.push(tt);
}
if !cur.is_empty() {
segments.push((cur, false));
}
if segments.is_empty() {
return Err(Error::new(
Span::call_site(),
"mdo! requires at least a final expression after the marker",
));
}
if segments.last().map(|(_, semi)| *semi).unwrap_or(true) {
return Err(Error::new(
Span::call_site(),
"mdo! block must end with a final monadic expression and no trailing `;`",
));
}
let (final_tokens, _) = segments.pop().expect("checked non-empty above");
if let Some(idx) = find_left_arrow(&final_tokens) {
return Err(Error::new(
final_tokens[idx].span(),
"the final line of an mdo! block must be a raw monadic value, not a `<-` bind",
));
}
let final_raw_ts: TokenStream2 = final_tokens.into_iter().collect();
let final_expr: Expr = syn::parse2(rewrite_pure(final_raw_ts, &marker))?;
let mut stmts = Vec::with_capacity(segments.len());
for (tokens, _) in segments {
stmts.push(classify(tokens, &marker)?);
}
Ok(MdoInput {
marker,
stmts,
final_expr,
})
}
}
#[proc_macro]
pub fn mdo(input: TokenStream) -> TokenStream {
let parsed = match syn::parse::<MdoInput>(input) {
Ok(p) => p,
Err(e) => return e.to_compile_error().into(),
};
let marker = &parsed.marker;
let mut acc: TokenStream2 = parsed.final_expr.to_token_stream();
for stmt in parsed.stmts.iter().rev() {
acc = match stmt {
Stmt::Bind(pat, expr) => quote! {
<#marker as ::monadify::Bind<_, _>>::bind(
(#expr).clone(),
move |#pat| { #acc }
)
},
Stmt::Bare(expr) => quote! {
<#marker as ::monadify::Bind<_, _>>::bind(
(#expr).clone(),
move |_| { #acc }
)
},
Stmt::Guard(cond) => quote! {
<#marker as ::monadify::Bind<_, _>>::bind(
(<#marker as ::monadify::MdoGuard>::guard(#cond)).clone(),
move |_| { #acc }
)
},
Stmt::Let(raw) => quote! {
{ #raw ; #acc }
},
};
}
acc.into()
}