use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::num::NonZeroUsize;
use thiserror::Error;
use crate::desugar::desugared_ast::{
DagDecl, DimExpr, Expr, GenericConstraint, MulDivOp, TypeExpr, TypeExprKind, UnitExpr,
};
use crate::syntax::ast::UnitConstness;
use crate::syntax::dimension::{BaseDimId, Dimension, Rational, RationalError};
use crate::syntax::names::{
ConstructorName, DeclName, DimName, FieldName, GenericParamName, IndexName, IndexVariantName,
StructTypeName, UnitRef,
};
/// Reasons a raw `f64` is rejected by [`PositiveFiniteScale::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum PositiveFiniteScaleError {
/// The value failed `f64::is_finite` (NaN or an infinity).
#[error("scale must be finite")]
NonFinite,
/// The value was finite but less than or equal to zero.
#[error("scale must be greater than zero")]
NonPositive,
}
/// An `f64` scale factor wrapper.
///
/// Values built through [`PositiveFiniteScale::new`] are guaranteed to be
/// finite and strictly greater than zero; `new_unchecked` performs no check.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct PositiveFiniteScale(f64);

impl PositiveFiniteScale {
    /// Validates `value` and wraps it.
    ///
    /// # Errors
    ///
    /// * [`PositiveFiniteScaleError::NonFinite`] when `value` is NaN or infinite.
    /// * [`PositiveFiniteScaleError::NonPositive`] when `value <= 0.0`.
    pub fn new(value: f64) -> Result<Self, PositiveFiniteScaleError> {
        match value {
            // Finiteness is tested first so NaN/inf never reach the sign check.
            v if !v.is_finite() => Err(PositiveFiniteScaleError::NonFinite),
            v if v <= 0.0 => Err(PositiveFiniteScaleError::NonPositive),
            v => Ok(Self(v)),
        }
    }

    /// Wraps `value` without validating it; the caller is responsible for
    /// only passing finite, positive numbers.
    #[must_use]
    pub(crate) const fn new_unchecked(value: f64) -> Self {
        Self(value)
    }

    /// Returns the wrapped `f64`.
    #[must_use]
    pub const fn get(self) -> f64 {
        let Self(inner) = self;
        inner
    }
}

impl std::fmt::Display for PositiveFiniteScale {
    /// Formats exactly like the wrapped `f64`, honouring the formatter flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Scale attached to a unit: either a fixed number or one driven by an
/// expression.
#[derive(Debug, Clone)]
pub enum UnitScale {
    /// A fixed, validated scale factor.
    Static(PositiveFiniteScale),
    /// A scale described by an expression.
    // NOTE(review): the exact relationship between `scale_expr` and
    // `base_unit_scale` is not visible in this file; confirm at the use site.
    Dynamic {
        scale_expr: Expr,
        base_unit_scale: PositiveFiniteScale,
    },
}

impl UnitScale {
    /// Returns the numeric scale for [`UnitScale::Static`], `None` otherwise.
    #[must_use]
    pub const fn as_static(&self) -> Option<f64> {
        if let Self::Static(fixed) = self {
            Some(fixed.get())
        } else {
            None
        }
    }

    /// `true` when this is the [`UnitScale::Static`] variant.
    #[must_use]
    pub const fn is_static(&self) -> bool {
        !self.is_dynamic()
    }

    /// `true` when this is the [`UnitScale::Dynamic`] variant.
    #[must_use]
    pub const fn is_dynamic(&self) -> bool {
        match self {
            Self::Dynamic { .. } => true,
            Self::Static(_) => false,
        }
    }
}
/// Everything the registry stores about one unit.
#[derive(Debug, Clone)]
pub struct UnitInfo {
/// Dimension the unit measures.
pub dimension: Dimension,
/// Constness recorded when the unit was registered.
pub constness: UnitConstness,
/// Scale of the unit (static number or dynamic expression).
pub scale: UnitScale,
}
/// A named field together with its declared type expression.
#[derive(Debug, Clone)]
pub struct StructField {
/// Field name.
pub name: FieldName,
/// Type annotation of the field, kept as an unresolved type expression.
pub type_ann: TypeExpr,
}
/// One member (constructor) of a union type and the fields it carries.
#[derive(Debug, Clone)]
pub struct UnionMemberDef {
/// Constructor name of this member.
pub name: ConstructorName,
/// Fields carried by this member, in declaration order as stored.
pub fields: Vec<StructField>,
}
/// Shape of a registered type definition.
#[derive(Debug, Clone)]
pub enum TypeDefKind {
/// A type with no member list stored here.
// NOTE(review): presumably a type that must be supplied from elsewhere; confirm with the declaring pass.
Required,
/// A union made of one or more constructor members.
Union { members: Vec<UnionMemberDef> },
}
/// Constraint placed on a generic parameter of a type definition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeGenericConstraint {
    Dim,
    Index,
    Nat,
    Unconstrained,
}

impl From<GenericConstraint> for TypeGenericConstraint {
    /// Maps the desugared-AST constraint onto the registry's constraint.
    ///
    /// `Dim`, `Index` and `Nat` map to the variant of the same name;
    /// `GenericConstraint::Type` maps to [`TypeGenericConstraint::Unconstrained`].
    fn from(constraint: GenericConstraint) -> Self {
        match constraint {
            GenericConstraint::Type => TypeGenericConstraint::Unconstrained,
            GenericConstraint::Nat => TypeGenericConstraint::Nat,
            GenericConstraint::Index => TypeGenericConstraint::Index,
            GenericConstraint::Dim => TypeGenericConstraint::Dim,
        }
    }
}
/// A generic parameter declared on a type definition.
#[derive(Debug, Clone)]
pub struct TypeGenericParam {
/// Name of the generic parameter.
pub name: GenericParamName,
/// Constraint the parameter must satisfy.
pub constraint: TypeGenericConstraint,
/// Optional default, stored as an unresolved type expression.
pub default: Option<crate::desugar::desugared_ast::TypeExpr>,
}
/// A registered type definition: its name, generic parameters and shape.
#[derive(Debug, Clone)]
pub struct TypeDef {
    pub name: StructTypeName,
    pub generic_params: Vec<TypeGenericParam>,
    pub kind: TypeDefKind,
}

