#![doc = include_str!("../README.md")]
use proc_macro::TokenStream;
use proc_macro2::Ident;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
Token,
};
#[proc_macro]
pub fn vts(input: TokenStream) -> TokenStream {
let vtss = parse_macro_input!(input as VTSS);
let vtss = vtss.vtss.into_iter().map(render_vts);
let expanded = quote::quote! {
#( #vtss )*
};
TokenStream::from(expanded)
}
/// Expands one parsed declaration into the items it generates.
///
/// With a base type (`type Name = Base (where expr)?;`) this emits:
/// * the tuple struct `Name(Base)`, carrying the declaration's attributes and visibility;
/// * the unit error type `NameConstraintError` (same visibility), implementing
///   `Error`, `Display`, `Clone`, `Copy` and `Debug`;
/// * `Name::new(value)`, which returns `Err(NameConstraintError)` when `expr`
///   evaluates to `false`;
/// * private helpers `__constraint` (evaluates `expr` with `self: &Name`, plus a
///   compile-time `Base: Sized` check) and `__assert`;
/// * `Deref<Target = Base>` for `Name`.
///
/// A missing `where` clause makes the constraint the literal `true`.
///
/// Without a base type (`type Name;`) only a unit struct, the error type and a
/// no-argument `new` are emitted.
fn render_vts(vts: VTS) -> proc_macro2::TokenStream {
    let VTS {
        attributes,
        visibility,
        name,
        base_type,
        expression,
        ..
    } = vts;
    // Rendered into the generated `Display` message.
    let type_name = name.to_string();
    let assert_sized_ident = Ident::new(&format!("__AssertSized{name}"), base_type.span());
    let error_type_name = format!("{name}ConstraintError");
    let error_type_ident = Ident::new(&error_type_name, expression.span());
    // No `where` clause: every value satisfies the constraint. The literal is
    // synthesized rather than taken from user input, so it gets the call-site span.
    let expression = expression.unwrap_or_else(|| {
        syn::Expr::Lit(syn::ExprLit {
            attrs: vec![],
            lit: syn::Lit::Bool(syn::LitBool::new(true, proc_macro2::Span::call_site())),
        })
    });
    // Compile-time check that the base type is `Sized`. The struct is never
    // constructed, so `dead_code` is silenced for it in the user's crate.
    let assert_sized = quote::quote_spanned! { base_type.span() =>
        #[allow(dead_code)]
        struct #assert_sized_ident where #base_type: ::std::marker::Sized;
    };
    // In the unit-struct case the error type is never constructed, hence the allow.
    let error_type_definition = quote::quote_spanned!(expression.span() =>
        #[doc = "Error that will be returned if the associated constraint does not hold true"]
        #[derive(Clone, Copy, Debug)]
        #[allow(dead_code)]
        #visibility
        struct #error_type_ident;
        impl ::std::error::Error for #error_type_ident {}
        impl ::std::fmt::Display for #error_type_ident {
            fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
                ::std::write!(f, "Value failed the constraint on {}", #type_name)
            }
        }
    );
    let constraint_function = quote::quote_spanned!(expression.span()=>
        fn __constraint(&self) -> bool {
            #assert_sized
            #expression
        }
    );
    // Nothing generated here calls `__assert`; allow it to be unused.
    let assert_function = quote::quote_spanned!(expression.span()=>
        #[allow(dead_code)]
        fn __assert(&self) {
            ::std::assert!(
                self.__constraint()
            )
        }
    );
    if let Some(base_type) = base_type {
        // `Ok`/`Err` are fully qualified so a user-side shadowing of those names
        // cannot break the expansion.
        quote::quote! {
            #( #attributes )*
            #visibility
            struct #name (#base_type);
            #error_type_definition
            impl #name {
                pub fn new(value: #base_type) -> ::std::result::Result<Self, #error_type_ident> {
                    let value = #name (value);
                    if value.__constraint() {
                        ::std::result::Result::Ok(value)
                    } else {
                        ::std::result::Result::Err(#error_type_ident)
                    }
                }
                #constraint_function
                #assert_function
            }
            impl ::std::ops::Deref for #name {
                type Target = #base_type;
                fn deref(&self) -> &Self::Target {
                    &self.0
                }
            }
        }
    } else {
        // NOTE(review): without a base type any `where` expression is not emitted
        // anywhere, so a constraint written here has no effect.
        quote::quote! {
            #( #attributes )*
            #visibility
            struct #name;
            #error_type_definition
            impl #name {
                pub fn new() -> Self {
                    Self
                }
            }
        }
    }
}
/// The complete macro input: zero or more declarations, in source order.
#[allow(clippy::upper_case_acronyms)] // acronym-style name kept for the whole crate
struct VTSS {
    vtss: ::std::vec::Vec<VTS>,
}

