1use std::collections::{BTreeMap, BTreeSet, HashMap};
2use std::num::NonZeroUsize;
3
4use thiserror::Error;
5
6use crate::desugar::desugared_ast::{
7 DagDecl, DimExpr, Expr, GenericConstraint, MulDivOp, TypeExpr, TypeExprKind, UnitExpr,
8};
9use crate::syntax::ast::UnitConstness;
10use crate::syntax::dimension::{BaseDimId, Dimension, Rational, RationalError};
11use crate::syntax::names::{
12 ConstructorName, DeclName, DimName, FieldName, GenericParamName, IndexName, IndexVariantName,
13 StructTypeName, UnitRef,
14};
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
21pub enum PositiveFiniteScaleError {
22 #[error("scale must be finite")]
23 NonFinite,
24 #[error("scale must be greater than zero")]
25 NonPositive,
26}
27
28#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
30pub struct PositiveFiniteScale(f64);
31
32impl PositiveFiniteScale {
33 pub fn new(value: f64) -> Result<Self, PositiveFiniteScaleError> {
39 if !value.is_finite() {
40 Err(PositiveFiniteScaleError::NonFinite)
41 } else if value <= 0.0 {
42 Err(PositiveFiniteScaleError::NonPositive)
43 } else {
44 Ok(Self(value))
45 }
46 }
47
48 #[must_use]
53 pub(crate) const fn new_unchecked(value: f64) -> Self {
54 Self(value)
55 }
56
57 #[must_use]
59 pub const fn get(self) -> f64 {
60 self.0
61 }
62}
63
64impl std::fmt::Display for PositiveFiniteScale {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 self.0.fmt(f)
67 }
68}
69
70#[derive(Debug, Clone)]
72pub enum UnitScale {
73 Static(PositiveFiniteScale),
75 Dynamic {
79 scale_expr: Expr,
81 base_unit_scale: PositiveFiniteScale,
84 },
85}
86
87impl UnitScale {
88 #[must_use]
90 pub const fn as_static(&self) -> Option<f64> {
91 match self {
92 Self::Static(s) => Some(s.get()),
93 Self::Dynamic { .. } => None,
94 }
95 }
96
97 #[must_use]
99 pub const fn is_static(&self) -> bool {
100 matches!(self, Self::Static(_))
101 }
102
103 #[must_use]
105 pub const fn is_dynamic(&self) -> bool {
106 matches!(self, Self::Dynamic { .. })
107 }
108}
109
110#[derive(Debug, Clone)]
112pub struct UnitInfo {
113 pub dimension: Dimension,
115 pub constness: UnitConstness,
117 pub scale: UnitScale,
120}
121
122#[derive(Debug, Clone)]
124pub struct StructField {
125 pub name: FieldName,
126 pub type_ann: TypeExpr,
127}
128
129#[derive(Debug, Clone)]
135pub struct UnionMemberDef {
136 pub name: ConstructorName,
138 pub fields: Vec<StructField>,
141}
142
143#[derive(Debug, Clone)]
152pub enum TypeDefKind {
153 Required,
156 Union { members: Vec<UnionMemberDef> },
160}
161
162#[derive(Debug, Clone, Copy, PartialEq, Eq)]
164pub enum TypeGenericConstraint {
165 Dim,
167 Index,
169 Nat,
171 Unconstrained,
173}
174
175impl From<GenericConstraint> for TypeGenericConstraint {
176 fn from(c: GenericConstraint) -> Self {
177 match c {
178 GenericConstraint::Dim => Self::Dim,
179 GenericConstraint::Index => Self::Index,
180 GenericConstraint::Nat => Self::Nat,
181 GenericConstraint::Type => Self::Unconstrained,
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
188pub struct TypeGenericParam {
189 pub name: GenericParamName,
190 pub constraint: TypeGenericConstraint,
191 pub default: Option<crate::desugar::desugared_ast::TypeExpr>,
193}
194
195#[derive(Debug, Clone)]
197pub struct TypeDef {
198 pub name: StructTypeName,
199 pub generic_params: Vec<TypeGenericParam>,
200 pub kind: TypeDefKind,
201}
202
203impl TypeDef {
204 #[must_use]
208 pub fn union_members(&self) -> Option<&[UnionMemberDef]> {
209 match &self.kind {
210 TypeDefKind::Union { members } => Some(members),
211 TypeDefKind::Required => None,
212 }
213 }
214
215 #[must_use]
218 pub const fn is_union(&self) -> bool {
219 matches!(self.kind, TypeDefKind::Union { .. })
220 }
221
222 #[must_use]
224 pub const fn is_required(&self) -> bool {
225 matches!(self.kind, TypeDefKind::Required)
226 }
227
228 #[must_use]
238 pub fn record_fields(&self) -> Option<&[StructField]> {
239 let TypeDefKind::Union { members } = &self.kind else {
240 return None;
241 };
242 let [only] = members.as_slice() else {
243 return None;
244 };
245 (only.name.as_str() == self.name.as_str()).then_some(only.fields.as_slice())
246 }
247
248 #[must_use]
253 pub fn fields(&self) -> &[StructField] {
254 self.record_fields().unwrap_or(&[])
255 }
256}
257
258#[derive(Debug, Clone)]
260pub struct RangeIndexData {
261 pub start: f64,
262 pub end: f64,
263 pub step: f64,
264 pub step_count: NonZeroUsize,
266 pub dimension: Dimension,
267 pub display_label: Option<String>,
269 pub display_scale: f64,
271}
272
273impl RangeIndexData {
274 #[must_use]
276 #[expect(
277 clippy::cast_precision_loss,
278 reason = "range step indices are small enough for exact f64 representation"
279 )]
280 pub fn step_value(&self, i: usize) -> f64 {
281 (i as f64).mul_add(self.step, self.start)
282 }
283
284 #[must_use]
286 pub const fn step_count(&self) -> usize {
287 self.step_count.get()
288 }
289}
290
291#[derive(Debug, Clone)]
293pub enum IndexKind {
294 Named { variants: Vec<IndexVariantName> },
296 Range(RangeIndexData),
298 RequiredNamed,
300 RequiredRange { dimension: Dimension },
302 NatRange {
306 size: NonZeroUsize,
310 },
311}
312
313#[derive(Debug, Clone)]
315pub struct IndexDef {
316 pub name: IndexName,
317 pub kind: IndexKind,
318}
319
320impl IndexDef {
321 #[must_use]
328 pub fn variants(&self) -> Vec<IndexVariantName> {
329 match &self.kind {
330 IndexKind::Named { variants } => variants.clone(),
331 IndexKind::Range(data) => {
332 let count = data.step_count();
333 (0..count).map(IndexVariantName::range_step).collect()
334 }
335 IndexKind::NatRange { size } => {
336 (0..size.get()).map(IndexVariantName::range_step).collect()
337 }
338 IndexKind::RequiredNamed | IndexKind::RequiredRange { .. } => vec![],
339 }
340 }
341
342 #[must_use]
346 pub const fn step_count(&self) -> usize {
347 match &self.kind {
348 IndexKind::Named { variants } => variants.len(),
349 IndexKind::Range(data) => data.step_count(),
350 IndexKind::NatRange { size } => size.get(),
351 IndexKind::RequiredNamed | IndexKind::RequiredRange { .. } => 0,
352 }
353 }
354
355 #[must_use]
357 pub const fn range_data(&self) -> Option<&RangeIndexData> {
358 match &self.kind {
359 IndexKind::Range(data) => Some(data),
360 _ => None,
361 }
362 }
363
364 #[must_use]
366 pub const fn is_range(&self) -> bool {
367 matches!(
368 self.kind,
369 IndexKind::Range(_) | IndexKind::RequiredRange { .. }
370 )
371 }
372
373 #[must_use]
375 pub const fn is_named(&self) -> bool {
376 matches!(
377 self.kind,
378 IndexKind::Named { .. } | IndexKind::RequiredNamed
379 )
380 }
381
382 #[must_use]
384 pub const fn is_nat_range(&self) -> bool {
385 matches!(self.kind, IndexKind::NatRange { .. })
386 }
387
388 #[must_use]
390 pub const fn nat_range_size(&self) -> Option<u64> {
391 match &self.kind {
392 IndexKind::NatRange { size } => Some(size.get() as u64),
393 _ => None,
394 }
395 }
396
397 #[must_use]
399 pub const fn is_required(&self) -> bool {
400 matches!(
401 self.kind,
402 IndexKind::RequiredNamed | IndexKind::RequiredRange { .. }
403 )
404 }
405}
406
407#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
413pub enum NatRangeIndexError {
414 #[error("range(0) is not allowed; indexes must contain at least one element")]
416 Empty,
417 #[error("nat range size {size} does not fit in usize on this target")]
419 DoesNotFitUsize { size: u64 },
420}
421
422#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
427pub struct NatRangeIndex {
428 size: NonZeroUsize,
429}
430
431impl NatRangeIndex {
432 #[must_use]
434 pub const fn new(size: NonZeroUsize) -> Self {
435 Self { size }
436 }
437
438 pub fn try_from_u64(size: u64) -> Result<Self, NatRangeIndexError> {
444 if size == 0 {
445 return Err(NatRangeIndexError::Empty);
446 }
447 let size =
448 usize::try_from(size).map_err(|_| NatRangeIndexError::DoesNotFitUsize { size })?;
449 let size = NonZeroUsize::new(size).ok_or(NatRangeIndexError::Empty)?;
450 Ok(Self::new(size))
451 }
452
453 #[must_use]
455 pub const fn size(self) -> NonZeroUsize {
456 self.size
457 }
458
459 #[must_use]
461 #[expect(
462 clippy::expect_used,
463 reason = "Graphcal currently supports targets where usize fits in u64"
464 )]
465 pub fn size_u64(self) -> u64 {
466 u64::try_from(self.size.get()).expect("usize fits in u64 on supported targets")
467 }
468
469 #[must_use]
471 pub fn display_name(self) -> IndexName {
472 IndexName::new(format!("range({})", self.size_u64()))
473 }
474}
475
476#[derive(Debug, Clone, PartialEq)]
486pub enum UnitResolveError {
487 UnknownUnit(UnitRef),
489 DynamicScale(UnitRef),
491 Overflow(RationalError),
493}
494
495impl From<RationalError> for UnitResolveError {
496 fn from(err: RationalError) -> Self {
497 Self::Overflow(err)
498 }
499}
500
501fn resolve_dim_expr_impl(
503 dimensions: &HashMap<DimName, Dimension>,
504 expr: &DimExpr,
505) -> Result<Option<Dimension>, RationalError> {
506 expr.terms
507 .iter()
508 .try_fold(Some(Dimension::dimensionless()), |acc, item| {
509 let Some(acc) = acc else {
510 return Ok(None);
511 };
512 let Some(atom) = item.term.name.value.as_bare() else {
513 return Ok(None);
514 };
515 let Some(base) = dimensions.get(atom.as_str()) else {
516 return Ok(None);
517 };
518 let exp = item.term.power.unwrap_or(Rational::ONE);
519 let powered = base.pow(exp)?;
520 match item.op {
521 MulDivOp::Mul => acc * powered,
522 MulDivOp::Div => acc / powered,
523 }
524 .map(Some)
525 })
526}
527
528fn resolve_type_expr_impl(
530 dimensions: &HashMap<DimName, Dimension>,
531 type_expr: &TypeExpr,
532) -> Result<Option<Dimension>, RationalError> {
533 match &type_expr.kind {
534 TypeExprKind::Dimensionless => Ok(Some(Dimension::dimensionless())),
535 TypeExprKind::Bool
536 | TypeExprKind::Int
537 | TypeExprKind::Datetime
538 | TypeExprKind::TypeApplication { .. }
539 | TypeExprKind::DatetimeApplication { .. } => Ok(None),
540 TypeExprKind::DimExpr(dim_expr) => resolve_dim_expr_impl(dimensions, dim_expr),
541 TypeExprKind::Indexed { base, .. } => resolve_type_expr_impl(dimensions, base),
542 }
543}
544
545#[must_use]
550pub fn pow_scale(scale: f64, exp: Rational) -> f64 {
551 if exp.is_integer() {
552 scale.powi(exp.num())
553 } else {
554 scale.powf(f64::from(exp.num()) / f64::from(exp.den()))
555 }
556}
557
558fn resolve_unit_expr_impl(
560 units: &HashMap<UnitRef, UnitInfo>,
561 expr: &UnitExpr,
562) -> Result<(Dimension, f64), UnitResolveError> {
563 let mut dim = Dimension::dimensionless();
564 let mut scale = 1.0_f64;
565 for item in &expr.terms {
566 let Some(info) = units.get(&item.name.value) else {
567 return Err(UnitResolveError::UnknownUnit(item.name.value.clone()));
568 };
569 let exp = item.power.unwrap_or(Rational::ONE);
570 let powered_dim = info.dimension.pow(exp)?;
571 let Some(static_scale) = info.scale.as_static() else {
572 return Err(UnitResolveError::DynamicScale(item.name.value.clone()));
573 };
574 let powered_scale = pow_scale(static_scale, exp);
575 match item.op {
576 MulDivOp::Mul => {
577 dim = (dim * powered_dim)?;
578 scale *= powered_scale;
579 }
580 MulDivOp::Div => {
581 dim = (dim / powered_dim)?;
582 scale /= powered_scale;
583 }
584 }
585 }
586 Ok((dim, scale))
587}
588
589fn resolve_unit_dimension_impl(
593 units: &HashMap<UnitRef, UnitInfo>,
594 expr: &UnitExpr,
595) -> Result<Dimension, UnitResolveError> {
596 let mut dim = Dimension::dimensionless();
597 for item in &expr.terms {
598 let Some(info) = units.get(&item.name.value) else {
599 return Err(UnitResolveError::UnknownUnit(item.name.value.clone()));
600 };
601 let exp = item.power.unwrap_or(Rational::ONE);
602 let powered_dim = info.dimension.pow(exp)?;
603 dim = match item.op {
604 MulDivOp::Mul => (dim * powered_dim)?,
605 MulDivOp::Div => (dim / powered_dim)?,
606 };
607 }
608 Ok(dim)
609}
610
611fn format_dimension_preferring_alias(
618 dimensions: &HashMap<DimName, Dimension>,
619 base_dim_names: &BTreeMap<BaseDimId, String>,
620 dim: &Dimension,
621) -> String {
622 let canonical = format!("{}", dim.display_with(base_dim_names));
623 let is_compound = canonical.contains([' ', '^', '*', '/']);
626 if is_compound
627 && let Some(alias) = dimensions
628 .iter()
629 .filter(|(_, d)| *d == dim)
630 .map(|(name, _)| name)
631 .min()
632 {
633 return alias.to_string();
634 }
635 canonical
636}
637
638#[derive(Debug, Clone)]
645pub struct DimensionRegistry {
646 base_dim_names: BTreeMap<BaseDimId, String>,
648 base_dim_symbols: BTreeMap<BaseDimId, String>,
650 dimensions: HashMap<DimName, Dimension>,
651}
652
653impl DimensionRegistry {
654 #[must_use]
656 pub fn get_dimension(&self, name: &str) -> Option<&Dimension> {
657 self.dimensions.get(name)
658 }
659
660 pub fn all_dimensions(&self) -> impl Iterator<Item = (&DimName, &Dimension)> {
662 self.dimensions.iter()
663 }
664
665 #[must_use]
667 pub const fn base_dim_names(&self) -> &BTreeMap<BaseDimId, String> {
668 &self.base_dim_names
669 }
670
671 #[must_use]
673 pub const fn base_dim_symbols(&self) -> &BTreeMap<BaseDimId, String> {
674 &self.base_dim_symbols
675 }
676
677 #[must_use]
684 pub fn format_dimension(&self, dim: &Dimension) -> String {
685 format_dimension_preferring_alias(&self.dimensions, &self.base_dim_names, dim)
686 }
687
688 pub fn resolve_dim_expr(&self, expr: &DimExpr) -> Result<Option<Dimension>, RationalError> {
693 resolve_dim_expr_impl(&self.dimensions, expr)
694 }
695
696 pub fn resolve_type_expr(
701 &self,
702 type_expr: &TypeExpr,
703 ) -> Result<Option<Dimension>, RationalError> {
704 resolve_type_expr_impl(&self.dimensions, type_expr)
705 }
706}
707
708#[derive(Debug, Clone)]
710pub struct UnitRegistry {
711 units: HashMap<UnitRef, UnitInfo>,
712}
713
714impl UnitRegistry {
715 #[must_use]
717 pub fn get_unit(&self, name: &UnitRef) -> Option<&UnitInfo> {
718 self.units.get(name)
719 }
720
721 pub fn all_units(&self) -> impl Iterator<Item = (&UnitRef, &Dimension, &UnitScale)> {
723 self.units
724 .iter()
725 .map(|(name, info)| (name, &info.dimension, &info.scale))
726 }
727
728 pub fn resolve_unit_expr(&self, expr: &UnitExpr) -> Result<(Dimension, f64), UnitResolveError> {
735 resolve_unit_expr_impl(&self.units, expr)
736 }
737
738 pub fn resolve_unit_dimension(&self, expr: &UnitExpr) -> Result<Dimension, UnitResolveError> {
747 resolve_unit_dimension_impl(&self.units, expr)
748 }
749}
750
751#[derive(Debug, Clone)]
760pub struct TypeRegistry {
761 types: HashMap<StructTypeName, TypeDef>,
762 ctors: HashMap<ConstructorName, StructTypeName>,
768}
769
770impl TypeRegistry {
771 #[must_use]
773 pub fn get_type(&self, name: &str) -> Option<&TypeDef> {
774 self.types.get(name)
775 }
776
777 #[must_use]
781 pub fn lookup_ctor(&self, ctor: &ConstructorName) -> Option<(&TypeDef, &UnionMemberDef)> {
782 let union_name = self.ctors.get(ctor)?;
783 let td = self.types.get(union_name)?;
784 let members = td.union_members()?;
785 let member = members.iter().find(|m| m.name == *ctor)?;
786 Some((td, member))
787 }
788
789 pub fn all_types(&self) -> impl Iterator<Item = &TypeDef> {
791 self.types.values()
792 }
793}
794
795#[derive(Debug, Clone)]
797pub struct IndexRegistry {
798 indexes: HashMap<IndexName, IndexDef>,
799 nat_ranges: HashMap<NatRangeIndex, IndexDef>,
800}
801
802impl IndexRegistry {
803 #[must_use]
805 pub fn get_index(&self, name: &str) -> Option<&IndexDef> {
806 self.indexes.get(name)
807 }
808
809 #[must_use]
811 pub fn get_nat_range(&self, index: NatRangeIndex) -> Option<&IndexDef> {
812 self.nat_ranges.get(&index)
813 }
814
815 pub fn all_indexes(&self) -> impl Iterator<Item = &IndexDef> {
817 self.indexes.values().chain(self.nat_ranges.values())
818 }
819}
820
821#[derive(Debug, Clone)]
830pub struct Registry {
831 pub dimensions: DimensionRegistry,
832 pub units: UnitRegistry,
833 pub types: TypeRegistry,
834 pub indexes: IndexRegistry,
835 pub dags: DagRegistry,
836}
837
838#[derive(Debug, Default, Clone)]
845pub struct DagRegistry {
846 dags: HashMap<DeclName, DagDecl>,
850}
851
852impl DagRegistry {
853 #[must_use]
855 pub fn get(&self, name: &str) -> Option<&DagDecl> {
856 self.dags.get(name)
857 }
858
859 pub fn all_dags(&self) -> impl Iterator<Item = (&DeclName, &DagDecl)> {
861 self.dags.iter()
862 }
863}
864
865#[derive(Debug, Default)]
874pub struct RegistryBuilder {
875 base_dim_names: BTreeMap<BaseDimId, String>,
876 base_dim_symbols: BTreeMap<BaseDimId, String>,
877
878 dimensions: HashMap<DimName, Dimension>,
879 units: HashMap<UnitRef, UnitInfo>,
880 types: HashMap<StructTypeName, TypeDef>,
881 ctors: HashMap<ConstructorName, StructTypeName>,
882 indexes: HashMap<IndexName, IndexDef>,
883 nat_ranges: HashMap<NatRangeIndex, IndexDef>,
884 dags: HashMap<DeclName, DagDecl>,
885 affine_prone_dims: BTreeSet<BaseDimId>,
890}
891
892impl RegistryBuilder {
893 #[must_use]
894 pub fn new() -> Self {
895 Self::default()
896 }
897
898 #[must_use]
900 pub fn build(self) -> Registry {
901 Registry {
902 dimensions: DimensionRegistry {
903 base_dim_names: self.base_dim_names,
904 base_dim_symbols: self.base_dim_symbols,
905 dimensions: self.dimensions,
906 },
907 units: UnitRegistry { units: self.units },
908 types: TypeRegistry {
909 types: self.types,
910 ctors: self.ctors,
911 },
912 indexes: IndexRegistry {
913 indexes: self.indexes,
914 nat_ranges: self.nat_ranges,
915 },
916 dags: DagRegistry { dags: self.dags },
917 }
918 }
919
920 pub fn register_dag(&mut self, name: DeclName, decl: DagDecl) {
925 self.dags.insert(name, decl);
926 }
927
928 pub fn merge_from_registry(&mut self, parent: &Registry) {
936 for (id, name) in &parent.dimensions.base_dim_names {
937 self.base_dim_names
938 .entry(id.clone())
939 .or_insert_with(|| name.clone());
940 }
941 for (id, symbol) in &parent.dimensions.base_dim_symbols {
942 self.base_dim_symbols
943 .entry(id.clone())
944 .or_insert_with(|| symbol.clone());
945 }
946 for (name, dim) in &parent.dimensions.dimensions {
947 self.dimensions
948 .entry(name.clone())
949 .or_insert_with(|| dim.clone());
950 }
951 for (name, info) in &parent.units.units {
952 self.units
953 .entry(name.clone())
954 .or_insert_with(|| info.clone());
955 }
956 for (name, def) in &parent.types.types {
957 self.types
958 .entry(name.clone())
959 .or_insert_with(|| def.clone());
960 }
961 for (ctor, union_name) in &parent.types.ctors {
962 self.ctors
963 .entry(ctor.clone())
964 .or_insert_with(|| union_name.clone());
965 }
966 for (name, def) in &parent.indexes.indexes {
967 self.indexes
968 .entry(name.clone())
969 .or_insert_with(|| def.clone());
970 }
971 for (index, def) in &parent.indexes.nat_ranges {
972 self.nat_ranges.entry(*index).or_insert_with(|| def.clone());
973 }
974 for (name, decl) in &parent.dags.dags {
975 self.dags
976 .entry(name.clone())
977 .or_insert_with(|| decl.clone());
978 }
979 }
980
981 pub fn mark_affine_prone(&mut self, id: BaseDimId) {
992 self.affine_prone_dims.insert(id);
993 }
994
995 #[must_use]
999 pub fn is_affine_prone(&self, dim: &Dimension) -> bool {
1000 let mut iter = dim.iter();
1001 let Some((id, &exp)) = iter.next() else {
1002 return false;
1003 };
1004 iter.next().is_none() && exp == Rational::ONE && self.affine_prone_dims.contains(id)
1005 }
1006
1007 pub fn register_base_dimension(&mut self, name: DimName, id: BaseDimId) -> BaseDimId {
1008 let dim = Dimension::base(id.clone());
1009 self.base_dim_names.insert(id.clone(), name.to_string());
1010 self.dimensions.insert(name, dim);
1011 id
1012 }
1013
1014 pub fn register_base_dimension_with_symbol(
1019 &mut self,
1020 name: DimName,
1021 id: BaseDimId,
1022 symbol: String,
1023 ) -> BaseDimId {
1024 let id = self.register_base_dimension(name, id);
1025 self.base_dim_symbols.insert(id.clone(), symbol);
1026 id
1027 }
1028
1029 pub fn set_base_dim_symbol(&mut self, id: BaseDimId, symbol: String) {
1034 self.base_dim_symbols.entry(id).or_insert(symbol);
1035 }
1036
1037 pub fn register_dimension(&mut self, name: DimName, dim: Dimension) {
1039 self.dimensions.insert(name, dim);
1040 }
1041
1042 pub fn register_unit(
1044 &mut self,
1045 name: impl Into<UnitRef>,
1046 dimension: Dimension,
1047 scale: PositiveFiniteScale,
1048 ) {
1049 self.units.insert(
1050 name.into(),
1051 UnitInfo {
1052 dimension,
1053 constness: UnitConstness::Const,
1054 scale: UnitScale::Static(scale),
1055 },
1056 );
1057 }
1058
1059 pub fn register_unit_with_scale(
1061 &mut self,
1062 name: impl Into<UnitRef>,
1063 dimension: Dimension,
1064 scale: UnitScale,
1065 constness: UnitConstness,
1066 ) {
1067 self.units.insert(
1068 name.into(),
1069 UnitInfo {
1070 dimension,
1071 constness,
1072 scale,
1073 },
1074 );
1075 }
1076
1077 pub fn register_unit_dynamic(
1079 &mut self,
1080 name: impl Into<UnitRef>,
1081 dimension: Dimension,
1082 scale: UnitScale,
1083 ) {
1084 self.register_unit_with_scale(name, dimension, scale, UnitConstness::Dynamic);
1085 }
1086
1087 pub fn register_type(&mut self, def: TypeDef) {
1097 if let TypeDefKind::Union { ref members } = def.kind {
1098 for member in members {
1099 self.ctors.insert(member.name.clone(), def.name.clone());
1102 }
1103 }
1104 self.types.insert(def.name.clone(), def);
1105 }
1106
1107 pub fn register_index(&mut self, def: IndexDef) {
1109 self.indexes.insert(def.name.clone(), def);
1110 }
1111
1112 pub fn ensure_nat_range_index(&mut self, size: NonZeroUsize) -> NatRangeIndex {
1120 let nat_range = NatRangeIndex::new(size);
1121 self.nat_ranges
1122 .entry(nat_range)
1123 .or_insert_with(|| IndexDef {
1124 name: nat_range.display_name(),
1125 kind: IndexKind::NatRange { size },
1126 });
1127 nat_range
1128 }
1129
1130 #[must_use]
1134 pub fn get_dimension(&self, name: &str) -> Option<&Dimension> {
1135 self.dimensions.get(name)
1136 }
1137
1138 #[must_use]
1140 pub fn get_unit(&self, name: &UnitRef) -> Option<&UnitInfo> {
1141 self.units.get(name)
1142 }
1143
1144 pub fn all_units(&self) -> impl Iterator<Item = (&UnitRef, &Dimension, &UnitScale)> {
1146 self.units
1147 .iter()
1148 .map(|(name, info)| (name, &info.dimension, &info.scale))
1149 }
1150
1151 #[must_use]
1153 pub fn get_type(&self, name: &str) -> Option<&TypeDef> {
1154 self.types.get(name)
1155 }
1156
1157 #[must_use]
1159 pub fn get_index(&self, name: &str) -> Option<&IndexDef> {
1160 self.indexes.get(name)
1161 }
1162
1163 #[must_use]
1165 pub fn get_nat_range(&self, index: NatRangeIndex) -> Option<&IndexDef> {
1166 self.nat_ranges.get(&index)
1167 }
1168
1169 #[must_use]
1171 pub const fn base_dim_names(&self) -> &BTreeMap<BaseDimId, String> {
1172 &self.base_dim_names
1173 }
1174
1175 #[must_use]
1177 pub const fn base_dim_symbols(&self) -> &BTreeMap<BaseDimId, String> {
1178 &self.base_dim_symbols
1179 }
1180
1181 #[must_use]
1186 pub fn format_dimension(&self, dim: &Dimension) -> String {
1187 format_dimension_preferring_alias(&self.dimensions, &self.base_dim_names, dim)
1188 }
1189
1190 pub fn resolve_dim_expr(&self, expr: &DimExpr) -> Result<Option<Dimension>, RationalError> {
1195 resolve_dim_expr_impl(&self.dimensions, expr)
1196 }
1197
1198 pub fn resolve_type_expr(
1203 &self,
1204 type_expr: &TypeExpr,
1205 ) -> Result<Option<Dimension>, RationalError> {
1206 resolve_type_expr_impl(&self.dimensions, type_expr)
1207 }
1208
1209 pub fn resolve_unit_expr(&self, expr: &UnitExpr) -> Result<(Dimension, f64), UnitResolveError> {
1216 resolve_unit_expr_impl(&self.units, expr)
1217 }
1218
1219 pub fn resolve_unit_dimension(&self, expr: &UnitExpr) -> Result<Dimension, UnitResolveError> {
1228 resolve_unit_dimension_impl(&self.units, expr)
1229 }
1230}
1231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registry::prelude::load_prelude;
    use crate::syntax::ast::{DimExprItem, DimTerm, UnitExprItem};
    use crate::syntax::dimension::BaseDimId;
    use crate::syntax::names::{NamePath, UnitName};
    use crate::syntax::span::{Span, Spanned};

    // `op name` as a dimension-expression item with no explicit exponent.
    macro_rules! dim_item {
        ($op:ident, $name:expr) => {
            DimExprItem {
                op: MulDivOp::$op,
                term: DimTerm {
                    name: make_dim_term_name($name),
                    power: None,
                    span: nowhere(),
                },
            }
        };
    }

    // `op name^power` as a unit-expression item.
    macro_rules! unit_item {
        ($op:ident, $name:expr, $power:expr) => {
            UnitExprItem {
                op: MulDivOp::$op,
                name: make_unit_name($name),
                power: $power,
            }
        };
    }

    /// Identifier of the prelude base dimension called `name`.
    fn prelude_id(name: &str) -> BaseDimId {
        BaseDimId::Prelude(name.to_owned())
    }

    fn length_id() -> BaseDimId {
        prelude_id("Length")
    }

    fn time_id() -> BaseDimId {
        prelude_id("Time")
    }

    fn mass_id() -> BaseDimId {
        prelude_id("Mass")
    }

    /// User-defined `Information` base dimension owned by a root DAG named `test`.
    fn information_id() -> BaseDimId {
        BaseDimId::UserDefined {
            dag: crate::dag_id::DagId::root("test"),
            name: "Information".to_string(),
        }
    }

    /// `Length / Time`.
    fn velocity_dim() -> Dimension {
        (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap()
    }

    /// Placeholder span; positions are irrelevant to registry lookups.
    fn nowhere() -> Span {
        Span::new(0, 0)
    }

    /// Builder pre-populated with the prelude.
    fn prelude_builder() -> RegistryBuilder {
        let mut builder = RegistryBuilder::new();
        load_prelude(&mut builder).unwrap();
        builder
    }

    fn make_registry() -> Registry {
        prelude_builder().build()
    }

    fn make_dim_term_name(name: &str) -> Spanned<NamePath> {
        Spanned::new(NamePath::from(name), nowhere())
    }

    /// A type expression consisting of the single dimension name `name`.
    fn make_dim_type_expr(name: &str) -> TypeExpr {
        use crate::syntax::ast::DimExpr;
        let single_term = DimExpr {
            terms: vec![dim_item!(Mul, name)],
            span: nowhere(),
        };
        TypeExpr {
            kind: TypeExprKind::DimExpr(single_term),
            constraints: vec![],
            span: nowhere(),
        }
    }

    fn make_unit_name(name: &str) -> Spanned<UnitRef> {
        Spanned::new(UnitRef::local(UnitName::new(name)), nowhere())
    }

    /// A field named `name` whose type is `Velocity`.
    fn velocity_field(name: &'static str) -> StructField {
        StructField {
            name: FieldName::new(name),
            type_ann: make_dim_type_expr("Velocity"),
        }
    }

    /// Asserts `|actual - expected| < tolerance`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "{actual} is not within {tolerance} of {expected}"
        );
    }

    #[test]
    fn registry_base_dimensions() {
        let registry = make_registry();
        for (name, id) in [
            ("Length", length_id()),
            ("Time", time_id()),
            ("Mass", mass_id()),
        ] {
            assert_eq!(
                registry.dimensions.get_dimension(name),
                Some(&Dimension::base(id)),
                "base dimension {name}"
            );
        }
    }

    #[test]
    fn registry_derived_dimensions() {
        let registry = make_registry();
        assert_eq!(
            registry.dimensions.get_dimension("Velocity"),
            Some(&velocity_dim())
        );
    }

    #[test]
    fn registry_base_units() {
        let registry = make_registry();
        let metre = registry.units.get_unit(&UnitRef::local("m")).unwrap();
        assert_eq!(metre.dimension, Dimension::base(length_id()));
        assert_close(metre.scale.as_static().unwrap(), 1.0, f64::EPSILON);
    }

    #[test]
    fn registry_derived_units() {
        let registry = make_registry();
        let kilometre = registry.units.get_unit(&UnitRef::local("km")).unwrap();
        assert_eq!(kilometre.dimension, Dimension::base(length_id()));
        assert_close(kilometre.scale.as_static().unwrap(), 1000.0, f64::EPSILON);
    }

    #[test]
    fn resolve_dim_expr_velocity() {
        let registry = make_registry();
        let expr = DimExpr {
            // Length / Time
            terms: vec![dim_item!(Mul, "Length"), dim_item!(Div, "Time")],
            span: nowhere(),
        };
        let resolved = registry.dimensions.resolve_dim_expr(&expr).unwrap();
        assert_eq!(resolved, Some(velocity_dim()));
    }

    #[test]
    fn resolve_unit_expr_m_per_s_squared() {
        let registry = make_registry();
        let expr = UnitExpr {
            // m / s^2
            terms: vec![
                unit_item!(Mul, "m", None),
                unit_item!(Div, "s", Some(Rational::from_int(2))),
            ],
            span: nowhere(),
        };
        let (dim, scale) = registry.units.resolve_unit_expr(&expr).unwrap();
        let acceleration = (Dimension::base(length_id())
            / Dimension::base(time_id()).pow_int(2).unwrap())
        .unwrap();
        assert_eq!(dim, acceleration);
        assert_close(scale, 1.0, f64::EPSILON);
    }

    #[test]
    fn resolve_unit_expr_km_per_hour() {
        let registry = make_registry();
        let expr = UnitExpr {
            // km / hour
            terms: vec![
                unit_item!(Mul, "km", None),
                unit_item!(Div, "hour", None),
            ],
            span: nowhere(),
        };
        let (dim, scale) = registry.units.resolve_unit_expr(&expr).unwrap();
        assert_eq!(dim, velocity_dim());
        assert_close(scale, 1000.0 / 3600.0, 1e-10);
    }

    #[test]
    fn registry_type_register_and_lookup() {
        let mut builder = prelude_builder();
        // A record-style type: single constructor sharing the type's name.
        builder.register_type(TypeDef {
            name: StructTypeName::new("TransferResult"),
            generic_params: vec![],
            kind: TypeDefKind::Union {
                members: vec![UnionMemberDef {
                    name: ConstructorName::new("TransferResult"),
                    fields: vec![velocity_field("dv1"), velocity_field("dv2")],
                }],
            },
        });
        let registry = builder.build();

        let def = registry.types.get_type("TransferResult").unwrap();
        assert_eq!(def.name.as_str(), "TransferResult");
        assert!(def.is_union());

        let fields = def.record_fields().expect("single-variant collision");
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].name.as_str(), "dv1");
        assert_eq!(
            registry.dimensions.resolve_type_expr(&fields[0].type_ann),
            Ok(Some(velocity_dim()))
        );

        assert!(registry.types.get_type("NonExistent").is_none());
    }

    #[test]
    fn registry_index_register_and_lookup() {
        let mut builder = prelude_builder();
        let variant_names = ["Departure", "Correction", "Insertion"];
        builder.register_index(IndexDef {
            name: IndexName::new("Maneuver"),
            kind: IndexKind::Named {
                variants: variant_names
                    .iter()
                    .map(|&variant| IndexVariantName::new(variant))
                    .collect(),
            },
        });
        let registry = builder.build();

        let def = registry.indexes.get_index("Maneuver").unwrap();
        assert_eq!(def.name.as_str(), "Maneuver");
        let variants = def.variants();
        let rendered: Vec<&str> = variants.iter().map(IndexVariantName::as_str).collect();
        assert_eq!(rendered, variant_names.to_vec());

        assert!(registry.indexes.get_index("NonExistent").is_none());
    }

    #[test]
    fn register_user_defined_base_dimension() {
        let mut builder = prelude_builder();
        let expected_id = information_id();
        let id = builder.register_base_dimension(DimName::new("Information"), expected_id.clone());
        assert_eq!(id, expected_id);

        let registry = builder.build();
        assert_eq!(
            registry.dimensions.get_dimension("Information"),
            Some(&Dimension::base(id.clone()))
        );
        assert_eq!(
            registry.dimensions.base_dim_names().get(&id),
            Some(&"Information".to_string())
        );
    }

    #[test]
    fn register_base_dimension_with_symbol() {
        let mut builder = RegistryBuilder::new();
        let id = builder.register_base_dimension_with_symbol(
            DimName::new("Length"),
            length_id(),
            "m".to_string(),
        );
        let registry = builder.build();
        assert_eq!(
            registry.dimensions.base_dim_symbols().get(&id),
            Some(&"m".to_string())
        );
    }

    #[test]
    fn set_base_dim_symbol_only_first() {
        let mut builder = RegistryBuilder::new();
        let id = builder.register_base_dimension(DimName::new("Information"), information_id());
        // The second call must not replace the symbol set by the first.
        for symbol in ["bit", "byte"] {
            builder.set_base_dim_symbol(id.clone(), symbol.to_string());
        }
        let registry = builder.build();
        assert_eq!(
            registry.dimensions.base_dim_symbols().get(&id),
            Some(&"bit".to_string())
        );
    }
}