impl TypeDef {
    /// Returns the member list when this definition is a union, else `None`.
    #[must_use]
    pub fn union_members(&self) -> Option<&[UnionMemberDef]> {
        if let TypeDefKind::Union { members } = &self.kind {
            Some(members)
        } else {
            None
        }
    }

    /// `true` when the definition is a [`TypeDefKind::Union`].
    #[must_use]
    pub const fn is_union(&self) -> bool {
        match self.kind {
            TypeDefKind::Union { .. } => true,
            TypeDefKind::Required => false,
        }
    }

    /// `true` when the definition is [`TypeDefKind::Required`].
    #[must_use]
    pub const fn is_required(&self) -> bool {
        !self.is_union()
    }

    /// Returns the fields when this definition is record-like.
    ///
    /// A definition counts as record-like only if it is a union with exactly
    /// one member whose constructor name equals the type's own name.
    #[must_use]
    pub fn record_fields(&self) -> Option<&[StructField]> {
        match self.union_members()? {
            [sole] if sole.name.as_str() == self.name.as_str() => Some(sole.fields.as_slice()),
            _ => None,
        }
    }

    /// Like [`TypeDef::record_fields`], but yields an empty slice instead of
    /// `None` for definitions that are not record-like.
    #[must_use]
    pub fn fields(&self) -> &[StructField] {
        self.record_fields().unwrap_or_default()
    }
}
/// Numeric description of a range index.
// NOTE(review): whether `end` is inclusive and what unit `display_scale`
// converts to are not established in this file; confirm with the producer.
#[derive(Debug, Clone)]
pub struct RangeIndexData {
    pub start: f64,
    pub end: f64,
    pub step: f64,
    pub step_count: NonZeroUsize,
    pub dimension: Dimension,
    pub display_label: Option<String>,
    pub display_scale: f64,
}

impl RangeIndexData {
    /// Value at step `i`, computed as `start + i * step` with a fused
    /// multiply-add. No bounds check against `step_count` is performed.
    #[must_use]
    #[expect(
        clippy::cast_precision_loss,
        reason = "range step indices are small enough for exact f64 representation"
    )]
    pub fn step_value(&self, i: usize) -> f64 {
        let index = i as f64;
        index.mul_add(self.step, self.start)
    }

    /// Number of steps as a plain `usize` (always at least one).
    #[must_use]
    pub const fn step_count(&self) -> usize {
        let Self { step_count, .. } = self;
        step_count.get()
    }
}
/// The different kinds of index a registry can hold.
#[derive(Debug, Clone)]
pub enum IndexKind {
/// An index with an explicit list of named variants.
Named { variants: Vec<IndexVariantName> },
/// A numeric range index described by `RangeIndexData`.
Range(RangeIndexData),
/// A named index with no variants stored here (`IndexDef::variants` yields an empty list).
// NOTE(review): presumably to be supplied later by a caller; confirm.
RequiredNamed,
/// A range index with only a dimension stored here (`IndexDef::step_count` yields 0).
// NOTE(review): presumably to be supplied later by a caller; confirm.
RequiredRange { dimension: Dimension },
/// A natural-number range with `size` elements.
NatRange {
size: NonZeroUsize,
},
}
#[derive(Debug, Clone)]
pub struct IndexDef {
pub name: IndexName,
pub kind: IndexKind,
}
impl IndexDef {
#[must_use]
pub fn variants(&self) -> Vec<IndexVariantName> {
match &self.kind {
IndexKind::Named { variants } => variants.clone(),
IndexKind::Range(data) => {
let count = data.step_count();
(0..count).map(IndexVariantName::range_step).collect()
}
IndexKind::NatRange { size } => {
(0..size.get()).map(IndexVariantName::range_step).collect()
}
IndexKind::RequiredNamed | IndexKind::RequiredRange { .. } => vec![],
}
}
#[must_use]
pub const fn step_count(&self) -> usize {
match &self.kind {
IndexKind::Named { variants } => variants.len(),
IndexKind::Range(data) => data.step_count(),
IndexKind::NatRange { size } => size.get(),
IndexKind::RequiredNamed | IndexKind::RequiredRange { .. } => 0,
}
}
#[must_use]
pub const fn range_data(&self) -> Option<&RangeIndexData> {
match &self.kind {
IndexKind::Range(data) => Some(data),
_ => None,
}
}
#[must_use]
pub const fn is_range(&self) -> bool {
matches!(
self.kind,
IndexKind::Range(_) | IndexKind::RequiredRange { .. }
)
}
#[must_use]
pub const fn is_named(&self) -> bool {
matches!(
self.kind,
IndexKind::Named { .. } | IndexKind::RequiredNamed
)
}
#[must_use]
pub const fn is_nat_range(&self) -> bool {
matches!(self.kind, IndexKind::NatRange { .. })
}
#[must_use]
pub const fn nat_range_size(&self) -> Option<u64> {
match &self.kind {
IndexKind::NatRange { size } => Some(size.get() as u64),
_ => None,
}
}
#[must_use]
pub const fn is_required(&self) -> bool {
matches!(
self.kind,
IndexKind::RequiredNamed | IndexKind::RequiredRange { .. }
)
}
}
/// Reasons `NatRangeIndex::try_from_u64` rejects a size.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
pub enum NatRangeIndexError {
/// The requested size was zero.
#[error("range(0) is not allowed; indexes must contain at least one element")]
Empty,
/// The requested size could not be converted to `usize`.
#[error("nat range size {size} does not fit in usize on this target")]
DoesNotFitUsize { size: u64 },
}
/// Key identifying a natural-number range index by its (non-zero) size.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NatRangeIndex {
    size: NonZeroUsize,
}