/// One declaration of the form `#[attrs]* vis type Name (= Base)? (where expr)? ;`.
///
/// The `_…_token` fields only record the punctuation/keywords that were parsed;
/// code generation uses `attributes`, `visibility`, `name`, `base_type` and `expression`.
#[allow(clippy::upper_case_acronyms)] // acronym-style name kept for the whole crate
struct VTS {
    /// Outer attributes (including `///` doc comments) written before `type`.
    attributes: ::std::vec::Vec<syn::Attribute>,
    /// Visibility written before `type` (`Inherited` when none was given).
    visibility: syn::Visibility,
    _type_token: syn::token::Type,
    /// Name of the type being declared.
    name: syn::Ident,
    _equal_token: ::std::option::Option<syn::token::Eq>,
    /// Identifier after `=`, if any.
    base_type: ::std::option::Option<syn::Ident>,
    _where_token: ::std::option::Option<syn::token::Where>,
    /// Expression after `where`, if any.
    expression: ::std::option::Option<syn::Expr>,
    _semi_colon_token: syn::token::Semi,
}
impl Parse for VTSS {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut vtss = Vec::new();
loop {
if input.is_empty() {
break;
}
let vts = input.parse()?;
vtss.push(vts);
}
Ok(Self { vtss })
}
}
impl Parse for VTS {
    /// Parses one declaration:
    /// `#[attrs]* vis type Name (= BaseIdent)? (where expr)? ;`.
    ///
    /// # Errors
    /// Returns a `syn::Error` when the input does not follow that grammar. An `=`
    /// must be followed by a (non-keyword) identifier naming the base type, so
    /// `type Name = ;` or `type Name = where …;` is rejected rather than being
    /// treated as a unit type.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let attributes = input.call(syn::Attribute::parse_outer)?;
        let visibility = input.parse()?;
        let _type_token = input.parse()?;
        let name = input.parse()?;
        let (_equal_token, base_type) = if input.peek(Token![=]) {
            let equal_token: Token![=] = input.parse()?;
            // Parse a bare `Ident`, not `Option<Ident>`: the latter only peeks and
            // yields `None` when no identifier follows, silently losing the base type.
            let base_type: syn::Ident = input.parse()?;
            (Some(equal_token), Some(base_type))
        } else {
            (None, None)
        };
        let (_where_token, expression) = if input.peek(Token![where]) {
            let where_token: Token![where] = input.parse()?;
            let expression: syn::Expr = input.parse()?;
            (Some(where_token), Some(expression))
        } else {
            (None, None)
        };
        let _semi_colon_token = input.parse()?;
        Ok(Self {
            attributes,
            visibility,
            _type_token,
            name,
            _equal_token,
            base_type,
            _where_token,
            expression,
            _semi_colon_token,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as exactly one declaration, panicking on a syntax error.
    fn parse_one(src: &str) -> VTS {
        syn::parse_str::<VTS>(src).unwrap()
    }

    /// True when no visibility modifier was written.
    fn is_inherited(visibility: &syn::Visibility) -> bool {
        matches!(visibility, syn::Visibility::Inherited)
    }

    #[test]
    fn parse_no_type() {
        let vts = parse_one(
            r###"
type Password;
"###,
        );
        assert!(vts.attributes.is_empty());
        assert!(is_inherited(&vts.visibility));
        assert!(vts.base_type.is_none());
        assert!(vts.expression.is_none());
    }

    #[test]
    fn parse_no_constraint() {
        let vts = parse_one(
            r###"
type Password = String;
"###,
        );
        assert!(vts.attributes.is_empty());
        assert!(is_inherited(&vts.visibility));
        assert!(vts.base_type.is_some());
        assert!(vts.expression.is_none());
    }

    #[test]
    fn parse_1() {
        let vts = parse_one(
            r###"
type Password = String where self.len() > 3 && self.len() < 10;
"###,
        );
        assert!(vts.attributes.is_empty());
        assert!(is_inherited(&vts.visibility));
        assert!(vts.base_type.is_some());
        assert!(vts.expression.is_some());
    }

    #[test]
    fn parse_n() {
        // The constraint expressions are only parsed here, never type-checked.
        let vtss: VTSS = syn::parse_str(
            r###"
/// some documentation about it
type Password = String where self.len() > 3 && self.len() < 10;
/// more documentation
type DigitStr = String where self.chars().all(|c| c.is_digit());
"###,
        )
        .unwrap();
        assert_eq!(vtss.vtss.len(), 2);
        for vts in &vtss.vtss {
            assert_eq!(vts.attributes.len(), 1);
            assert!(is_inherited(&vts.visibility));
            assert!(vts.base_type.is_some());
            assert!(vts.expression.is_some());
        }
    }
}