// graphcal_compiler/registry/types.rs

1use std::collections::{BTreeMap, BTreeSet, HashMap};
2use std::num::NonZeroUsize;
3
4use thiserror::Error;
5
6use crate::desugar::desugared_ast::{
7    DagDecl, DimExpr, Expr, GenericConstraint, MulDivOp, TypeExpr, TypeExprKind, UnitExpr,
8};
9use crate::syntax::ast::UnitConstness;
10use crate::syntax::dimension::{BaseDimId, Dimension, Rational, RationalError};
11use crate::syntax::names::{
12    ConstructorName, DeclName, DimName, FieldName, GenericParamName, IndexName, IndexVariantName,
13    StructTypeName, UnitRef,
14};
15// ---------------------------------------------------------------------------
16// Data types
17// ---------------------------------------------------------------------------
18
19/// Error returned when a unit scale is not a positive finite scalar.
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
21pub enum PositiveFiniteScaleError {
22    #[error("scale must be finite")]
23    NonFinite,
24    #[error("scale must be greater than zero")]
25    NonPositive,
26}
27
28/// A unit scale factor that is guaranteed to be positive and finite.
29#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
30pub struct PositiveFiniteScale(f64);
31
32impl PositiveFiniteScale {
33    /// Validate a raw scale factor.
34    ///
35    /// # Errors
36    ///
37    /// Returns an error when `value` is `NaN`, infinite, zero, or negative.
38    pub fn new(value: f64) -> Result<Self, PositiveFiniteScaleError> {
39        if !value.is_finite() {
40            Err(PositiveFiniteScaleError::NonFinite)
41        } else if value <= 0.0 {
42            Err(PositiveFiniteScaleError::NonPositive)
43        } else {
44            Ok(Self(value))
45        }
46    }
47
48    /// Construct a scale from trusted internal constants.
49    ///
50    /// Callers must ensure `value` is positive and finite. This is restricted
51    /// to the compiler crate so external code must use [`Self::new`].
52    #[must_use]
53    pub(crate) const fn new_unchecked(value: f64) -> Self {
54        Self(value)
55    }
56
57    /// Return the wrapped raw scale factor.
58    #[must_use]
59    pub const fn get(self) -> f64 {
60        self.0
61    }
62}
63
64impl std::fmt::Display for PositiveFiniteScale {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        self.0.fmt(f)
67    }
68}
69
70/// How a unit's scale factor is determined.
71#[derive(Debug, Clone)]
72pub enum UnitScale {
73    /// Scale factor known at compile time (e.g., `const unit km: Length = 1000 m;`).
74    Static(PositiveFiniteScale),
75    /// Scale factor depends on runtime values (e.g., `unit EUR: Money = (@rate) USD;`).
76    ///
77    /// The final SI scale = `eval(scale_expr) * base_unit_scale`.
78    Dynamic {
79        /// The unevaluated scale expression containing `@`-references.
80        scale_expr: Expr,
81        /// The scale factor of the base unit in the definition (resolved at compile time).
82        /// For `(@rate) USD` where USD has scale 1.0, this is 1.0.
83        base_unit_scale: PositiveFiniteScale,
84    },
85}
86
87impl UnitScale {
88    /// Returns the static scale factor, or `None` if the scale is dynamic.
89    #[must_use]
90    pub const fn as_static(&self) -> Option<f64> {
91        match self {
92            Self::Static(s) => Some(s.get()),
93            Self::Dynamic { .. } => None,
94        }
95    }
96
97    /// Returns `true` if the scale is resolved at compile time.
98    #[must_use]
99    pub const fn is_static(&self) -> bool {
100        matches!(self, Self::Static(_))
101    }
102
103    /// Returns `true` if the scale depends on runtime values.
104    #[must_use]
105    pub const fn is_dynamic(&self) -> bool {
106        matches!(self, Self::Dynamic { .. })
107    }
108}
109
110/// Information about a registered unit.
111#[derive(Debug, Clone)]
112pub struct UnitInfo {
113    /// The dimension this unit measures.
114    pub dimension: Dimension,
115    /// Whether this unit may appear in compile-time (`const`) contexts.
116    pub constness: UnitConstness,
117    /// Scale factor to convert 1 of this unit to base SI units.
118    /// e.g., km -> `Static(1000.0)` (1 km = 1000 m)
119    pub scale: UnitScale,
120}
121
122/// A field in a record type definition.
123#[derive(Debug, Clone)]
124pub struct StructField {
125    pub name: FieldName,
126    pub type_ann: TypeExpr,
127}
128
129/// A member (constructor) of a tagged-union type.
130///
131/// The compiler treats every `type T { ... }` declaration as an n-variant
132/// tagged union — including single-variant cases. Each variant carries
133/// its payload fields inline; there are no per-variant standalone types.
134#[derive(Debug, Clone)]
135pub struct UnionMemberDef {
136    /// Constructor name.
137    pub name: ConstructorName,
138    /// Payload fields for this constructor. An empty `Vec` means a unit
139    /// constructor (`Coast`).
140    pub fields: Vec<StructField>,
141}
142
143/// The kind of a type definition.
144///
145/// The functional core only distinguishes two shapes: a *required* type
146/// stub (no body, awaits binding via include) and an *n-variant union*
147/// — single-variant or multi-variant alike. Record-shaped types are
148/// represented as a single-variant union whose sole constructor's name
149/// matches the type's name (e.g.,
150/// `type Position { Position(x: Length, y: Length) }`).
151#[derive(Debug, Clone)]
152pub enum TypeDefKind {
153    /// A required type with no body: `type Element;`. Bound from outside
154    /// via parameterized include.
155    Required,
156    /// A tagged union: `type Maneuver { Impulsive(delta_v: Velocity), Coast }`
157    /// or, as a single-variant special case,
158    /// `type Position { Position(x: Length, y: Length) }`.
159    Union { members: Vec<UnionMemberDef> },
160}
161
162/// The constraint on a generic parameter of a type definition.
163#[derive(Debug, Clone, Copy, PartialEq, Eq)]
164pub enum TypeGenericConstraint {
165    /// `D: Dim` — the generic stands for a dimension.
166    Dim,
167    /// `I: Index` — the generic stands for an index.
168    Index,
169    /// `N: Nat` — the generic stands for a natural number (type-level).
170    Nat,
171    /// `F: Type` — unconstrained phantom type parameter.
172    Unconstrained,
173}
174
175impl From<GenericConstraint> for TypeGenericConstraint {
176    fn from(c: GenericConstraint) -> Self {
177        match c {
178            GenericConstraint::Dim => Self::Dim,
179            GenericConstraint::Index => Self::Index,
180            GenericConstraint::Nat => Self::Nat,
181            GenericConstraint::Type => Self::Unconstrained,
182        }
183    }
184}
185
186/// A generic parameter on a type definition.
187#[derive(Debug, Clone)]
188pub struct TypeGenericParam {
189    pub name: GenericParamName,
190    pub constraint: TypeGenericConstraint,
191    /// Optional default type expression, e.g. `F: Type = Unframed`.
192    pub default: Option<crate::desugar::desugared_ast::TypeExpr>,
193}
194
195/// A registered type definition: either a required type stub or a tagged union.
196#[derive(Debug, Clone)]
197pub struct TypeDef {
198    pub name: StructTypeName,
199    pub generic_params: Vec<TypeGenericParam>,
200    pub kind: TypeDefKind,
201}
202
203impl TypeDef {
204    /// Returns the union members if this is a tagged union.
205    ///
206    /// Returns `None` only for a required (unbound) type stub.
207    #[must_use]
208    pub fn union_members(&self) -> Option<&[UnionMemberDef]> {
209        match &self.kind {
210            TypeDefKind::Union { members } => Some(members),
211            TypeDefKind::Required => None,
212        }
213    }
214
215    /// Returns `true` if this is a tagged union — single-variant or
216    /// multi-variant.
217    #[must_use]
218    pub const fn is_union(&self) -> bool {
219        matches!(self.kind, TypeDefKind::Union { .. })
220    }
221
222    /// Returns `true` if this is a required type stub awaiting binding.
223    #[must_use]
224    pub const fn is_required(&self) -> bool {
225        matches!(self.kind, TypeDefKind::Required)
226    }
227
228    /// If this is a single-variant union whose sole constructor's name
229    /// equals the type's name, returns that variant's payload fields.
230    /// This is the record-like shape: field access and brace
231    /// construction work directly on it.
232    ///
233    /// For multi-variant unions or single-variant unions whose
234    /// constructor name differs from the type name, returns `None` —
235    /// callers must dispatch through the constructor namespace and / or
236    /// `match`.
237    #[must_use]
238    pub fn record_fields(&self) -> Option<&[StructField]> {
239        let TypeDefKind::Union { members } = &self.kind else {
240            return None;
241        };
242        let [only] = members.as_slice() else {
243            return None;
244        };
245        (only.name.as_str() == self.name.as_str()).then_some(only.fields.as_slice())
246    }
247
248    /// Backward-compatible accessor that returns the record-shaped
249    /// fields (empty when the type is multi-variant or a required
250    /// stub). Prefer [`record_fields`](Self::record_fields) at new call
251    /// sites — it makes the single-variant precondition explicit.
252    #[must_use]
253    pub fn fields(&self) -> &[StructField] {
254        self.record_fields().unwrap_or(&[])
255    }
256}
257
258/// Data for a concrete numeric range index (e.g., `linspace(0.0 s, 100.0 s, step: 0.1 s)`).
259#[derive(Debug, Clone)]
260pub struct RangeIndexData {
261    pub start: f64,
262    pub end: f64,
263    pub step: f64,
264    /// Validated number of inclusive range steps.
265    pub step_count: NonZeroUsize,
266    pub dimension: Dimension,
267    /// Display unit label (e.g., `"s"`) for formatting step values.
268    pub display_label: Option<String>,
269    /// Scale factor from SI to display unit: `display_value = si_value / scale`.
270    pub display_scale: f64,
271}
272
273impl RangeIndexData {
274    /// Returns the SI value at step `i`.
275    #[must_use]
276    #[expect(
277        clippy::cast_precision_loss,
278        reason = "range step indices are small enough for exact f64 representation"
279    )]
280    pub fn step_value(&self, i: usize) -> f64 {
281        (i as f64).mul_add(self.step, self.start)
282    }
283
284    /// Returns the number of steps in this range.
285    #[must_use]
286    pub const fn step_count(&self) -> usize {
287        self.step_count.get()
288    }
289}
290
291/// The kind of an index: either named variants or a numeric range.
292#[derive(Debug, Clone)]
293pub enum IndexKind {
294    /// A named label set, e.g. `index Maneuver = { Departure, Correction, Insertion };`
295    Named { variants: Vec<IndexVariantName> },
296    /// A numeric range, e.g. `index T = linspace(0.0 s, 100.0 s, step: 0.1 s);`
297    Range(RangeIndexData),
298    /// Required named index (no variants): must be bound via parameterized import.
299    RequiredNamed,
300    /// Required range index with dimension constraint: must be bound via parameterized import.
301    RequiredRange { dimension: Dimension },
302    /// A Nat-parameterized range: `range(N)` with elements `{0, 1, ..., N-1}`.
303    ///
304    /// Created synthetically for integer literals in index position (e.g., `D[3]`).
305    NatRange {
306        /// The non-zero size of the range (number of elements). Stored as
307        /// `usize` because it bounds in-memory variant tables; AST-level Nat
308        /// literals are converted at the registry boundary.
309        size: NonZeroUsize,
310    },
311}
312
313/// A declared index with its ordered variants.
314#[derive(Debug, Clone)]
315pub struct IndexDef {
316    pub name: IndexName,
317    pub kind: IndexKind,
318}
319
320impl IndexDef {
321    /// Returns the ordered variant names for this index.
322    ///
323    /// For named indexes, returns the declared variants.
324    /// For range indexes, generates synthetic names like `"#0"`, `"#1"`, etc.
325    /// For nat range indexes, generates synthetic names like `"#0"`, `"#1"`, etc.
326    /// For required indexes, returns an empty vec (no variants until bound).
327    #[must_use]
328    pub fn variants(&self) -> Vec<IndexVariantName> {
329        match &self.kind {
330            IndexKind::Named { variants } => variants.clone(),
331            IndexKind::Range(data) => {
332                let count = data.step_count();
333                (0..count).map(IndexVariantName::range_step).collect()
334            }
335            IndexKind::NatRange { size } => {
336                (0..size.get()).map(IndexVariantName::range_step).collect()
337            }
338            IndexKind::RequiredNamed | IndexKind::RequiredRange { .. } => vec![],
339        }
340    }
341
342    /// Returns the number of steps/variants in this index.
343    ///
344    /// Returns 0 for required indexes (no variants until bound).
345    #[must_use]
346    pub const fn step_count(&self) -> usize {
347        match &self.kind {
348            IndexKind::Named { variants } => variants.len(),
349            IndexKind::Range(data) => data.step_count(),
350            IndexKind::NatRange { size } => size.get(),
351            IndexKind::RequiredNamed | IndexKind::RequiredRange { .. } => 0,
352        }
353    }
354
355    /// Returns the range data if this is a concrete range index.
356    #[must_use]
357    pub const fn range_data(&self) -> Option<&RangeIndexData> {
358        match &self.kind {
359            IndexKind::Range(data) => Some(data),
360            _ => None,
361        }
362    }
363
364    /// Returns true if this is a range index (concrete or required, not nat range).
365    #[must_use]
366    pub const fn is_range(&self) -> bool {
367        matches!(
368            self.kind,
369            IndexKind::Range(_) | IndexKind::RequiredRange { .. }
370        )
371    }
372
373    /// Returns true if this is a named index (concrete or required).
374    #[must_use]
375    pub const fn is_named(&self) -> bool {
376        matches!(
377            self.kind,
378            IndexKind::Named { .. } | IndexKind::RequiredNamed
379        )
380    }
381
382    /// Returns true if this is a nat range index.
383    #[must_use]
384    pub const fn is_nat_range(&self) -> bool {
385        matches!(self.kind, IndexKind::NatRange { .. })
386    }
387
388    /// Returns the nat range size, if this is a nat range index.
389    #[must_use]
390    pub const fn nat_range_size(&self) -> Option<u64> {
391        match &self.kind {
392            IndexKind::NatRange { size } => Some(size.get() as u64),
393            _ => None,
394        }
395    }
396
397    /// Returns true if this is a required index (must be bound via parameterized import).
398    #[must_use]
399    pub const fn is_required(&self) -> bool {
400        matches!(
401            self.kind,
402            IndexKind::RequiredNamed | IndexKind::RequiredRange { .. }
403        )
404    }
405}
406
407// ---------------------------------------------------------------------------
408// Nat range helpers
409// ---------------------------------------------------------------------------
410
411/// Error returned when an AST/runtime Nat range size cannot become a concrete index.
412#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
413pub enum NatRangeIndexError {
414    /// Empty Nat ranges are deliberately not representable.
415    #[error("range(0) is not allowed; indexes must contain at least one element")]
416    Empty,
417    /// The source-level `u64` size does not fit in this target's in-memory index size.
418    #[error("nat range size {size} does not fit in usize on this target")]
419    DoesNotFitUsize { size: u64 },
420}
421
422/// Typed identity for a concrete compiler-generated Nat range index.
423///
424/// The core carries this non-zero size directly; display names are derived only
425/// for diagnostics and compatibility with APIs that still need an [`IndexName`].
426#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
427pub struct NatRangeIndex {
428    size: NonZeroUsize,
429}
430
431impl NatRangeIndex {
432    /// Create an identity for a non-empty Nat range index.
433    #[must_use]
434    pub const fn new(size: NonZeroUsize) -> Self {
435        Self { size }
436    }
437
438    /// Try to create an identity from an AST/runtime `u64` size.
439    ///
440    /// # Errors
441    ///
442    /// Returns an error when `size` is zero or cannot fit in `usize` on this target.
443    pub fn try_from_u64(size: u64) -> Result<Self, NatRangeIndexError> {
444        if size == 0 {
445            return Err(NatRangeIndexError::Empty);
446        }
447        let size =
448            usize::try_from(size).map_err(|_| NatRangeIndexError::DoesNotFitUsize { size })?;
449        let size = NonZeroUsize::new(size).ok_or(NatRangeIndexError::Empty)?;
450        Ok(Self::new(size))
451    }
452
453    /// Return the non-zero in-memory size.
454    #[must_use]
455    pub const fn size(self) -> NonZeroUsize {
456        self.size
457    }
458
459    /// Return the size as a `u64` for Nat-expression comparisons and display.
460    #[must_use]
461    #[expect(
462        clippy::expect_used,
463        reason = "Graphcal currently supports targets where usize fits in u64"
464    )]
465    pub fn size_u64(self) -> u64 {
466        u64::try_from(self.size.get()).expect("usize fits in u64 on supported targets")
467    }
468
469    /// Render this identity for diagnostics as source-level `range(N)` syntax.
470    #[must_use]
471    pub fn display_name(self) -> IndexName {
472        IndexName::new(format!("range({})", self.size_u64()))
473    }
474}
475
476// ---------------------------------------------------------------------------
// Resolution helpers: error type, scale arithmetic, and private resolvers
478// ---------------------------------------------------------------------------
479
480/// Why a unit expression could not be resolved.
481///
482/// Carries the failing unit name so callers can produce a precise
483/// diagnostic instead of re-scanning the expression to find it (the old
484/// `Ok(None)` return conflated unknown names with dynamic scales).
485#[derive(Debug, Clone, PartialEq)]
486pub enum UnitResolveError {
487    /// A unit name in the expression is not registered.
488    UnknownUnit(UnitRef),
489    /// A unit in the expression has a runtime-dependent scale.
490    DynamicScale(UnitRef),
491    /// Dimension exponent arithmetic overflowed.
492    Overflow(RationalError),
493}
494
495impl From<RationalError> for UnitResolveError {
496    fn from(err: RationalError) -> Self {
497        Self::Overflow(err)
498    }
499}
500
501/// Shared implementation for resolving a `DimExpr` to a concrete `Dimension`.
502fn resolve_dim_expr_impl(
503    dimensions: &HashMap<DimName, Dimension>,
504    expr: &DimExpr,
505) -> Result<Option<Dimension>, RationalError> {
506    expr.terms
507        .iter()
508        .try_fold(Some(Dimension::dimensionless()), |acc, item| {
509            let Some(acc) = acc else {
510                return Ok(None);
511            };
512            let Some(atom) = item.term.name.value.as_bare() else {
513                return Ok(None);
514            };
515            let Some(base) = dimensions.get(atom.as_str()) else {
516                return Ok(None);
517            };
518            let exp = item.term.power.unwrap_or(Rational::ONE);
519            let powered = base.pow(exp)?;
520            match item.op {
521                MulDivOp::Mul => acc * powered,
522                MulDivOp::Div => acc / powered,
523            }
524            .map(Some)
525        })
526}
527
528/// Shared implementation for resolving a `TypeExpr` to a concrete `Dimension`.
529fn resolve_type_expr_impl(
530    dimensions: &HashMap<DimName, Dimension>,
531    type_expr: &TypeExpr,
532) -> Result<Option<Dimension>, RationalError> {
533    match &type_expr.kind {
534        TypeExprKind::Dimensionless => Ok(Some(Dimension::dimensionless())),
535        TypeExprKind::Bool
536        | TypeExprKind::Int
537        | TypeExprKind::Datetime
538        | TypeExprKind::TypeApplication { .. }
539        | TypeExprKind::DatetimeApplication { .. } => Ok(None),
540        TypeExprKind::DimExpr(dim_expr) => resolve_dim_expr_impl(dimensions, dim_expr),
541        TypeExprKind::Indexed { base, .. } => resolve_type_expr_impl(dimensions, base),
542    }
543}
544
545/// Raise a positive unit scale to a rational power.
546///
547/// Integer powers use `powi` for exactness; fractional powers fall back to
548/// `powf`, which is well-defined because unit scales are always positive.
549#[must_use]
550pub fn pow_scale(scale: f64, exp: Rational) -> f64 {
551    if exp.is_integer() {
552        scale.powi(exp.num())
553    } else {
554        scale.powf(f64::from(exp.num()) / f64::from(exp.den()))
555    }
556}
557
558/// Shared implementation for resolving a `UnitExpr` to its dimension and static scale factor.
559fn resolve_unit_expr_impl(
560    units: &HashMap<UnitRef, UnitInfo>,
561    expr: &UnitExpr,
562) -> Result<(Dimension, f64), UnitResolveError> {
563    let mut dim = Dimension::dimensionless();
564    let mut scale = 1.0_f64;
565    for item in &expr.terms {
566        let Some(info) = units.get(&item.name.value) else {
567            return Err(UnitResolveError::UnknownUnit(item.name.value.clone()));
568        };
569        let exp = item.power.unwrap_or(Rational::ONE);
570        let powered_dim = info.dimension.pow(exp)?;
571        let Some(static_scale) = info.scale.as_static() else {
572            return Err(UnitResolveError::DynamicScale(item.name.value.clone()));
573        };
574        let powered_scale = pow_scale(static_scale, exp);
575        match item.op {
576            MulDivOp::Mul => {
577                dim = (dim * powered_dim)?;
578                scale *= powered_scale;
579            }
580            MulDivOp::Div => {
581                dim = (dim / powered_dim)?;
582                scale /= powered_scale;
583            }
584        }
585    }
586    Ok((dim, scale))
587}
588
589/// Shared implementation for resolving a `UnitExpr` to its dimension only (ignoring scales).
590///
591/// Works for both static and dynamic units.
592fn resolve_unit_dimension_impl(
593    units: &HashMap<UnitRef, UnitInfo>,
594    expr: &UnitExpr,
595) -> Result<Dimension, UnitResolveError> {
596    let mut dim = Dimension::dimensionless();
597    for item in &expr.terms {
598        let Some(info) = units.get(&item.name.value) else {
599            return Err(UnitResolveError::UnknownUnit(item.name.value.clone()));
600        };
601        let exp = item.power.unwrap_or(Rational::ONE);
602        let powered_dim = info.dimension.pow(exp)?;
603        dim = match item.op {
604            MulDivOp::Mul => (dim * powered_dim)?,
605            MulDivOp::Div => (dim / powered_dim)?,
606        };
607    }
608    Ok(dim)
609}
610
611/// Format a dimension, preferring a registered named alias for compound forms.
612///
613/// A pure base dimension (`Length`) or `Dimensionless` keeps its canonical
614/// rendering. A compound dimension (`Length^2 * Mass / Time^2`) is replaced by
615/// a matching named dimension (`Energy`) when one is registered; if several
616/// names match, the lexicographically smallest is chosen for determinism.
617fn format_dimension_preferring_alias(
618    dimensions: &HashMap<DimName, Dimension>,
619    base_dim_names: &BTreeMap<BaseDimId, String>,
620    dim: &Dimension,
621) -> String {
622    let canonical = format!("{}", dim.display_with(base_dim_names));
623    // Base dimensions and Dimensionless render as a single bare name already;
624    // only compound renderings benefit from an alias.
625    let is_compound = canonical.contains([' ', '^', '*', '/']);
626    if is_compound
627        && let Some(alias) = dimensions
628            .iter()
629            .filter(|(_, d)| *d == dim)
630            .map(|(name, _)| name)
631            .min()
632    {
633        return alias.to_string();
634    }
635    canonical
636}
637
638// ---------------------------------------------------------------------------
639// Domain-specific registries (frozen / read-only)
640// ---------------------------------------------------------------------------
641
642/// Dimension registry: maps dimension names to `Dimension` values and tracks
643/// base dimension metadata (ID assignment, names, default unit symbols).
644#[derive(Debug, Clone)]
645pub struct DimensionRegistry {
646    /// Base dimension ID → dimension name (for display).
647    base_dim_names: BTreeMap<BaseDimId, String>,
648    /// Base dimension ID → default unit symbol for runtime display.
649    base_dim_symbols: BTreeMap<BaseDimId, String>,
650    dimensions: HashMap<DimName, Dimension>,
651}
652
653impl DimensionRegistry {
654    /// Look up a dimension by name.
655    #[must_use]
656    pub fn get_dimension(&self, name: &str) -> Option<&Dimension> {
657        self.dimensions.get(name)
658    }
659
660    /// Iterate over all named dimensions.
661    pub fn all_dimensions(&self) -> impl Iterator<Item = (&DimName, &Dimension)> {
662        self.dimensions.iter()
663    }
664
665    /// Get the base dimension names map (for display purposes).
666    #[must_use]
667    pub const fn base_dim_names(&self) -> &BTreeMap<BaseDimId, String> {
668        &self.base_dim_names
669    }
670
671    /// Get the base dimension symbols map for runtime display.
672    #[must_use]
673    pub const fn base_dim_symbols(&self) -> &BTreeMap<BaseDimId, String> {
674        &self.base_dim_symbols
675    }
676
677    /// Format a dimension as a human-readable string using registered base dimension names.
678    ///
679    /// Returns `"Dimensionless"` for dimensionless, or names like `"Length / Time"`.
680    /// When a compound dimension matches a named dimension alias (e.g. `Energy`
681    /// for `Length^2 * Mass / Time^2`), the alias is preferred so diagnostics
682    /// speak the user's vocabulary.
683    #[must_use]
684    pub fn format_dimension(&self, dim: &Dimension) -> String {
685        format_dimension_preferring_alias(&self.dimensions, &self.base_dim_names, dim)
686    }
687
688    /// Resolve a `DimExpr` AST node to a concrete `Dimension`.
689    ///
690    /// Returns `Ok(None)` if any dimension name is unknown, and `Err` if
691    /// dimension exponent arithmetic overflows `i32`.
692    pub fn resolve_dim_expr(&self, expr: &DimExpr) -> Result<Option<Dimension>, RationalError> {
693        resolve_dim_expr_impl(&self.dimensions, expr)
694    }
695
696    /// Resolve a `TypeExpr` to a concrete `Dimension`.
697    ///
698    /// Returns `Ok(None)` if the type references unknown dimensions, and
699    /// `Err` if dimension exponent arithmetic overflows `i32`.
700    pub fn resolve_type_expr(
701        &self,
702        type_expr: &TypeExpr,
703    ) -> Result<Option<Dimension>, RationalError> {
704        resolve_type_expr_impl(&self.dimensions, type_expr)
705    }
706}
707
708/// Unit registry: maps unit names to `UnitInfo` (dimension + scale).
709#[derive(Debug, Clone)]
710pub struct UnitRegistry {
711    units: HashMap<UnitRef, UnitInfo>,
712}
713
714impl UnitRegistry {
715    /// Look up a unit by reference (bare or module-alias-qualified).
716    #[must_use]
717    pub fn get_unit(&self, name: &UnitRef) -> Option<&UnitInfo> {
718        self.units.get(name)
719    }
720
721    /// Iterate over all units: (reference, dimension, scale).
722    pub fn all_units(&self) -> impl Iterator<Item = (&UnitRef, &Dimension, &UnitScale)> {
723        self.units
724            .iter()
725            .map(|(name, info)| (name, &info.dimension, &info.scale))
726    }
727
728    /// Resolve a `UnitExpr` to its dimension and compound static scale factor.
729    ///
730    /// # Errors
731    ///
732    /// Returns a [`UnitResolveError`] naming the unknown or dynamic-scale
733    /// unit, or the exponent overflow.
734    pub fn resolve_unit_expr(&self, expr: &UnitExpr) -> Result<(Dimension, f64), UnitResolveError> {
735        resolve_unit_expr_impl(&self.units, expr)
736    }
737
738    /// Resolve a `UnitExpr` to its dimension only (ignoring scales).
739    ///
740    /// Works for both static and dynamic units.
741    ///
742    /// # Errors
743    ///
744    /// Returns a [`UnitResolveError`] naming the unknown unit, or the
745    /// exponent overflow.
746    pub fn resolve_unit_dimension(&self, expr: &UnitExpr) -> Result<Dimension, UnitResolveError> {
747        resolve_unit_dimension_impl(&self.units, expr)
748    }
749}
750
751/// Type registry: maps type names to `TypeDef` and provides
752/// constructor-namespace lookup.
753///
754/// The constructor namespace is *separate from* the type namespace: a
755/// single lexeme can name both a type (`Position` — the n-variant
756/// union) and a constructor (`Position` — the sole constructor of that
757/// union). [`lookup_ctor`](Self::lookup_ctor) walks the constructor
758/// side; [`get_type`](Self::get_type) walks the type side.
759#[derive(Debug, Clone)]
760pub struct TypeRegistry {
761    types: HashMap<StructTypeName, TypeDef>,
762    /// Constructor namespace: each constructor name resolves to the
763    /// union it belongs to. With no module system, the namespace is
764    /// flat. Duplicate names are rejected upstream during name
765    /// resolution; like every `register_*` entry point, insertion here
766    /// is last-wins defense-in-depth, not a validation layer.
767    ctors: HashMap<ConstructorName, StructTypeName>,
768}
769
770impl TypeRegistry {
771    /// Look up a type definition by type name.
772    #[must_use]
773    pub fn get_type(&self, name: &str) -> Option<&TypeDef> {
774        self.types.get(name)
775    }
776
777    /// Look up the union that owns a constructor name, plus the
778    /// constructor's payload fields. Returns `None` if the name is not
779    /// a registered constructor.
780    #[must_use]
781    pub fn lookup_ctor(&self, ctor: &ConstructorName) -> Option<(&TypeDef, &UnionMemberDef)> {
782        let union_name = self.ctors.get(ctor)?;
783        let td = self.types.get(union_name)?;
784        let members = td.union_members()?;
785        let member = members.iter().find(|m| m.name == *ctor)?;
786        Some((td, member))
787    }
788
789    /// Iterate over all registered type definitions.
790    pub fn all_types(&self) -> impl Iterator<Item = &TypeDef> {
791        self.types.values()
792    }
793}
794
795/// Index registry: maps declared index names and typed Nat-range identities to `IndexDef`.
796#[derive(Debug, Clone)]
797pub struct IndexRegistry {
798    indexes: HashMap<IndexName, IndexDef>,
799    nat_ranges: HashMap<NatRangeIndex, IndexDef>,
800}
801
802impl IndexRegistry {
803    /// Look up a declared index definition by name.
804    #[must_use]
805    pub fn get_index(&self, name: &str) -> Option<&IndexDef> {
806        self.indexes.get(name)
807    }
808
809    /// Look up a compiler-generated Nat range index by typed identity.
810    #[must_use]
811    pub fn get_nat_range(&self, index: NatRangeIndex) -> Option<&IndexDef> {
812        self.nat_ranges.get(&index)
813    }
814
815    /// Iterate over all index definitions.
816    pub fn all_indexes(&self) -> impl Iterator<Item = &IndexDef> {
817        self.indexes.values().chain(self.nat_ranges.values())
818    }
819}
820
821// ---------------------------------------------------------------------------
822// Frozen aggregate registry
823// ---------------------------------------------------------------------------
824
825/// The frozen, read-only aggregate of all domain registries.
826///
827/// Produced by [`RegistryBuilder::build`]. All fields are public so that
828/// consumers can access individual domain registries directly.
829#[derive(Debug, Clone)]
830pub struct Registry {
831    pub dimensions: DimensionRegistry,
832    pub units: UnitRegistry,
833    pub types: TypeRegistry,
834    pub indexes: IndexRegistry,
835    pub dags: DagRegistry,
836}
837
838/// Registry of `dag` declaration bodies accessible by name within a file.
839///
840/// Populated at IR lowering time with the raw AST body for each declared `dag`.
841/// Used during dim-checking (and later, evaluation) to resolve inline DAG
842/// invocations `@dag(args).out` against the called `dag`'s `pub param` and
843/// `pub node` signatures.
844#[derive(Debug, Default, Clone)]
845pub struct DagRegistry {
846    /// Dag bodies keyed by their declaration name. Dags live in the
847    /// declaration namespace, so the key is the typed [`DeclName`] like
848    /// every other registry — not a bare `String`.
849    dags: HashMap<DeclName, DagDecl>,
850}
851
852impl DagRegistry {
853    /// Return the AST body of the named `dag`, if one is declared in this file.
854    #[must_use]
855    pub fn get(&self, name: &str) -> Option<&DagDecl> {
856        self.dags.get(name)
857    }
858
859    /// Iterate over all registered dags.
860    pub fn all_dags(&self) -> impl Iterator<Item = (&DeclName, &DagDecl)> {
861        self.dags.iter()
862    }
863}
864
865// ---------------------------------------------------------------------------
866// Mutable builder
867// ---------------------------------------------------------------------------
868
869/// Mutable builder for constructing a [`Registry`].
870///
871/// Used during IR lowering and prelude loading. Call [`build()`](Self::build)
872/// to produce an immutable [`Registry`].
873#[derive(Debug, Default)]
874pub struct RegistryBuilder {
875    base_dim_names: BTreeMap<BaseDimId, String>,
876    base_dim_symbols: BTreeMap<BaseDimId, String>,
877
878    dimensions: HashMap<DimName, Dimension>,
879    units: HashMap<UnitRef, UnitInfo>,
880    types: HashMap<StructTypeName, TypeDef>,
881    ctors: HashMap<ConstructorName, StructTypeName>,
882    indexes: HashMap<IndexName, IndexDef>,
883    nat_ranges: HashMap<NatRangeIndex, IndexDef>,
884    dags: HashMap<DeclName, DagDecl>,
885    /// Base dimensions whose real-world units are affine (offset) scales,
886    /// e.g. Temperature (°C, °F). User unit definitions on these dimensions
887    /// are rejected because a purely multiplicative definition would display
888    /// silently wrong values (#648 U4).
889    affine_prone_dims: BTreeSet<BaseDimId>,
890}
891
892impl RegistryBuilder {
893    #[must_use]
894    pub fn new() -> Self {
895        Self::default()
896    }
897
898    /// Freeze the builder into an immutable [`Registry`].
899    #[must_use]
900    pub fn build(self) -> Registry {
901        Registry {
902            dimensions: DimensionRegistry {
903                base_dim_names: self.base_dim_names,
904                base_dim_symbols: self.base_dim_symbols,
905                dimensions: self.dimensions,
906            },
907            units: UnitRegistry { units: self.units },
908            types: TypeRegistry {
909                types: self.types,
910                ctors: self.ctors,
911            },
912            indexes: IndexRegistry {
913                indexes: self.indexes,
914                nat_ranges: self.nat_ranges,
915            },
916            dags: DagRegistry { dags: self.dags },
917        }
918    }
919
920    /// Register a `dag` declaration body keyed by the declaration's name.
921    ///
922    /// Accessed later during dim-checking of inline `@dag(args).out`
923    /// expressions.
924    pub fn register_dag(&mut self, name: DeclName, decl: DagDecl) {
925        self.dags.insert(name, decl);
926    }
927
928    /// Merge every entry from a frozen [`Registry`] into this builder.
929    ///
930    /// Used by inline-dag compilation: the dag body is lowered as a virtual
931    /// file whose registry is seeded with the enclosing file's dimensions,
932    /// units, indexes, types, and sibling dags so that reference resolution and
933    /// type checking behave as if the dag body were declared inline at the
934    /// top level.
935    pub fn merge_from_registry(&mut self, parent: &Registry) {
936        for (id, name) in &parent.dimensions.base_dim_names {
937            self.base_dim_names
938                .entry(id.clone())
939                .or_insert_with(|| name.clone());
940        }
941        for (id, symbol) in &parent.dimensions.base_dim_symbols {
942            self.base_dim_symbols
943                .entry(id.clone())
944                .or_insert_with(|| symbol.clone());
945        }
946        for (name, dim) in &parent.dimensions.dimensions {
947            self.dimensions
948                .entry(name.clone())
949                .or_insert_with(|| dim.clone());
950        }
951        for (name, info) in &parent.units.units {
952            self.units
953                .entry(name.clone())
954                .or_insert_with(|| info.clone());
955        }
956        for (name, def) in &parent.types.types {
957            self.types
958                .entry(name.clone())
959                .or_insert_with(|| def.clone());
960        }
961        for (ctor, union_name) in &parent.types.ctors {
962            self.ctors
963                .entry(ctor.clone())
964                .or_insert_with(|| union_name.clone());
965        }
966        for (name, def) in &parent.indexes.indexes {
967            self.indexes
968                .entry(name.clone())
969                .or_insert_with(|| def.clone());
970        }
971        for (index, def) in &parent.indexes.nat_ranges {
972            self.nat_ranges.entry(*index).or_insert_with(|| def.clone());
973        }
974        for (name, decl) in &parent.dags.dags {
975            self.dags
976                .entry(name.clone())
977                .or_insert_with(|| decl.clone());
978        }
979    }
980
981    // -- Mutation methods (only on builder) --
982
983    /// Register a new base dimension (`base dim Foo;`).
984    ///
985    /// The caller provides the [`BaseDimId`] which encodes the dimension's
986    /// identity (prelude name or user-defined file+name).
987    /// Mark a base dimension as affine-prone: its real-world units (e.g.
988    /// Celsius/Fahrenheit on Temperature) need offset conversions that unit
989    /// definitions cannot express, so user unit definitions on the bare
990    /// dimension are rejected (#648 U4).
991    pub fn mark_affine_prone(&mut self, id: BaseDimId) {
992        self.affine_prone_dims.insert(id);
993    }
994
995    /// Returns `true` when `dim` is exactly an affine-prone base dimension
996    /// (power 1). Compound dimensions involving the base (e.g.
997    /// `Temperature / Time`) stay allowed: offsets cancel in differences.
998    #[must_use]
999    pub fn is_affine_prone(&self, dim: &Dimension) -> bool {
1000        let mut iter = dim.iter();
1001        let Some((id, &exp)) = iter.next() else {
1002            return false;
1003        };
1004        iter.next().is_none() && exp == Rational::ONE && self.affine_prone_dims.contains(id)
1005    }
1006
1007    pub fn register_base_dimension(&mut self, name: DimName, id: BaseDimId) -> BaseDimId {
1008        let dim = Dimension::base(id.clone());
1009        self.base_dim_names.insert(id.clone(), name.to_string());
1010        self.dimensions.insert(name, dim);
1011        id
1012    }
1013
1014    /// Register a new base dimension with an SI symbol.
1015    ///
1016    /// Same as `register_base_dimension` but also records the default unit symbol
1017    /// used for runtime display (e.g., `"m"` for Length).
1018    pub fn register_base_dimension_with_symbol(
1019        &mut self,
1020        name: DimName,
1021        id: BaseDimId,
1022        symbol: String,
1023    ) -> BaseDimId {
1024        let id = self.register_base_dimension(name, id);
1025        self.base_dim_symbols.insert(id.clone(), symbol);
1026        id
1027    }
1028
1029    /// Record an SI symbol for an existing base dimension.
1030    ///
1031    /// Used when the first base unit for a user-defined dimension is declared
1032    /// (e.g., `base unit bit: Information;` records `"bit"` as the symbol).
1033    pub fn set_base_dim_symbol(&mut self, id: BaseDimId, symbol: String) {
1034        self.base_dim_symbols.entry(id).or_insert(symbol);
1035    }
1036
1037    /// Register a named dimension.
1038    pub fn register_dimension(&mut self, name: DimName, dim: Dimension) {
1039        self.dimensions.insert(name, dim);
1040    }
1041
1042    /// Register a named unit with its dimension and SI scale factor.
1043    pub fn register_unit(
1044        &mut self,
1045        name: impl Into<UnitRef>,
1046        dimension: Dimension,
1047        scale: PositiveFiniteScale,
1048    ) {
1049        self.units.insert(
1050            name.into(),
1051            UnitInfo {
1052                dimension,
1053                constness: UnitConstness::Const,
1054                scale: UnitScale::Static(scale),
1055            },
1056        );
1057    }
1058
1059    /// Register a named unit with an explicitly specified scale and constness.
1060    pub fn register_unit_with_scale(
1061        &mut self,
1062        name: impl Into<UnitRef>,
1063        dimension: Dimension,
1064        scale: UnitScale,
1065        constness: UnitConstness,
1066    ) {
1067        self.units.insert(
1068            name.into(),
1069            UnitInfo {
1070                dimension,
1071                constness,
1072                scale,
1073            },
1074        );
1075    }
1076
1077    /// Register a named runtime unit with a static or dynamic scale factor.
1078    pub fn register_unit_dynamic(
1079        &mut self,
1080        name: impl Into<UnitRef>,
1081        dimension: Dimension,
1082        scale: UnitScale,
1083    ) {
1084        self.register_unit_with_scale(name, dimension, scale, UnitConstness::Dynamic);
1085    }
1086
1087    /// Register a type definition.
1088    ///
1089    /// For tagged unions (the common case), also populates the
1090    /// constructor namespace: each variant's name resolves back to the
1091    /// union it belongs to. Constructor collisions are detected here —
1092    /// the prelude is loaded first, so any later user-defined
1093    /// constructor that collides with a prelude or sibling type's
1094    /// constructor is silently ignored on the *second* registration
1095    /// (consistent with the type-name "first wins" behavior).
1096    pub fn register_type(&mut self, def: TypeDef) {
1097        if let TypeDefKind::Union { ref members } = def.kind {
1098            for member in members {
1099                // Last-wins, like every other register_* entry point —
1100                // duplicates are rejected upstream during declaration collection.
1101                self.ctors.insert(member.name.clone(), def.name.clone());
1102            }
1103        }
1104        self.types.insert(def.name.clone(), def);
1105    }
1106
1107    /// Register an index definition.
1108    pub fn register_index(&mut self, def: IndexDef) {
1109        self.indexes.insert(def.name.clone(), def);
1110    }
1111
1112    /// Ensure a typed Nat range index of the given size is registered.
1113    ///
1114    /// If the index already exists, this is a no-op.
1115    ///
1116    /// `size` is `NonZeroUsize` because empty indexes are not representable.
1117    /// AST-level `u64` literals must be checked at the boundary before
1118    /// reaching this entry point.
1119    pub fn ensure_nat_range_index(&mut self, size: NonZeroUsize) -> NatRangeIndex {
1120        let nat_range = NatRangeIndex::new(size);
1121        self.nat_ranges
1122            .entry(nat_range)
1123            .or_insert_with(|| IndexDef {
1124                name: nat_range.display_name(),
1125                kind: IndexKind::NatRange { size },
1126            });
1127        nat_range
1128    }
1129
1130    // -- Read methods (needed during mid-build reads in ir.rs) --
1131
1132    /// Look up a dimension by name.
1133    #[must_use]
1134    pub fn get_dimension(&self, name: &str) -> Option<&Dimension> {
1135        self.dimensions.get(name)
1136    }
1137
1138    /// Look up a unit by name.
1139    #[must_use]
1140    pub fn get_unit(&self, name: &UnitRef) -> Option<&UnitInfo> {
1141        self.units.get(name)
1142    }
1143
1144    /// Iterate over all units: (reference, dimension, scale).
1145    pub fn all_units(&self) -> impl Iterator<Item = (&UnitRef, &Dimension, &UnitScale)> {
1146        self.units
1147            .iter()
1148            .map(|(name, info)| (name, &info.dimension, &info.scale))
1149    }
1150
1151    /// Look up a type definition by type name.
1152    #[must_use]
1153    pub fn get_type(&self, name: &str) -> Option<&TypeDef> {
1154        self.types.get(name)
1155    }
1156
1157    /// Look up a declared index definition by name.
1158    #[must_use]
1159    pub fn get_index(&self, name: &str) -> Option<&IndexDef> {
1160        self.indexes.get(name)
1161    }
1162
1163    /// Look up a compiler-generated Nat range index by typed identity.
1164    #[must_use]
1165    pub fn get_nat_range(&self, index: NatRangeIndex) -> Option<&IndexDef> {
1166        self.nat_ranges.get(&index)
1167    }
1168
1169    /// Get the base dimension names map (for display purposes).
1170    #[must_use]
1171    pub const fn base_dim_names(&self) -> &BTreeMap<BaseDimId, String> {
1172        &self.base_dim_names
1173    }
1174
1175    /// Get the base dimension symbols map for runtime display.
1176    #[must_use]
1177    pub const fn base_dim_symbols(&self) -> &BTreeMap<BaseDimId, String> {
1178        &self.base_dim_symbols
1179    }
1180
1181    /// Format a dimension as a human-readable string using registered base dimension names.
1182    ///
1183    /// Prefers a named dimension alias for compound dimensions, like
1184    /// [`DimensionRegistry::format_dimension`].
1185    #[must_use]
1186    pub fn format_dimension(&self, dim: &Dimension) -> String {
1187        format_dimension_preferring_alias(&self.dimensions, &self.base_dim_names, dim)
1188    }
1189
1190    /// Resolve a `DimExpr` AST node to a concrete `Dimension`.
1191    ///
1192    /// Returns `Ok(None)` if any dimension name is unknown, and `Err` if
1193    /// dimension exponent arithmetic overflows `i32`.
1194    pub fn resolve_dim_expr(&self, expr: &DimExpr) -> Result<Option<Dimension>, RationalError> {
1195        resolve_dim_expr_impl(&self.dimensions, expr)
1196    }
1197
1198    /// Resolve a `TypeExpr` to a concrete `Dimension`.
1199    ///
1200    /// Returns `Ok(None)` if the type references unknown dimensions, and
1201    /// `Err` if dimension exponent arithmetic overflows `i32`.
1202    pub fn resolve_type_expr(
1203        &self,
1204        type_expr: &TypeExpr,
1205    ) -> Result<Option<Dimension>, RationalError> {
1206        resolve_type_expr_impl(&self.dimensions, type_expr)
1207    }
1208
1209    /// Resolve a `UnitExpr` to its dimension and compound static scale factor.
1210    ///
1211    /// # Errors
1212    ///
1213    /// Returns a [`UnitResolveError`] naming the unknown or dynamic-scale
1214    /// unit, or the exponent overflow.
1215    pub fn resolve_unit_expr(&self, expr: &UnitExpr) -> Result<(Dimension, f64), UnitResolveError> {
1216        resolve_unit_expr_impl(&self.units, expr)
1217    }
1218
1219    /// Resolve a `UnitExpr` to its dimension only (ignoring scales).
1220    ///
1221    /// Works for both static and dynamic units.
1222    ///
1223    /// # Errors
1224    ///
1225    /// Returns a [`UnitResolveError`] naming the unknown unit, or the
1226    /// exponent overflow.
1227    pub fn resolve_unit_dimension(&self, expr: &UnitExpr) -> Result<Dimension, UnitResolveError> {
1228        resolve_unit_dimension_impl(&self.units, expr)
1229    }
1230}
1231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registry::prelude::load_prelude;
    use crate::syntax::ast::{DimExprItem, DimTerm, UnitExprItem};
    use crate::syntax::dimension::BaseDimId;
    use crate::syntax::names::{NamePath, UnitName};
    use crate::syntax::span::Span;
    use crate::syntax::span::Spanned;

    // Well-known IDs matching prelude dimension names.
    fn length_id() -> BaseDimId {
        BaseDimId::Prelude("Length".to_string())
    }
    fn time_id() -> BaseDimId {
        BaseDimId::Prelude("Time".to_string())
    }
    fn mass_id() -> BaseDimId {
        BaseDimId::Prelude("Mass".to_string())
    }

    fn make_registry() -> Registry {
        let mut b = RegistryBuilder::new();
        load_prelude(&mut b).unwrap();
        b.build()
    }

    fn make_dim_term_name(name: &str) -> Spanned<NamePath> {
        Spanned::new(NamePath::from(name), Span::new(0, 0))
    }

    /// Create a simple dimension `TypeExpr` from a name string.
    fn make_dim_type_expr(name: &str) -> TypeExpr {
        use crate::syntax::ast::{DimExpr, DimExprItem, DimTerm};
        TypeExpr {
            kind: TypeExprKind::DimExpr(DimExpr {
                terms: vec![DimExprItem {
                    op: MulDivOp::Mul,
                    term: DimTerm {
                        name: make_dim_term_name(name),
                        power: None,
                        span: Span::new(0, 0),
                    },
                }],
                span: Span::new(0, 0),
            }),
            constraints: vec![],
            span: Span::new(0, 0),
        }
    }

    fn make_unit_name(name: &str) -> Spanned<UnitRef> {
        Spanned::new(UnitRef::local(UnitName::new(name)), Span::new(0, 0))
    }

    /// Build a union type named `type_name` whose constructors carry no payload.
    fn make_payload_free_union(type_name: &str, ctors: &[&str]) -> TypeDef {
        TypeDef {
            name: StructTypeName::new(type_name),
            generic_params: vec![],
            kind: TypeDefKind::Union {
                members: ctors
                    .iter()
                    .map(|ctor| UnionMemberDef {
                        name: ConstructorName::new(*ctor),
                        fields: vec![],
                    })
                    .collect(),
            },
        }
    }

    #[test]
    fn registry_base_dimensions() {
        let r = make_registry();
        assert_eq!(
            r.dimensions.get_dimension("Length"),
            Some(&Dimension::base(length_id()))
        );
        assert_eq!(
            r.dimensions.get_dimension("Time"),
            Some(&Dimension::base(time_id()))
        );
        assert_eq!(
            r.dimensions.get_dimension("Mass"),
            Some(&Dimension::base(mass_id()))
        );
    }

    #[test]
    fn registry_derived_dimensions() {
        let r = make_registry();
        let velocity = r.dimensions.get_dimension("Velocity").unwrap();
        let expected = (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap();
        assert_eq!(*velocity, expected);
    }

    #[test]
    fn registry_base_units() {
        let r = make_registry();
        let m = r.units.get_unit(&UnitRef::local("m")).unwrap();
        assert_eq!(m.dimension, Dimension::base(length_id()));
        assert!((m.scale.as_static().unwrap() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn registry_derived_units() {
        let r = make_registry();
        let km = r.units.get_unit(&UnitRef::local("km")).unwrap();
        assert_eq!(km.dimension, Dimension::base(length_id()));
        assert!((km.scale.as_static().unwrap() - 1000.0).abs() < f64::EPSILON);
    }

    #[test]
    fn resolve_dim_expr_velocity() {
        let r = make_registry();
        // Length / Time
        let expr = DimExpr {
            terms: vec![
                DimExprItem {
                    op: MulDivOp::Mul,
                    term: DimTerm {
                        name: make_dim_term_name("Length"),
                        power: None,
                        span: Span::new(0, 0),
                    },
                },
                DimExprItem {
                    op: MulDivOp::Div,
                    term: DimTerm {
                        name: make_dim_term_name("Time"),
                        power: None,
                        span: Span::new(0, 0),
                    },
                },
            ],
            span: Span::new(0, 0),
        };
        let dim = r.dimensions.resolve_dim_expr(&expr).unwrap().unwrap();
        let expected = (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap();
        assert_eq!(dim, expected);
    }

    #[test]
    fn resolve_unit_expr_m_per_s_squared() {
        let r = make_registry();
        // m / s^2
        let expr = UnitExpr {
            terms: vec![
                UnitExprItem {
                    op: MulDivOp::Mul,
                    name: make_unit_name("m"),
                    power: None,
                },
                UnitExprItem {
                    op: MulDivOp::Div,
                    name: make_unit_name("s"),
                    power: Some(Rational::from_int(2)),
                },
            ],
            span: Span::new(0, 0),
        };
        let (dim, scale) = r.units.resolve_unit_expr(&expr).unwrap();
        let expected_dim = (Dimension::base(length_id())
            / Dimension::base(time_id()).pow_int(2).unwrap())
        .unwrap();
        assert_eq!(dim, expected_dim);
        assert!((scale - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn resolve_unit_expr_km_per_hour() {
        let r = make_registry();
        // km / hour
        let expr = UnitExpr {
            terms: vec![
                UnitExprItem {
                    op: MulDivOp::Mul,
                    name: make_unit_name("km"),
                    power: None,
                },
                UnitExprItem {
                    op: MulDivOp::Div,
                    name: make_unit_name("hour"),
                    power: None,
                },
            ],
            span: Span::new(0, 0),
        };
        let (dim, scale) = r.units.resolve_unit_expr(&expr).unwrap();
        let expected_dim = (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap();
        assert_eq!(dim, expected_dim);
        // km/hour = 1000 m / 3600 s ≈ 0.2778 m/s
        assert!((scale - 1000.0 / 3600.0).abs() < 1e-10);
    }

    #[test]
    fn registry_type_register_and_lookup() {
        let mut b = RegistryBuilder::new();
        load_prelude(&mut b).unwrap();
        // Record-shaped types are single-variant unions whose sole
        // constructor's name matches the type's name.
        b.register_type(TypeDef {
            name: StructTypeName::new("TransferResult"),
            generic_params: vec![],
            kind: TypeDefKind::Union {
                members: vec![UnionMemberDef {
                    name: ConstructorName::new("TransferResult"),
                    fields: vec![
                        StructField {
                            name: FieldName::new("dv1"),
                            type_ann: make_dim_type_expr("Velocity"),
                        },
                        StructField {
                            name: FieldName::new("dv2"),
                            type_ann: make_dim_type_expr("Velocity"),
                        },
                    ],
                }],
            },
        });
        let r = b.build();
        let velocity_dim = (Dimension::base(length_id()) / Dimension::base(time_id())).unwrap();
        let def = r.types.get_type("TransferResult").unwrap();
        assert_eq!(def.name.as_str(), "TransferResult");
        assert!(def.is_union());
        let fields = def.record_fields().expect("single-variant collision");
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].name.as_str(), "dv1");
        assert_eq!(
            r.dimensions.resolve_type_expr(&fields[0].type_ann),
            Ok(Some(velocity_dim))
        );
        assert!(r.types.get_type("NonExistent").is_none());
    }

    #[test]
    fn registry_index_register_and_lookup() {
        let mut b = RegistryBuilder::new();
        load_prelude(&mut b).unwrap();
        b.register_index(IndexDef {
            name: IndexName::new("Maneuver"),
            kind: IndexKind::Named {
                variants: vec![
                    IndexVariantName::new("Departure"),
                    IndexVariantName::new("Correction"),
                    IndexVariantName::new("Insertion"),
                ],
            },
        });
        let r = b.build();
        let def = r.indexes.get_index("Maneuver").unwrap();
        assert_eq!(def.name.as_str(), "Maneuver");
        let variants = def.variants();
        let variant_strs: Vec<&str> = variants.iter().map(IndexVariantName::as_str).collect();
        assert_eq!(variant_strs, vec!["Departure", "Correction", "Insertion"]);
        assert!(r.indexes.get_index("NonExistent").is_none());
    }

    #[test]
    fn register_user_defined_base_dimension() {
        let mut b = RegistryBuilder::new();
        load_prelude(&mut b).unwrap();
        let info_id = BaseDimId::UserDefined {
            dag: crate::dag_id::DagId::root("test"),
            name: "Information".to_string(),
        };
        let id = b.register_base_dimension(DimName::new("Information"), info_id.clone());
        assert_eq!(id, info_id);
        let r = b.build();
        // Should be retrievable
        let dim = r.dimensions.get_dimension("Information").unwrap();
        assert_eq!(*dim, Dimension::base(id.clone()));
        // Name should be recorded
        assert_eq!(
            r.dimensions.base_dim_names().get(&id),
            Some(&"Information".to_string())
        );
    }

    #[test]
    fn register_base_dimension_with_symbol() {
        let mut b = RegistryBuilder::new();
        let id = b.register_base_dimension_with_symbol(
            DimName::new("Length"),
            BaseDimId::Prelude("Length".to_string()),
            "m".to_string(),
        );
        let r = b.build();
        assert_eq!(
            r.dimensions.base_dim_symbols().get(&id),
            Some(&"m".to_string())
        );
    }

    #[test]
    fn set_base_dim_symbol_only_first() {
        let mut b = RegistryBuilder::new();
        let info_id = BaseDimId::UserDefined {
            dag: crate::dag_id::DagId::root("test"),
            name: "Information".to_string(),
        };
        let id = b.register_base_dimension(DimName::new("Information"), info_id);
        b.set_base_dim_symbol(id.clone(), "bit".to_string());
        // Second call should not overwrite
        b.set_base_dim_symbol(id.clone(), "byte".to_string());
        let r = b.build();
        assert_eq!(
            r.dimensions.base_dim_symbols().get(&id),
            Some(&"bit".to_string())
        );
    }

    /// Only the bare (power-1) affine-prone base dimension is flagged;
    /// compounds, powers, and unmarked bases are not (#648 U4).
    #[test]
    fn affine_prone_only_matches_bare_base_dimension() {
        let mut b = RegistryBuilder::new();
        let temperature = BaseDimId::Prelude("Temperature".to_string());
        b.mark_affine_prone(temperature.clone());

        assert!(b.is_affine_prone(&Dimension::base(temperature.clone())));

        let per_time =
            (Dimension::base(temperature.clone()) / Dimension::base(time_id())).unwrap();
        assert!(!b.is_affine_prone(&per_time));

        let squared = Dimension::base(temperature).pow_int(2).unwrap();
        assert!(!b.is_affine_prone(&squared));

        assert!(!b.is_affine_prone(&Dimension::base(length_id())));
    }

    /// `lookup_ctor` returns both the owning union and the matching member.
    #[test]
    fn lookup_ctor_resolves_owner_and_member() {
        let mut b = RegistryBuilder::new();
        b.register_type(make_payload_free_union("Shape", &["Circle", "Square"]));
        let r = b.build();

        let (owner, member) = r
            .types
            .lookup_ctor(&ConstructorName::new("Square"))
            .expect("Square should be a registered constructor");
        assert_eq!(owner.name.as_str(), "Shape");
        assert!(member.name == ConstructorName::new("Square"));

        assert!(r
            .types
            .lookup_ctor(&ConstructorName::new("Triangle"))
            .is_none());
    }

    /// Constructor registration is last-wins, matching `register_type`'s docs.
    #[test]
    fn register_type_constructor_collision_is_last_wins() {
        let mut b = RegistryBuilder::new();
        b.register_type(make_payload_free_union("First", &["Shared"]));
        b.register_type(make_payload_free_union("Second", &["Shared"]));
        let r = b.build();

        let (owner, member) = r
            .types
            .lookup_ctor(&ConstructorName::new("Shared"))
            .expect("Shared should be a registered constructor");
        assert_eq!(owner.name.as_str(), "Second");
        assert!(member.name == ConstructorName::new("Shared"));
    }

    /// Merging keeps entries already in the builder and fills in the rest.
    #[test]
    fn merge_from_registry_keeps_existing_entries() {
        let parent = make_registry();
        let mut b = RegistryBuilder::new();
        let shadowed = Dimension::base(mass_id());
        b.register_dimension(DimName::new("Length"), shadowed.clone());

        b.merge_from_registry(&parent);

        assert_eq!(b.get_dimension("Length"), Some(&shadowed));
        assert_eq!(
            b.get_dimension("Time"),
            Some(&Dimension::base(time_id()))
        );
        assert!(b.get_unit(&UnitRef::local("km")).is_some());
    }
}