// vts 1.1.2
//
// Macro to generate boilerplate that defines new types with associated constraints.
// Documentation
#![doc = include_str!("../README.md")]

use proc_macro::TokenStream;
use proc_macro2::Ident;
use syn::{
    parse::{Parse, ParseStream},
    parse_macro_input,
    spanned::Spanned,
    Token,
};

/// macro to define new type with given constraint
///
/// This macro generates all the boiler plate for you
///
/// * create an error types for this specific constraint verifation
/// * implement the constructor with verification of the constraint
///
/// ```
/// use vts::vts;
/// vts! {
///   /// some documentations
///   #[derive(Debug, Clone)]
///   pub type NonNulNatural = usize
///       where self.0 > 0;
/// }
///
/// assert!(matches!(NonNulNatural::new(0), Err(NonNulNaturalConstraintError)));
/// assert!(matches!(NonNulNatural::new(1), Ok(_)));
/// ```
#[proc_macro]
pub fn vts(input: TokenStream) -> TokenStream {
    let vtss = parse_macro_input!(input as VTSS);

    let vtss = vtss.vtss.into_iter().map(render_vts);

    let expanded = quote::quote! {
        #( #vtss )*
    };

    TokenStream::from(expanded)
}

fn render_vts(vts: VTS) -> proc_macro2::TokenStream {
    let VTS {
        attributes,
        visibility,
        name,
        base_type,
        expression,
        ..
    } = vts;

    let type_name = name.to_string();
    let asset_sized_ident = Ident::new(&format!("__AssertSized{name}"), base_type.span());
    let error_type_name = format!("{name}ConstraintError");
    let error_type_ident = Ident::new(&error_type_name, expression.span());
    let expression = if let Some(expression) = expression {
        expression
    } else {
        syn::Expr::Lit(syn::ExprLit {
            attrs: vec![],
            lit: syn::Lit::Bool(syn::LitBool::new(true, type_name.span())),
        })
    };

    let assert_sized = quote::quote_spanned! { base_type.span() =>
        struct #asset_sized_ident where #base_type: ::std::marker::Sized;
    };

    let error_type_definition = quote::quote_spanned!(expression.span() =>
        #[doc = "Error that will be returned if the associated constraint does not hold true"]
        #[derive(Clone, Copy, Debug)]
        #visibility
        struct #error_type_ident;

        impl ::std::error::Error for #error_type_ident {}
        impl ::std::fmt::Display for #error_type_ident {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                ::std::write!(f, "Value failed the constraint on {}", #type_name)
            }
        }
    );

    let constraint_function = quote::quote_spanned!(expression.span()=>
        fn __constraint(&self) -> bool {
            #assert_sized

            #expression
        }
    );

    let assert_function = quote::quote_spanned!(expression.span()=>
        fn __assert(&self) {
            ::std::assert!(
                self.__constraint()
            )
        }
    );

    if let Some(base_type) = base_type {
        quote::quote! {
            #( #attributes )*
            #visibility
            struct #name (#base_type);

            #error_type_definition

            impl #name {
                pub fn new(value: #base_type) -> ::std::result::Result<Self, #error_type_ident> {
                    let value = #name (value);
                    if value.__constraint() {
                        Ok(value)
                    } else {
                        Err(#error_type_ident)
                    }
                }

                #constraint_function
                #assert_function
            }

            impl ::std::ops::Deref for #name {
                type Target = #base_type;
                fn deref(&self) -> &Self::Target {
                    &self.0
                }
            }
        }
    } else {
        quote::quote! {
            #( #attributes )*
            #visibility
            struct #name;

            #error_type_definition

            impl #name {
                pub fn new() -> Self {
                    Self
                }
            }
        }
    }
}

/// The whole input of the `vts!` macro: zero or more `VTS` declarations, in order.
// Named after the crate; the upper-case acronym spelling is intentional.
#[allow(clippy::upper_case_acronyms)]
struct VTSS {
    vtss: Vec<VTS>,
}