impl NatRangeIndex {
    /// Builds the key from an already-validated non-zero size.
    #[must_use]
    pub const fn new(size: NonZeroUsize) -> Self {
        Self { size }
    }

    /// Builds the key from a raw `u64`.
    ///
    /// # Errors
    ///
    /// * [`NatRangeIndexError::Empty`] when `size` is zero.
    /// * [`NatRangeIndexError::DoesNotFitUsize`] when `size` does not fit in
    ///   `usize`.
    pub fn try_from_u64(size: u64) -> Result<Self, NatRangeIndexError> {
        // Zero always fits in `usize`, so it falls through to the non-zero
        // check below and is reported as `Empty`.
        let Ok(narrowed) = usize::try_from(size) else {
            return Err(NatRangeIndexError::DoesNotFitUsize { size });
        };
        NonZeroUsize::new(narrowed)
            .map(Self::new)
            .ok_or(NatRangeIndexError::Empty)
    }

    /// The size as a `NonZeroUsize`.
    #[must_use]
    pub const fn size(self) -> NonZeroUsize {
        self.size
    }

    /// The size widened to `u64`.
    ///
    /// # Panics
    ///
    /// Panics if `usize` does not fit in `u64` on the current target.
    #[must_use]
    #[expect(
        clippy::expect_used,
        reason = "Graphcal currently supports targets where usize fits in u64"
    )]
    pub fn size_u64(self) -> u64 {
        let raw = self.size.get();
        u64::try_from(raw).expect("usize fits in u64 on supported targets")
    }

    /// Name of the form `range(<size>)` used for this index.
    #[must_use]
    pub fn display_name(self) -> IndexName {
        let label = format!("range({})", self.size_u64());
        IndexName::new(label)
    }
}
/// Failure modes of resolving a unit expression against the registry.
#[derive(Debug, Clone, PartialEq)]
pub enum UnitResolveError {
    /// The referenced unit is not present in the registry.
    UnknownUnit(UnitRef),
    /// The referenced unit has a non-static scale where a number is needed.
    DynamicScale(UnitRef),
    /// Rational arithmetic on dimension exponents failed.
    Overflow(RationalError),
}

