rsmonad_macros/
lib.rs

//! Provides the `monad! {...}` macro, which parses (1) a data structure definition, (2) a function `bind`, and (3) a function `consume`, and expands them into the most idiomatic boilerplate available as Rust continues to evolve.
2//! # Use
3//! ```rust
4//! // use rsmonad::prelude::*; // <-- In your code, use this; here, though, we redefine a simpler `Maybe`, so we can't import everything
5//! use rsmonad::prelude::{Monad, monad};
6//!
7//! monad! {
8//!     /// Encodes the possibility of failure.
9//!     enum ExampleMaybe<A> {
10//!         EgNothing,
11//!         EgJust(A),
12//!     }
13//!
14//!     fn bind(self, f) {
15//!         match self {
16//!             EgNothing => EgNothing,
17//!             EgJust(b) => f(b),
18//!         }
19//!     }
20//!
21//!     fn consume(a) {
22//!         EgJust(a)
23//!     }
24//! }
25//!
26//! fn could_overflow(x: u8) -> ExampleMaybe<u8> {
27//!     x.checked_add(1).map_or(EgNothing, EgJust)
28//! }
29//!
30//! # fn main() {
31//! assert_eq!(
32//!     EgNothing >> could_overflow,
33//!     EgNothing
34//! );
35//! assert_eq!(
36//!     EgJust(1) >> could_overflow,
37//!     EgJust(2)
38//! );
39//! assert_eq!(
40//!     EgJust(255) >> could_overflow,
41//!     EgNothing
42//! );
43//! # }
44//! ```
45
46#![deny(warnings)]
47#![warn(
48    clippy::all,
49    clippy::missing_docs_in_private_items,
50    clippy::nursery,
51    clippy::pedantic,
52    clippy::restriction,
53    clippy::cargo,
54    missing_docs,
55    rustdoc::all
56)]
57#![allow(
58    clippy::blanket_clippy_restriction_lints,
59    clippy::implicit_return,
60    clippy::pattern_type_mismatch,
61    clippy::question_mark_used,
62    clippy::shadow_reuse,
63    clippy::shadow_unrelated,
64    clippy::string_add,
65    clippy::wildcard_enum_match_arm
66)]
67
68use proc_macro2::{Delimiter, Span, TokenStream, TokenTree};
69use quote::{quote, ToTokens};
70use syn::spanned::Spanned;
71
72/// Write the boilerplate for a monad given the minimal definition.
73#[proc_macro]
74pub fn monad(ts: proc_macro::TokenStream) -> proc_macro::TokenStream {
75    match transmute(ts.into()) {
76        Ok(out) => out,
77        Err(e) => e.to_compile_error(),
78    }
79    .into()
80}
81
/// Pull the next token tree out of `$tokens`.
///
/// If the stream is already exhausted, the enclosing function returns a `syn::Error` located at
/// `$span` with message `$msg`.
macro_rules! next {
    ($tokens:ident, $span:expr, $msg:expr $(,)?) => {
        match $tokens.next() {
            ::core::option::Option::Some(token) => token,
            ::core::option::Option::None => {
                return ::core::result::Result::Err(syn::Error::new($span, $msg))
            }
        }
    };
}
88
/// Return early from the enclosing function with a `syn::Error` at `$span` carrying `$msg`.
///
/// Expands to a diverging `return` expression, so it can be used as a statement, as the tail of
/// a block, or as a `match` arm of any type.
macro_rules! bail {
    ($span:expr, $msg:expr $(,)?) => {
        return ::core::result::Result::Err(syn::Error::new($span, $msg))
    };
}
95
/// Take the next token tree from `$tokens` and unwrap it as `TokenTree::$Type`.
///
/// Error locations:
/// - If the stream is empty, the error (`$msg` followed by " after this") points at
///   `$prev_span`.
/// - If the token is of a different kind, the error (`$msg`) points at that token.
macro_rules! match_tt {
    ($tokens:ident, $Type:ident, $msg:expr, $prev_span:expr $(,)?) => {{
        let token = next!($tokens, $prev_span, concat!($msg, " after this"));
        if let TokenTree::$Type(matched) = token {
            matched
        } else {
            bail!(token.span(), $msg)
        }
    }};
}
105
106/// Speed through attributes, pasting them verbatim
107fn skip_attributes(
108    out: &mut TokenStream,
109    tokens: &mut proc_macro2::token_stream::IntoIter,
110) -> syn::Result<TokenTree> {
111    loop {
112        let tt = next!(
113            tokens,
114            Span::call_site(),
115            "Expected a data structure definition",
116        );
117        let TokenTree::Punct(pound) = tt else {
118            return Ok(tt);
119        };
120        if pound.as_char() != '#' {
121            bail!(pound.span(), "Expected a data structure definition; found a single character (that is not a '#' before an attribute)");
122        }
123        pound.to_tokens(out);
124
125        let attr = match_tt!(tokens, Group, "Expected an attribute", pound.span());
126        if attr.delimiter() != Delimiter::Bracket {
127            bail!(attr.span(), "Expected an attribute in [...] brackets")
128        }
129        attr.to_tokens(out);
130    }
131}
132
133/// Actually transform the AST, returning an error without boilerplate to be handled above.
134#[allow(clippy::too_many_lines)]
135fn transmute(raw_ts: TokenStream) -> syn::Result<TokenStream> {
136    let mut out = TokenStream::new();
137    let mut tokens = raw_ts.into_iter();
138
139    // Parse the data-structure declaration
140    let mut data_structure = TokenStream::new();
141    let mut structure = match skip_attributes(&mut out, &mut tokens)? {
142        TokenTree::Ident(i) => i,
143        tt => bail!(tt.span(), "Expected a data structure definition",),
144    };
145    let mut publicity = None;
146    structure.to_tokens(&mut data_structure);
147    if structure == "pub" {
148        publicity = Some(structure);
149        structure = match_tt!(
150            tokens,
151            Ident,
152            "Expected a data structure definition",
153            publicity.span(),
154        );
155        structure.to_tokens(&mut data_structure);
156    }
157    let name = match_tt!(tokens, Ident, "Expected a name", Span::call_site());
158    name.to_tokens(&mut data_structure);
159    let generic_open = match_tt!(tokens, Punct, "Expected generics, e.g. `<A>`", name.span());
160    if generic_open.as_char() != '<' {
161        bail!(generic_open.span(), "Expected generics, e.g. `<A>`");
162    }
163    generic_open.to_tokens(&mut data_structure);
164    let mut inception: u8 = 1;
165    'generic_loop: loop {
166        let generic = next!(tokens, generic_open.span(), "Unterminated generics");
167        generic.to_tokens(&mut data_structure);
168        if let TokenTree::Punct(ref maybe_close) = generic {
169            match maybe_close.as_char() {
170                '<' => {
171                    inception = inception.checked_add(1).ok_or_else(|| {
172                        syn::Error::new(
173                            maybe_close.span(),
174                            "Call Christopher Nolan: this inception is too deep",
175                        )
176                    })?;
177                }
178                '>' => {
179                    inception = inception.wrapping_sub(1);
180                    if inception == 0 {
181                        break 'generic_loop;
182                    }
183                }
184                _ => (),
185            }
186        }
187    }
188
189    // Parse the definition itself
190    let def_block = match_tt!(
191        tokens,
192        Group,
193        "Expected a definition block, e.g. `{...}`",
194        data_structure.span()
195    );
196    def_block.to_tokens(&mut data_structure);
197    if def_block.delimiter() == Delimiter::Parenthesis {
198        let semicolon = match_tt!(
199            tokens,
200            Punct,
201            "Expected a semicolon (after a tuple-struct)",
202            def_block.span()
203        );
204        if semicolon.as_char() != ';' {
205            bail!(
206                semicolon.span(),
207                "Expected a semicolon (after a tuple-struct)",
208            );
209        }
210        semicolon.to_tokens(&mut data_structure);
211    }
212
213    // Parse as either a `struct` or `enum`
214    let (ident, generics, fields) = if structure == "enum" {
215        from_enum(
216            &mut out,
217            syn::parse2(data_structure).map_err(move |e| {
218                syn::Error::new(
219                    e.span(),
220                    e.to_string() + " (`syn` error while parsing as an enum)",
221                )
222            })?,
223            publicity,
224        )?
225    } else if structure == "struct" {
226        from_struct(
227            &mut out,
228            syn::parse2(data_structure).map_err(move |e| {
229                syn::Error::new(
230                    e.span(),
231                    e.to_string() + " (`syn` error while parsing as a struct)",
232                )
233            })?,
234        )?
235    } else {
236        bail!(structure.span(), "Expected either `struct` or `enum`");
237    };
238
239    let bind = parse_bind(&mut tokens, &def_block, &generics)?;
240    let consume = parse_consume(&mut tokens, &def_block)?;
241    impl_mod(&ident, bind, consume)?.to_tokens(&mut out);
242
243    write_arbitrary_impl(ident, fields)?.to_tokens(&mut out);
244
245    Ok(out)
246}
247
248/// Write a `quickcheck::Arbitrary` implementation.
249#[allow(clippy::too_many_lines)]
250fn write_arbitrary_impl(ident: syn::Ident, fields: Fields) -> syn::Result<syn::ItemImpl> {
251    Ok(syn::ItemImpl {
252        attrs: vec![],
253        defaultness: None,
254        unsafety: None,
255        impl_token: syn::token::Impl {
256            span: Span::call_site(),
257        },
258        generics: syn::parse2(quote! { <A: quickcheck::Arbitrary> })?,
259        trait_: Some((
260            None,
261            syn::parse2(quote! { quickcheck::Arbitrary })?,
262            syn::token::For {
263                span: Span::call_site(),
264            },
265        )),
266        self_ty: Box::new(syn::Type::Path(syn::TypePath {
267            qself: None,
268            path: syn::Path {
269                leading_colon: None,
270                segments: {
271                    let mut p = syn::punctuated::Punctuated::new();
272                    p.push_value(syn::PathSegment {
273                        ident,
274                        arguments: syn::PathArguments::AngleBracketed(syn::parse2(quote! { <A> })?),
275                    });
276                    p
277                },
278            },
279        })),
280        brace_token: syn::token::Brace {
281            span: proc_macro2::Group::new(Delimiter::Brace, TokenStream::new()).delim_span(),
282        },
283        items: vec![syn::ImplItem::Fn({
284            let mut def: syn::ImplItemFn =
285                syn::parse2(quote! { fn arbitrary(g: &mut quickcheck::Gen) -> FixItInPost {} })?;
286            def.sig.output = syn::ReturnType::Type(
287                syn::parse2(quote! { -> })?,
288                Box::new(syn::parse2(quote! { Self })?),
289            );
290            def.block.stmts.push(syn::Stmt::Expr(
291                match fields {
292                    Fields::EnumVariants(variants) => {
293                        let mut elems = syn::punctuated::Punctuated::new();
294                        for variant in variants {
295                            let body = if matches!(variant.fields, syn::Fields::Unit) {
296                                Box::new(syn::Expr::Path(syn::ExprPath {
297                                    attrs: vec![],
298                                    qself: None,
299                                    path: syn::Path {
300                                        leading_colon: None,
301                                        segments: {
302                                            let mut p = syn::punctuated::Punctuated::new();
303                                            p.push_value(syn::PathSegment {
304                                                ident: variant.ident,
305                                                arguments: syn::PathArguments::None,
306                                            });
307                                            p
308                                        },
309                                    },
310                                }))
311                            } else {
312                                Box::new(syn::Expr::Call(syn::ExprCall {
313                                    attrs: vec![],
314                                    func: Box::new(syn::Expr::Path(syn::ExprPath {
315                                        attrs: vec![],
316                                        qself: None,
317                                        path: syn::Path {
318                                            leading_colon: None,
319                                            segments: {
320                                                let mut p = syn::punctuated::Punctuated::new();
321                                                p.push_value(syn::PathSegment {
322                                                    ident: variant.ident,
323                                                    arguments: syn::PathArguments::None,
324                                                });
325                                                p
326                                            },
327                                        },
328                                    })),
329                                    paren_token: syn::token::Paren {
330                                        span: proc_macro2::Group::new(
331                                            Delimiter::Parenthesis,
332                                            TokenStream::new(),
333                                        )
334                                        .delim_span(),
335                                    },
336                                    args: {
337                                        let mut p = syn::punctuated::Punctuated::new();
338                                        match variant.fields {
339                                            // SAFETY:
340                                            // Logically impossible. See `if` statement at definition of `body`.
341                                            syn::Fields::Unit => unsafe { core::hint::unreachable_unchecked() },
342                                            syn::Fields::Unnamed(members) => {
343                                                for member in members.unnamed {
344                                                    p.push(syn::Expr::Call({
345                                                        let mut init: syn::ExprCall = syn::parse2(quote! { <FixItInPost as quickcheck::Arbitrary>::arbitrary(gen) })?;
346                                                        match init.func.as_mut() {
347                                                            syn::Expr::Path(path) => {
348                                                                let Some(qself) = &mut path.qself else {
349                                                                    bail!(init.span(), "rsmonad-internal error: couldn't parse qself in `<T as quickcheck::Arbitrary>::arbitrary(gen)`");
350                                                                };
351                                                                *qself.ty.as_mut() = member.ty;
352                                                            }
353                                                            _ => bail!(init.span(), "rsmonad-internal error: couldn't parse `<T as quickcheck::Arbitrary>::arbitrary(gen)` as a path"),
354                                                        }
355                                                        init
356                                                    }));
357                                                }
358                                            }
359                                            syn::Fields::Named(members) => {
360                                                for member in members.named {
361                                                    p.push(syn::Expr::Call({
362                                                        let mut init: syn::ExprCall = syn::parse2(quote! { <FixItInPost as quickcheck::Arbitrary>::arbitrary(gen) })?;
363                                                        match init.func.as_mut() {
364                                                            syn::Expr::Path(path) => {
365                                                                let Some(qself) = &mut path.qself else {
366                                                                    bail!(init.span(), "rsmonad-internal error: couldn't parse qself in `<T as quickcheck::Arbitrary>::arbitrary(gen)`");
367                                                                };
368                                                                *qself.ty.as_mut() = member.ty;
369                                                            }
370                                                            _ => bail!(init.span(), "rsmonad-internal error: couldn't parse `<T as quickcheck::Arbitrary>::arbitrary(gen)` as a path"),
371                                                        }
372                                                        init
373                                                    }));
374                                                }
375                                            }
376                                        }
377                                        p
378                                    },
379                                }))
380                            };
381                            let closure = syn::Expr::Closure(syn::ExprClosure {
382                                attrs: vec![],
383                                lifetimes: None,
384                                constness: None,
385                                movability: None,
386                                asyncness: None,
387                                capture: Some(syn::token::Move { span: Span::call_site() }),
388                                or1_token: syn::token::Or {
389                                    spans: [Span::call_site()],
390                                },
391                                inputs: {
392                                    let mut inputs = syn::punctuated::Punctuated::new();
393                                    inputs.push_value(syn::Pat::Ident(syn::PatIdent {
394                                        attrs: vec![],
395                                        by_ref: None,
396                                        mutability: None,
397                                        ident: syn::Ident::new("gen", Span::call_site()),
398                                        subpat: None,
399                                    }));
400                                    inputs
401                                },
402                                or2_token: syn::token::Or {
403                                    spans: [Span::call_site()],
404                                },
405                                output: syn::ReturnType::Default,
406                                body,
407                            });
408                            let paren = syn::Expr::Paren(syn::ExprParen {
409                                attrs: vec![],
410                                paren_token: syn::token::Paren {
411                                    span: proc_macro2::Group::new(
412                                        Delimiter::Parenthesis,
413                                        TokenStream::new(),
414                                    )
415                                    .delim_span(),
416                                },
417                                expr: Box::new(closure),
418                            });
419                            elems.push(syn::Expr::Cast(syn::ExprCast {
420                                attrs: vec![],
421                                expr: Box::new(paren),
422                                as_token: syn::token::As { span: Span::call_site() },
423                                ty: Box::new(syn::parse2(quote! { fn(&mut quickcheck::Gen) -> Self })?),
424                            }));
425                        }
426                        let mut choose: syn::ExprCall = syn::parse2(quote! { g.choose::<fn(&mut quickcheck::Gen) -> Self>(&[]).unwrap()(g) })?;
427                        let syn::Expr::MethodCall(pre_call) = choose.func.as_mut() else {
428                            bail!(Span::call_site(), "rsmonad-internal error: expected a method call")
429                        };
430                        let syn::Expr::MethodCall(pre_pre_call) = pre_call.receiver.as_mut() else {
431                            bail!(Span::call_site(), "rsmonad-internal error: expected a method call")
432                        };
433                        let Some(syn::Expr::Reference(array_ref)) = pre_pre_call.args.first_mut() else {
434                            bail!(Span::call_site(), "rsmonad-internal error: expected a single reference argument")
435                        };
436                        let syn::Expr::Array(closures) = array_ref.expr.as_mut() else {
437                            bail!(choose.args.span(), "rsmonad-internal error: expected an array")
438                        };
439                        closures.elems = elems;
440                        syn::Expr::Call(choose)
441                    }
442                    Fields::StructMembers(members) => match members {
443                        syn::Fields::Unit => syn::Expr::Path(syn::ExprPath {
444                            attrs: vec![],
445                            qself: None,
446                            path: syn::Path {
447                                leading_colon: None,
448                                segments: {
449                                    let mut p = syn::punctuated::Punctuated::new();
450                                    p.push_value(syn::PathSegment {
451                                        ident: syn::Ident::new("Self", Span::call_site()),
452                                        arguments: syn::PathArguments::None,
453                                    });
454                                    p
455                                }
456                            }
457                        }),
458                        syn::Fields::Named(named) => {
459                            syn::Expr::Struct(syn::ExprStruct {
460                                attrs: vec![],
461                                qself: None,
462                                path: syn::Path {
463                                    leading_colon: None,
464                                    segments: {
465                                        let mut p = syn::punctuated::Punctuated::new();
466                                        p.push_value(syn::PathSegment {
467                                            ident: syn::Ident::new("Self", Span::call_site()),
468                                            arguments: syn::PathArguments::None,
469                                        });
470                                        p
471                                    }
472                                },
473                                brace_token: syn::token::Brace {
474                                    span: proc_macro2::Group::new(
475                                        Delimiter::Brace,
476                                        TokenStream::new(),
477                                    )
478                                    .delim_span(),
479                                },
480                                fields: {
481                                    let mut p = syn::punctuated::Punctuated::new();
482                                    for member in named.named {
483                                        p.push(syn::FieldValue {
484                                            attrs: vec![],
485                                            member: syn::Member::Named(member.ident.clone().ok_or_else(|| syn::Error::new(member.span(), "Expected a named field"))?),
486                                            colon_token: Some(syn::token::Colon { spans: [Span::call_site()] }),
487                                            expr: syn::Expr::Call({
488                                                let mut init: syn::ExprCall = syn::parse2(quote! { <FixItInPost as quickcheck::Arbitrary>::arbitrary(g) })?;
489                                                match init.func.as_mut() {
490                                                    syn::Expr::Path(path) => {
491                                                        let Some(qself) = &mut path.qself else {
492                                                            bail!(init.span(), "rsmonad-internal error: couldn't parse qself in `<T as quickcheck::Arbitrary>::arbitrary(g)`");
493                                                        };
494                                                        *qself.ty.as_mut() = member.ty;
495                                                    }
496                                                    _ => bail!(init.span(), "rsmonad-internal error: couldn't parse `<T as quickcheck::Arbitrary>::arbitrary(g)` as a path"),
497                                                }
498                                                init
499                                            }),
500                                        });
501                                    }
502                                    p
503                                },
504                                dot2_token: None,
505                                rest: None,
506                            })
507                        },
508                        syn::Fields::Unnamed(unnamed) => {
509                            syn::Expr::Call(syn::ExprCall {
510                                attrs: vec![],
511                                func: Box::new(syn::Expr::Path(syn::ExprPath {
512                                    attrs: vec![],
513                                    qself: None,
514                                    path: syn::Path {
515                                        leading_colon: None,
516                                        segments: {
517                                            let mut p = syn::punctuated::Punctuated::new();
518                                            p.push_value(syn::PathSegment {
519                                                ident: syn::Ident::new("Self", Span::call_site()),
520                                                arguments: syn::PathArguments::None,
521                                            });
522                                            p
523                                        },
524                                    },
525                                })),
526                                paren_token: syn::token::Paren {
527                                    span: proc_macro2::Group::new(
528                                        Delimiter::Parenthesis,
529                                        TokenStream::new(),
530                                    )
531                                    .delim_span(),
532                                },
533                                args: {
534                                    let mut args = syn::punctuated::Punctuated::new();
535                                    for member in unnamed.unnamed {
536                                        args.push(
537                                            syn::Expr::Call({
538                                                let mut init: syn::ExprCall = syn::parse2(quote! { <FixItInPost as quickcheck::Arbitrary>::arbitrary(g) })?;
539                                                match init.func.as_mut() {
540                                                    syn::Expr::Path(path) => {
541                                                        let Some(qself) = &mut path.qself else {
542                                                            bail!(init.span(), "rsmonad-internal error: couldn't parse qself in `<T as quickcheck::Arbitrary>::arbitrary(g)`");
543                                                        };
544                                                        *qself.ty.as_mut() = member.ty;
545                                                    }
546                                                    _ => bail!(init.span(), "rsmonad-internal error: couldn't parse `<T as quickcheck::Arbitrary>::arbitrary(g)` as a path"),
547                                                }
548                                                init
549                                            }),
550                                        );
551                                    }
552                                    args
553                                },
554                            })
555                        },
556                    },
557                },
558                None,
559            ));
560            def
561        })],
562    })
563}
564
565/// Parse the definition of `bind`.
566fn parse_bind(
567    tokens: &mut proc_macro2::token_stream::IntoIter,
568    def_block: &proc_macro2::Group,
569    generics: &syn::Generics,
570) -> syn::Result<syn::ImplItemFn> {
571    // Parse `bind`
572    let mut bind = TokenStream::new();
573    let t_fn = match_tt!(tokens, Ident, "Expected `fn`", def_block.span());
574    if t_fn != "fn" {
575        bail!(t_fn.span(), "Expected `fn`",);
576    }
577    t_fn.to_tokens(&mut bind);
578    let t_bind = match_tt!(tokens, Ident, "Expected `bind`", t_fn.span());
579    if t_bind != "bind" {
580        bail!(t_bind.span(), "Expected `bind`")
581    }
582    t_bind.to_tokens(&mut bind);
583    let args = match_tt!(
584        tokens,
585        Group,
586        "Expected arguments immediately after `bind` (no need to repeat the <A: ...> bound)",
587        t_bind.span(),
588    );
589    if args.delimiter() != Delimiter::Parenthesis {
590        bail!(args.span(), "Expected arguments immediately after `bind`");
591    }
592    let args = proc_macro2::Group::new(Delimiter::Parenthesis, {
593        let mut args_ts = TokenStream::new();
594        let mut bare = args.stream().into_iter();
595        let t_self = match skip_attributes(&mut args_ts, &mut bare)? {
596            TokenTree::Ident(i) => i,
597            tt => bail!(tt.span(), "Expected `self`"),
598        };
599        if t_self != "self" {
600            bail!(t_self.span(), "Expected `self`");
601        }
602        t_self.to_tokens(&mut args_ts);
603        let comma = match_tt!(bare, Punct, "Expected a comma", t_self.span());
604        if comma.as_char() != ',' {
605            bail!(comma.span(), "Expected a comma");
606        }
607        comma.to_tokens(&mut args_ts);
608        let f = match skip_attributes(&mut args_ts, &mut bare)? {
609            TokenTree::Ident(i) => i,
610            tt => bail!(tt.span(), "Expected `f`"),
611        };
612        f.to_tokens(&mut args_ts);
613        proc_macro2::Punct::new(':', proc_macro2::Spacing::Alone).to_tokens(&mut args_ts);
614        proc_macro2::Ident::new("F", Span::call_site()).to_tokens(&mut args_ts);
615        args_ts
616    });
617    args.to_tokens(&mut bind);
618    let def_block = match_tt!(
619        tokens,
620        Group,
621        "Expected a function definition block (please don't try to specify return type; it's extremely long and will change as Rust evolves)",
622        args.span(),
623    );
624    if def_block.delimiter() != Delimiter::Brace {
625        bail!(def_block.span(), "Expected a function definition block");
626    }
627    def_block.to_tokens(&mut bind);
628    let mut bind: syn::ImplItemFn = syn::parse2(bind)?;
629    let inline_always: syn::MetaList = syn::parse2(quote! { inline(always) })?;
630    bind.attrs.push(syn::Attribute {
631        pound_token: syn::token::Pound {
632            spans: [Span::call_site()],
633        },
634        style: syn::AttrStyle::Outer,
635        bracket_token: syn::token::Bracket {
636            span: *inline_always.delimiter.span(),
637        },
638        meta: syn::Meta::List(inline_always.clone()),
639    });
640    bind.sig.generics.lt_token = Some(syn::token::Lt {
641        spans: [Span::call_site()],
642    });
643    bind.sig.generics.gt_token = Some(syn::token::Gt {
644        spans: [Span::call_site()],
645    });
646    bind.sig.generics.params.push_value({
647        let Some(syn::GenericParam::Type(gpt)) = generics.params.first() else {
648            bail!(generics.span(), "Expected at least one generic argument");
649        };
650        let mut gpt = gpt.clone();
651        gpt.ident = syn::Ident::new("B", Span::call_site());
652        syn::GenericParam::Type(gpt)
653    });
654    bind.sig
655        .generics
656        .params
657        .push(syn::GenericParam::Type(syn::parse2(
658            quote! { F: Fn(A) -> M<B> },
659        )?));
660    bind.sig.output = syn::ReturnType::Type(
661        syn::token::RArrow {
662            spans: [Span::call_site(), Span::call_site()],
663        },
664        Box::new(syn::Type::Path(syn::parse2(quote! { M<B> })?)),
665    );
666    Ok(bind)
667}
668
669/// Parse the definition of `consume`.
670fn parse_consume(
671    tokens: &mut proc_macro2::token_stream::IntoIter,
672    def_block: &proc_macro2::Group,
673) -> syn::Result<syn::ImplItemFn> {
674    // Parse `consume`
675    let mut consume = TokenStream::new();
676    let t_fn = match_tt!(tokens, Ident, "Expected `fn`", def_block.span(),);
677    if t_fn != "fn" {
678        bail!(
679            Span::call_site(),
680            "Expected a definition for `consume` after `bind`",
681        );
682    }
683    t_fn.to_tokens(&mut consume);
684    let t_consume = match_tt!(
685        tokens,
686        Ident,
687        "Expected a definition for `consume` after `bind`",
688        Span::call_site()
689    );
690    if t_consume != "consume" {
691        bail!(
692            Span::call_site(),
693            "Expected a definition for `consume` after `bind`"
694        )
695    }
696    t_consume.to_tokens(&mut consume);
697    let args = match_tt!(
698        tokens,
699        Group,
700        "Expected arguments immediately after the function name `consume` (no need to repeat the <A: ...> bound)",
701        Span::call_site(),
702    );
703    if args.delimiter() != Delimiter::Parenthesis {
704        bail!(
705            args.span(),
706            "Expected arguments immediately after `consume`"
707        );
708    }
709    let args = proc_macro2::Group::new(Delimiter::Parenthesis, {
710        let mut args_ts = TokenStream::new();
711        let mut bare = args.stream().into_iter();
712        let a = match skip_attributes(&mut args_ts, &mut bare)? {
713            TokenTree::Ident(i) => i,
714            tt => bail!(tt.span(), "Expected `a`"),
715        };
716        a.to_tokens(&mut args_ts);
717        proc_macro2::Punct::new(':', proc_macro2::Spacing::Alone).to_tokens(&mut args_ts);
718        proc_macro2::Ident::new("A", Span::call_site()).to_tokens(&mut args_ts);
719        args_ts
720    });
721    args.to_tokens(&mut consume);
722    let def_block = match_tt!(
723        tokens,
724        Group,
725        "Expected a function definition block (please don't try to specify return type; it's extremely long and will change as Rust evolves)",
726        args.span(),
727    );
728    if def_block.delimiter() != Delimiter::Brace {
729        bail!(def_block.span(), "Expected a function definition block");
730    }
731    def_block.to_tokens(&mut consume);
732    let mut consume: syn::ImplItemFn = syn::parse2(consume)?;
733    let inline_always: syn::MetaList = syn::parse2(quote! { inline(always) })?;
734    consume.attrs.push(syn::Attribute {
735        pound_token: syn::token::Pound {
736            spans: [Span::call_site()],
737        },
738        style: syn::AttrStyle::Outer,
739        bracket_token: syn::token::Bracket {
740            span: *inline_always.delimiter.span(),
741        },
742        meta: syn::Meta::List(inline_always),
743    });
744    consume.sig.generics.lt_token = Some(syn::token::Lt {
745        spans: [Span::call_site()],
746    });
747    consume.sig.generics.gt_token = Some(syn::token::Gt {
748        spans: [Span::call_site()],
749    });
750    consume.sig.output = syn::ReturnType::Type(
751        syn::token::RArrow {
752            spans: [Span::call_site(), Span::call_site()],
753        },
754        Box::new(syn::Type::Path(syn::parse2(quote! { Self })?)),
755    );
756    Ok(consume)
757}
758
759/// Write a `use` statement so we can refer to the implementee by an alias and use more `quote! { ...`.
760fn use_as_m(ident: &syn::Ident) -> syn::Item {
761    syn::Item::Use(syn::ItemUse {
762        attrs: vec![],
763        vis: syn::Visibility::Inherited,
764        use_token: syn::token::Use {
765            span: Span::call_site(),
766        },
767        leading_colon: None,
768        tree: syn::UseTree::Path(syn::UsePath {
769            ident: syn::Ident::new("super", Span::call_site()),
770            colon2_token: syn::token::PathSep {
771                spans: [Span::call_site(), Span::call_site()],
772            },
773            tree: Box::new(syn::UseTree::Rename(syn::UseRename {
774                ident: ident.clone(),
775                as_token: syn::token::As {
776                    span: Span::call_site(),
777                },
778                rename: syn::Ident::new("M", Span::call_site()),
779            })),
780        }),
781        semi_token: syn::token::Semi {
782            spans: [Span::call_site()],
783        },
784    })
785}
786
787/// Write a `mod` with all the implementation details without cluttering the surrounding namespace.
788fn impl_mod(
789    ident: &syn::Ident,
790    bind: syn::ImplItemFn,
791    consume: syn::ImplItemFn,
792) -> syn::Result<syn::ItemMod> {
793    let items = vec![
794        syn::Item::Use(syn::parse2(quote! { use rsmonad::prelude::*; })?),
795        syn::Item::Use(syn::parse2(quote! { use super::*; })?),
796        use_as_m(ident),
797        impl_functor()?,
798        impl_pipe()?,
799        impl_monad(bind, consume)?,
800        impl_rshift()?,
801        quickcheck_laws()?,
802    ];
803    Ok(syn::ItemMod {
804        attrs: vec![],
805        vis: syn::Visibility::Inherited,
806        unsafety: None,
807        mod_token: syn::token::Mod {
808            span: Span::call_site(),
809        },
810        ident: syn::Ident::new(
811            (heck::ToSnakeCase::to_snake_case(ident.to_string().as_str()) + "_impl").as_str(),
812            ident.span(),
813        ),
814        content: Some((
815            syn::token::Brace {
816                span: proc_macro2::Group::new(Delimiter::Brace, TokenStream::new()).delim_span(),
817            },
818            items,
819        )),
820        semi: None,
821    })
822}
823
824/// Write an `impl Functor { ...`
825fn impl_functor() -> syn::Result<syn::Item> {
826    Ok(syn::Item::Impl(syn::parse2(quote! {
827        impl<A> Functor<A> for M<A> {
828            type Functor<B> = M<B>;
829            #[inline(always)]
830            fn fmap<B, F: Fn(A) -> B>(self, f: F) -> M<B> {
831                self.bind(move |x| consume(f(x)))
832            }
833        }
834    })?))
835}
836
837/// Write an `impl Monad { ...`.
838fn impl_monad(
839    // ident: &syn::Ident,
840    // generics: &syn::Generics,
841    bind: syn::ImplItemFn,
842    consume: syn::ImplItemFn,
843) -> syn::Result<syn::Item> {
844    let mut item: syn::ItemImpl = syn::parse2(quote! { impl<A> Monad<A> for M<A> {} })?;
845    item.items.push(syn::ImplItem::Type(syn::parse2(
846        quote! { type Monad<B> = M<B>; },
847    )?));
848    item.items.push(syn::ImplItem::Fn(bind));
849    item.items.push(syn::ImplItem::Fn(consume));
850    Ok(syn::Item::Impl(item))
851}
852
853/// Write an `impl BitOr { ...`.
854fn impl_pipe() -> syn::Result<syn::Item> {
855    Ok(syn::Item::Impl(syn::parse2(quote! {
856        impl<A, B, F: Fn(A) -> B> core::ops::BitOr<F> for M<A> {
857            type Output = M<B>;
858            fn bitor(self, f: F) -> M<B> {
859                self.fmap(f)
860            }
861        }
862    })?))
863}
864
865/// Write an `impl Shr { ...`.
866fn impl_rshift() -> syn::Result<syn::Item> {
867    Ok(syn::Item::Impl(syn::parse2(quote! {
868        impl<A, B, F: Fn(A) -> M<B>> core::ops::Shr<F> for M<A> {
869            type Output = M<B>;
870            fn shr(self, f: F) -> M<B> {
871                self.bind(f)
872            }
873        }
874    })?))
875}
876
877/// Write property-based tests for the monad laws and similar laws for other typeclasses.
878fn quickcheck_laws() -> syn::Result<syn::Item> {
879    Ok(syn::Item::Macro(syn::parse2(quote! {
880        quickcheck::quickcheck! {
881            fn prop_monad_left_identity(a: u64) -> bool {
882                rsmonad::laws::monad::left_identity::<u64, u64, M<u64>, _>(a, &rsmonad::laws::hash_consume)
883            }
884            fn prop_monad_right_identity(ma: M<u64>) -> bool {
885                rsmonad::laws::monad::right_identity(ma)
886            }
887            fn prop_monad_associativity(ma: M<u64>) -> bool {
888                rsmonad::laws::monad::associativity::<u64, u64, u64, M<u64>, _, _>(ma, &rsmonad::laws::hash_consume, &(move |x| consume(u64::reverse_bits(x))))
889            }
890            fn prop_functor_identity(fa: M<u64>) -> bool {
891                rsmonad::laws::functor::identity(fa)
892            }
893            fn prop_functor_composition(fa: M<u64>) -> bool {
894                rsmonad::laws::functor::composition(fa, rsmonad::laws::hash, u64::reverse_bits)
895            }
896        }
897    })?))
898}
899
900/// Attribute deriving common traits.
901fn derives() -> syn::Result<syn::Attribute> {
902    let ml: syn::MetaList = syn::parse2(
903        quote! {derive(Clone, Debug, /* Default, */ Eq, Hash, Ord, PartialEq, PartialOrd)},
904    )
905    .map_err(move |e| {
906        syn::Error::new(
907            e.span(),
908            "rsmonad-internal error: couldn't parse #[derive(...)]. Please file an error--we want to fix what went wrong!",
909        )
910    })?;
911    Ok(syn::Attribute {
912        pound_token: syn::token::Pound {
913            spans: [Span::call_site()],
914        },
915        style: syn::AttrStyle::Outer,
916        bracket_token: syn::token::Bracket {
917            span: *ml.delimiter.span(),
918        },
919        meta: syn::Meta::List(ml),
920    })
921}
922
923/// Attribute allowing exhaustive enums and structs.
924fn exhaustion() -> syn::Result<syn::Attribute> {
925    let ml: syn::MetaList = syn::parse2(
926        quote! { allow(clippy::non_exhaustive_enums, clippy::non_exhaustive_structs) },
927    )
928    .map_err(move |e| {
929        syn::Error::new(
930            e.span(),
931            "rsmonad-internal error: couldn't parse #[allow(...)]. Please file an error--we want to fix what went wrong!",
932        )
933    })?;
934    Ok(syn::Attribute {
935        pound_token: syn::token::Pound {
936            spans: [Span::call_site()],
937        },
938        style: syn::AttrStyle::Outer,
939        bracket_token: syn::token::Bracket {
940            span: *ml.delimiter.span(),
941        },
942        meta: syn::Meta::List(ml),
943    })
944}
945
946/// Either `struct` or `enum` fields.
947enum Fields {
948    /// Enum variants.
949    EnumVariants(syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>),
950    /// Struct members.
951    StructMembers(syn::Fields),
952}
953
954/// Parse an enum.
955fn from_enum(
956    out: &mut TokenStream,
957    mut item: syn::ItemEnum,
958    publicity: Option<proc_macro2::Ident>,
959) -> syn::Result<(syn::Ident, syn::Generics, Fields)> {
960    item.attrs.push(exhaustion()?);
961    item.attrs.push(derives()?);
962    item.to_tokens(out);
963    if let Some(p) = publicity {
964        p.to_tokens(out);
965    }
966    syn::Ident::new("use", Span::call_site()).to_tokens(out);
967    item.ident.to_tokens(out);
968    proc_macro2::Punct::new(':', proc_macro2::Spacing::Joint).to_tokens(out);
969    proc_macro2::Punct::new(':', proc_macro2::Spacing::Alone).to_tokens(out);
970    proc_macro2::Group::new(Delimiter::Brace, {
971        let mut ctors = TokenStream::new();
972        for ctor in &item.variants {
973            ctor.ident.to_tokens(&mut ctors);
974            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone).to_tokens(&mut ctors);
975        }
976        ctors
977    })
978    .to_tokens(out);
979    proc_macro2::Punct::new(';', proc_macro2::Spacing::Alone).to_tokens(out);
980    Ok((
981        item.ident,
982        item.generics,
983        Fields::EnumVariants(item.variants),
984    ))
985}
986
987/// Parse a struct.
988fn from_struct(
989    out: &mut TokenStream,
990    mut item: syn::ItemStruct,
991) -> syn::Result<(syn::Ident, syn::Generics, Fields)> {
992    item.attrs.push(exhaustion()?);
993    item.attrs.push(derives()?);
994    item.to_tokens(out);
995    Ok((
996        item.ident,
997        item.generics,
998        Fields::StructMembers(item.fields),
999    ))
1000}