/// One declaration inside `vts!`:
///
/// ```text
/// #[attr]* vis type Name [= BaseType] [where <expr>] ;
/// ```
// Named after the crate; the upper-case acronym spelling is intentional.
#[allow(clippy::upper_case_acronyms)]
struct VTS {
    /// Outer attributes, including `///` doc comments, placed on the generated struct.
    attributes: Vec<syn::Attribute>,
    visibility: syn::Visibility,
    _type_token: Token![type],
    name: syn::Ident,
    _equal_token: Option<Token![=]>,
    /// The wrapped type. Any Rust type is accepted (e.g. `Vec<u8>`), not only a bare
    /// identifier; it is `Some` exactly when `=` was written.
    base_type: Option<syn::Type>,
    _where_token: Option<Token![where]>,
    /// The constraint, evaluated inside a method taking `&self` on the generated type.
    expression: Option<syn::Expr>,
    _semi_colon_token: Token![;],
}

impl Parse for VTSS {
    /// Parses declarations until the macro input is exhausted.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut vtss = Vec::new();
        while !input.is_empty() {
            vtss.push(input.parse()?);
        }
        Ok(Self { vtss })
    }
}

impl Parse for VTS {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let attributes = input.call(syn::Attribute::parse_outer)?;
        let visibility = input.parse()?;
        let _type_token = input.parse()?;
        let name = input.parse()?;

        // `= BaseType` is optional, but once `=` is written a type must follow. Parsing a
        // full `syn::Type` (instead of an optional identifier) rejects `type X = ;` and
        // accepts generic or compound types such as `Vec<u8>`.
        let (_equal_token, base_type) = if input.peek(Token![=]) {
            let equal_token = input.parse()?;
            let base_type = input.parse()?;
            (Some(equal_token), Some(base_type))
        } else {
            (None, None)
        };

        let (_where_token, expression) = if input.peek(Token![where]) {
            let where_token = input.parse()?;
            let expression = input.parse()?;
            (Some(where_token), Some(expression))
        } else {
            (None, None)
        };

        let _semi_colon_token = input.parse()?;

        Ok(Self {
            attributes,
            visibility,
            _type_token,
            name,
            _equal_token,
            base_type,
            _where_token,
            expression,
            _semi_colon_token,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as one declaration, panicking if it is not valid `vts!` syntax.
    fn parse_single(src: &str) -> VTS {
        syn::parse_str(src).unwrap()
    }

    /// Checks the parsed shape of a declaration written without a visibility modifier.
    fn assert_private_shape(
        vts: &VTS,
        attribute_count: usize,
        has_base_type: bool,
        has_expression: bool,
    ) {
        assert_eq!(vts.attributes.len(), attribute_count);
        assert!(matches!(vts.visibility, syn::Visibility::Inherited));
        assert_eq!(vts.base_type.is_some(), has_base_type);
        assert_eq!(vts.expression.is_some(), has_expression);
    }

    #[test]
    fn parse_no_type() {
        let vts = parse_single("type Password;");
        assert_private_shape(&vts, 0, false, false);
    }

    #[test]
    fn parse_no_constraint() {
        let vts = parse_single("type Password = String;");
        assert_private_shape(&vts, 0, true, false);
    }

    #[test]
    fn parse_1() {
        let vts = parse_single("type Password = String where self.len() > 3 && self.len() < 10;");
        assert_private_shape(&vts, 0, true, true);
    }

    #[test]
    fn parse_n() {
        let source = r###"
            /// some documentation about it
            type Password = String where self.len() > 3 && self.len() < 10;
            /// more documentation
            type DigitStr = String where self.chars().all(|c| c.is_digit());
        "###;
        let parsed: VTSS = syn::parse_str(source).unwrap();

        assert_eq!(parsed.vtss.len(), 2);
        // Both declarations carry one doc attribute, a base type and a constraint.
        for declaration in &parsed.vtss {
            assert_private_shape(declaration, 1, true, true);
        }
    }
}