impl From<RationalError> for UnitResolveError {
    /// Wraps a rational-arithmetic error so `?` can be used on dimension math.
    fn from(source: RationalError) -> Self {
        UnitResolveError::Overflow(source)
    }
}
/// Resolves a dimension expression into a concrete [`Dimension`].
///
/// Terms are combined left to right, starting from the dimensionless value.
/// Returns `Ok(None)` as soon as a term's name is not a bare name or is not a
/// key of `dimensions`; later terms are then not evaluated.
///
/// # Errors
///
/// Propagates any `RationalError` raised while raising a base dimension to
/// its power or while multiplying/dividing dimensions.
fn resolve_dim_expr_impl(
    dimensions: &HashMap<DimName, Dimension>,
    expr: &DimExpr,
) -> Result<Option<Dimension>, RationalError> {
    let mut combined = Dimension::dimensionless();
    for item in &expr.terms {
        let Some(atom) = item.term.name.value.as_bare() else {
            return Ok(None);
        };
        let Some(base) = dimensions.get(atom.as_str()) else {
            return Ok(None);
        };
        // A missing power means the term is used to the first power.
        let exponent = item.term.power.unwrap_or(Rational::ONE);
        let powered = base.pow(exponent)?;
        combined = match item.op {
            MulDivOp::Mul => (combined * powered)?,
            MulDivOp::Div => (combined / powered)?,
        };
    }
    Ok(Some(combined))
}
/// Resolves the dimension denoted by a type expression, if it has one.
///
/// * `Dimensionless` resolves to the dimensionless value.
/// * `DimExpr` is resolved through `resolve_dim_expr_impl`.
/// * `Indexed` resolves to whatever its `base` resolves to.
/// * `Bool`, `Int`, `Datetime`, `TypeApplication` and `DatetimeApplication`
///   yield `Ok(None)`.
///
/// # Errors
///
/// Propagates `RationalError` from dimension-expression resolution.
fn resolve_type_expr_impl(
    dimensions: &HashMap<DimName, Dimension>,
    type_expr: &TypeExpr,
) -> Result<Option<Dimension>, RationalError> {
    match &type_expr.kind {
        TypeExprKind::Indexed { base, .. } => resolve_type_expr_impl(dimensions, base),
        TypeExprKind::DimExpr(inner) => resolve_dim_expr_impl(dimensions, inner),
        TypeExprKind::Dimensionless => {
            let unitless = Dimension::dimensionless();
            Ok(Some(unitless))
        }
        TypeExprKind::Bool
        | TypeExprKind::Int
        | TypeExprKind::Datetime
        | TypeExprKind::TypeApplication { .. }
        | TypeExprKind::DatetimeApplication { .. } => Ok(None),
    }
}
/// Raises `scale` to the rational power `exp`.
///
/// Integer exponents go through `f64::powi` with the numerator; all other
/// exponents go through `f64::powf` with `num / den` converted to `f64`.
#[must_use]
pub fn pow_scale(scale: f64, exp: Rational) -> f64 {
    if !exp.is_integer() {
        let exponent = f64::from(exp.num()) / f64::from(exp.den());
        return scale.powf(exponent);
    }
    scale.powi(exp.num())
}
/// Resolves a unit expression to its combined dimension and numeric scale.
///
/// Terms are folded left to right, starting from the dimensionless value and
/// a scale of `1.0`. A missing power is treated as `Rational::ONE`.
///
/// # Errors
///
/// Per term, checked in this order:
/// * [`UnitResolveError::UnknownUnit`] when the unit is not in `units`;
/// * [`UnitResolveError::Overflow`] when raising the dimension to its power fails;
/// * [`UnitResolveError::DynamicScale`] when the unit has no static scale;
/// * [`UnitResolveError::Overflow`] when combining dimensions fails.
fn resolve_unit_expr_impl(
    units: &HashMap<UnitRef, UnitInfo>,
    expr: &UnitExpr,
) -> Result<(Dimension, f64), UnitResolveError> {
    let mut total_dim = Dimension::dimensionless();
    let mut total_scale = 1.0_f64;
    for term in &expr.terms {
        let unit = &term.name.value;
        let info = units
            .get(unit)
            .ok_or_else(|| UnitResolveError::UnknownUnit(unit.clone()))?;
        let exponent = term.power.unwrap_or(Rational::ONE);
        let term_dim = info.dimension.pow(exponent)?;
        let term_scale = info
            .scale
            .as_static()
            .map(|fixed| pow_scale(fixed, exponent))
            .ok_or_else(|| UnitResolveError::DynamicScale(unit.clone()))?;
        (total_dim, total_scale) = match term.op {
            MulDivOp::Mul => ((total_dim * term_dim)?, total_scale * term_scale),
            MulDivOp::Div => ((total_dim / term_dim)?, total_scale / term_scale),
        };
    }
    Ok((total_dim, total_scale))
}
/// Resolves only the dimension of a unit expression, ignoring scales.
///
/// Because scales are never inspected, units with a dynamic scale are
/// accepted here.
///
/// # Errors
///
/// * [`UnitResolveError::UnknownUnit`] when a unit is not in `units`;
/// * [`UnitResolveError::Overflow`] when dimension arithmetic fails.
fn resolve_unit_dimension_impl(
    units: &HashMap<UnitRef, UnitInfo>,
    expr: &UnitExpr,
) -> Result<Dimension, UnitResolveError> {
    let mut total = Dimension::dimensionless();
    for term in &expr.terms {
        let unit = &term.name.value;
        let info = units
            .get(unit)
            .ok_or_else(|| UnitResolveError::UnknownUnit(unit.clone()))?;
        // A missing power means the unit is used to the first power.
        let exponent = term.power.unwrap_or(Rational::ONE);
        let term_dim = info.dimension.pow(exponent)?;
        total = match term.op {
            MulDivOp::Mul => total * term_dim,
            MulDivOp::Div => total / term_dim,
        }?;
    }
    Ok(total)
}
fn format_dimension_preferring_alias(
dimensions: &HashMap<DimName, Dimension>,
base_dim_names: &BTreeMap<BaseDimId, String>,
dim: &Dimension,
) -> String {
let canonical = format!("{}", dim.display_with(base_dim_names));
let is_compound = canonical.contains([' ', '^', '*', '/']);
if is_compound
&& let Some(alias) = dimensions
.iter()
.filter(|(_, d)| *d == dim)
.map(|(name, _)| name)
.min()
{
return alias.to_string();
}
canonical
}
/// Read-only view of the registered dimensions.
#[derive(Debug, Clone)]
pub struct DimensionRegistry {
// Display name of each base dimension, keyed by its id.
base_dim_names: BTreeMap<BaseDimId, String>,
// Symbol of each base dimension, keyed by its id.
base_dim_symbols: BTreeMap<BaseDimId, String>,
// Every named dimension (base and derived), keyed by name.
dimensions: HashMap<DimName, Dimension>,
}
impl DimensionRegistry {
/// Looks up a dimension by name.
#[must_use]
pub fn get_dimension(&self, name: &str) -> Option<&Dimension> {
self.dimensions.get(name)
}
/// Iterates over all `(name, dimension)` pairs in `HashMap` order (unspecified).
pub fn all_dimensions(&self) -> impl Iterator<Item = (&DimName, &Dimension)> {
self.dimensions.iter()
}
/// Map from base-dimension id to its display name.
#[must_use]
pub const fn base_dim_names(&self) -> &BTreeMap<BaseDimId, String> {
&self.base_dim_names
}
/// Map from base-dimension id to its symbol.
#[must_use]
pub const fn base_dim_symbols(&self) -> &BTreeMap<BaseDimId, String> {
&self.base_dim_symbols
}
/// Formats `dim` via `format_dimension_preferring_alias`, which prefers a
/// registered alias when the canonical rendering is compound.
#[must_use]
pub fn format_dimension(&self, dim: &Dimension) -> String {
format_dimension_preferring_alias(&self.dimensions, &self.base_dim_names, dim)
}
/// Resolves a dimension expression; `Ok(None)` means it could not be resolved
/// against the registered names. Errors are rational-arithmetic failures.
pub fn resolve_dim_expr(&self, expr: &DimExpr) -> Result<Option<Dimension>, RationalError> {
resolve_dim_expr_impl(&self.dimensions, expr)
}
/// Resolves the dimension of a type expression; `Ok(None)` when the type has
/// no resolvable dimension. Errors are rational-arithmetic failures.
pub fn resolve_type_expr(
&self,
type_expr: &TypeExpr,
) -> Result<Option<Dimension>, RationalError> {
resolve_type_expr_impl(&self.dimensions, type_expr)
}
}
/// Read-only view of the registered units.
#[derive(Debug, Clone)]
pub struct UnitRegistry {
// Unit metadata keyed by unit reference.
units: HashMap<UnitRef, UnitInfo>,
}
impl UnitRegistry {
/// Looks up a unit by reference.
#[must_use]
pub fn get_unit(&self, name: &UnitRef) -> Option<&UnitInfo> {
self.units.get(name)
}
/// Iterates over `(unit, dimension, scale)` triples in `HashMap` order (unspecified).
pub fn all_units(&self) -> impl Iterator<Item = (&UnitRef, &Dimension, &UnitScale)> {
self.units
.iter()
.map(|(name, info)| (name, &info.dimension, &info.scale))
}
/// Resolves a unit expression to `(dimension, scale)`; fails on unknown units,
/// units without a static scale, or rational-arithmetic errors.
pub fn resolve_unit_expr(&self, expr: &UnitExpr) -> Result<(Dimension, f64), UnitResolveError> {
resolve_unit_expr_impl(&self.units, expr)
}
/// Resolves only the dimension of a unit expression; scales are not inspected.
pub fn resolve_unit_dimension(&self, expr: &UnitExpr) -> Result<Dimension, UnitResolveError> {
resolve_unit_dimension_impl(&self.units, expr)
}
}
/// Read-only view of the registered type definitions.
#[derive(Debug, Clone)]
pub struct TypeRegistry {
// Type definitions keyed by type name.
types: HashMap<StructTypeName, TypeDef>,
// Reverse index: constructor name -> name of the type that declares it.
ctors: HashMap<ConstructorName, StructTypeName>,
}
impl TypeRegistry {
/// Looks up a type definition by name.
#[must_use]
pub fn get_type(&self, name: &str) -> Option<&TypeDef> {
self.types.get(name)
}
/// Finds the type owning constructor `ctor` together with the matching member.
///
/// Returns `None` if the constructor is unknown, its owning type is missing,
/// the owning type is not a union, or no member of that union has this name.
#[must_use]
pub fn lookup_ctor(&self, ctor: &ConstructorName) -> Option<(&TypeDef, &UnionMemberDef)> {
let union_name = self.ctors.get(ctor)?;
let td = self.types.get(union_name)?;
let members = td.union_members()?;
// Re-check membership: the reverse index alone does not prove the member exists.
let member = members.iter().find(|m| m.name == *ctor)?;
Some((td, member))
}
/// Iterates over all type definitions in `HashMap` order (unspecified).
pub fn all_types(&self) -> impl Iterator<Item = &TypeDef> {
self.types.values()
}
}
/// Read-only view of the registered indexes.
#[derive(Debug, Clone)]
pub struct IndexRegistry {
// Indexes registered under an explicit name.
indexes: HashMap<IndexName, IndexDef>,
// Natural-number range indexes, keyed by their size.
nat_ranges: HashMap<NatRangeIndex, IndexDef>,
}
impl IndexRegistry {
/// Looks up an index by name; nat-range indexes are not searched here.
#[must_use]
pub fn get_index(&self, name: &str) -> Option<&IndexDef> {
self.indexes.get(name)
}
/// Looks up a nat-range index by its size key.
#[must_use]
pub fn get_nat_range(&self, index: NatRangeIndex) -> Option<&IndexDef> {
self.nat_ranges.get(&index)
}
/// Iterates over named indexes first, then nat-range indexes.
pub fn all_indexes(&self) -> impl Iterator<Item = &IndexDef> {
self.indexes.values().chain(self.nat_ranges.values())
}
}
/// Immutable bundle of all sub-registries, produced by `RegistryBuilder::build`.
#[derive(Debug, Clone)]
pub struct Registry {
/// Named dimensions and base-dimension metadata.
pub dimensions: DimensionRegistry,
/// Registered units.
pub units: UnitRegistry,
/// Registered type definitions and their constructors.
pub types: TypeRegistry,
/// Registered named and nat-range indexes.
pub indexes: IndexRegistry,
/// Registered DAG declarations.
pub dags: DagRegistry,
}
/// Read-only view of the registered DAG declarations.
#[derive(Debug, Default, Clone)]
pub struct DagRegistry {
// DAG declarations keyed by declaration name.
dags: HashMap<DeclName, DagDecl>,
}
impl DagRegistry {
/// Looks up a DAG declaration by name.
#[must_use]
pub fn get(&self, name: &str) -> Option<&DagDecl> {
self.dags.get(name)
}
/// Iterates over all `(name, declaration)` pairs in `HashMap` order (unspecified).
pub fn all_dags(&self) -> impl Iterator<Item = (&DeclName, &DagDecl)> {
self.dags.iter()
}
}
/// Mutable accumulator used to assemble a `Registry`.
#[derive(Debug, Default)]
pub struct RegistryBuilder {
// Display name of each base dimension, keyed by its id.
base_dim_names: BTreeMap<BaseDimId, String>,
// Symbol of each base dimension, keyed by its id.
base_dim_symbols: BTreeMap<BaseDimId, String>,
// Every named dimension (base and derived), keyed by name.
dimensions: HashMap<DimName, Dimension>,
// Unit metadata keyed by unit reference.
units: HashMap<UnitRef, UnitInfo>,
// Type definitions keyed by type name.
types: HashMap<StructTypeName, TypeDef>,
// Reverse index: constructor name -> owning type name.
ctors: HashMap<ConstructorName, StructTypeName>,
// Indexes registered under an explicit name.
indexes: HashMap<IndexName, IndexDef>,
// Natural-number range indexes, keyed by their size.
nat_ranges: HashMap<NatRangeIndex, IndexDef>,
// DAG declarations keyed by declaration name.
dags: HashMap<DeclName, DagDecl>,
// Base dimensions marked through `mark_affine_prone`; builder-only state.
affine_prone_dims: BTreeSet<BaseDimId>,
}
impl RegistryBuilder {
/// Creates an empty builder (same as `Default::default()`).
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Consumes the builder and produces the immutable `Registry`.
///
/// `affine_prone_dims` is not carried over: `Registry` has no field for it.
#[must_use]
pub fn build(self) -> Registry {
Registry {
dimensions: DimensionRegistry {
base_dim_names: self.base_dim_names,
base_dim_symbols: self.base_dim_symbols,
dimensions: self.dimensions,
},
units: UnitRegistry { units: self.units },
types: TypeRegistry {
types: self.types,
ctors: self.ctors,
},
indexes: IndexRegistry {
indexes: self.indexes,
nat_ranges: self.nat_ranges,
},
dags: DagRegistry { dags: self.dags },
}
}
/// Registers a DAG declaration, replacing any previous one with the same name.
pub fn register_dag(&mut self, name: DeclName, decl: DagDecl) {
self.dags.insert(name, decl);
}
/// Copies every entry of `parent` into this builder without overwriting:
/// keys already present in the builder keep their current value.
pub fn merge_from_registry(&mut self, parent: &Registry) {
for (id, name) in &parent.dimensions.base_dim_names {
self.base_dim_names
.entry(id.clone())
.or_insert_with(|| name.clone());
}
for (id, symbol) in &parent.dimensions.base_dim_symbols {
self.base_dim_symbols
.entry(id.clone())
.or_insert_with(|| symbol.clone());
}
for (name, dim) in &parent.dimensions.dimensions {
self.dimensions
.entry(name.clone())
.or_insert_with(|| dim.clone());
}
for (name, info) in &parent.units.units {
self.units
.entry(name.clone())
.or_insert_with(|| info.clone());
}
for (name, def) in &parent.types.types {
self.types
.entry(name.clone())
.or_insert_with(|| def.clone());
}
for (ctor, union_name) in &parent.types.ctors {
self.ctors
.entry(ctor.clone())
.or_insert_with(|| union_name.clone());
}
for (name, def) in &parent.indexes.indexes {
self.indexes
.entry(name.clone())
.or_insert_with(|| def.clone());
}
for (index, def) in &parent.indexes.nat_ranges {
self.nat_ranges.entry(*index).or_insert_with(|| def.clone());
}
for (name, decl) in &parent.dags.dags {
self.dags
.entry(name.clone())
.or_insert_with(|| decl.clone());
}
}
/// Marks base dimension `id` as affine-prone (queried by `is_affine_prone`).
pub fn mark_affine_prone(&mut self, id: BaseDimId) {
self.affine_prone_dims.insert(id);
}
/// Returns `true` only when `dim` consists of exactly one base dimension,
/// raised to the power `Rational::ONE`, that was marked affine-prone.
#[must_use]
pub fn is_affine_prone(&self, dim: &Dimension) -> bool {
let mut iter = dim.iter();
// An empty iterator (no base-dimension entries) is never affine-prone.
let Some((id, &exp)) = iter.next() else {
return false;
};
// A second entry means the dimension is compound, which disqualifies it.
iter.next().is_none() && exp == Rational::ONE && self.affine_prone_dims.contains(id)
}
/// Registers a base dimension under `name` and records its display name.
/// Existing entries for the same id/name are overwritten. Returns `id`.
pub fn register_base_dimension(&mut self, name: DimName, id: BaseDimId) -> BaseDimId {
let dim = Dimension::base(id.clone());
self.base_dim_names.insert(id.clone(), name.to_string());
self.dimensions.insert(name, dim);
id
}
/// Like `register_base_dimension`, and also sets the symbol for `id`,
/// overwriting any symbol already stored. Returns `id`.
pub fn register_base_dimension_with_symbol(
&mut self,
name: DimName,
id: BaseDimId,
symbol: String,
) -> BaseDimId {
let id = self.register_base_dimension(name, id);
self.base_dim_symbols.insert(id.clone(), symbol);
id
}
/// Sets the symbol for `id` only if none is stored yet (first write wins).
pub fn set_base_dim_symbol(&mut self, id: BaseDimId, symbol: String) {
self.base_dim_symbols.entry(id).or_insert(symbol);
}
/// Registers (or replaces) a named dimension.
pub fn register_dimension(&mut self, name: DimName, dim: Dimension) {
self.dimensions.insert(name, dim);
}
/// Registers (or replaces) a unit with a static scale and `UnitConstness::Const`.
pub fn register_unit(
&mut self,
name: impl Into<UnitRef>,
dimension: Dimension,
scale: PositiveFiniteScale,
) {
self.units.insert(
name.into(),
UnitInfo {
dimension,
constness: UnitConstness::Const,
scale: UnitScale::Static(scale),
},
);
}
/// Registers (or replaces) a unit with an explicit scale and constness.
pub fn register_unit_with_scale(
&mut self,
name: impl Into<UnitRef>,
dimension: Dimension,
scale: UnitScale,
constness: UnitConstness,
) {
self.units.insert(
name.into(),
UnitInfo {
dimension,
constness,
scale,
},
);
}
/// Registers (or replaces) a unit with `UnitConstness::Dynamic`.
/// The `scale` is stored as given; it is not required to be the `Dynamic` variant.
pub fn register_unit_dynamic(
&mut self,
name: impl Into<UnitRef>,
dimension: Dimension,
scale: UnitScale,
) {
self.register_unit_with_scale(name, dimension, scale, UnitConstness::Dynamic);
}
/// Registers (or replaces) a type definition and indexes each union member's
/// constructor name to the type. Constructor entries from an earlier
/// definition of the same type are not removed.
pub fn register_type(&mut self, def: TypeDef) {
if let TypeDefKind::Union { ref members } = def.kind {
for member in members {
self.ctors.insert(member.name.clone(), def.name.clone());
}
}
self.types.insert(def.name.clone(), def);
}
/// Registers (or replaces) a named index under `def.name`.
pub fn register_index(&mut self, def: IndexDef) {
self.indexes.insert(def.name.clone(), def);
}
/// Returns the key for a nat-range index of `size`, creating its `IndexDef`
/// (named via `NatRangeIndex::display_name`) if it does not exist yet.
pub fn ensure_nat_range_index(&mut self, size: NonZeroUsize) -> NatRangeIndex {
let nat_range = NatRangeIndex::new(size);
self.nat_ranges
.entry(nat_range)
.or_insert_with(|| IndexDef {
name: nat_range.display_name(),
kind: IndexKind::NatRange { size },
});
nat_range
}
/// Looks up a dimension by name.
#[must_use]
pub fn get_dimension(&self, name: &str) -> Option<&Dimension> {
self.dimensions.get(name)
}
/// Looks up a unit by reference.
#[must_use]
pub fn get_unit(&self, name: &UnitRef) -> Option<&UnitInfo> {
self.units.get(name)
}
/// Iterates over `(unit, dimension, scale)` triples in `HashMap` order (unspecified).
pub fn all_units(&self) -> impl Iterator<Item = (&UnitRef, &Dimension, &UnitScale)> {
self.units
.iter()
.map(|(name, info)| (name, &info.dimension, &info.scale))
}
/// Looks up a type definition by name.
#[must_use]
pub fn get_type(&self, name: &str) -> Option<&TypeDef> {
self.types.get(name)
}
/// Looks up a named index; nat-range indexes are not searched here.
#[must_use]
pub fn get_index(&self, name: &str) -> Option<&IndexDef> {
self.indexes.get(name)
}
/// Looks up a nat-range index by its size key.
#[must_use]
pub fn get_nat_range(&self, index: NatRangeIndex) -> Option<&IndexDef> {
self.nat_ranges.get(&index)
}
/// Map from base-dimension id to its display name.
#[must_use]
pub const fn base_dim_names(&self) -> &BTreeMap<BaseDimId, String> {
&self.base_dim_names
}
/// Map from base-dimension id to its symbol.
#[must_use]
pub const fn base_dim_symbols(&self) -> &BTreeMap<BaseDimId, String> {
&self.base_dim_symbols
}
/// Formats `dim` via `format_dimension_preferring_alias`, which prefers a
/// registered alias when the canonical rendering is compound.
#[must_use]
pub fn format_dimension(&self, dim: &Dimension) -> String {
format_dimension_preferring_alias(&self.dimensions, &self.base_dim_names, dim)
}
/// Resolves a dimension expression against the dimensions registered so far.
pub fn resolve_dim_expr(&self, expr: &DimExpr) -> Result<Option<Dimension>, RationalError> {
resolve_dim_expr_impl(&self.dimensions, expr)
}
/// Resolves the dimension of a type expression against the dimensions registered so far.
pub fn resolve_type_expr(
&self,
type_expr: &TypeExpr,
) -> Result<Option<Dimension>, RationalError> {
resolve_type_expr_impl(&self.dimensions, type_expr)
}
/// Resolves a unit expression to `(dimension, scale)` against the units registered so far.
pub fn resolve_unit_expr(&self, expr: &UnitExpr) -> Result<(Dimension, f64), UnitResolveError> {
resolve_unit_expr_impl(&self.units, expr)
}
/// Resolves only the dimension of a unit expression; scales are not inspected.
pub fn resolve_unit_dimension(&self, expr: &UnitExpr) -> Result<Dimension, UnitResolveError> {
resolve_unit_dimension_impl(&self.units, expr)
}
}
// Unit tests for the registry: prelude contents, expression resolution, and
// the builder's registration methods.
#[cfg(test)]
mod tests {
use super::*;
use crate::registry::prelude::load_prelude;
use crate::syntax::ast::{DimExprItem, DimTerm, UnitExprItem};
use crate::syntax::dimension::BaseDimId;
use crate::syntax::names::{NamePath, UnitName};
use crate::syntax::span::Span;
use crate::syntax::span::Spanned;
// Ids of the prelude base dimensions the tests compare against.
fn length_id() -> BaseDimId {
BaseDimId::Prelude("Length".to_string())
}
fn time_id() -> BaseDimId {
BaseDimId::Prelude("Time".to_string())
}
fn mass_id() -> BaseDimId {
BaseDimId::Prelude("Mass".to_string())
}
// Builds a registry containing only what `load_prelude` registers.
fn make_registry() -> Registry {
let mut b = RegistryBuilder::new();
load_prelude(&mut b).unwrap();
b.build()
}
// Wraps a bare dimension name in a `Spanned<NamePath>` with a dummy span.
fn make_dim_term_name(name: &str) -> Spanned<NamePath> {
Spanned::new(NamePath::from(name), Span::new(0, 0))
}
// Builds a type expression consisting of the single dimension `name`.
fn make_dim_type_expr(name: &str) -> TypeExpr {
use crate::syntax::ast::{DimExpr, DimExprItem, DimTerm};
TypeExpr {
kind: TypeExprKind::DimExpr(DimExpr {
terms: vec![DimExprItem {
op: MulDivOp::Mul,
term: DimTerm {
name: make_dim_term_name(name),
power: None,
span: Span::new(0, 0),
},
}],
span: Span::new(0, 0),
}),
constraints: vec![],
span: Span::new(0, 0),
}
}
// Wraps a local unit name in a `Spanned<UnitRef>` with a dummy span.
fn make_unit_name(name: &str) -> Spanned<UnitRef> {
Spanned::new(UnitRef::local(UnitName::new(name)), Span::new(0, 0))
}
// The prelude exposes Length, Time and Mass as base dimensions.
#[test]
fn registry_base_dimensions() {
let r = make_registry();
assert_eq!(
r.dimensions.get_dimension("Length"),
Some(&Dimension::base(length_id()))
);
assert_eq!(
r.dimensions.get_dimension("Time"),
Some(&Dimension::base(time_id()))
);
assert_eq!(
r.dimensions.get_dimension("Mass"),
Some(&Dimension::base(mass_id()))
);
}
// The prelude's Velocity equals Length / Time.
#[test]
fn registry_derived_dimensions() {
let r = make_registry();
let velocity = r.dimensions.get_dimension("Velocity").unwrap();
let expected = (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap();
assert_eq!(*velocity, expected);
}
// The prelude unit `m` is a Length with static scale 1.
#[test]
fn registry_base_units() {
let r = make_registry();
let m = r.units.get_unit(&UnitRef::local("m")).unwrap();
assert_eq!(m.dimension, Dimension::base(length_id()));
assert!((m.scale.as_static().unwrap() - 1.0).abs() < f64::EPSILON);
}
// The prelude unit `km` is a Length with static scale 1000.
#[test]
fn registry_derived_units() {
let r = make_registry();
let km = r.units.get_unit(&UnitRef::local("km")).unwrap();
assert_eq!(km.dimension, Dimension::base(length_id()));
assert!((km.scale.as_static().unwrap() - 1000.0).abs() < f64::EPSILON);
}
// `Length / Time` written as a dimension expression resolves to Length / Time.
#[test]
fn resolve_dim_expr_velocity() {
let r = make_registry();
let expr = DimExpr {
terms: vec![
DimExprItem {
op: MulDivOp::Mul,
term: DimTerm {
name: make_dim_term_name("Length"),
power: None,
span: Span::new(0, 0),
},
},
DimExprItem {
op: MulDivOp::Div,
term: DimTerm {
name: make_dim_term_name("Time"),
power: None,
span: Span::new(0, 0),
},
},
],
span: Span::new(0, 0),
};
let dim = r.dimensions.resolve_dim_expr(&expr).unwrap().unwrap();
let expected = (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap();
assert_eq!(dim, expected);
}
// `m / s^2` resolves to Length / Time^2 with scale 1.
#[test]
fn resolve_unit_expr_m_per_s_squared() {
let r = make_registry();
let expr = UnitExpr {
terms: vec![
UnitExprItem {
op: MulDivOp::Mul,
name: make_unit_name("m"),
power: None,
},
UnitExprItem {
op: MulDivOp::Div,
name: make_unit_name("s"),
power: Some(Rational::from_int(2)),
},
],
span: Span::new(0, 0),
};
let (dim, scale) = r.units.resolve_unit_expr(&expr).unwrap();
let expected_dim = (Dimension::base(length_id())
/ Dimension::base(time_id()).pow_int(2).unwrap())
.unwrap();
assert_eq!(dim, expected_dim);
assert!((scale - 1.0).abs() < f64::EPSILON);
}
// `km / hour` resolves to Length / Time with scale 1000 / 3600.
#[test]
fn resolve_unit_expr_km_per_hour() {
let r = make_registry();
let expr = UnitExpr {
terms: vec![
UnitExprItem {
op: MulDivOp::Mul,
name: make_unit_name("km"),
power: None,
},
UnitExprItem {
op: MulDivOp::Div,
name: make_unit_name("hour"),
power: None,
},
],
span: Span::new(0, 0),
};
let (dim, scale) = r.units.resolve_unit_expr(&expr).unwrap();
let expected_dim = (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap();
assert_eq!(dim, expected_dim);
assert!((scale - 1000.0 / 3600.0).abs() < 1e-10);
}
// A single-member union whose constructor shares the type name is record-like:
// it is found by name, reports its fields, and its field types resolve.
#[test]
fn registry_type_register_and_lookup() {
let mut b = RegistryBuilder::new();
load_prelude(&mut b).unwrap();
b.register_type(TypeDef {
name: StructTypeName::new("TransferResult"),
generic_params: vec![],
kind: TypeDefKind::Union {
members: vec![UnionMemberDef {
name: ConstructorName::new("TransferResult"),
fields: vec![
StructField {
name: FieldName::new("dv1"),
type_ann: make_dim_type_expr("Velocity"),
},
StructField {
name: FieldName::new("dv2"),
type_ann: make_dim_type_expr("Velocity"),
},
],
}],
},
});
let r = b.build();
let velocity_dim = (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap();
let def = r.types.get_type("TransferResult").unwrap();
assert_eq!(def.name.as_str(), "TransferResult");
assert!(def.is_union());
let fields = def.record_fields().expect("single-variant collision");
assert_eq!(fields.len(), 2);
assert_eq!(fields[0].name.as_str(), "dv1");
assert_eq!(
r.dimensions.resolve_type_expr(&fields[0].type_ann),
Ok(Some(velocity_dim))
);
assert!(r.types.get_type("NonExistent").is_none());
}
// A named index is found by name and returns its variants in stored order.
#[test]
fn registry_index_register_and_lookup() {
let mut b = RegistryBuilder::new();
load_prelude(&mut b).unwrap();
b.register_index(IndexDef {
name: IndexName::new("Maneuver"),
kind: IndexKind::Named {
variants: vec![
IndexVariantName::new("Departure"),
IndexVariantName::new("Correction"),
IndexVariantName::new("Insertion"),
],
},
});
let r = b.build();
let def = r.indexes.get_index("Maneuver").unwrap();
assert_eq!(def.name.as_str(), "Maneuver");
let variants = def.variants();
let variant_strs: Vec<&str> = variants.iter().map(IndexVariantName::as_str).collect();
assert_eq!(variant_strs, vec!["Departure", "Correction", "Insertion"]);
assert!(r.indexes.get_index("NonExistent").is_none());
}
// A user-defined base dimension is registered under its name and id.
#[test]
fn register_user_defined_base_dimension() {
let mut b = RegistryBuilder::new();
load_prelude(&mut b).unwrap();
let info_id = BaseDimId::UserDefined {
dag: crate::dag_id::DagId::root("test"),
name: "Information".to_string(),
};
let id = b.register_base_dimension(DimName::new("Information"), info_id.clone());
assert_eq!(id, info_id);
let r = b.build();
let dim = r.dimensions.get_dimension("Information").unwrap();
assert_eq!(*dim, Dimension::base(id.clone()));
assert_eq!(
r.dimensions.base_dim_names().get(&id),
Some(&"Information".to_string())
);
}
// Registering with a symbol stores that symbol for the base dimension.
#[test]
fn register_base_dimension_with_symbol() {
let mut b = RegistryBuilder::new();
let id = b.register_base_dimension_with_symbol(
DimName::new("Length"),
BaseDimId::Prelude("Length".to_string()),
"m".to_string(),
);
let r = b.build();
assert_eq!(
r.dimensions.base_dim_symbols().get(&id),
Some(&"m".to_string())
);
}
// `set_base_dim_symbol` keeps the first symbol; later calls do not overwrite.
#[test]
fn set_base_dim_symbol_only_first() {
let mut b = RegistryBuilder::new();
let info_id = BaseDimId::UserDefined {
dag: crate::dag_id::DagId::root("test"),
name: "Information".to_string(),
};
let id = b.register_base_dimension(DimName::new("Information"), info_id);
b.set_base_dim_symbol(id.clone(), "bit".to_string());
b.set_base_dim_symbol(id.clone(), "byte".to_string());
let r = b.build();
assert_eq!(
r.dimensions.base_dim_symbols().get(&id),
Some(&"bit".to_string())
);
}
}