use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::quote;
use crate::config::{ArithmeticOp, ArithmeticResult, TypeConfig, get_standard_arithmetic_ops};
use crate::generator::{
generate_arithmetic_for_all_types, generate_arithmetic_for_primitive_types,
};
/// Generates every arithmetic operator impl described by `config`.
///
/// The standard operator set from `get_standard_arithmetic_ops` is fed through two
/// generators, and each per-operator impl is produced by `generate_arithmetic_impl`:
///
/// 1. `generate_arithmetic_for_all_types`: the reversal flag it supplies is ignored
///    and `false` is always forwarded.
/// 2. `generate_arithmetic_for_primitive_types`: the reversal flag it supplies is
///    forwarded unchanged.
///
/// The resulting token streams are emitted in that order.
pub fn generate_arithmetic_impls(config: &TypeConfig) -> TokenStream2 {
    let standard_ops = get_standard_arithmetic_ops();

    // Impls produced here are always emitted in forward (non-reversed) operand order.
    let between_types = generate_arithmetic_for_all_types(
        config,
        &standard_ops,
        |lhs, rhs, out, trait_id, method, symbol, outcome, op, _| {
            generate_arithmetic_impl(lhs, rhs, out, trait_id, method, symbol, outcome, op, false)
        },
    );

    // Impls involving raw primitives honour whatever reversal flag the generator reports.
    let with_primitives = generate_arithmetic_for_primitive_types(
        config,
        &standard_ops,
        |lhs, rhs, out, trait_id, method, symbol, outcome, op, rev| {
            generate_arithmetic_impl(lhs, rhs, out, trait_id, method, symbol, outcome, op, rev)
        },
    );

    quote! {
        #between_types
        #with_primitives
    }
}
#[allow(clippy::too_many_arguments)]
fn generate_arithmetic_impl(
lhs_alias: Ident,
rhs_alias: Ident,
output_alias: Ident,
trait_ident: Ident,
method_ident: Ident,
op_symbol: TokenStream2,
result: &ArithmeticResult,
op: ArithmeticOp,
is_reversed: bool,
) -> TokenStream2 {
let rhs_is_primitive = rhs_alias == "f32" || rhs_alias == "f64";
let lhs_is_primitive = lhs_alias == "f32" || lhs_alias == "f64";
if result.is_safe && !rhs_is_primitive && !lhs_is_primitive {
quote! {
impl #trait_ident<#rhs_alias> for #lhs_alias {
type Output = #output_alias;
fn #method_ident(self, rhs: #rhs_alias) -> Self::Output {
let result = self.get() #op_symbol rhs.get();
unsafe { #output_alias::new_unchecked(result) }
}
}
}
} else if rhs_is_primitive || lhs_is_primitive {
if is_reversed {
let fin_type = if lhs_alias == "f32" {
quote! { FinF32 }
} else {
quote! { FinF64 }
};
if op == ArithmeticOp::Div {
quote! {
impl #trait_ident<#rhs_alias> for #lhs_alias {
type Output = Result<#output_alias, FloatError>;
fn #method_ident(self, rhs: #rhs_alias) -> Self::Output {
let lhs_fin = #fin_type::new(self).map_err(|_| FloatError::NaN)?;
let result = lhs_fin.get() / rhs.get();
if !result.is_finite() {
return Err(FloatError::NaN);
}
#output_alias::new(result)
}
}
}
} else {
quote! {
impl #trait_ident<#rhs_alias> for #lhs_alias {
type Output = Result<#output_alias, FloatError>;
fn #method_ident(self, rhs: #rhs_alias) -> Self::Output {
let lhs_fin = #fin_type::new(self).map_err(|_| FloatError::NaN)?;
let result = lhs_fin.get() #op_symbol rhs.get();
#output_alias::new(result)
}
}
}
}
} else {
let fin_type = if rhs_alias == "f32" {
quote! { FinF32 }
} else {
quote! { FinF64 }
};
if op == ArithmeticOp::Div {
quote! {
impl #trait_ident<#rhs_alias> for #lhs_alias {
type Output = Result<#output_alias, FloatError>;
fn #method_ident(self, rhs: #rhs_alias) -> Self::Output {
let rhs_fin = #fin_type::new(rhs).map_err(|_| FloatError::NaN)?;
let result = self.get() / rhs_fin.get();
if !result.is_finite() {
return Err(FloatError::NaN);
}
#output_alias::new(result)
}
}
}
} else {
quote! {
impl #trait_ident<#rhs_alias> for #lhs_alias {
type Output = Result<#output_alias, FloatError>;
fn #method_ident(self, rhs: #rhs_alias) -> Self::Output {
let rhs_fin = #fin_type::new(rhs).map_err(|_| FloatError::NaN)?;
let result = self.get() #op_symbol rhs_fin.get();
#output_alias::new(result)
}
}
}
}
}
} else if op == ArithmeticOp::Div {
quote! {
impl #trait_ident<#rhs_alias> for #lhs_alias {
type Output = Result<#output_alias, FloatError>;
fn #method_ident(self, rhs: #rhs_alias) -> Self::Output {
let result = self.get() / rhs.get();
if !result.is_finite() {
return Err(FloatError::NaN);
}
#output_alias::new(result)
}
}
}
} else {
quote! {
impl #trait_ident<#rhs_alias> for #lhs_alias {
type Output = Result<#output_alias, FloatError>;
fn #method_ident(self, rhs: #rhs_alias) -> Self::Output {
let result = self.get() #op_symbol rhs.get();
#output_alias::new(result)
}
}
}
}
}