use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
spanned::Spanned,
FnArg,
Ident,
ImplItem,
ImplItemFn,
ItemImpl,
Result,
ReturnType,
Token,
};
fn pascal_to_snake_case(ident: &Ident) -> Ident {
let mut snake = String::new();
for (i, ch) in ident.to_string().chars().enumerate() {
if ch.is_ascii_uppercase() {
if i != 0 {
snake.push('_');
}
snake.push(ch.to_ascii_lowercase());
} else {
snake.push(ch);
}
}
syn::parse_str(&snake).unwrap()
}
fn find_eval_static_fn(block: &ItemImpl) -> Result<&ImplItemFn> {
block.items.iter()
.find_map(|item| match item {
ImplItem::Fn(func) if func.sig.ident == "eval_static" => Some(func),
_ => None,
})
.ok_or_else(|| syn::Error::new(block.span(), "expected `eval_static` function inside `impl` block"))
}
fn path_ident(path: &syn::Type) -> Result<&Ident> {
match path {
syn::Type::Path(path) => path.path.get_ident()
.ok_or_else(|| syn::Error::new(path.span(), "expected identifier")),
_ => Err(syn::Error::new(path.span(), "expected path")),
}
}
/// The classified type of an `eval_static` parameter or return value.
///
/// Built by the `TryFrom` impls for `syn::Type` / `ReturnType`; only a fixed set of type
/// names is recognised (see `TypeKind`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Type {
/// True when the type was written as `Option<T>`.
optional: bool,
/// True when the parameter was written as a shared reference `&T`; the generated call then
/// passes the argument as `&ident`.
is_ref: bool,
/// The underlying kind, with any `Option` / reference wrapper stripped.
kind: TypeKind,
}
/// The underlying kind of a builtin parameter or return type, independent of whether it
/// is optional or taken by reference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeKind {
/// Written as `Float`; arguments are coerced with `coerce_float()` before matching.
Float,
/// Written as `Integer`; arguments are coerced with `coerce_integer()` before matching.
Integer,
/// Written as `Complex`; arguments are coerced with `coerce_complex()` before matching.
Complex,
/// Written as `bool`.
Bool,
/// No value; produced for an `eval_static` that has no return type.
Unit,
/// Written as `Value`; arguments are passed through without coercion or type checking.
Value,
}
fn match_first_segment(ty: &syn::Type) -> Result<Type> {
let path = match ty {
syn::Type::Path(ty) => &ty.path,
_ => return Err(syn::Error::new(ty.span(), "expected path")),
};
let Some(first) = path.segments.first() else {
return Err(syn::Error::new(path.span(), "expected path"));
};
let ident_str = first.ident.to_string();
if ident_str == "Option" {
let args = match &first.arguments {
syn::PathArguments::AngleBracketed(bracketed) if bracketed.args.len() == 1 => &bracketed.args,
_ => return Err(syn::Error::new(first.ident.span(), "expected one angle-bracketed argument")),
};
let first_arg = args.first().unwrap();
let ty = match first_arg {
syn::GenericArgument::Type(ty) => ty,
_ => return Err(syn::Error::new(first_arg.span(), "expected type as generic argument")),
};
Ok(Type {
optional: true,
is_ref: false,
kind: match_first_segment(ty)?.kind,
})
} else {
Ok(Type {
optional: false,
is_ref: false,
kind: match &*ident_str {
"Float" => TypeKind::Float,
"Integer" => TypeKind::Integer,
"Complex" => TypeKind::Complex,
"bool" => TypeKind::Bool,
"Value" => TypeKind::Value,
_ => return Err(syn::Error::new(first.ident.span(), format!("expected `Float`, `Integer`, `Complex`, `bool`, or `Value`, found `{}`", ident_str))),
},
})
}
}
impl TryFrom<ReturnType> for Type {
    type Error = syn::Error;

