use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use quote::quote;
use super::type_utils::make_type_alias;
use crate::config::{ArithmeticOp, ArithmeticResult, TypeConfig};
/// Generates operator-trait impls for every ordered pair of constraint types.
///
/// For each `(lhs, rhs)` pair drawn from `config.constraint_types` (including `lhs == rhs`)
/// and for each entry of `ops`, the matching [`ArithmeticResult`] is looked up in
/// `config.arithmetic_results` under the key `(op, lhs_name, rhs_name)`. `impl_generator`
/// is then called once per float type listed on the *left-hand* type. The returned token
/// streams are concatenated in iteration order: lhs, then rhs, then op, then float type.
///
/// Each `ops` entry is `(op, trait_name, method_name, op_symbol)`. The trait and method
/// names are turned into call-site [`Ident`]s.
///
/// `impl_generator` receives, in order:
/// - the lhs alias
/// - the rhs alias
/// - the output alias (all produced by `make_type_alias` for the current float type)
/// - the trait ident
/// - the method ident
/// - the operator-symbol tokens
/// - the looked-up result
/// - the op
/// - `false`
///
/// The last argument is `true` only in `generate_arithmetic_for_primitive_types` when a
/// primitive is the left operand. Presumably it means "primitive on the lhs"; confirm
/// against the generators.
///
/// NOTE(review): the rhs alias is built from the *lhs* type's float types. This assumes
/// every constraint type supports the same float types. Confirm against `TypeConfig`.
///
/// # Panics
///
/// Panics if `config.arithmetic_results` has no entry for some `(op, lhs, rhs)`
/// combination, or if a trait/method name in `ops` is not a valid identifier.
pub fn generate_arithmetic_for_all_types<F>(
    config: &TypeConfig,
    ops: &[(ArithmeticOp, &str, &str, TokenStream2)],
    mut impl_generator: F,
) -> TokenStream2
where
    F: FnMut(
        Ident,
        Ident,
        Ident,
        Ident,
        Ident,
        TokenStream2,
        &ArithmeticResult,
        ArithmeticOp,
        bool,
    ) -> TokenStream2,
{
    // Trait/method idents depend only on the op, so build them once up front
    // rather than once per (lhs, rhs) pair.
    let op_idents: Vec<(Ident, Ident)> = ops
        .iter()
        .map(|(_, trait_name, method_name, _)| {
            (
                Ident::new(trait_name, Span::call_site()),
                Ident::new(method_name, Span::call_site()),
            )
        })
        .collect();

    let mut impls = Vec::new();
    for lhs_type in &config.constraint_types {
        let lhs_name = lhs_type.type_name.to_string();
        for rhs_type in &config.constraint_types {
            let rhs_name = rhs_type.type_name.to_string();
            for ((op, trait_name, _, op_symbol), (trait_ident, method_ident)) in
                ops.iter().zip(&op_idents)
            {
                let key = (*op, lhs_name.clone(), rhs_name.clone());
                // A missing entry is a configuration bug; name the exact combination so the
                // macro-expansion error is actionable.
                let result = config.arithmetic_results.get(&key).unwrap_or_else(|| {
                    panic!(
                        "Arithmetic result not found: {} {} {}",
                        lhs_name, trait_name, rhs_name
                    )
                });
                for float_type in &lhs_type.float_types {
                    impls.push(impl_generator(
                        make_type_alias(&lhs_type.type_name, float_type),
                        make_type_alias(&rhs_type.type_name, float_type),
                        make_type_alias(&result.output_type, float_type),
                        trait_ident.clone(),
                        method_ident.clone(),
                        op_symbol.clone(),
                        result,
                        *op,
                        false,
                    ));
                }
            }
        }
    }
    quote! {
        #(#impls)*
    }
}
/// Generates operator-trait impls between constraint types and the `f32`/`f64` primitives.
///
/// Each primitive is paired with a constraint name (`"Fin"` for both). That name stands in
/// for the primitive when looking up the [`ArithmeticResult`] in
/// `config.arithmetic_results`. Combinations with no entry are skipped silently, unlike
/// `generate_arithmetic_for_all_types`, which panics.
///
/// Two passes are made, and their outputs are concatenated in this order:
/// 1. `constraint <op> primitive`: the key is `(op, lhs_name, "Fin")`. `impl_generator` gets
///    the lhs alias, the primitive ident, the output alias, … and `false` as the last argument.
/// 2. `primitive <op> constraint`: the key is `(op, "Fin", rhs_name)`. `impl_generator` gets
///    the primitive ident, the rhs alias, the output alias, … and `true` as the last argument.
///
/// In both passes, only the constraint type's float types whose string form equals the
/// primitive's name (e.g. `f32`) produce an impl. Aliases are built with `make_type_alias`
/// for that float type.
///
/// # Panics
///
/// Panics if a trait/method name in `ops` is not a valid identifier.
pub fn generate_arithmetic_for_primitive_types<F>(
    config: &TypeConfig,
    ops: &[(ArithmeticOp, &str, &str, TokenStream2)],
    mut impl_generator: F,
) -> TokenStream2
where
    F: FnMut(
        Ident,
        Ident,
        Ident,
        Ident,
        Ident,
        TokenStream2,
        &ArithmeticResult,
        ArithmeticOp,
        bool,
    ) -> TokenStream2,
{
    // (primitive type name, constraint name used for the result lookup). A fixed table,
    // so a const array rather than a heap-allocated `Vec`.
    const PRIMITIVE_MAPPINGS: [(&str, &str); 2] = [("f32", "Fin"), ("f64", "Fin")];

    // Trait/method idents depend only on the op, so build them once up front.
    let op_idents: Vec<(Ident, Ident)> = ops
        .iter()
        .map(|(_, trait_name, method_name, _)| {
            (
                Ident::new(trait_name, Span::call_site()),
                Ident::new(method_name, Span::call_site()),
            )
        })
        .collect();

    let mut impls = Vec::new();

    // Pass 1: constraint type on the left, primitive on the right.
    for lhs_type in &config.constraint_types {
        for (primitive_name, fin_constraint) in &PRIMITIVE_MAPPINGS {
            let primitive_ident = Ident::new(primitive_name, Span::call_site());
            for ((op, _, _, op_symbol), (trait_ident, method_ident)) in
                ops.iter().zip(&op_idents)
            {
                let key = (
                    *op,
                    lhs_type.type_name.to_string(),
                    fin_constraint.to_string(),
                );
                if let Some(result) = config.arithmetic_results.get(&key) {
                    // Only the float type spelled like the primitive applies.
                    for float_type in lhs_type
                        .float_types
                        .iter()
                        .filter(|float_type| float_type.to_string() == *primitive_name)
                    {
                        impls.push(impl_generator(
                            make_type_alias(&lhs_type.type_name, float_type),
                            primitive_ident.clone(),
                            make_type_alias(&result.output_type, float_type),
                            trait_ident.clone(),
                            method_ident.clone(),
                            op_symbol.clone(),
                            result,
                            *op,
                            false,
                        ));
                    }
                }
            }
        }
    }

    // Pass 2: primitive on the left, constraint type on the right.
    for (primitive_name, fin_constraint) in &PRIMITIVE_MAPPINGS {
        let primitive_ident = Ident::new(primitive_name, Span::call_site());
        for rhs_type in &config.constraint_types {
            for ((op, _, _, op_symbol), (trait_ident, method_ident)) in
                ops.iter().zip(&op_idents)
            {
                let key = (
                    *op,
                    fin_constraint.to_string(),
                    rhs_type.type_name.to_string(),
                );
                if let Some(result) = config.arithmetic_results.get(&key) {
                    // Only the float type spelled like the primitive applies.
                    for float_type in rhs_type
                        .float_types
                        .iter()
                        .filter(|float_type| float_type.to_string() == *primitive_name)
                    {
                        impls.push(impl_generator(
                            primitive_ident.clone(),
                            make_type_alias(&rhs_type.type_name, float_type),
                            make_type_alias(&result.output_type, float_type),
                            trait_ident.clone(),
                            method_ident.clone(),
                            op_symbol.clone(),
                            result,
                            *op,
                            true,
                        ));
                    }
                }
            }
        }
    }

    quote! {
        #(#impls)*
    }
}