use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use quote::quote;
use crate::config::{Bounds, ConstraintDef, Sign, TypeConfig};
use crate::generator::{find_constraint_def, make_type_alias};
/// Chooses the constrained type that `abs()` on `constraint_def` should return.
///
/// Resolution order:
/// 1. If the input bounds are symmetric, normalized or negative-normalized, look
///    for a positive-signed type with bounds `[0, 1]` (zero allowed).
/// 2. Otherwise, or if step 1 finds nothing, look for a positive-signed type with
///    bounds `[0, ∞)`, excluding zero exactly when the input excludes zero.
/// 3. Fall back to the ident `Positive` (zero excluded) or `NonNegative`.
fn infer_abs_output_type(constraint_def: &ConstraintDef, config: &TypeConfig) -> Ident {
    let bounds = &constraint_def.bounds;
    let excludes_zero = constraint_def.excludes_zero;
    let fits_unit_interval =
        bounds.is_symmetric() || bounds.is_normalized() || bounds.is_negative_normalized();

    fits_unit_interval
        .then(|| {
            let unit_interval = Bounds {
                lower: Some(0.0),
                upper: Some(1.0),
            };
            config.find_type_by_constraints(Sign::Positive, &unit_interval, false)
        })
        .flatten()
        .or_else(|| {
            // |x| is never negative; it keeps the input's "non-zero" guarantee.
            let half_line = Bounds {
                lower: Some(0.0),
                upper: None,
            };
            config.find_type_by_constraints(Sign::Positive, &half_line, excludes_zero)
        })
        .unwrap_or_else(|| {
            let fallback = if excludes_zero { "Positive" } else { "NonNegative" };
            Ident::new(fallback, Span::call_site())
        })
}
/// Chooses the constrained type that `signum()` on `constraint_def` should return.
///
/// `f32::signum` / `f64::signum` never return zero. They return `1.0` for `+0.0`
/// and positive values, `-1.0` for `-0.0` and negative values, and NaN for NaN.
/// The generated code wraps the result with `new_unchecked`. The chosen type
/// must therefore hold every value `signum` can produce for inputs admitted by
/// `constraint_def`:
///
/// * Positive-signed types map to a `[0, 1]` type (fallback `Normalized`).
/// * Negative-signed types that exclude zero map to a `[-1, 0]` type
///   (fallback `NegativeNormalized`).
/// * Negative-signed types that admit zero, and unsigned types, map to a
///   `[-1, 1]` type (fallback `Symmetric`). `(+0.0).signum()` is `1.0`, which
///   a `[-1, 0]` type cannot hold.
///
/// NOTE(review): if positive-signed types that admit zero also accept `-0.0`,
/// then `(-0.0).signum()` is `-1.0`, which would not fit a `[0, 1]` type either.
/// Confirm how the validators treat negative zero.
fn infer_signum_output_type(constraint_def: &ConstraintDef, config: &TypeConfig) -> Ident {
    let (sign, lower, upper, fallback) =
        match (constraint_def.sign, constraint_def.excludes_zero) {
            (Sign::Positive, _) => (Sign::Positive, 0.0, 1.0, "Normalized"),
            // Strictly negative inputs always yield -1.0.
            (Sign::Negative, true) => (Sign::Negative, -1.0, 0.0, "NegativeNormalized"),
            // A negative type that admits +0.0 can yield +1.0, so it needs the
            // full [-1, 1] range, just like an unsigned type.
            (Sign::Negative, false) | (Sign::Any, _) => (Sign::Any, -1.0, 1.0, "Symmetric"),
        };
    let bounds = Bounds {
        lower: Some(lower),
        upper: Some(upper),
    };
    config
        .find_type_by_constraints(sign, &bounds, false)
        .unwrap_or_else(|| Ident::new(fallback, Span::call_site()))
}
/// Emits an inherent `abs()` method for every (constrained type, float width)
/// pair listed in `config`.
///
/// Each method's return type comes from `infer_abs_output_type`. The generated
/// body wraps `self.get().abs()` with `new_unchecked`, so its soundness depends
/// on that inference picking a type that contains every possible result.
pub fn generate_abs_impls(config: &TypeConfig) -> TokenStream2 {
    let impls: Vec<TokenStream2> = config
        .constraint_types
        .iter()
        .flat_map(|type_def| {
            let constraint_def = find_constraint_def(config, &type_def.type_name);
            let abs_type = infer_abs_output_type(constraint_def, config);
            type_def.float_types.iter().map(move |float_type| {
                let self_ty = make_type_alias(&type_def.type_name, float_type);
                let ret_ty = make_type_alias(&abs_type, float_type);
                quote! {
                    impl #self_ty {
                        #[inline]
                        #[must_use]
                        pub fn abs(self) -> #ret_ty {
                            let result = self.get().abs();
                            unsafe { #ret_ty::new_unchecked(result) }
                        }
                    }
                }
            })
        })
        .collect();

    quote! {
        #(#impls)*
    }
}
/// Emits an inherent `signum()` method for every (constrained type, float width)
/// pair listed in `config`.
///
/// Each method's return type comes from `infer_signum_output_type`. The
/// generated body wraps `self.get().signum()` with `new_unchecked`, so it relies
/// on that inference choosing a type that holds every possible result.
pub fn generate_signum_impls(config: &TypeConfig) -> TokenStream2 {
    let mut methods = Vec::new();
    for type_def in &config.constraint_types {
        let signum_type =
            infer_signum_output_type(find_constraint_def(config, &type_def.type_name), config);

        methods.extend(type_def.float_types.iter().map(|float_type| {
            let self_ty = make_type_alias(&type_def.type_name, float_type);
            let ret_ty = make_type_alias(&signum_type, float_type);
            quote! {
                impl #self_ty {
                    #[inline]
                    #[must_use]
                    pub fn signum(self) -> #ret_ty {
                        let result = self.get().signum();
                        unsafe { #ret_ty::new_unchecked(result) }
                    }
                }
            }
        }));
    }

    quote! {
        #(#methods)*
    }
}
/// Emits a `sin()` method for every (constrained type, float width) pair listed
/// in `config`. Each impl is gated behind `#[cfg(feature = "std")]`.
///
/// Every input type returns the same output type. That type is a sign-agnostic
/// `[-1, 1]` type found in `config`, falling back to the ident `Symmetric`.
///
/// NOTE(review): `sin(±inf)` is NaN. The `new_unchecked` wrap assumes no input
/// type admits infinities. Confirm against the type validators.
pub fn generate_sin_impls(config: &TypeConfig) -> TokenStream2 {
    let unit_range = Bounds {
        lower: Some(-1.0),
        upper: Some(1.0),
    };
    let symmetric = config
        .find_type_by_constraints(Sign::Any, &unit_range, false)
        .unwrap_or_else(|| Ident::new("Symmetric", Span::call_site()));
    let symmetric = &symmetric;

    let impls: Vec<TokenStream2> = config
        .constraint_types
        .iter()
        .flat_map(move |type_def| {
            type_def.float_types.iter().map(move |float_type| {
                let self_ty = make_type_alias(&type_def.type_name, float_type);
                let ret_ty = make_type_alias(symmetric, float_type);
                quote! {
                    #[cfg(feature = "std")]
                    impl #self_ty {
                        #[inline]
                        #[must_use]
                        pub fn sin(self) -> #ret_ty {
                            let result = self.get().sin();
                            unsafe { #ret_ty::new_unchecked(result) }
                        }
                    }
                }
            })
        })
        .collect();

    quote! {
        #(#impls)*
    }
}
/// Emits a `cos()` method for every (constrained type, float width) pair listed
/// in `config`. Each impl is gated behind `#[cfg(feature = "std")]`.
///
/// Every input type returns the same output type. That type is a sign-agnostic
/// `[-1, 1]` type found in `config`, falling back to the ident `Symmetric`.
///
/// NOTE(review): `cos(±inf)` is NaN. The `new_unchecked` wrap assumes no input
/// type admits infinities. Confirm against the type validators.
pub fn generate_cos_impls(config: &TypeConfig) -> TokenStream2 {
    let cos_range = Bounds {
        lower: Some(-1.0),
        upper: Some(1.0),
    };
    let range_type = match config.find_type_by_constraints(Sign::Any, &cos_range, false) {
        Some(found) => found,
        None => Ident::new("Symmetric", Span::call_site()),
    };

    let mut generated = Vec::new();
    for type_def in &config.constraint_types {
        let input_name = &type_def.type_name;
        for float_type in &type_def.float_types {
            let receiver = make_type_alias(input_name, float_type);
            let returned = make_type_alias(&range_type, float_type);
            let method = quote! {
                #[cfg(feature = "std")]
                impl #receiver {
                    #[inline]
                    #[must_use]
                    pub fn cos(self) -> #returned {
                        let result = self.get().cos();
                        unsafe { #returned::new_unchecked(result) }
                    }
                }
            };
            generated.push(method);
        }
    }

    quote! {
        #(#generated)*
    }
}
/// Emits a fallible `tan()` method for every (constrained type, float width)
/// pair listed in `config`. Each impl is gated behind `#[cfg(feature = "std")]`.
///
/// The generated method returns the `Fin` alias for the matching float width.
/// When `tan` yields a non-finite value, it returns `Err(FloatError::NaN)`
/// instead of constructing the output.
pub fn generate_tan_impls(config: &TypeConfig) -> TokenStream2 {
    // The output type does not depend on the input type, so build it once.
    let fin = Ident::new("Fin", Span::call_site());
    let fin = &fin;

    let impls: Vec<TokenStream2> = config
        .constraint_types
        .iter()
        .flat_map(move |type_def| {
            type_def.float_types.iter().map(move |float_type| {
                let self_ty = make_type_alias(&type_def.type_name, float_type);
                let ret_ty = make_type_alias(fin, float_type);
                quote! {
                    #[cfg(feature = "std")]
                    impl #self_ty {
                        #[inline]
                        #[must_use]
                        pub fn tan(self) -> Result<#ret_ty, FloatError> {
                            let result = self.get().tan();
                            if !result.is_finite() {
                                return Err(FloatError::NaN);
                            }
                            unsafe { Ok(#ret_ty::new_unchecked(result)) }
                        }
                    }
                }
            })
        })
        .collect();

    quote! {
        #(#impls)*
    }
}