    /// Classifies the return type of `eval_static`.
    ///
    /// An omitted return type and an explicitly written `-> ()` both map to a non-optional,
    /// non-reference `TypeKind::Unit`. Before this, `-> ()` was rejected with "expected path"
    /// even though it means the same thing as omitting the return type. Any other return
    /// type is classified by `match_first_segment`.
    ///
    /// # Errors
    /// Propagates `match_first_segment` errors for unsupported return types.
    fn try_from(ty: ReturnType) -> Result<Self> {
        let unit = Type {
            optional: false,
            is_ref: false,
            kind: TypeKind::Unit,
        };
        match ty {
            ReturnType::Default => Ok(unit),
            ReturnType::Type(_, ty)
                if matches!(&*ty, syn::Type::Tuple(tuple) if tuple.elems.is_empty()) =>
            {
                Ok(unit)
            }
            ReturnType::Type(_, ty) => match_first_segment(&ty),
        }
    }
}
impl TryFrom<syn::Type> for Type {
type Error = syn::Error;
fn try_from(ty: syn::Type) -> Result<Self> {
match ty {
syn::Type::Reference(ty) => Ok(Type { optional: false, is_ref: true, kind: match_first_segment(&ty.elem)?.kind }),
_ => match_first_segment(&ty),
}
}
}
impl Type {
pub fn rug_tokens(&self) -> TokenStream2 {
let kind = match self.kind {
TypeKind::Float => quote! { Float },
TypeKind::Integer => quote! { Integer },
TypeKind::Complex => quote! { Complex },
TypeKind::Bool => quote! { bool },
TypeKind::Unit => quote! { () },
TypeKind::Value => quote! { Value },
};
let reffed = if self.is_ref {
quote! { & #kind }
} else {
quote! { #kind }
};
if self.optional {
quote! { Option<#reffed> }
} else {
reffed
}
}
pub fn typename(&self) -> &'static str {
match self.kind {
TypeKind::Float => "Float",
TypeKind::Integer => "Integer",
TypeKind::Complex => "Complex",
TypeKind::Bool => "Boolean",
TypeKind::Unit => "Unit",
TypeKind::Value => "Value",
}
}
pub fn value_tokens(&self) -> TokenStream2 {
match self.kind {
TypeKind::Float => quote! { crate::numerical::value::Value::Float },
TypeKind::Integer => quote! { crate::numerical::value::Value::Integer },
TypeKind::Complex => quote! { crate::numerical::value::Value::Complex },
TypeKind::Bool => quote! { crate::numerical::value::Value::Bool },
TypeKind::Unit => quote! { crate::numerical::value::Value::Unit },
TypeKind::Value => quote! { crate::numerical::value::Value::Value },
}
}
}
/// One parameter of a builtin's `eval_static` function.
#[derive(Debug)]
pub struct Param {
/// The `mut` keyword, if the parameter binding was declared mutable.
mutable: Option<Token![mut]>,
/// The parameter's name; the generated code binds the checked argument to a local of this name.
ident: Ident,
/// The parameter's classified type.
ty: Type,
}
impl TryFrom<FnArg> for Param {
type Error = syn::Error;
fn try_from(arg: FnArg) -> Result<Self> {
match arg {
FnArg::Typed(pat) => {
let (mutable, ident) = match *pat.pat {
syn::Pat::Ident(pat) => (pat.mutability, pat.ident),
_ => return Err(syn::Error::new(pat.pat.span(), "expected identifier")),
};
let ty = Type::try_from(*pat.ty)?;
Ok(Param { mutable, ident, ty })
},
FnArg::Receiver(_) => Err(syn::Error::new(arg.span(), "expected argument")),
}
}
}
impl ToTokens for Param {
    /// Emits the parameter as `[mut] ident: Type`, rendering the type with `Type::rug_tokens`.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        let mutable = &self.mutable;
        let ident = &self.ident;
        let rendered_ty = self.ty.rug_tokens();
        let rendered = quote! { #mutable #ident: #rendered_ty };
        tokens.extend(rendered);
    }
}
/// Angle-unit handling requested for a builtin via `radian = input` / `radian = output`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Radian {
/// Arguments are converted with `into_radians()` when the context's trig mode is degrees.
Input,
/// The result is converted with `into_degrees()` when the context's trig mode is degrees.
Output,
/// No angle conversion is performed.
None,
}
impl Parse for Radian {
    /// Parses the optional `radian = input` / `radian = output` argument.
    ///
    /// If the stream does not start with the identifier `radian`, nothing is consumed and
    /// `Radian::None` is returned. Any stray tokens, such as a misspelled `radain`, therefore
    /// stay in the stream for the caller to reject; previously they were silently swallowed.
    ///
    /// # Errors
    /// After `radian`, errors if `=` is missing or the value is not `input` or `output`.
    fn parse(input: ParseStream) -> Result<Self> {
        // Look ahead on a fork so an unrelated identifier is not consumed.
        let starts_with_radian = input
            .fork()
            .parse::<Ident>()
            .map_or(false, |key| key == "radian");
        if !starts_with_radian {
            return Ok(Radian::None);
        }
        input.parse::<Ident>()?;
        input.parse::<Token![=]>()?;
        let kind = input.parse::<Ident>()?;
        match kind.to_string().as_str() {
            "input" => Ok(Radian::Input),
            "output" => Ok(Radian::Output),
            _ => Err(syn::Error::new(kind.span(), "expected `input` or `output`")),
        }
    }
}
/// A parsed builtin definition: an `impl` block containing an `eval_static` function, from
/// which an implementation of `crate::numerical::builtin::Builtin` is generated.
#[derive(Debug)]
pub struct Builtin {
/// The original `impl` block, re-emitted unchanged in the macro output.
item: ItemImpl,
/// The snake_case form of `pascal_name`; used in the signature string and error messages.
name: Ident,
/// The identifier of the type the `impl` block is for (e.g. `Sin` in `impl Sin { .. }`).
pascal_name: Ident,
/// The parameters of `eval_static`, in declaration order.
params: Vec<Param>,
}
impl Builtin {
    /// Builds the human-readable signature shown in argument errors, for example
    /// `name(a: Float, b: Integer (optional))`.
    pub fn signature(&self) -> String {
        let params = self
            .params
            .iter()
            .map(|param| {
                let ty = param.ty.typename();
                if param.ty.optional {
                    format!("{}: {} (optional)", param.ident, ty)
                } else {
                    format!("{}: {}", param.ident, ty)
                }
            })
            .collect::<Vec<_>>()
            .join(", ");
        format!("{}({})", self.name, params)
    }

