use super::super::Codegen;
use proc_macro2::{Ident, TokenStream};
use quote::{quote, quote_spanned};
use syn::Type;
/// Operator or comparison trait for which impls are generated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Trait {
    /// `core::ops::Add`
    Add,
    /// `core::ops::Sub`
    Sub,
    /// `core::ops::Mul`
    Mul,
    /// `core::ops::Div`
    Div,
    /// `core::ops::AddAssign`
    AddAssign,
    /// `core::ops::SubAssign`
    SubAssign,
    /// `core::ops::MulAssign`
    MulAssign,
    /// `core::ops::DivAssign`
    DivAssign,
    /// `core::cmp::PartialEq`
    PartialEq,
    /// `core::cmp::PartialOrd`
    PartialOrd,
}
use Trait::*;
impl Trait {
pub fn name(&self) -> TokenStream {
match self {
Add => quote! { core::ops::Add },
Sub => quote! { core::ops::Sub },
Mul => quote! { core::ops::Mul },
Div => quote! { core::ops::Div },
AddAssign => quote! { core::ops::AddAssign },
SubAssign => quote! { core::ops::SubAssign },
MulAssign => quote! { core::ops::MulAssign },
DivAssign => quote! { core::ops::DivAssign },
PartialEq => quote! { core::cmp::PartialEq },
PartialOrd => quote! { core::cmp::PartialOrd },
}
}
pub fn fn_name(&self) -> TokenStream {
match self {
Add => quote! { add },
Sub => quote! { sub },
Mul => quote! { mul },
Div => quote! { div },
AddAssign => quote! { add_assign },
SubAssign => quote! { sub_assign },
MulAssign => quote! { mul_assign },
DivAssign => quote! { div_assign },
PartialEq => quote! { eq },
PartialOrd => quote! { partial_cmp },
}
}
pub fn fn_return_type(&self) -> TokenStream {
match self {
Add | Sub | Mul | Div => quote! { Self::Output },
AddAssign | SubAssign | MulAssign | DivAssign => {
quote! { () }
}
PartialEq => quote! { bool },
PartialOrd => quote! { Option<core::cmp::Ordering> },
}
}
pub fn lhs_arg(&self) -> TokenStream {
match self {
Add | Sub | Mul | Div => quote! { self },
AddAssign | SubAssign | MulAssign | DivAssign => {
quote! { &mut self }
}
PartialEq | PartialOrd => quote! { &self },
}
}
pub fn rhs_arg_type(&self, rhs: &TokenStream) -> TokenStream {
match self {
Add | Sub | Mul | Div | AddAssign | SubAssign | MulAssign | DivAssign => rhs.clone(),
PartialEq | PartialOrd => {
let rhs = rhs.clone();
quote! { &#rhs }
}
}
}
pub fn has_output_type(&self) -> bool {
matches!(self, Add | Sub | Mul | Div)
}
}
/// Storage (numeric) type of an operand in a generated impl.
enum StorageType {
    /// A type parameter of the generated impl (`S`, `LHS` or `RHS`; see
    /// `OperatorTrait::storage_types`).
    Generic,
    /// A specific storage type, interpolated verbatim into the generated code.
    Concrete(Type),
}
/// What kind of value an operand is (see `OperatorTrait::type_for_operand`).
enum QuantityType {
    /// A quantity whose dimension is a const generic parameter (`D`, `DL` or `DR`).
    Quantity,
    /// A quantity whose dimension is fixed to `<dimension type>::none()`.
    Dimensionless,
    /// A bare storage value, not wrapped in the quantity type.
    Storage,
}
/// How an operand is passed in the generated impl.
enum ReferenceType {
    /// Owned value.
    Value,
    /// Shared reference, emitted as `&'a` (see `Operand::ref_sign`).
    Reference,
    /// Mutable reference, emitted as `&'a mut` (see `Operand::ref_sign`).
    MutableReference,
}
impl ReferenceType {
    /// Whether the operand is borrowed (shared or mutably) rather than owned.
    fn is_ref(&self) -> bool {
        !matches!(self, ReferenceType::Value)
    }
}
/// One side (lhs or rhs) of a generated operator impl.
struct Operand {
    /// Quantity, dimensionless quantity, or bare storage value.
    type_: QuantityType,
    /// Generic or concrete storage type.
    storage: StorageType,
    /// Passed by value, by `&'a`, or by `&'a mut`.
    reference: ReferenceType,
}
impl Operand {
    /// Tokens that prefix the operand's type: `&'a mut `, `&'a `, or nothing.
    /// They are emitted with the given `span`.
    fn ref_sign(&self, span: proc_macro2::Span) -> TokenStream {
        match &self.reference {
            ReferenceType::MutableReference => quote_spanned! {span=>&'a mut },
            ReferenceType::Reference => quote_spanned! {span=>&'a },
            ReferenceType::Value => quote_spanned! {span=>},
        }
    }

    /// True when the operand is a bare storage value rather than a quantity
    /// (dimensioned or dimensionless).
    fn is_storage(&self) -> bool {
        matches!(self.type_, QuantityType::Storage)
    }
}
/// Const dimension argument of an operator's `Output` quantity.
enum OutputQuantityDimension {
    /// Reuses a dimension parameter of the impl unchanged (`D`).
    Existing(TokenStream),
    /// A computed const expression such as `{ DL.add(DR) }`; it additionally gets a
    /// where-clause predicate via `OutputQuantity::generic_const_bound`.
    New(TokenStream),
}
impl OutputQuantityDimension {
fn unwrap(&self) -> &TokenStream {
match self {
OutputQuantityDimension::Existing(t) => t,
OutputQuantityDimension::New(t) => t,
}
}
}
/// Storage and dimension of the `Output` associated type, for operators that have one.
struct OutputQuantity {
    /// Storage type tokens: a concrete type, or `< Lhs as Trait<Rhs> >::Output`.
    storage: TokenStream,
    /// Const dimension argument of the output quantity.
    dimension: OutputQuantityDimension,
}
impl OutputQuantity {
    /// Emits the associated type item `type Output = Quantity<storage, dimension>;`.
    fn output_type_def(&self, quantity_type: &Ident) -> TokenStream {
        let OutputQuantity { storage, dimension } = self;
        let dimension = dimension.unwrap();
        quote! { type Output = #quantity_type < #storage, #dimension >; }
    }

    /// Where-clause predicate for a newly computed output dimension.
    ///
    /// For `New(dim)` this emits `Quantity<(), dim>:` with no trait bounds. For
    /// `Existing` it emits nothing.
    /// NOTE(review): presumably this is the well-formedness bound that const generic
    /// expressions require — confirm.
    fn generic_const_bound(&self, quantity_type: &Ident) -> TokenStream {
        match &self.dimension {
            OutputQuantityDimension::New(dim) => quote! { #quantity_type < (), #dim >: },
            OutputQuantityDimension::Existing(_) => quote! {},
        }
    }
}
/// One operator trait impl to generate: `impl name<rhs> for lhs`.
struct OperatorTrait {
    /// The operator or comparison trait.
    name: Trait,
    /// The implementing (`Self`) operand.
    lhs: Operand,
    /// The operand used as the trait's type parameter.
    rhs: Operand,
}
impl OperatorTrait {
    /// Whether the two quantity operands may have different dimensions.
    ///
    /// Only `Mul` and `Div` combine dimensions (into `DL.add(DR)` / `DL.sub(DR)`, see
    /// `output_quantity_dimension`); every other operator uses one shared `D`.
    fn different_dimensions_allowed(&self) -> bool {
        // Exhaustive on purpose: a new `Trait` variant must force a decision here.
        match self.name {
            Trait::Add
            | Trait::Sub
            | Trait::AddAssign
            | Trait::SubAssign
            | Trait::MulAssign
            | Trait::DivAssign
            | Trait::PartialEq
            | Trait::PartialOrd => false,
            Trait::Mul | Trait::Div => true,
        }
    }

    /// Whether two generic storages may be distinct type parameters (`LHS`/`RHS`).
    ///
    /// This holds only when neither operand is a bare storage value. Otherwise two
    /// generic storages share the single parameter `S` (see `storage_types`).
    fn different_storage_types_allowed(&self) -> bool {
        !self.lhs.is_storage() && !self.rhs.is_storage()
    }

    /// Names of the const dimension parameters used by the lhs and rhs operand types.
    ///
    /// Only `Quantity` operands get a name. `Dimensionless` and `Storage` operands
    /// yield `None`, because a dimensionless operand is spelled with `none()` instead.
    fn dimension_types(&self) -> (Option<TokenStream>, Option<TokenStream>) {
        use QuantityType::*;
        match (&self.lhs.type_, &self.rhs.type_) {
            (Quantity, Quantity) => {
                if self.different_dimensions_allowed() {
                    (Some(quote! { DL }), Some(quote! { DR }))
                } else {
                    (Some(quote! { D }), Some(quote! { D }))
                }
            }
            (Quantity, _) => (Some(quote! { D }), None),
            (_, Quantity) => (None, Some(quote! { D })),
            _ => (None, None),
        }
    }

    /// Storage types of the lhs and rhs operands, as written in the impl.
    ///
    /// Generic storage becomes the type parameter `S`. When both sides are generic
    /// and `different_storage_types_allowed`, it becomes `LHS`/`RHS` instead.
    /// Concrete storage is spelled out.
    fn storage_types(&self) -> (TokenStream, TokenStream) {
        use StorageType::*;
        match (&self.lhs.storage, &self.rhs.storage) {
            (Generic, Generic) => {
                if self.different_storage_types_allowed() {
                    (quote! { LHS }, quote! { RHS })
                } else {
                    (quote! { S }, quote! { S })
                }
            }
            (Concrete(ty), Generic) => (quote! {#ty}, quote! {S}),
            (Generic, Concrete(ty)) => (quote! {S}, quote! {#ty}),
            (Concrete(tyl), Concrete(tyr)) => (quote! {#tyl}, quote! {#tyr}),
        }
    }

    /// Generic parameters of the impl, in order:
    /// - an optional lifetime `'a`;
    /// - the const dimension parameters;
    /// - the generic storage type parameters.
    fn generics(&self, dimension_type: &Ident) -> Vec<TokenStream> {
        let mut types = vec![];
        // Every reference operand borrows with the same lifetime `'a` (see
        // `Operand::ref_sign`), so a single lifetime parameter covers both sides.
        if self.lhs.reference.is_ref() || self.rhs.reference.is_ref() {
            types.push(quote! { 'a });
        }
        let make_dim_expr_from_name = |name| quote! { const #name: #dimension_type };
        let (lhs_dimension, rhs_dimension) = self.dimension_types();
        let has_lhs_dimension = lhs_dimension.is_some();
        types.extend(lhs_dimension.into_iter().map(make_dim_expr_from_name));
        // When both sides share `D`, it has already been declared for the lhs.
        if self.different_dimensions_allowed() || !has_lhs_dimension {
            types.extend(rhs_dimension.into_iter().map(make_dim_expr_from_name));
        }
        let (lhs_storage, rhs_storage) = self.storage_types();
        let lhs_is_generic = matches!(self.lhs.storage, StorageType::Generic);
        if lhs_is_generic {
            types.push(lhs_storage);
        }
        // A generic rhs storage needs its own parameter, unless it is the same `S`
        // that was already declared for the lhs.
        if matches!(self.rhs.storage, StorageType::Generic)
            && (!lhs_is_generic || self.different_storage_types_allowed())
        {
            types.push(rhs_storage);
        }
        types
    }

    /// The impl's generic parameter list, as `< ... >`.
    fn generics_gen(&self, dimension_type: &Ident) -> TokenStream {
        let types = self.generics(dimension_type);
        quote! {
            < #(#types),* >
        }
    }

    /// Full type of the lhs (`Self`) operand.
    fn lhs_type(&self, quantity_type: &Ident, dimension_type: &Ident) -> TokenStream {
        let storage = self.storage_types().0;
        let dimension = self.dimension_types().0;
        self.type_for_operand(&self.lhs, quantity_type, dimension_type, storage, dimension)
    }

    /// Full type of the rhs operand (the trait's type parameter).
    fn rhs_type(&self, quantity_type: &Ident, dimension_type: &Ident) -> TokenStream {
        let storage = self.storage_types().1;
        let dimension = self.dimension_types().1;
        self.type_for_operand(&self.rhs, quantity_type, dimension_type, storage, dimension)
    }

    /// Spells out the type of `operand`: an optional `&'a` / `&'a mut` prefix,
    /// followed by one of:
    /// - `Quantity<storage, dimension>`;
    /// - `Quantity<storage, { Dim::none() }>`;
    /// - the bare storage type.
    ///
    /// All tokens carry `dimension_type`'s span.
    fn type_for_operand(
        &self,
        operand: &Operand,
        quantity_type: &Ident,
        dimension_type: &Ident,
        storage: TokenStream,
        dimension: Option<TokenStream>,
    ) -> TokenStream {
        let span = dimension_type.span();
        let ref_sign = operand.ref_sign(span);
        let type_name = match operand.type_ {
            QuantityType::Quantity => {
                quote_spanned! {span=> #quantity_type < #storage, #dimension > }
            }
            QuantityType::Dimensionless => {
                quote_spanned! {span=>#quantity_type < #storage, { #dimension_type :: none() } >}
            }
            QuantityType::Storage => quote_spanned! {span=>#storage},
        };
        quote_spanned! {span=>#ref_sign #type_name}
    }

    /// Where-clause predicates for the impl.
    ///
    /// If either storage is generic, this requires `LhsStorage: Trait<RhsStorage>`,
    /// plus `Copy` for the storage of each reference operand. If the output dimension
    /// is a new const expression, it also appends the predicate from
    /// `OutputQuantity::generic_const_bound`.
    fn trait_bounds(
        &self,
        quantity_type: &Ident,
        output_type: &Option<OutputQuantity>,
    ) -> TokenStream {
        let storage_bounds = if matches!(self.lhs.storage, StorageType::Generic)
            || matches!(self.rhs.storage, StorageType::Generic)
        {
            let (lhs_storage, rhs_storage) = self.storage_types();
            let trait_name = self.name.name();
            // `Copy` lets `fn_return_expr` read a borrowed operand's storage out by
            // value (e.g. `rhs.0` when `rhs` is `&'a Quantity`). The bound is added
            // for every reference operand.
            let lhs_copy_bound = if self.lhs.reference.is_ref() {
                quote! { #lhs_storage: Copy, }
            } else {
                quote! {}
            };
            let rhs_copy_bound = if self.rhs.reference.is_ref() {
                quote! { #rhs_storage: Copy, }
            } else {
                quote! {}
            };
            quote! {
                #lhs_storage: #trait_name :: < #rhs_storage >,
                #lhs_copy_bound
                #rhs_copy_bound
            }
        } else {
            quote! {}
        };
        let generic_const_bound = output_type
            .as_ref()
            .map(|output_type| output_type.generic_const_bound(quantity_type))
            .unwrap_or_default();
        quote! {
            #storage_bounds
            #generic_const_bound
        }
    }

    /// Storage type of the `Output` quantity.
    ///
    /// When both storages are concrete they must be identical, and that type is used.
    /// Otherwise the result is `< Lhs as Trait<Rhs> >::Output`.
    ///
    /// # Panics
    /// Panics if the operator has no `Output` type, or if the two concrete storage
    /// types differ.
    fn output_quantity_storage(&self) -> TokenStream {
        assert!(self.name.has_output_type());
        if let (StorageType::Concrete(lhs_ty), StorageType::Concrete(rhs_ty)) =
            (&self.lhs.storage, &self.rhs.storage)
        {
            assert_eq!(lhs_ty, rhs_ty);
            return quote! { #lhs_ty };
        }
        let trait_name = self.name.name();
        let (lhs, rhs) = self.storage_types();
        quote! { < #lhs as #trait_name<#rhs> >::Output }
    }

    /// Const dimension of the `Output` quantity for the operand combination.
    ///
    /// # Panics
    /// Panics if the operator has no `Output` type, or if the operand combination has
    /// no dimension rule (such combinations are not produced by `iter_numeric_traits`).
    fn output_quantity_dimension(&self, dimension_type: &Ident) -> OutputQuantityDimension {
        assert!(self.name.has_output_type());
        let span = dimension_type.span();
        use OutputQuantityDimension::*;
        use QuantityType::*;
        let existing = Existing(quote_spanned! { span=> D });
        match (&self.lhs.type_, &self.rhs.type_) {
            (Quantity, Quantity) => match self.name {
                Trait::Mul => New(quote_spanned! {span=> { DL.add(DR) } }),
                Trait::Div => New(quote_spanned! {span=> { DL.sub(DR) } }),
                _ => existing,
            },
            (Quantity, Storage) => existing,
            (Storage, Quantity) => match self.name {
                Trait::Mul => existing,
                Trait::Div => New(quote_spanned! {span=> { D.neg() } }),
                _ => unreachable!("only Mul and Div are generated for (Storage, Quantity) operands"),
            },
            (Dimensionless, Storage) | (Storage, Dimensionless) => {
                New(quote_spanned! {span=> { #dimension_type :: none() } })
            }
            _ => unreachable!("no Output dimension rule for this operand combination"),
        }
    }

    /// The `Output` quantity, or `None` for traits without an `Output` type.
    fn output_type(&self, dimension_type: &Ident) -> Option<OutputQuantity> {
        self.name.has_output_type().then(|| OutputQuantity {
            storage: self.output_quantity_storage(),
            dimension: self.output_quantity_dimension(dimension_type),
        })
    }

    /// Whether the trait method takes `rhs` by reference (the comparison traits).
    fn rhs_takes_ref(&self) -> bool {
        matches!(self.name, Trait::PartialOrd | Trait::PartialEq)
    }

    /// Body of the trait method.
    ///
    /// It calls the same-named method on the underlying storage values (`.0` for
    /// quantity operands). Borrowed bare-storage operands are dereferenced
    /// (`(*self)`, `*rhs`), and the comparison traits pass their rhs by reference.
    /// For operators with an `Output` type, the result is wrapped in the quantity
    /// constructor.
    fn fn_return_expr(
        &self,
        quantity_type: &Ident,
        output_type: &Option<OutputQuantity>,
    ) -> TokenStream {
        let lhs = match self.lhs.type_ {
            QuantityType::Quantity | QuantityType::Dimensionless => quote! { self.0 },
            QuantityType::Storage => quote! { self },
        };
        let rhs = match self.rhs.type_ {
            QuantityType::Quantity | QuantityType::Dimensionless => quote! { rhs.0 },
            QuantityType::Storage => quote! { rhs },
        };
        let fn_name = self.name.fn_name();
        let deref_or_ref = if self.rhs_takes_ref() {
            quote! { & }
        } else if self.rhs.reference.is_ref() && self.rhs.is_storage() {
            quote! {*}
        } else {
            quote! {}
        };
        let lhs = if self.lhs.reference.is_ref() && self.lhs.is_storage() {
            quote! {(*#lhs)}
        } else {
            lhs
        };
        let result = quote! { #lhs.#fn_name(#deref_or_ref #rhs) };
        if output_type.is_some() {
            quote! { #quantity_type ( #result ) }
        } else {
            result
        }
    }
}
// Builds an `Operand` from a compact spec. The spec is an optional `&mut` or `&`
// prefix (giving the `ReferenceType`), a `QuantityType` variant name, and an
// expression for the operand's `StorageType`.
// Examples: `def_operand!(Quantity, Generic)`,
// `def_operand!(&Dimensionless, Concrete(ty.clone()))`.
// Arms are tried in order; the more specific `&mut` arm is listed before the `&` arm.
macro_rules! def_operand {
    (&mut $quantity: ident, $storage: expr) => {
        Operand {
            reference: ReferenceType::MutableReference,
            type_: QuantityType::$quantity,
            storage: $storage,
        }
    };
    (& $quantity: ident, $storage: expr) => {
        Operand {
            reference: ReferenceType::Reference,
            type_: QuantityType::$quantity,
            storage: $storage,
        }
    };
    ($quantity: ident, $storage: expr) => {
        Operand {
            reference: ReferenceType::Value,
            type_: QuantityType::$quantity,
            storage: $storage,
        }
    };
}
// Pushes an `OperatorTrait` for operator `$name` onto the vector `$traits`.
// Each operand is a parenthesized `def_operand!` spec, forwarded token-for-token,
// e.g. `add_trait!(traits, t, (Quantity, Generic), (&Quantity, Generic))`.
macro_rules! add_trait {
    (
        $traits: ident,
        $name: path,
        ($($lhs:tt)*), ($($rhs:tt)*)) => {
        $traits.push(OperatorTrait {
            name: $name,
            lhs: def_operand!($($lhs)*),
            rhs: def_operand!($($rhs)*),
        })
    }
}
impl Codegen {
pub fn gen_operator_trait_impls(&self) -> TokenStream {
self.iter_numeric_traits()
.map(|num_trait| self.gen_operator_trait_impl(num_trait))
.collect()
}
#[rustfmt::skip]
fn iter_numeric_traits(&self) -> impl Iterator<Item = OperatorTrait> + '_ {
let mut traits = vec![];
use StorageType::*;
for t in [Add, Sub, Mul, Div] {
add_trait!(traits, t, (Quantity, Generic), (Quantity, Generic));
add_trait!(traits, t, (Quantity, Generic), (&Quantity, Generic));
add_trait!(traits, t, (&Quantity, Generic), (Quantity, Generic));
add_trait!(traits, t, (&Quantity, Generic), (&Quantity, Generic));
}
for t in [AddAssign, SubAssign] {
add_trait!(traits, t, (Quantity, Generic), (Quantity, Generic));
add_trait!(traits, t, (Quantity, Generic), (&Quantity, Generic));
add_trait!(traits, t, (&mut Quantity, Generic), (Quantity, Generic));
add_trait!(traits, t, (&mut Quantity, Generic), (&Quantity, Generic));
}
for t in [MulAssign, DivAssign] {
add_trait!(traits, t, (Quantity, Generic), (Dimensionless, Generic));
add_trait!(traits, t, (Quantity, Generic), (&Dimensionless, Generic));
add_trait!(traits, t, (&mut Quantity, Generic), (Dimensionless, Generic));
add_trait!(traits, t, (&mut Quantity, Generic), (&Dimensionless, Generic));
}
for t in [PartialOrd, PartialEq] {
add_trait!(traits, t, (Quantity, Generic), (Quantity, Generic));
add_trait!(traits, t, (Quantity, Generic), (&Quantity, Generic));
add_trait!(traits, t, (&Quantity, Generic), (Quantity, Generic));
}
for t in [Add, Sub] {
add_trait!(traits, t, (Dimensionless, Generic), (Storage, Generic));
add_trait!(traits, t, (Dimensionless, Generic), (&Storage, Generic));
add_trait!(traits, t, (&Dimensionless, Generic), (Storage, Generic));
add_trait!(traits, t, (&Dimensionless, Generic), (&Storage, Generic));
}
for t in [AddAssign, SubAssign] {
add_trait!(traits, t, (Dimensionless, Generic), (Storage, Generic));
add_trait!(traits, t, (Dimensionless, Generic), (&Storage, Generic));
add_trait!(traits, t, (&mut Dimensionless, Generic), (Storage, Generic));
add_trait!(traits, t, (&mut Dimensionless, Generic), (&Storage, Generic));
}
for ty in self.storage_type_names() {
for t in [Add, Sub] {
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (Dimensionless, Concrete(ty.clone())));
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (&Dimensionless, Concrete(ty.clone())));
add_trait!(traits, t, (&Storage, Concrete(ty.clone())), (Dimensionless, Concrete(ty.clone())));
add_trait!(traits, t, (&Storage, Concrete(ty.clone())), (&Dimensionless, Concrete(ty.clone())));
}
for t in [AddAssign, SubAssign] {
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (Dimensionless, Concrete(ty.clone())));
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (&Dimensionless, Concrete(ty.clone())));
}
for t in [Mul, Div] {
add_trait!(traits, t, (Quantity, Generic), (Storage, Concrete(ty.clone())));
add_trait!(traits, t, (&Quantity, Generic), (Storage, Concrete(ty.clone())));
add_trait!(traits, t, (Quantity, Generic), (&Storage, Concrete(ty.clone())));
add_trait!(traits, t, (&Quantity, Generic), (&Storage, Concrete(ty.clone())));
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (Quantity, Generic));
add_trait!(traits, t, (&Storage, Concrete(ty.clone())), (Quantity, Generic));
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (&Quantity, Generic));
add_trait!(traits, t, (&Storage, Concrete(ty.clone())), (&Quantity, Generic));
}
for t in [MulAssign, DivAssign] {
add_trait!(traits, t, (Quantity, Concrete(ty.clone())), (Storage, Concrete(ty.clone())));
add_trait!(traits, t, (Quantity, Concrete(ty.clone())), (&Storage, Concrete(ty.clone())));
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (Dimensionless, Concrete(ty.clone())));
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (&Dimensionless, Concrete(ty.clone())));
}
for t in [PartialEq, PartialOrd] {
add_trait!(traits, t, (Dimensionless, Generic), (Storage, Concrete(ty.clone())));
add_trait!(traits, t, (Storage, Concrete(ty.clone())), (Dimensionless, Generic));
}
}
traits.into_iter()
}
fn gen_operator_trait_impl(&self, numeric_trait: OperatorTrait) -> TokenStream {
let name = numeric_trait.name;
let fn_name = name.fn_name();
let trait_name = name.name();
let fn_return_type = name.fn_return_type();
let lhs = numeric_trait.lhs_type(&self.defs.quantity_type, &self.defs.dimension_type);
let rhs = numeric_trait.rhs_type(&self.defs.quantity_type, &self.defs.dimension_type);
let lhs_arg = name.lhs_arg();
let rhs_arg = name.rhs_arg_type(&rhs);
let fn_args = quote! { #lhs_arg, rhs: #rhs_arg };
let impl_generics = numeric_trait.generics_gen(&self.defs.dimension_type);
let output_type = numeric_trait.output_type(&self.defs.dimension_type);
let output_type_def = output_type
.as_ref()
.map(|output_type| output_type.output_type_def(&self.defs.quantity_type));
let trait_bounds = numeric_trait.trait_bounds(&self.defs.quantity_type, &output_type);
let fn_return_expr = numeric_trait.fn_return_expr(&self.defs.quantity_type, &output_type);
quote! {
impl #impl_generics #trait_name::<#rhs> for #lhs
where
#trait_bounds
{
#output_type_def
fn #fn_name(#fn_args) -> #fn_return_type {
#fn_return_expr
}
}
}
}
}