    /// Generates statements that pull each argument from `args`, validate it, and bind it to
    /// a local with the same name as the corresponding `eval_static` parameter.
    ///
    /// For each parameter, in order:
    /// - `Value` parameters accept any argument as-is; a missing required one is an error,
    ///   and a missing optional one binds `None`.
    /// - `Float` / `Integer` / `Complex` arguments are first coerced with `coerce_*()`.
    /// - With [`Radian::Input`], each non-`Value` argument goes through `into_radians()` when
    ///   `ctxt.trig_mode` is degrees.
    /// - If a required argument is missing, the generated code returns
    ///   `BuiltinError::MissingArgument`; a missing optional one binds `None`. An argument of
    ///   the wrong type returns `BuiltinError::TypeMismatch`.
    ///
    /// Finally, if more arguments were supplied than there are parameters, it returns
    /// `BuiltinError::TooManyArguments`.
    ///
    /// The generated code expects `args`, `arg_count`, and `ctxt` to be in scope (see
    /// `impl_builtin`).
    pub fn generate_check_stmts(&self, radian: Radian) -> TokenStream2 {
        let Self { name, .. } = self;
        let signature = self.signature();
        let num_params = self.params.len();

        // Expression building the error for a required argument at `index` that was not supplied.
        let missing_argument = |index: usize| {
            quote! {
                crate::numerical::builtin::error::BuiltinError::MissingArgument(crate::numerical::error::kind::MissingArgument {
                    name: stringify!(#name).to_owned(),
                    index: #index,
                    expected: #num_params,
                    given: crate::funcs::helper::count_all_args(args, &mut arg_count),
                    signature: #signature.to_owned(),
                })
            }
        };

        let type_checkers = self
            .params
            .iter()
            .enumerate()
            .map(|(i, param)| {
                let (ident, ty) = (&param.ident, &param.ty);
                let base_call = quote! { crate::funcs::helper::next_arg(args, &mut arg_count) };

                if ty.kind == TypeKind::Value {
                    // `Option<Value>` must bind an `Option`, matching the user's signature.
                    if ty.optional {
                        return quote! {
                            let #ident = #base_call;
                        };
                    }
                    let missing = missing_argument(i);
                    return quote! {
                        let #ident = #base_call.ok_or_else(|| #missing)?;
                    };
                }

                let type_coerce_expr = match ty.kind {
                    TypeKind::Float => Some(quote! { .map(|arg| arg.coerce_float()) }),
                    TypeKind::Integer => Some(quote! { .map(|arg| arg.coerce_integer()) }),
                    TypeKind::Complex => Some(quote! { .map(|arg| arg.coerce_complex()) }),
                    _ => None,
                };
                let trig_convert_expr = if radian == Radian::Input {
                    Some(quote! {
                        .map(|arg| {
                            if ctxt.trig_mode == crate::numerical::ctxt::TrigMode::Degrees {
                                arg.into_radians()
                            } else {
                                arg
                            }
                        })
                    })
                } else {
                    None
                };
                let full_getter = quote! { #base_call #type_coerce_expr #trig_convert_expr };
                let (received_type, none_branch) = if ty.optional {
                    (quote! { Some(#ident) }, quote! { None })
                } else {
                    let missing = missing_argument(i);
                    (quote! { #ident }, quote! { return Err(#missing); })
                };
                let user_ty = ty.typename();
                let variant = ty.value_tokens();
                quote! {
                    let #ident = match #full_getter {
                        Some(#variant(#ident)) => #received_type,
                        Some(bad_value) => {
                            return Err(crate::numerical::builtin::error::BuiltinError::TypeMismatch(crate::numerical::error::kind::TypeMismatch {
                                name: stringify!(#name).to_owned(),
                                index: #i,
                                expected: #user_ty,
                                given: bad_value.typename(),
                                signature: #signature.to_owned(),
                            }));
                        },
                        None => { #none_branch },
                    };
                }
            });

        // The argument count is computed once rather than re-queried for the error payload.
        // The block scope keeps `given` from shadowing a user parameter of the same name.
        quote! {
            #( #type_checkers )*
            {
                let given = crate::funcs::helper::count_all_args(args, &mut arg_count);
                if given > #num_params {
                    return Err(crate::numerical::builtin::error::BuiltinError::TooManyArguments(crate::numerical::error::kind::TooManyArguments {
                        name: stringify!(#name).to_owned(),
                        expected: #num_params,
                        given,
                        signature: #signature.to_owned(),
                    }));
                }
            }
        }
    }

    /// Generates the expression that calls `#pascal_name::eval_static` with the bound
    /// arguments and wraps the result as `Ok(Value::from(..))`. Arguments whose parameter type
    /// is `&T` are passed by reference.
    ///
    /// With [`Radian::Output`], the value is passed through `into_degrees()` when
    /// `ctxt.trig_mode` is degrees.
    pub fn generate_call(&self, radian: Radian) -> TokenStream2 {
        let pascal_name = &self.pascal_name;
        let param_idents = self.params.iter().map(|param| {
            let ident = &param.ident;
            if param.ty.is_ref {
                quote! { &#ident }
            } else {
                quote! { #ident }
            }
        });
        let make_value = quote! { crate::numerical::value::Value::from(
            #pascal_name::eval_static(#(#param_idents),*)
        ) };
        if radian == Radian::Output {
            quote! {
                if ctxt.trig_mode == crate::numerical::ctxt::TrigMode::Degrees {
                    Ok(#make_value.into_degrees())
                } else {
                    Ok(#make_value)
                }
            }
        } else {
            quote! { Ok(#make_value) }
        }
    }

    /// Re-emits the user's original `impl` block unchanged.
    fn impl_static(&self) -> TokenStream2 {
        let Self { item, .. } = self;
        quote! { #item }
    }

    /// Generates `impl crate::numerical::builtin::Builtin for #pascal_name`.
    ///
    /// `num_args` returns the number of `eval_static` parameters, including optional ones.
    /// `eval` runs the argument checks and then the call.
    fn impl_builtin(&self, radian: Radian) -> TokenStream2 {
        let Self { pascal_name, params, .. } = self;
        let arg_count = params.len();
        let type_checkers = self.generate_check_stmts(radian);
        let call = self.generate_call(radian);
        quote! {
            impl crate::numerical::builtin::Builtin for #pascal_name {
                fn num_args(&self) -> usize { #arg_count }
                fn eval(
                    &self,
                    ctxt: &crate::numerical::ctxt::Ctxt,
                    args: &mut dyn Iterator<Item = crate::numerical::value::Value>,
                ) -> Result<crate::numerical::value::Value, crate::numerical::builtin::error::BuiltinError> {
                    let mut arg_count = 0;
                    #type_checkers
                    #call
                }
            }
        }
    }

    /// Generates the full macro output: the original `impl` block followed by the generated
    /// `Builtin` trait implementation.
    pub fn generate(&self, radian: Radian) -> TokenStream2 {
        let static_impl = self.impl_static();
        let builtin_impl = self.impl_builtin(radian);
        quote! {
            #static_impl
            #builtin_impl
        }
    }
}
impl Parse for Builtin {
    /// Parses a builtin's `impl` block and extracts what is needed to generate its `Builtin`
    /// trait implementation.
    ///
    /// # Errors
    /// Fails if:
    /// - the input is not an `impl` block for a plain identifier type;
    /// - the block has no `eval_static` function;
    /// - any `eval_static` parameter is unsupported;
    /// - a required parameter follows an optional one.
    fn parse(input: ParseStream) -> Result<Self> {
        let item = input.parse::<ItemImpl>()?;
        let pascal_name = path_ident(&item.self_ty)?.clone();
        let name = pascal_to_snake_case(&pascal_name);
        // Only the parameter list is needed; cloning it alone avoids copying the function body.
        let inputs = find_eval_static_fn(&item)?.sig.inputs.clone();
        let params = inputs
            .into_iter()
            .map(Param::try_from)
            .collect::<Result<Vec<_>>>()?;

        // Arguments are consumed positionally, so once a parameter is optional every later
        // one must be optional too.
        let misplaced = params
            .iter()
            .skip_while(|param| !param.ty.optional)
            .find(|param| !param.ty.optional);
        if let Some(param) = misplaced {
            return Err(syn::Error::new(param.ident.span(), "optional parameters cannot be followed by non-optional parameters"));
        }

        Ok(Builtin {
            item,
            name,
            pascal_name,
            params,
        })
    }
}