use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::sync::Arc;
use miette::NamedSource;
use crate::desugar::desugared_ast::{MulDivOp, TypeExpr, TypeExprKind};
use crate::hir;
use crate::hir::diagnostics::{
expr_lower_error_to_graphcal, hir_lower_error_to_graphcal, resolved_decl_key,
};
pub use crate::ir::lower::{LoweredPlotBody, LoweredPlotField};
use crate::syntax::dimension::{Dimension, Rational};
use crate::syntax::names::{
ConstructorName, DeclName, DimName, FieldName, GenericParamName, IndexName, ModuleAliasName,
NameAtom, NamePath, StructTypeName,
};
use crate::syntax::nat::Monomial;
pub use crate::syntax::nat::{NatLinearForm, NatPolyForm};
use crate::syntax::span::{Span, Spanned};
use crate::ir::lower::IR;
use crate::ir::resolve::{DeclCategory, ExpectedFail};
use crate::registry::declared_type::IndexTypeRef;
use crate::registry::error::GraphcalError;
use crate::registry::time_scale::TimeScale;
use crate::registry::types::{
IndexDef, Registry, RegistryBuilder, TypeDef, TypeGenericConstraint, UnionMemberDef,
};
use crate::syntax::module_resolve::{ModuleResolveError, ModuleResolver};
use crate::syntax::names::{ResolvedName, ScopedName, namespace};
/// A type expression after name resolution.
///
/// Named references have been resolved to owner-qualified definitions
/// (`ResolvedName`) or to generic parameters. Use
/// [`ResolvedTypeExpr::format`] to render one for users.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolvedTypeExpr {
    /// The dimensionless scalar type.
    Dimensionless,
    /// The boolean type.
    Bool,
    /// The integer type.
    Int,
    /// A datetime on the given time scale (rendered as bare `Datetime` when UTC).
    Datetime(TimeScale),
    /// An index-typed value (rendered as `index <index>`).
    IndexArg(ResolvedIndex),
    /// A scalar with a concrete dimension.
    Scalar(Dimension),
    /// A reference to a non-generic struct type, with an associated span.
    Struct(ResolvedName<namespace::StructType>, Span),
    /// A struct type applied to type arguments (rendered as `Name<A, B>`).
    GenericStruct {
        name: ResolvedName<namespace::StructType>,
        type_args: Vec<Self>,
        span: Span,
    },
    /// A reference to a generic dimension parameter.
    GenericDimParam(GenericParamName, Span),
    /// A reference to a generic type parameter.
    GenericTypeParam(GenericParamName, Span),
    /// A product/quotient of terms that may mix concrete dimensions and
    /// generic parameters.
    GenericDimExpr {
        terms: Vec<ResolvedDimTerm>,
        span: Span,
    },
    /// A base type indexed by one or more indexes (rendered as `Base[I, J]`).
    Indexed {
        base: Box<Self>,
        indexes: Vec<ResolvedIndex>,
    },
}
impl ResolvedTypeExpr {
    /// Renders this type as user-facing text.
    ///
    /// `registry` supplies the spelling of concrete dimensions. A scalar whose
    /// dimension renders as an empty string is shown as `Dimensionless`, and
    /// a UTC datetime is shown as bare `Datetime`.
    #[must_use]
    pub fn format(&self, registry: &Registry) -> String {
        match self {
            Self::Dimensionless => String::from("Dimensionless"),
            Self::Bool => String::from("Bool"),
            Self::Int => String::from("Int"),
            Self::Datetime(scale) if scale.is_utc() => String::from("Datetime"),
            Self::Datetime(scale) => format!("Datetime<{scale}>"),
            Self::IndexArg(index) => format!("index {}", format_resolved_index(index)),
            Self::Scalar(dim) => {
                let rendered = registry.dimensions.format_dimension(dim);
                if rendered.is_empty() {
                    String::from("Dimensionless")
                } else {
                    rendered
                }
            }
            Self::Struct(name, _) => name.as_str().to_owned(),
            Self::GenericStruct {
                name, type_args, ..
            } => {
                let joined_args = type_args
                    .iter()
                    .map(|arg| arg.format(registry))
                    .collect::<Vec<_>>()
                    .join(", ");
                format!("{}<{joined_args}>", name.as_str())
            }
            Self::GenericDimParam(param, _) | Self::GenericTypeParam(param, _) => param.to_string(),
            Self::GenericDimExpr { terms, .. } => terms
                .iter()
                .map(|term| term.format(registry))
                .collect::<Vec<_>>()
                .join(" "),
            Self::Indexed { base, indexes } => {
                let joined_indexes = indexes
                    .iter()
                    .map(format_resolved_index)
                    .collect::<Vec<_>>()
                    .join(", ");
                format!("{}[{joined_indexes}]", base.format(registry))
            }
        }
    }
}
fn format_resolved_index(index: &ResolvedIndex) -> String {
match index {
ResolvedIndex::Concrete(name, _) => name.as_str().to_string(),
ResolvedIndex::GenericParam(name, _) => name.to_string(),
ResolvedIndex::NatExpr(form, _) => format!("range({})", form.format()),
}
}
/// One term of a [`ResolvedTypeExpr::GenericDimExpr`].
///
/// A term is a concrete dimension or a generic parameter, raised to `power`.
/// Its `op` says whether it multiplies or divides (`Div` terms render with a
/// leading `/ `).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolvedDimTerm {
    /// A concrete dimension factor.
    Concrete {
        dim: Dimension,
        power: Rational,
        op: MulDivOp,
    },
    /// A generic-parameter factor, with an associated span.
    GenericParam {
        name: GenericParamName,
        power: Rational,
        op: MulDivOp,
        span: Span,
    },
}
impl ResolvedDimTerm {
    /// Returns whether this term multiplies or divides.
    #[must_use]
    pub const fn op(&self) -> MulDivOp {
        match *self {
            Self::GenericParam { op, .. } | Self::Concrete { op, .. } => op,
        }
    }

    /// Renders this term as `[/ ]<name>[<exponent>]`.
    ///
    /// The exponent suffix is omitted when the power is exactly one.
    #[must_use]
    pub fn format(&self, registry: &Registry) -> String {
        let (label, exponent, op) = match self {
            Self::GenericParam {
                name, power, op, ..
            } => (name.to_string(), *power, *op),
            Self::Concrete { dim, power, op } => {
                (registry.dimensions.format_dimension(dim), *power, *op)
            }
        };
        let prefix = match op {
            MulDivOp::Div => "/ ",
            MulDivOp::Mul => "",
        };
        if exponent == Rational::ONE {
            return format!("{prefix}{label}");
        }
        let rendered_exponent = crate::registry::format::format_exponent(exponent);
        format!("{prefix}{label}{rendered_exponent}")
    }
}
/// A [`NatPolyForm`] accepted as the identity of a `range(...)` index.
///
/// Build it with [`NatRangeIndexIdentity::try_from_form`], which range-checks
/// forms that are constant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NatRangeIndexIdentity {
    form: NatPolyForm,
}
impl NatRangeIndexIdentity {
    /// Wraps `form` as a range-index identity.
    ///
    /// Non-constant forms are accepted unchanged. Constant forms are
    /// validated by converting the constant with `NatRangeIndex::try_from_u64`.
    ///
    /// # Errors
    ///
    /// Returns the conversion error when a constant form is rejected.
    pub fn try_from_form(
        form: NatPolyForm,
    ) -> Result<Self, crate::registry::types::NatRangeIndexError> {
        if !form.is_constant() {
            return Ok(Self { form });
        }
        crate::registry::types::NatRangeIndex::try_from_u64(form.constant())?;
        Ok(Self { form })
    }

    /// Borrows the underlying polynomial form.
    #[must_use]
    pub const fn form(&self) -> &NatPolyForm {
        &self.form
    }

    /// Consumes the identity and returns the underlying polynomial form.
    #[must_use]
    pub fn into_form(self) -> NatPolyForm {
        let Self { form } = self;
        form
    }

    /// Converts a copy of the form into an [`IndexTypeRef`].
    ///
    /// # Errors
    ///
    /// Propagates the error from `IndexTypeRef::from_nat_range_form`.
    pub fn to_index_type_ref(
        &self,
    ) -> Result<IndexTypeRef, crate::registry::types::NatRangeIndexError> {
        let owned_form = self.form.clone();
        IndexTypeRef::from_nat_range_form(owned_form)
    }
}
impl NatPolyForm {
    /// Validates a copy of this form as a range-index identity.
    ///
    /// # Errors
    ///
    /// Same as [`NatRangeIndexIdentity::try_from_form`].
    pub fn to_nat_range_identity(
        &self,
    ) -> Result<NatRangeIndexIdentity, crate::registry::types::NatRangeIndexError> {
        let copy = Self::clone(self);
        NatRangeIndexIdentity::try_from_form(copy)
    }
}
/// Normalizes a natural-number expression into polynomial form.
///
/// Variables must name one of `nat_params`. Sums and products are combined
/// recursively with `NatPolyForm::add` / `NatPolyForm::mul`.
///
/// # Errors
///
/// - `GraphcalError::UnknownIndex` when a variable is not in `nat_params`.
/// - The error built by [`nat_overflow_error`] when combining overflows.
pub fn normalize_nat_expr(
    expr: &crate::desugar::desugared_ast::NatExpr,
    nat_params: &[GenericParamName],
    src: &NamedSource<Arc<String>>,
) -> Result<NatPolyForm, GraphcalError> {
    use crate::desugar::desugared_ast::NatExpr;
    match expr {
        NatExpr::Literal(value, _) => Ok(NatPolyForm::from_constant(*value)),
        NatExpr::Var(ident) => {
            let wanted = ident.name.as_str();
            let Some(param) = nat_params
                .iter()
                .find(|candidate| candidate.as_str() == wanted)
            else {
                return Err(GraphcalError::UnknownIndex {
                    name: IndexName::new(&ident.name),
                    src: src.clone(),
                    span: ident.span.into(),
                });
            };
            Ok(NatPolyForm::from_var(param.clone()))
        }
        NatExpr::Add(lhs, rhs, span) => {
            let left = normalize_nat_expr(lhs, nat_params, src)?;
            let right = normalize_nat_expr(rhs, nat_params, src)?;
            match left.add(&right) {
                Ok(sum) => Ok(sum),
                Err(overflow) => Err(nat_overflow_error(overflow, src, *span)),
            }
        }
        NatExpr::Mul(lhs, rhs, span) => {
            let left = normalize_nat_expr(lhs, nat_params, src)?;
            let right = normalize_nat_expr(rhs, nat_params, src)?;
            match left.mul(&right) {
                Ok(product) => Ok(product),
                Err(overflow) => Err(nat_overflow_error(overflow, src, *span)),
            }
        }
    }
}
/// Wraps a natural-number overflow as an `EvalError` diagnostic at `span`.
///
/// The diagnostic message is the overflow error's `Display` text.
#[must_use]
pub fn nat_overflow_error(
    err: crate::syntax::nat::NatOverflowError,
    src: &NamedSource<Arc<String>>,
    span: Span,
) -> GraphcalError {
    let message = err.to_string();
    let src = src.clone();
    let span = span.into();
    GraphcalError::EvalError { message, src, span }
}
/// An index reference after name resolution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolvedIndex {
    /// A declared index, with an associated span.
    Concrete(ResolvedName<namespace::Index>, Span),
    /// A generic index parameter, with an associated span.
    GenericParam(GenericParamName, Span),
    /// A `range(...)` index sized by a normalized natural-number form.
    NatExpr(NatPolyForm, Span),
}
impl ResolvedIndex {
    /// Renders the index for diagnostics.
    ///
    /// Output is the index name, the generic parameter name, or
    /// `range(<form>)` for natural-number range indexes.
    #[must_use]
    pub fn format_for_diagnostic(&self) -> String {
        match self {
            Self::NatExpr(poly, _) => format!("range({})", poly.format()),
            Self::GenericParam(param, _) => param.to_string(),
            Self::Concrete(resolved, _) => resolved.as_str().to_owned(),
        }
    }
}
/// A union constructor registered in a [`ModuleTypeRegistry`].
///
/// It records the struct type that owns the constructor, a copy of that
/// type's definition, and the member definition for this constructor.
#[derive(Debug, Clone)]
pub struct ModuleConstructorDef {
    pub owning_type: ResolvedName<namespace::StructType>,
    pub type_def: TypeDef,
    pub variant: UnionMemberDef,
}
/// Type-level definitions aggregated across module registries.
///
/// Dimensions, indexes, struct types and union constructors are keyed by
/// owner-qualified `ResolvedName`s, so same-named definitions from different
/// owners do not collide.
#[derive(Debug, Default, Clone)]
pub struct ModuleTypeRegistry {
    dimensions: HashMap<ResolvedName<namespace::Dim>, Dimension>,
    indexes: HashMap<ResolvedName<namespace::Index>, IndexDef>,
    struct_types: HashMap<ResolvedName<namespace::StructType>, TypeDef>,
    constructors: HashMap<ResolvedName<namespace::Constructor>, ModuleConstructorDef>,
}
impl ModuleTypeRegistry {
pub fn insert_graphcal_prelude(
&mut self,
) -> Result<(), crate::syntax::dimension::RationalError> {
let mut builder = RegistryBuilder::new();
crate::registry::prelude::load_prelude(&mut builder)?;
let registry = builder.build();
let owner = crate::registry::prelude::prelude_dag_id();
for name in crate::registry::prelude::PRELUDE_DIMENSION_NAMES {
if let Some(dim) = registry.dimensions.get_dimension(name) {
self.dimensions.insert(
ResolvedName::from_def(owner.clone(), DimName::new(*name)),
dim.clone(),
);
}
}
Ok(())
}
pub fn insert_registry(&mut self, owner: &crate::dag_id::DagId, registry: &Registry) {
for (name, dim) in registry.dimensions.all_dimensions() {
self.dimensions.insert(
ResolvedName::from_def(owner.clone(), name.clone()),
dim.clone(),
);
}
for index in registry.indexes.all_indexes() {
self.indexes.insert(
ResolvedName::from_def(owner.clone(), index.name.clone()),
index.clone(),
);
}
for type_def in registry.types.all_types() {
let type_name = ResolvedName::from_def(owner.clone(), type_def.name.clone());
self.struct_types
.insert(type_name.clone(), type_def.clone());
if let Some(members) = type_def.union_members() {
for member in members {
self.constructors.insert(
ResolvedName::from_def(owner.clone(), member.name.clone()),
ModuleConstructorDef {
owning_type: type_name.clone(),
type_def: type_def.clone(),
variant: member.clone(),
},
);
}
}
}
}
#[must_use]
pub fn get_dimension(&self, name: &ResolvedName<namespace::Dim>) -> Option<&Dimension> {
self.dimensions.get(name)
}
#[must_use]
pub fn get_index(&self, name: &ResolvedName<namespace::Index>) -> Option<&IndexDef> {
self.indexes.get(name)
}
#[must_use]
pub fn get_struct_type(&self, name: &ResolvedName<namespace::StructType>) -> Option<&TypeDef> {
self.struct_types.get(name)
}
#[must_use]
pub fn lookup_constructor(
&self,
constructor: &ResolvedName<namespace::Constructor>,
) -> Option<&ModuleConstructorDef> {
self.constructors.get(constructor)
}
}
/// A copyable bundle of borrowed state used while resolving one DAG's types.
///
/// It holds the DAG being resolved (`owner`), the module resolver, and the
/// aggregated [`ModuleTypeRegistry`].
#[derive(Debug, Clone, Copy)]
pub struct ModuleTypeContext<'a> {
    owner: &'a crate::dag_id::DagId,
    resolver: &'a ModuleResolver,
    types: &'a ModuleTypeRegistry,
}
impl<'a> ModuleTypeContext<'a> {
    /// Bundles the owning DAG id, module resolver and module type registry.
    #[must_use]
    pub const fn new(
        owner: &'a crate::dag_id::DagId,
        resolver: &'a ModuleResolver,
        types: &'a ModuleTypeRegistry,
    ) -> Self {
        Self { owner, resolver, types }
    }

    /// Returns the id of the DAG this context resolves types for.
    #[must_use]
    pub const fn owner(self) -> &'a crate::dag_id::DagId {
        let Self { owner, .. } = self;
        owner
    }
}
/// Numeric domain limits attached to a declaration.
///
/// Each bound is optional. `min_display` / `max_display` hold optional text
/// for the bounds, presumably their source spelling for diagnostics (TODO
/// confirm against the producer).
#[derive(Debug, Clone)]
pub struct ResolvedDomainConstraint {
    pub min: Option<f64>,
    pub max: Option<f64>,
    pub min_display: Option<String>,
    pub max_display: Option<String>,
    pub span: Span,
}
/// Identifies one field of one constructor of a struct type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StructFieldConstraintKey {
    pub owning_type: crate::registry::declared_type::StructTypeRef,
    pub constructor: ConstructorName,
    pub field: FieldName,
}
impl StructFieldConstraintKey {
    /// Builds a key from its owning type, constructor and field names.
    #[must_use]
    pub const fn new(
        owning_type: crate::registry::declared_type::StructTypeRef,
        constructor: ConstructorName,
        field: FieldName,
    ) -> Self {
        Self { owning_type, constructor, field }
    }
}
/// Type-resolved DAGs keyed by their DAG id.
pub type DagRegistry = HashMap<crate::dag_id::DagId, DagTIR>;
/// Declaration dependency edges, keyed by canonical declaration name.
///
/// Runtime and const dependencies are kept in separate maps. Each value set
/// is a `BTreeSet`, so it iterates in sorted order.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ResolvedDagDependencies {
    pub runtime_deps:
        HashMap<ResolvedName<namespace::Decl>, BTreeSet<ResolvedName<namespace::Decl>>>,
    pub const_deps: HashMap<ResolvedName<namespace::Decl>, BTreeSet<ResolvedName<namespace::Decl>>>,
}
/// HIR bodies of a DAG's declarations, keyed by canonical declaration name.
#[derive(Debug, Clone, Default)]
pub struct ResolvedExpressions {
    /// Const bodies.
    pub consts: HashMap<ResolvedName<namespace::Decl>, hir::Expr>,
    /// Default expressions, only for params that declare one.
    pub param_defaults: HashMap<ResolvedName<namespace::Decl>, hir::Expr>,
    /// Node bodies.
    pub nodes: HashMap<ResolvedName<namespace::Decl>, hir::Expr>,
    /// Assert bodies.
    pub asserts: HashMap<ResolvedName<namespace::Decl>, hir::AssertBody>,
}
impl ResolvedExpressions {
    /// Returns the runtime body for `key`.
    ///
    /// A param default takes precedence. Otherwise the node body is used.
    /// Consts and asserts are not consulted.
    #[must_use]
    pub fn runtime_expr(&self, key: &ResolvedName<namespace::Decl>) -> Option<&hir::Expr> {
        match self.param_defaults.get(key) {
            Some(default_expr) => Some(default_expr),
            None => self.nodes.get(key),
        }
    }
}
/// Index definitions referenced by a DAG, keyed by resolved index name.
#[derive(Debug, Clone, Default)]
pub struct ResolvedCollectionRefs {
    pub index_defs: HashMap<ResolvedName<namespace::Index>, IndexDef>,
}
/// Union constructors referenced by a DAG, keyed by resolved constructor name.
#[derive(Debug, Clone, Default)]
pub struct ResolvedConstructorRefs {
    pub constructor_defs: HashMap<ResolvedName<namespace::Constructor>, ResolvedConstructorTarget>,
}
/// Inline DAG calls made by a DAG, keyed by the span of each call.
#[derive(Debug, Clone, Default)]
pub struct ResolvedInlineDagRefs {
    pub calls: HashMap<Span, ResolvedInlineDagCall>,
}
/// Identifies one field of one constructor of a resolved struct type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ResolvedStructFieldTypeKey {
    pub owning_type: ResolvedName<namespace::StructType>,
    pub constructor: ConstructorName,
    pub field: FieldName,
}
/// Struct type definitions reachable from a DAG, plus their resolved pieces.
#[derive(Debug, Clone, Default)]
pub struct ResolvedTypeDefs {
    /// The struct definitions themselves.
    pub struct_types: HashMap<ResolvedName<namespace::StructType>, TypeDef>,
    /// Resolved type of each union-member field.
    pub field_types: HashMap<ResolvedStructFieldTypeKey, ResolvedTypeExpr>,
    /// Lowered domain bounds, only for fields that have at least one.
    pub field_bounds: HashMap<ResolvedStructFieldTypeKey, Vec<ResolvedDomainBound>>,
    /// Resolved default for each generic parameter that declares one.
    pub generic_defaults:
        HashMap<(ResolvedName<namespace::StructType>, GenericParamName), ResolvedTypeExpr>,
}
/// One domain bound from a type annotation, with its value lowered to HIR.
#[derive(Debug, Clone)]
pub struct ResolvedDomainBound {
    pub kind: crate::syntax::ast::DomainBoundKind,
    /// Span of the bound kind.
    pub kind_span: Span,
    pub value: hir::Expr,
    pub span: Span,
}
/// Name-resolved semantic data for one DAG, produced during type resolution.
#[derive(Debug, Clone, Default)]
pub struct DagSemanticBody {
    /// HIR bodies of consts, param defaults, nodes and asserts.
    pub expressions: ResolvedExpressions,
    /// Lowered domain bounds, only for declarations that have at least one.
    pub domain_bounds: HashMap<ResolvedName<namespace::Decl>, Vec<ResolvedDomainBound>>,
    /// Plot, figure and layer field expressions.
    pub plot_exprs: ResolvedPlotExprs,
    /// Lowered scale expressions for units whose scale is dynamic.
    pub dynamic_unit_scales: HashMap<crate::syntax::names::UnitRef, hir::Expr>,
    pub dependencies: ResolvedDagDependencies,
    pub collection_refs: ResolvedCollectionRefs,
    pub constructor_refs: ResolvedConstructorRefs,
    pub inline_dag_refs: ResolvedInlineDagRefs,
    pub type_defs: ResolvedTypeDefs,
    /// Scoped names bound to canonical declaration keys.
    pub decl_bindings: HashMap<ScopedName, ResolvedName<namespace::Decl>>,
}
/// Lowered plot, figure and layer bodies, keyed by declaration name.
#[derive(Debug, Clone, Default)]
pub struct ResolvedPlotExprs {
    pub plots: HashMap<ScopedName, LoweredPlotBody>,
    pub figures: HashMap<ScopedName, Vec<LoweredPlotField>>,
    pub layers: HashMap<ScopedName, Vec<LoweredPlotField>>,
}
/// A resolved inline call to another DAG.
#[derive(Debug, Clone)]
pub struct ResolvedInlineDagCall {
    /// The DAG being called.
    pub target: crate::dag_id::DagId,
    /// Target declaration for each argument, keyed by the argument's span.
    pub arg_targets: HashMap<Span, ResolvedName<namespace::Decl>>,
    /// The declaration read as the call's result, with its span.
    pub output: Spanned<ResolvedName<namespace::Decl>>,
}
/// A resolved union constructor with its owning type and member definition.
#[derive(Debug, Clone)]
pub struct ResolvedConstructorTarget {
    pub constructor: ResolvedName<namespace::Constructor>,
    pub owning_type: ResolvedName<namespace::StructType>,
    pub type_def: TypeDef,
    pub variant: UnionMemberDef,
}
/// Typed intermediate representation of a program.
///
/// Invariant: `dags` always contains an entry for `root_dag_id`
/// ([`TIR::root`] panics otherwise).
#[derive(Debug, Clone)]
pub struct TIR {
    pub registry: Registry,
    pub root_dag_id: crate::dag_id::DagId,
    pub dags: DagRegistry,
    /// DAG id each module alias refers to; used by [`TIR::resolve_call_path`].
    pub module_aliases: HashMap<ModuleAliasName, crate::dag_id::DagId>,
}
impl TIR {
    /// Returns the root DAG.
    ///
    /// # Panics
    ///
    /// Panics if `dags` has no entry for `root_dag_id` (a construction invariant).
    #[must_use]
    #[expect(
        clippy::expect_used,
        reason = "TIR invariant: root entry always present"
    )]
    pub fn root(&self) -> &DagTIR {
        self.dags
            .get(&self.root_dag_id)
            .expect("TIR.dags must contain root_dag_id")
    }

    /// Returns the root DAG mutably.
    ///
    /// # Panics
    ///
    /// Panics if `dags` has no entry for `root_dag_id` (a construction invariant).
    #[expect(
        clippy::expect_used,
        reason = "TIR invariant: root entry always present"
    )]
    pub fn root_mut(&mut self) -> &mut DagTIR {
        self.dags
            .get_mut(&self.root_dag_id)
            .expect("TIR.dags must contain root_dag_id")
    }

    /// Reports whether the program needs external inputs.
    ///
    /// That is the case when some root param has no default expression, or
    /// when some registry index is required.
    #[must_use]
    pub fn is_library(&self) -> bool {
        self.root().params.iter().any(|p| p.default_expr.is_none())
            || self
                .registry
                .indexes
                .all_indexes()
                .any(crate::registry::types::IndexDef::is_required)
    }

    /// Builds the declared-type environment of the root DAG.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`DagTIR::build_declared_types`].
    pub fn build_declared_types(
        &self,
        src: &NamedSource<Arc<String>>,
    ) -> Result<HashMap<ScopedName, crate::registry::declared_type::DeclaredType>, GraphcalError>
    {
        self.root().build_declared_types(src)
    }

    /// Resolves `path` with [`TIR::resolve_call_path`] and returns the DAG it names.
    #[must_use]
    pub fn lookup_call_target(&self, path: &crate::syntax::ast::ModulePath) -> Option<&DagTIR> {
        let id = self.resolve_call_path(path)?;
        self.dags.get(&id)
    }

    /// Turns a module path into a DAG id.
    ///
    /// - A single-segment path names a child of the root DAG.
    /// - A longer path starts with a module alias from `module_aliases`; each
    ///   remaining segment descends one child level.
    ///
    /// Returns `None` for an empty path or an unknown alias. Empty paths used
    /// to panic on `segments[0]`.
    #[must_use]
    pub fn resolve_call_path(
        &self,
        path: &crate::syntax::ast::ModulePath,
    ) -> Option<crate::dag_id::DagId> {
        let (head, rest) = path.segments.as_slice().split_first()?;
        if rest.is_empty() {
            return Some(self.root_dag_id.child(head.name.as_str()));
        }
        let dep_id = self.module_aliases.get(head.name.as_str())?;
        Some(
            rest.iter()
                .fold(dep_id.clone(), |id, segment| id.child(segment.name.as_str())),
        )
    }

    /// Builds a minimal TIR around `registry`, with one empty root DAG named
    /// `<eval-helper>` and no module aliases.
    #[must_use]
    pub fn empty_for_eval_helpers(registry: Registry) -> Self {
        let root_dag_id = crate::dag_id::DagId::root("<eval-helper>");
        let mut dags = DagRegistry::new();
        dags.insert(
            root_dag_id.clone(),
            DagTIR {
                dag_id: root_dag_id.clone(),
                consts: Vec::new(),
                params: Vec::new(),
                nodes: Vec::new(),
                asserts: Vec::new(),
                plots: Vec::new(),
                figures: Vec::new(),
                layers: Vec::new(),
                included_plots: Vec::new(),
                semantic: DagSemanticBody::default(),
                source_order: Vec::new(),
                assert_names: std::collections::HashSet::new(),
                assumes_map: HashMap::new(),
                expected_fail: HashMap::new(),
                resolved_decl_types: HashMap::new(),
                domain_constraints: HashMap::new(),
                imported_values: HashMap::new(),
                imported_decl_types: HashMap::new(),
                imported_value_sources: HashMap::new(),
                pub_nodes: std::collections::HashSet::new(),
            },
        );
        Self {
            registry,
            root_dag_id,
            dags,
            module_aliases: HashMap::new(),
        }
    }
}
/// Typed representation of a single DAG: its IR entries plus the results of
/// type and name resolution.
#[derive(Debug, Clone)]
pub struct DagTIR {
    pub dag_id: crate::dag_id::DagId,
    pub consts: Vec<crate::ir::lower::ConstEntry>,
    pub params: Vec<crate::ir::lower::ParamEntry>,
    pub nodes: Vec<crate::ir::lower::NodeEntry>,
    pub asserts: Vec<crate::ir::lower::AssertEntry>,
    pub plots: Vec<crate::ir::lower::PlotEntry>,
    pub figures: Vec<crate::ir::lower::FigureEntry>,
    pub layers: Vec<crate::ir::lower::LayerEntry>,
    pub included_plots: Vec<crate::ir::lower::IncludedPlotEntry>,
    pub semantic: DagSemanticBody,
    /// Declaration names paired with their category (presumably in source
    /// order, per the field name; TODO confirm).
    pub source_order: Vec<(ScopedName, DeclCategory)>,
    pub assert_names: std::collections::HashSet<ScopedName>,
    pub assumes_map: HashMap<ScopedName, Vec<ScopedName>>,
    pub expected_fail: HashMap<ScopedName, ExpectedFail>,
    /// Resolved type annotation of each const, param and node.
    pub resolved_decl_types: HashMap<ScopedName, ResolvedTypeExpr>,
    pub domain_constraints: HashMap<ScopedName, ResolvedDomainConstraint>,
    /// Imported values together with their declared types.
    pub imported_values: HashMap<
        ScopedName,
        (
            crate::registry::runtime_value::RuntimeValue,
            crate::registry::declared_type::DeclaredType,
        ),
    >,
    pub imported_decl_types: HashMap<ScopedName, crate::registry::declared_type::DeclaredType>,
    pub imported_value_sources: HashMap<ScopedName, crate::ir::lower::ImportedValueSource>,
    /// Names of public nodes, filled by `populate_pub_nodes`.
    pub pub_nodes: std::collections::HashSet<DeclName>,
}
impl DagTIR {
    /// Builds the declared type of every name visible in this DAG.
    ///
    /// Layers are applied in this order, and a later layer overwrites an
    /// earlier one for the same name:
    /// 1. builtin constants (dimensionless scalars),
    /// 2. imported declaration types,
    /// 3. types of imported values,
    /// 4. this DAG's own resolved declaration types.
    ///
    /// # Errors
    ///
    /// Returns the first error from `resolved_to_declared_type`.
    pub fn build_declared_types(
        &self,
        src: &NamedSource<Arc<String>>,
    ) -> Result<HashMap<ScopedName, crate::registry::declared_type::DeclaredType>, GraphcalError>
    {
        let mut declared_types: HashMap<ScopedName, crate::registry::declared_type::DeclaredType> =
            crate::registry::builtins::builtin_constants()
                .keys()
                .map(|builtin| {
                    (
                        ScopedName::local(*builtin),
                        crate::registry::declared_type::DeclaredType::Scalar(
                            Dimension::dimensionless(),
                        ),
                    )
                })
                .collect();
        declared_types.extend(
            self.imported_decl_types
                .iter()
                .map(|(name, declared)| (name.clone(), declared.clone())),
        );
        declared_types.extend(
            self.imported_values
                .iter()
                .map(|(name, (_value, declared))| (name.clone(), declared.clone())),
        );
        for (name, resolved) in &self.resolved_decl_types {
            let declared = resolved_to_declared_type(resolved, src)?;
            declared_types.insert(name.clone(), declared);
        }
        Ok(declared_types)
    }

    /// Adds the name of every public node in `body` to `pub_nodes`.
    pub fn populate_pub_nodes(&mut self, body: &[crate::desugar::desugared_ast::Declaration]) {
        use crate::desugar::desugared_ast::DeclKind;
        let public_nodes = body.iter().filter_map(|decl| match &decl.kind {
            DeclKind::Node(node) if node.visibility.is_public() => Some(node.name.value.clone()),
            _ => None,
        });
        self.pub_nodes.extend(public_nodes);
    }

    /// Maps a scoped name used in this DAG to its canonical declaration key.
    ///
    /// Lookup order:
    /// 1. an explicit entry in `semantic.decl_bindings`;
    /// 2. a name declared here (it has a resolved type or appears in
    ///    `source_order`), or any qualified name, is keyed directly;
    /// 3. an unqualified name whose member matches exactly one resolved
    ///    declaration resolves to that declaration;
    /// 4. otherwise the name is keyed directly.
    #[must_use]
    pub fn resolved_decl_key_for_local(
        &self,
        name: &ScopedName,
    ) -> Option<ResolvedName<namespace::Decl>> {
        if let Some(bound) = self.semantic.decl_bindings.get(name) {
            return Some(bound.clone());
        }
        let declared_here = self.resolved_decl_types.contains_key(name)
            || self
                .source_order
                .iter()
                .any(|(declared, _)| declared == name);
        if declared_here || name.is_qualified() {
            return resolved_decl_key(&self.dag_id, name);
        }
        let mut member_matches = self
            .resolved_decl_types
            .keys()
            .filter(|candidate| candidate.member() == name.member())
            .filter_map(|candidate| resolved_decl_key(&self.dag_id, candidate));
        match (member_matches.next(), member_matches.next()) {
            (Some(unique), None) => Some(unique),
            _ => resolved_decl_key(&self.dag_id, name),
        }
    }
}
/// Type-resolves a program's root IR into a [`TIR`] holding only the root DAG.
///
/// # Errors
///
/// Propagates any [`GraphcalError`] raised during resolution.
pub fn type_resolve_with_modules(
    ir: IR,
    root_dag_id: crate::dag_id::DagId,
    src: &NamedSource<Arc<String>>,
    module_resolver: &ModuleResolver,
    module_types: &ModuleTypeRegistry,
) -> Result<TIR, GraphcalError> {
    // The context only borrows its owner, but `root_dag_id` itself is moved
    // into the resulting TIR, so the context gets its own copy.
    let context_owner = root_dag_id.clone();
    type_resolve_impl(
        ir,
        root_dag_id,
        src,
        ModuleTypeContext::new(&context_owner, module_resolver, module_types),
    )
}
fn type_resolve_impl(
ir: IR,
root_dag_id: crate::dag_id::DagId,
src: &NamedSource<Arc<String>>,
module_ctx: ModuleTypeContext<'_>,
) -> Result<TIR, GraphcalError> {
let imported_value_sources_for_hir = ir.imported_value_sources.clone();
let asserts_for_hir = ir.asserts.clone();
let mut root_dag = type_resolve_dag(
ir.consts,
ir.params,
ir.nodes,
&asserts_for_hir,
&ir.registry,
src,
&root_dag_id,
module_ctx,
&imported_value_sources_for_hir,
)?
.with_body(
ir.asserts,
ir.plots,
ir.figures,
ir.layers,
ir.included_plots,
ir.source_order,
ir.assert_names,
ir.assumes_map,
ir.expected_fail,
ir.imported_values,
ir.imported_decl_types,
ir.imported_value_sources,
module_ctx,
src,
)?;
lower_dynamic_unit_scales(&ir.registry, module_ctx, &mut root_dag.semantic);
augment_runtime_deps_for_dynamic_units(&mut root_dag.semantic);
check_hir_body_policies(
&root_dag.semantic,
&ir.registry,
&ir.pub_names,
module_ctx,
src,
)?;
let mut dags = DagRegistry::new();
dags.insert(root_dag_id.clone(), root_dag);
Ok(TIR {
registry: ir.registry,
root_dag_id,
dags,
module_aliases: HashMap::new(),
})
}
pub fn type_resolve_single_with_modules(
ir: IR,
dag_id: &crate::dag_id::DagId,
src: &NamedSource<Arc<String>>,
module_resolver: &ModuleResolver,
module_types: &ModuleTypeRegistry,
) -> Result<DagTIR, GraphcalError> {
let ctx = ModuleTypeContext::new(dag_id, module_resolver, module_types);
type_resolve_single_impl(ir, dag_id, src, ctx)
}
fn type_resolve_single_impl(
ir: IR,
dag_id: &crate::dag_id::DagId,
src: &NamedSource<Arc<String>>,
module_ctx: ModuleTypeContext<'_>,
) -> Result<DagTIR, GraphcalError> {
let imported_value_sources_for_hir = ir.imported_value_sources.clone();
let asserts_for_hir = ir.asserts.clone();
let mut dag = type_resolve_dag(
ir.consts,
ir.params,
ir.nodes,
&asserts_for_hir,
&ir.registry,
src,
dag_id,
module_ctx,
&imported_value_sources_for_hir,
)?
.with_body(
ir.asserts,
ir.plots,
ir.figures,
ir.layers,
ir.included_plots,
ir.source_order,
ir.assert_names,
ir.assumes_map,
ir.expected_fail,
ir.imported_values,
ir.imported_decl_types,
ir.imported_value_sources,
module_ctx,
src,
)?;
lower_dynamic_unit_scales(&ir.registry, module_ctx, &mut dag.semantic);
augment_runtime_deps_for_dynamic_units(&mut dag.semantic);
check_hir_body_policies(&dag.semantic, &ir.registry, &ir.pub_names, module_ctx, src)?;
Ok(dag)
}
/// Lowers the scale expression of every dynamically scaled unit in `registry`.
///
/// Results go into `semantic.dynamic_unit_scales`. Expressions are lowered in
/// `ctx`'s owner scope with no generic parameters, the Graphcal prelude, and
/// the DAG's `decl_bindings`.
fn lower_dynamic_unit_scales(
    registry: &Registry,
    ctx: ModuleTypeContext<'_>,
    semantic: &mut DagSemanticBody,
) {
    let generic_scope = hir::GenericScope::new();
    let prelude = hir::PreludeTypeScope::graphcal();
    let expr_ctx = hir::ExprLoweringContext::new(ctx.owner, ctx.resolver, &generic_scope)
        .with_prelude(&prelude)
        .with_decl_bindings(&semantic.decl_bindings);
    for (unit_name, _dim, unit_scale) in registry.units.all_units() {
        let crate::registry::types::UnitScale::Dynamic { scale_expr, .. } = unit_scale else {
            continue;
        };
        // Best-effort: a scale expression that fails to lower is skipped and
        // no error is reported here.
        if let Ok(lowered) = hir::lower_expr(scale_expr, expr_ctx) {
            semantic
                .dynamic_unit_scales
                .insert(unit_name.clone(), lowered);
        }
    }
}
/// Type-resolves one DAG's declarations and builds its semantic body.
///
/// Steps:
/// 1. Resolve the type annotation of every const, param and node, with no
///    generic parameters in scope.
/// 2. Lower their bodies and domain bounds to HIR.
/// 3. Collect dependencies, collection refs, constructor refs, inline DAG
///    refs, and reachable struct type definitions.
///
/// Plot expressions, dynamic unit scales and decl bindings start empty here.
///
/// # Errors
///
/// Returns the first error from any resolution or collection step.
#[expect(
    clippy::too_many_arguments,
    reason = "orchestrates per-DAG type resolution across IR declarations and semantic body data"
)]
fn type_resolve_dag(
    consts: Vec<crate::ir::lower::ConstEntry>,
    params: Vec<crate::ir::lower::ParamEntry>,
    nodes: Vec<crate::ir::lower::NodeEntry>,
    asserts: &[crate::ir::lower::AssertEntry],
    registry: &Registry,
    src: &NamedSource<Arc<String>>,
    dag_id: &crate::dag_id::DagId,
    module_ctx: ModuleTypeContext<'_>,
    imported_value_sources: &HashMap<ScopedName, crate::ir::lower::ImportedValueSource>,
) -> Result<DagTIRSeed, GraphcalError> {
    let mut resolved_decl_types = HashMap::new();
    let no_generic_params: &[GenericParamName] = &[];
    // Consts, params and nodes share one resolution path. Chaining them keeps
    // the original order: consts, then params, then nodes. Resolution stops
    // at the first error.
    let annotated_decls = consts
        .iter()
        .map(|entry| (&entry.name, &entry.type_ann, entry.src.resolve(src)))
        .chain(
            params
                .iter()
                .map(|entry| (&entry.name, &entry.type_ann, entry.src.resolve(src))),
        )
        .chain(
            nodes
                .iter()
                .map(|entry| (&entry.name, &entry.type_ann, entry.src.resolve(src))),
        );
    for (decl_name, type_ann, decl_src) in annotated_decls {
        let resolved = resolve_type_expr_inner(
            type_ann,
            registry,
            dag_id,
            no_generic_params,
            no_generic_params,
            no_generic_params,
            decl_src,
            Some(module_ctx),
        )?;
        resolved_decl_types.insert(decl_name.clone(), resolved);
    }
    let LoweredDagExpressions {
        exprs: expressions,
        domain_bounds,
    } = lower_resolved_expressions(
        &consts,
        &params,
        &nodes,
        asserts,
        module_ctx,
        imported_value_sources,
        src,
    )?;
    let dependencies =
        collect_resolved_dag_dependencies(&consts, &params, &nodes, &expressions, module_ctx, src)?;
    let collection_refs = collect_resolved_collection_refs(
        &expressions,
        &domain_bounds,
        &resolved_decl_types,
        module_ctx,
        src,
    )?;
    let constructor_refs =
        collect_resolved_constructor_refs(&expressions, &domain_bounds, module_ctx, src)?;
    let inline_dag_refs = collect_resolved_inline_dag_refs(&expressions);
    let type_defs = collect_resolved_type_defs(
        &resolved_decl_types,
        &constructor_refs,
        module_ctx,
        registry,
        src,
    )?;
    let semantic = DagSemanticBody {
        expressions,
        domain_bounds,
        plot_exprs: ResolvedPlotExprs::default(),
        dynamic_unit_scales: HashMap::new(),
        dependencies,
        collection_refs,
        constructor_refs,
        inline_dag_refs,
        type_defs,
        decl_bindings: HashMap::new(),
    };
    Ok(DagTIRSeed {
        dag_id: dag_id.clone(),
        consts,
        params,
        nodes,
        resolved_decl_types,
        semantic,
    })
}
/// Gathers every struct type definition relevant to a DAG.
///
/// Sources, in order:
/// - every struct type declared by the owning module;
/// - struct types mentioned in resolved declaration types (recursively
///   through generic arguments and indexed bases);
/// - owning types of referenced constructors.
///
/// # Errors
///
/// Returns the first error from recording a struct definition.
fn collect_resolved_type_defs(
    resolved_decl_types: &HashMap<ScopedName, ResolvedTypeExpr>,
    constructor_refs: &ResolvedConstructorRefs,
    ctx: ModuleTypeContext<'_>,
    registry: &Registry,
    src: &NamedSource<Arc<String>>,
) -> Result<ResolvedTypeDefs, GraphcalError> {
    let mut defs = ResolvedTypeDefs::default();
    if let Some(symbols) = ctx.resolver.modules().get(ctx.owner) {
        symbols.struct_types().values().try_for_each(|symbol| {
            record_resolved_struct_type_def(symbol.resolved(), ctx, registry, src, &mut defs)
        })?;
    }
    resolved_decl_types.values().try_for_each(|resolved| {
        collect_struct_type_defs_from_resolved_type(resolved, ctx, registry, src, &mut defs)
    })?;
    constructor_refs
        .constructor_defs
        .values()
        .try_for_each(|target| {
            record_resolved_struct_type_def(&target.owning_type, ctx, registry, src, &mut defs)
        })?;
    Ok(defs)
}
/// Records every struct type mentioned in `resolved` into `defs`.
///
/// Descends into generic struct arguments and into the base of indexed
/// types. All other variants mention no struct type.
///
/// # Errors
///
/// Returns the first error from recording a struct definition.
fn collect_struct_type_defs_from_resolved_type(
    resolved: &ResolvedTypeExpr,
    ctx: ModuleTypeContext<'_>,
    registry: &Registry,
    src: &NamedSource<Arc<String>>,
    defs: &mut ResolvedTypeDefs,
) -> Result<(), GraphcalError> {
    match resolved {
        ResolvedTypeExpr::Struct(struct_name, _) => {
            record_resolved_struct_type_def(struct_name, ctx, registry, src, defs)
        }
        ResolvedTypeExpr::GenericStruct {
            name: struct_name,
            type_args,
            ..
        } => {
            record_resolved_struct_type_def(struct_name, ctx, registry, src, defs)?;
            type_args.iter().try_for_each(|arg| {
                collect_struct_type_defs_from_resolved_type(arg, ctx, registry, src, defs)
            })
        }
        ResolvedTypeExpr::Indexed { base, .. } => {
            collect_struct_type_defs_from_resolved_type(base, ctx, registry, src, defs)
        }
        ResolvedTypeExpr::Dimensionless
        | ResolvedTypeExpr::Bool
        | ResolvedTypeExpr::Int
        | ResolvedTypeExpr::Datetime(_)
        | ResolvedTypeExpr::IndexArg(_)
        | ResolvedTypeExpr::Scalar(_)
        | ResolvedTypeExpr::GenericDimParam(_, _)
        | ResolvedTypeExpr::GenericTypeParam(_, _)
        | ResolvedTypeExpr::GenericDimExpr { .. } => Ok(()),
    }
}
/// Records the definition of struct type `name` into `defs`, once.
///
/// Besides the definition itself, this resolves:
/// - the default of every generic parameter that declares one;
/// - for union types, each member field's type and its lowered domain bounds.
///   Bounds are stored only when non-empty.
///
/// Names already present in `defs.struct_types`, and names unknown to the
/// module type registry, are skipped without error. Field bounds are lowered
/// relative to `name.owner()` rather than `ctx.owner`.
///
/// # Errors
///
/// Returns the first error from scoping generics or resolving and lowering
/// a default, a field type, or a field bound.
fn record_resolved_struct_type_def(
    name: &ResolvedName<namespace::StructType>,
    ctx: ModuleTypeContext<'_>,
    registry: &Registry,
    src: &NamedSource<Arc<String>>,
    defs: &mut ResolvedTypeDefs,
) -> Result<(), GraphcalError> {
    if defs.struct_types.contains_key(name) {
        return Ok(());
    }
    let Some(type_def) = ctx.types.get_struct_type(name) else {
        return Ok(());
    };
    let defaulted_params = type_def
        .generic_params
        .iter()
        .filter_map(|param| param.default.as_ref().map(|default| (param, default)));
    for (param, default) in defaulted_params {
        let resolved_default =
            resolve_type_expr_in_struct_scope(default, name, type_def, ctx, registry, src)?;
        defs.generic_defaults
            .insert((name.clone(), param.name.clone()), resolved_default);
    }
    if let Some(members) = type_def.union_members() {
        let generic_scope = generic_scope_for_type_def(name, type_def, src)?;
        let prelude = hir::PreludeTypeScope::graphcal();
        let bound_expr_ctx =
            hir::ExprLoweringContext::new(name.owner(), ctx.resolver, &generic_scope)
                .with_prelude(&prelude);
        for member in members {
            for field in &member.fields {
                let field_type = resolve_type_expr_in_struct_scope(
                    &field.type_ann,
                    name,
                    type_def,
                    ctx,
                    registry,
                    src,
                )?;
                let field_bounds = lower_domain_bounds(&field.type_ann, bound_expr_ctx, src)?;
                let key = ResolvedStructFieldTypeKey {
                    owning_type: name.clone(),
                    constructor: member.name.clone(),
                    field: field.name.clone(),
                };
                if !field_bounds.is_empty() {
                    defs.field_bounds.insert(key.clone(), field_bounds);
                }
                defs.field_types.insert(key, field_type);
            }
        }
    }
    defs.struct_types.insert(name.clone(), type_def.clone());
    Ok(())
}
/// Builds a HIR generic scope holding every generic parameter of `type_def`.
///
/// Parameters are owned by struct type `name`. Registry constraints map onto
/// AST generic constraints, with `Unconstrained` becoming `Type`. Bindings
/// get a zero-width placeholder span at offset 0.
///
/// # Errors
///
/// Returns `GraphcalError::InternalError` if two parameters collide in the scope.
fn generic_scope_for_type_def(
    name: &ResolvedName<namespace::StructType>,
    type_def: &TypeDef,
    src: &NamedSource<Arc<String>>,
) -> Result<hir::GenericScope, GraphcalError> {
    use crate::syntax::ast::GenericConstraint;
    let param_owner = hir::GenericParamOwner::Type(name.clone());
    let mut scope = hir::GenericScope::new();
    for param in &type_def.generic_params {
        let constraint = match param.constraint {
            TypeGenericConstraint::Unconstrained => GenericConstraint::Type,
            TypeGenericConstraint::Nat => GenericConstraint::Nat,
            TypeGenericConstraint::Index => GenericConstraint::Index,
            TypeGenericConstraint::Dim => GenericConstraint::Dim,
        };
        let param_id = hir::GenericParamId::new(param_owner.clone(), param.name.clone());
        let binding = hir::GenericParamBinding::new(param_id, constraint, Span::new(0, 0));
        if let Err(err) = scope.insert_binding(binding) {
            return Err(GraphcalError::InternalError {
                message: format!("duplicate generic param while scoping `{name}`: {err}"),
                src: src.clone(),
                span: Span::new(0, 0).into(),
            });
        }
    }
    Ok(scope)
}
/// Resolves a type expression written inside struct type `type_owner`.
///
/// The expression is first lowered to HIR with `lower_type_generic_default`
/// so the struct's generic parameters are in scope, then resolved with the
/// Graphcal prelude.
///
/// # Errors
///
/// Returns the first error from lowering or resolution.
fn resolve_type_expr_in_struct_scope(
    type_expr: &TypeExpr,
    type_owner: &ResolvedName<namespace::StructType>,
    type_def: &TypeDef,
    ctx: ModuleTypeContext<'_>,
    registry: &Registry,
    src: &NamedSource<Arc<String>>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let prelude = hir::PreludeTypeScope::graphcal();
    let hir_ctx = HirTypeResolutionContext {
        prelude: &prelude,
        registry: Some(registry),
        module_types: ctx.types,
        resolver: ctx.resolver,
        src,
    };
    let lowered = lower_type_generic_default(type_expr, type_owner, type_def, hir_ctx)?;
    resolve_hir_type_expr_inner(&lowered, hir_ctx)
}
fn lower_resolved_expressions(
consts: &[crate::ir::lower::ConstEntry],
params: &[crate::ir::lower::ParamEntry],
nodes: &[crate::ir::lower::NodeEntry],
asserts: &[crate::ir::lower::AssertEntry],
ctx: ModuleTypeContext<'_>,
imported_value_sources: &HashMap<ScopedName, crate::ir::lower::ImportedValueSource>,
src: &NamedSource<Arc<String>>,
) -> Result<LoweredDagExpressions, GraphcalError> {
let generic_scope = hir::GenericScope::new();
let prelude = hir::PreludeTypeScope::graphcal();
let decl_bindings = collect_hir_decl_bindings(
ctx.owner,
consts,
params,
nodes,
imported_value_sources,
src,
)?;
let expr_ctx = hir::ExprLoweringContext::new(ctx.owner, ctx.resolver, &generic_scope)
.with_prelude(&prelude)
.with_decl_bindings(&decl_bindings);
let mut exprs = ResolvedExpressions::default();
let mut domain_bounds = HashMap::new();
for entry in consts {
let body_src = entry.src.resolve(src);
let key = decl_key_or_internal_error(ctx.owner, &entry.name, entry.span, body_src)?;
let bounds = lower_domain_bounds(&entry.type_ann, expr_ctx, body_src)?;
if !bounds.is_empty() {
domain_bounds.insert(key.clone(), bounds);
}
exprs.consts.insert(key, entry.expr.clone());
}
for entry in params {
let body_src = entry.src.resolve(src);
let key = decl_key_or_internal_error(ctx.owner, &entry.name, entry.span, body_src)?;
let bounds = lower_domain_bounds(&entry.type_ann, expr_ctx, body_src)?;
if !bounds.is_empty() {
domain_bounds.insert(key.clone(), bounds);
}
let Some(expr) = &entry.default_expr else {
continue;
};
exprs.param_defaults.insert(key, expr.clone());
}
for entry in nodes {
let body_src = entry.src.resolve(src);
let key = decl_key_or_internal_error(ctx.owner, &entry.name, entry.span, body_src)?;
let bounds = lower_domain_bounds(&entry.type_ann, expr_ctx, body_src)?;
if !bounds.is_empty() {
domain_bounds.insert(key.clone(), bounds);
}
exprs.nodes.insert(key, entry.expr.clone());
}
for entry in asserts {
let key =
decl_key_or_internal_error(ctx.owner, &entry.name, entry.span, entry.src.resolve(src))?;
exprs.asserts.insert(key, entry.body.clone());
}
Ok(LoweredDagExpressions {
exprs,
domain_bounds,
})
}
/// Records plot, figure and layer expressions into `semantic`.
///
/// Every visited expression also feeds `semantic.collection_refs` and
/// `semantic.constructor_refs`. For plots, the encodings, mark properties and
/// properties are visited. Plots without a body are skipped. Afterwards
/// `semantic.plot_exprs` is replaced with the collected bodies.
///
/// # Errors
///
/// Returns the first error from reference collection.
fn collect_plot_exprs(
    plots: &[crate::ir::lower::PlotEntry],
    figures: &[crate::ir::lower::FigureEntry],
    layers: &[crate::ir::lower::LayerEntry],
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    semantic: &mut DagSemanticBody,
) -> Result<(), GraphcalError> {
    let mut plot_exprs = ResolvedPlotExprs::default();
    let mut visit = |expr: &hir::Expr| -> Result<(), GraphcalError> {
        collect_resolved_collection_refs_from_expr(expr, ctx, src, &mut semantic.collection_refs)?;
        collect_resolved_constructor_refs_from_expr(expr, ctx, src, &mut semantic.constructor_refs)
    };
    for entry in plots {
        let Some(body) = &entry.body else {
            continue;
        };
        for (_, expr) in &body.encodings {
            visit(expr)?;
        }
        for field in body.mark_properties.iter().chain(&body.properties) {
            visit(&field.value)?;
        }
        plot_exprs.plots.insert(entry.name.clone(), body.clone());
    }
    for entry in figures {
        for field in &entry.fields {
            visit(&field.value)?;
        }
        plot_exprs
            .figures
            .insert(entry.name.clone(), entry.fields.clone());
    }
    for entry in layers {
        for field in &entry.fields {
            visit(&field.value)?;
        }
        plot_exprs
            .layers
            .insert(entry.name.clone(), entry.fields.clone());
    }
    semantic.plot_exprs = plot_exprs;
    Ok(())
}
/// Builds the canonical declaration key for `name` declared in `owner`.
///
/// # Errors
///
/// Returns an internal error anchored at `span` in `src` when
/// `resolved_decl_key` cannot produce a key for `name`.
fn decl_key_or_internal_error(
    owner: &crate::dag_id::DagId,
    name: &ScopedName,
    span: Span,
    src: &NamedSource<Arc<String>>,
) -> Result<ResolvedName<namespace::Decl>, GraphcalError> {
    match resolved_decl_key(owner, name) {
        Some(key) => Ok(key),
        None => Err(internal_error(
            format!("could not build canonical declaration key for `{name}`"),
            src,
            span,
        )),
    }
}
/// Lowers each domain bound attached to `type_ann` into HIR with `expr_ctx`.
///
/// Bounds keep the order reported by `domain_bounds()`. Each lowered bound
/// carries over its `kind`, `kind_span` and `span` unchanged.
///
/// # Errors
///
/// Returns the first bound expression that fails to lower, converted with
/// `expr_lower_error_to_graphcal` against `src`.
fn lower_domain_bounds(
    type_ann: &crate::desugar::desugared_ast::TypeExpr,
    expr_ctx: hir::ExprLoweringContext<'_>,
    src: &NamedSource<Arc<String>>,
) -> Result<Vec<ResolvedDomainBound>, GraphcalError> {
    let declared = type_ann.domain_bounds();
    declared
        .iter()
        .map(|bound| {
            hir::lower_expr(&bound.value, expr_ctx)
                .map(|value| ResolvedDomainBound {
                    kind: bound.kind,
                    kind_span: bound.kind_span,
                    value,
                    span: bound.span,
                })
                .map_err(|err| expr_lower_error_to_graphcal(&err, src))
        })
        .collect()
}
/// Result of lowering a DAG's declaration bodies to HIR.
struct LoweredDagExpressions {
    /// HIR bodies of consts, param defaults, nodes and asserts, keyed by
    /// canonical declaration name.
    exprs: ResolvedExpressions,
    /// Lowered domain bounds per declaration. Only declarations whose type
    /// annotation carries at least one bound have an entry.
    domain_bounds: HashMap<ResolvedName<namespace::Decl>, Vec<ResolvedDomainBound>>,
}
/// Enforces the declaration-body policies on every lowered HIR body of a DAG.
///
/// Bodies are checked in this order, and the first violation is returned:
/// 1. const bodies, then the domain bounds of const declarations (both as
///    const bodies);
/// 2. node bodies;
/// 3. param default expressions (never with the variant-literal rule);
/// 4. assert bodies, in plain or tolerance form;
/// 5. plot encodings, mark properties and properties;
/// 6. figure fields, then layer fields.
///
/// The variant-literal rule is enabled for const and node bodies owned by
/// this module. For asserts it additionally requires the name to be listed
/// in `pub_names`. For plots, figures and layers it requires an unqualified
/// name listed in `pub_names`.
///
/// NOTE(review): domain bounds of params and nodes are not checked here —
/// confirm they are validated elsewhere.
fn check_hir_body_policies(
    semantic: &DagSemanticBody,
    registry: &Registry,
    pub_names: &std::collections::HashSet<DeclName>,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
) -> Result<(), GraphcalError> {
    let checker = HirPolicyChecker { registry, ctx, src };
    let owned_here = |decl: &ResolvedName<namespace::Decl>| decl.owner() == ctx.owner;
    let exported = |member: &str| pub_names.contains(&DeclName::new(member));
    let bodies = &semantic.expressions;

    for (decl, body) in &bodies.consts {
        checker.check_expr(body, true, owned_here(decl))?;
    }
    for (decl, bounds) in &semantic.domain_bounds {
        if !bodies.consts.contains_key(decl) {
            continue;
        }
        for bound in bounds {
            checker.check_expr(&bound.value, true, owned_here(decl))?;
        }
    }
    for (decl, body) in &bodies.nodes {
        checker.check_expr(body, false, owned_here(decl))?;
    }
    for default in bodies.param_defaults.values() {
        checker.check_expr(default, false, false)?;
    }

    for (decl, body) in &bodies.asserts {
        let literals = owned_here(decl) && exported(decl.as_str());
        match body {
            hir::AssertBody::Expr(condition) => checker.check_expr(condition, false, literals)?,
            hir::AssertBody::Tolerance {
                actual,
                expected,
                tolerance,
                ..
            } => {
                checker.check_expr(actual, false, literals)?;
                checker.check_expr(expected, false, literals)?;
                checker.check_expr(tolerance, false, literals)?;
            }
        }
    }

    for (name, plot) in &semantic.plot_exprs.plots {
        let literals = !name.is_qualified() && exported(name.member());
        for (_, encoding) in &plot.encodings {
            checker.check_expr(encoding, false, literals)?;
        }
        for field in plot.mark_properties.iter().chain(&plot.properties) {
            checker.check_expr(&field.value, false, literals)?;
        }
    }
    let figure_and_layer_fields = semantic
        .plot_exprs
        .figures
        .iter()
        .chain(&semantic.plot_exprs.layers);
    for (name, fields) in figure_and_layer_fields {
        let literals = !name.is_qualified() && exported(name.member());
        for field in fields {
            checker.check_expr(&field.value, false, literals)?;
        }
    }
    Ok(())
}
/// Walks HIR expressions and reports declaration-body policy violations:
/// non-const units and non-const graph refs in const bodies, graph refs to
/// asserts, and (on request) variant literals of local bindable indexes.
struct HirPolicyChecker<'a> {
    /// Registry consulted for unit constness.
    registry: &'a Registry,
    /// Owning module plus resolver, used to classify graph refs and to look
    /// up index symbols.
    ctx: ModuleTypeContext<'a>,
    /// Source attached to every diagnostic the checker produces.
    src: &'a NamedSource<Arc<String>>,
}
impl HirPolicyChecker<'_> {
    /// Checks `expr` and all of its sub-expressions against the body
    /// policies. Uses `with_stack_growth` so deeply nested expressions do not
    /// overflow the stack.
    ///
    /// * `const_body`: the expression belongs to a const body (or a const's
    ///   domain bound); enables the const-only unit and graph-ref rules.
    /// * `check_pub_bind_literals`: reject variant literals of locally
    ///   declared, bindable indexes that have variants.
    fn check_expr(
        &self,
        expr: &hir::Expr,
        const_body: bool,
        check_pub_bind_literals: bool,
    ) -> Result<(), GraphcalError> {
        let descend = || self.check_expr_inner(expr, const_body, check_pub_bind_literals);
        crate::stack::with_stack_growth(descend)
    }

    /// Single-level dispatch for `check_expr`: applies the policy for this
    /// node, then visits child expressions in source order and stops at the
    /// first violation.
    fn check_expr_inner(
        &self,
        expr: &hir::Expr,
        const_body: bool,
        check_pub_bind_literals: bool,
    ) -> Result<(), GraphcalError> {
        // Children are checked under the same flags as their parent.
        let visit = |child: &hir::Expr| self.check_expr(child, const_body, check_pub_bind_literals);
        let literal = |variant: &hir::expr::IndexVariantRef| {
            self.check_variant_literal(variant, check_pub_bind_literals)
        };
        match &expr.kind {
            hir::ExprKind::UnitLiteral { unit, .. } => self.check_const_unit_expr(unit, const_body),
            hir::ExprKind::GraphRef(target) => self.check_graph_ref(target, expr.span, const_body),
            hir::ExprKind::VariantLiteral(variant) => literal(variant),
            hir::ExprKind::Convert {
                expr: operand,
                target,
            } => self
                .check_const_unit_expr(target, const_body)
                .and_then(|()| visit(operand)),
            hir::ExprKind::BinOp { lhs, rhs, .. } => visit(lhs).and_then(|()| visit(rhs)),
            hir::ExprKind::UnaryOp { operand, .. }
            | hir::ExprKind::DisplayTimezone { expr: operand, .. }
            | hir::ExprKind::FieldAccess { expr: operand, .. } => visit(operand),
            hir::ExprKind::ForComp { body, .. } => visit(body),
            hir::ExprKind::FnCall { args, .. } => args.iter().try_for_each(visit),
            hir::ExprKind::If {
                condition,
                then_branch,
                else_branch,
            } => {
                visit(condition)?;
                visit(then_branch)?;
                visit(else_branch)
            }
            hir::ExprKind::ConstructorCall { fields, .. } => {
                for field in fields {
                    visit(&field.value)?;
                }
                Ok(())
            }
            hir::ExprKind::MapLiteral { entries } => {
                for entry in entries {
                    for key in &entry.keys {
                        match key {
                            hir::expr::MapEntryKey::IndexVariant(variant) => literal(variant)?,
                            hir::expr::MapEntryKey::NatRangeVariant { .. } => {}
                        }
                    }
                    visit(&entry.value)?;
                }
                Ok(())
            }
            hir::ExprKind::IndexAccess {
                expr: indexed,
                args,
            } => {
                visit(indexed)?;
                for arg in args {
                    match arg {
                        hir::expr::IndexArg::Variant(variant) => literal(variant)?,
                        hir::expr::IndexArg::Expr(arg_expr) => visit(arg_expr)?,
                        hir::expr::IndexArg::Var(_) => {}
                    }
                }
                Ok(())
            }
            hir::ExprKind::Scan {
                source, init, body, ..
            } => {
                visit(source)?;
                visit(init)?;
                visit(body)
            }
            hir::ExprKind::Unfold { init, body, .. } => visit(init).and_then(|()| visit(body)),
            hir::ExprKind::Match { scrutinee, arms } => {
                visit(scrutinee)?;
                for arm in arms {
                    if let hir::expr::MatchPattern::IndexLabel { variant, .. } = &arm.pattern {
                        literal(variant)?;
                    }
                    visit(&arm.body)?;
                }
                Ok(())
            }
            hir::ExprKind::InlineDagRef { args, .. } => {
                for arg in args {
                    visit(&arg.value)?;
                }
                Ok(())
            }
            // Leaves carry no policy of their own.
            hir::ExprKind::Error
            | hir::ExprKind::Number(_)
            | hir::ExprKind::Integer(_)
            | hir::ExprKind::Bool(_)
            | hir::ExprKind::StringLiteral(_)
            | hir::ExprKind::TypeSystemRef(_)
            | hir::ExprKind::ConstRef(_)
            | hir::ExprKind::LocalRef(_) => Ok(()),
        }
    }

    /// In a const body, rejects the first unit term that the registry knows
    /// and reports as non-const. Terms unknown to the registry are skipped
    /// here. Outside const bodies this always succeeds.
    fn check_const_unit_expr(
        &self,
        unit: &crate::desugar::desugared_ast::UnitExpr,
        const_body: bool,
    ) -> Result<(), GraphcalError> {
        if const_body {
            let runtime_term = unit.terms.iter().find(|term| {
                self.registry
                    .units
                    .get_unit(&term.name.value)
                    .is_some_and(|info| !info.constness.is_const())
            });
            if let Some(term) = runtime_term {
                return Err(GraphcalError::NonConstUnitInConst {
                    name: term.name.value.clone(),
                    src: self.src.clone(),
                    span: term.name.span.into(),
                });
            }
        }
        Ok(())
    }

    /// Rejects graph refs to asserts everywhere. In const bodies it also
    /// rejects graph refs to non-const declarations. Targets whose symbol
    /// kind cannot be resolved are skipped by this check.
    fn check_graph_ref(
        &self,
        target: &Spanned<ResolvedName<namespace::Decl>>,
        ref_span: Span,
        const_body: bool,
    ) -> Result<(), GraphcalError> {
        let Ok(kind) = self.ctx.resolver.decl_symbol_kind(&target.value) else {
            return Ok(());
        };
        let leaf = target.value.as_str();
        if matches!(kind, crate::syntax::module_resolve::DeclSymbolKind::Assert) {
            Err(GraphcalError::GraphRefToAssert {
                name: DeclName::new(leaf),
                src: self.src.clone(),
                span: ref_span.into(),
            })
        } else if const_body && !kind.is_const() {
            Err(GraphcalError::GraphRefInConst {
                name: ScopedName::local(leaf),
                src: self.src.clone(),
                span: ref_span.into(),
            })
        } else {
            Ok(())
        }
    }

    /// When `check_pub_bind_literals` is set, rejects a variant literal
    /// whose index is declared in this module with bindable visibility and a
    /// non-empty variant list.
    fn check_variant_literal(
        &self,
        variant: &hir::expr::IndexVariantRef,
        check_pub_bind_literals: bool,
    ) -> Result<(), GraphcalError> {
        if !check_pub_bind_literals {
            return Ok(());
        }
        let index = variant.variant.index();
        // Indexes from other modules are never subject to this rule.
        if index.owner() != self.ctx.owner {
            return Ok(());
        }
        let bindable_with_variants = self
            .ctx
            .resolver
            .modules()
            .get(self.ctx.owner)
            .is_some_and(|symbols| {
                symbols
                    .indexes()
                    .get(&IndexName::new(index.as_str()))
                    .is_some_and(|symbol| {
                        symbol.visibility().is_bindable() && !symbol.variants().is_empty()
                    })
            });
        if !bindable_with_variants {
            return Ok(());
        }
        Err(GraphcalError::PubIndexVariantLiteral {
            index: index.as_str().to_string(),
            variant: variant.variant.variant().as_str().to_string(),
            src: self.src.clone(),
            span: variant.path_span().into(),
        })
    }
}
/// Adds runtime dependencies introduced by dynamic units.
///
/// For every param default and node body, each unit it names (in a unit
/// literal or a conversion target) that has an entry in
/// `semantic.dynamic_unit_scales` contributes the graph refs of that unit's
/// scale expression. These are added to the declaration's
/// `runtime_deps`. Declarations that gain nothing are left untouched. This
/// is a no-op when there are no dynamic unit scales.
fn augment_runtime_deps_for_dynamic_units(semantic: &mut DagSemanticBody) {
    if semantic.dynamic_unit_scales.is_empty() {
        return;
    }
    // Unit -> declarations its scale expression reads.
    let mut scale_deps: HashMap<
        crate::syntax::names::UnitRef,
        BTreeSet<ResolvedName<namespace::Decl>>,
    > = HashMap::new();
    for (unit, scale_expr) in &semantic.dynamic_unit_scales {
        let reads = hir::collect_expr_dependencies(scale_expr).graph_refs;
        scale_deps.insert(unit.clone(), reads);
    }

    let DagSemanticBody {
        expressions,
        dependencies,
        ..
    } = semantic;
    let runtime_bodies = expressions.param_defaults.iter().chain(&expressions.nodes);
    for (decl, body) in runtime_bodies {
        let mut used_units = std::collections::HashSet::new();
        collect_unit_names_from_hir(body, &mut used_units);

        let mut extra: BTreeSet<ResolvedName<namespace::Decl>> = BTreeSet::new();
        for unit in &used_units {
            if let Some(reads) = scale_deps.get(unit) {
                extra.extend(reads.iter().cloned());
            }
        }
        if extra.is_empty() {
            continue;
        }
        dependencies
            .runtime_deps
            .entry(decl.clone())
            .or_default()
            .extend(extra);
    }
}
fn collect_unit_names_from_hir(
expr: &hir::Expr,
names: &mut std::collections::HashSet<crate::syntax::names::UnitRef>,
) {
crate::stack::with_stack_growth(|| match &expr.kind {
hir::ExprKind::UnitLiteral { unit, .. } => {
for term in &unit.terms {
names.insert(term.name.value.clone());
}
}
hir::ExprKind::Convert {
expr: inner,
target,
} => {
for term in &target.terms {
names.insert(term.name.value.clone());
}
collect_unit_names_from_hir(inner, names);
}
hir::ExprKind::Error
| hir::ExprKind::Number(_)
| hir::ExprKind::Integer(_)
| hir::ExprKind::Bool(_)
| hir::ExprKind::StringLiteral(_)
| hir::ExprKind::TypeSystemRef(_)
| hir::ExprKind::GraphRef(_)
| hir::ExprKind::ConstRef(_)
| hir::ExprKind::LocalRef(_)
| hir::ExprKind::VariantLiteral(_) => {}
hir::ExprKind::BinOp { lhs, rhs, .. } => {
collect_unit_names_from_hir(lhs, names);
collect_unit_names_from_hir(rhs, names);
}
hir::ExprKind::UnaryOp { operand, .. }
| hir::ExprKind::DisplayTimezone { expr: operand, .. }
| hir::ExprKind::FieldAccess { expr: operand, .. } => {
collect_unit_names_from_hir(operand, names);
}
hir::ExprKind::FnCall { args, .. } => {
for arg in args {
collect_unit_names_from_hir(arg, names);
}
}
hir::ExprKind::If {
condition,
then_branch,
else_branch,
} => {
collect_unit_names_from_hir(condition, names);
collect_unit_names_from_hir(then_branch, names);
collect_unit_names_from_hir(else_branch, names);
}
hir::ExprKind::ConstructorCall { fields, .. } => {
for field in fields {
collect_unit_names_from_hir(&field.value, names);
}
}
hir::ExprKind::MapLiteral { entries } => {
for entry in entries {
collect_unit_names_from_hir(&entry.value, names);
}
}
hir::ExprKind::ForComp { body, .. } => collect_unit_names_from_hir(body, names),
hir::ExprKind::IndexAccess { expr: inner, args } => {
collect_unit_names_from_hir(inner, names);
for arg in args {
if let hir::expr::IndexArg::Expr(arg_expr) = arg {
collect_unit_names_from_hir(arg_expr, names);
}
}
}
hir::ExprKind::Scan {
source, init, body, ..
} => {
collect_unit_names_from_hir(source, names);
collect_unit_names_from_hir(init, names);
collect_unit_names_from_hir(body, names);
}
hir::ExprKind::Unfold { init, body, .. } => {
collect_unit_names_from_hir(init, names);
collect_unit_names_from_hir(body, names);
}
hir::ExprKind::Match { scrutinee, arms } => {
collect_unit_names_from_hir(scrutinee, names);
for arm in arms {
collect_unit_names_from_hir(&arm.body, names);
}
}
hir::ExprKind::InlineDagRef { args, .. } => {
for arg in args {
collect_unit_names_from_hir(&arg.value, names);
}
}
});
}
fn collect_resolved_dag_dependencies(
consts: &[crate::ir::lower::ConstEntry],
params: &[crate::ir::lower::ParamEntry],
nodes: &[crate::ir::lower::NodeEntry],
exprs: &ResolvedExpressions,
ctx: ModuleTypeContext<'_>,
src: &NamedSource<Arc<String>>,
) -> Result<ResolvedDagDependencies, GraphcalError> {
let mut resolved = ResolvedDagDependencies::default();
for entry in consts {
let body_src = entry.src.resolve(src);
let key = resolved_decl_key(ctx.owner, &entry.name).ok_or_else(|| {
internal_error(
format!(
"could not build canonical declaration key for `{}`",
entry.name
),
body_src,
entry.span,
)
})?;
let hir_expr = exprs.consts.get(&key).ok_or_else(|| {
internal_error(
format!(
"missing HIR expression for const declaration `{}`",
entry.name
),
body_src,
entry.span,
)
})?;
let mut deps = hir::collect_expr_dependencies(hir_expr);
for graph_ref in &deps.graph_refs {
let kind = ctx
.resolver
.decl_symbol_kind(graph_ref)
.map_err(|err| module_resolve_error(&err, body_src, entry.span))?;
if kind.is_const() {
deps.const_refs.insert(graph_ref.clone());
}
}
resolved.const_deps.insert(key, deps.const_refs);
}
for entry in params {
let key = resolved_decl_key(ctx.owner, &entry.name).ok_or_else(|| {
internal_error(
format!(
"could not build canonical declaration key for `{}`",
entry.name
),
entry.src.resolve(src),
entry.span,
)
})?;
let deps = exprs.param_defaults.get(&key).map_or_else(
hir::ExprDependencies::default,
hir::collect_expr_dependencies,
);
resolved.runtime_deps.insert(key, deps.graph_refs);
}
for entry in nodes {
let body_src = entry.src.resolve(src);
let key = resolved_decl_key(ctx.owner, &entry.name).ok_or_else(|| {
internal_error(
format!(
"could not build canonical declaration key for `{}`",
entry.name
),
body_src,
entry.span,
)
})?;
let hir_expr = exprs.nodes.get(&key).ok_or_else(|| {
internal_error(
format!(
"missing HIR expression for node declaration `{}`",
entry.name
),
body_src,
entry.span,
)
})?;
let mut deps = hir::collect_expr_dependencies(hir_expr);
if !hir::has_ref_outside_unfold(hir_expr, &key) {
deps.graph_refs.remove(&key);
}
resolved.runtime_deps.insert(key, deps.graph_refs);
}
Ok(resolved)
}
/// Gathers the index definitions referenced by a DAG's declarations.
///
/// Scans, in order: the resolved declaration types, const bodies, param
/// defaults, node bodies, domain bounds, and assert bodies.
///
/// # Errors
///
/// Propagates the first internal error raised while recording an index
/// unknown to `ctx.types`.
fn collect_resolved_collection_refs(
    exprs: &ResolvedExpressions,
    domain_bounds: &HashMap<ResolvedName<namespace::Decl>, Vec<ResolvedDomainBound>>,
    resolved_decl_types: &HashMap<ScopedName, ResolvedTypeExpr>,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
) -> Result<ResolvedCollectionRefs, GraphcalError> {
    let mut collected = ResolvedCollectionRefs::default();
    for declared in resolved_decl_types.values() {
        collect_resolved_collection_indexes_from_type(declared, ctx, src, &mut collected)?;
    }
    for body in exprs.consts.values() {
        collect_resolved_collection_refs_from_expr(body, ctx, src, &mut collected)?;
    }
    for default in exprs.param_defaults.values() {
        collect_resolved_collection_refs_from_expr(default, ctx, src, &mut collected)?;
    }
    for body in exprs.nodes.values() {
        collect_resolved_collection_refs_from_expr(body, ctx, src, &mut collected)?;
    }
    for bound in domain_bounds.values().flatten() {
        collect_resolved_collection_refs_from_expr(&bound.value, ctx, src, &mut collected)?;
    }
    for assertion in exprs.asserts.values() {
        collect_resolved_collection_refs_from_assert_body(assertion, ctx, src, &mut collected)?;
    }
    Ok(collected)
}
/// Records the definition of `index` in `refs`, once per index.
///
/// # Errors
///
/// Returns an internal error anchored at `span` when `index` is not
/// registered in `ctx.types`.
fn record_resolved_collection_index(
    index: &ResolvedName<namespace::Index>,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    span: Span,
    refs: &mut ResolvedCollectionRefs,
) -> Result<(), GraphcalError> {
    if refs.index_defs.contains_key(index) {
        return Ok(());
    }
    let Some(def) = ctx.types.get_index(index) else {
        return Err(internal_error(
            format!("semantic collection metadata references unknown index `{index}`"),
            src,
            span,
        ));
    };
    refs.index_defs.insert(index.clone(), def.clone());
    Ok(())
}
/// Records every concrete index mentioned by `resolved_type`.
///
/// This covers a concrete index argument, the concrete indexes of an
/// indexed type (after its base), and the type arguments of a generic
/// struct, searched recursively.
///
/// # Errors
///
/// Propagates errors from `record_resolved_collection_index`.
fn collect_resolved_collection_indexes_from_type(
    resolved_type: &ResolvedTypeExpr,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    refs: &mut ResolvedCollectionRefs,
) -> Result<(), GraphcalError> {
    match resolved_type {
        ResolvedTypeExpr::IndexArg(ResolvedIndex::Concrete(index, span)) => {
            record_resolved_collection_index(index, ctx, src, *span, refs)
        }
        ResolvedTypeExpr::Indexed { base, indexes } => {
            collect_resolved_collection_indexes_from_type(base, ctx, src, refs)?;
            for index in indexes {
                let ResolvedIndex::Concrete(concrete, span) = index else {
                    continue;
                };
                record_resolved_collection_index(concrete, ctx, src, *span, refs)?;
            }
            Ok(())
        }
        ResolvedTypeExpr::GenericStruct { type_args, .. } => type_args
            .iter()
            .try_for_each(|arg| collect_resolved_collection_indexes_from_type(arg, ctx, src, refs)),
        ResolvedTypeExpr::Dimensionless
        | ResolvedTypeExpr::Bool
        | ResolvedTypeExpr::Int
        | ResolvedTypeExpr::Datetime(_)
        | ResolvedTypeExpr::IndexArg(_)
        | ResolvedTypeExpr::Scalar(_)
        | ResolvedTypeExpr::Struct(_, _)
        | ResolvedTypeExpr::GenericDimParam(_, _)
        | ResolvedTypeExpr::GenericTypeParam(_, _)
        | ResolvedTypeExpr::GenericDimExpr { .. } => Ok(()),
    }
}
/// Stack-safe entry point for collecting index refs from `expr` into `refs`.
/// Each recursion level runs under `with_stack_growth`.
fn collect_resolved_collection_refs_from_expr(
    expr: &hir::Expr,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    refs: &mut ResolvedCollectionRefs,
) -> Result<(), GraphcalError> {
    let descend = || collect_resolved_collection_refs_from_expr_inner(expr, ctx, src, refs);
    crate::stack::with_stack_growth(descend)
}
/// Single-level dispatch for `collect_resolved_collection_refs_from_expr`.
///
/// Records every index named directly by `expr` into `refs`: variant
/// literals, map-key variants, named `for` bindings, index-access variants
/// and match index labels. It then recurses into child expressions through
/// the stack-growing wrapper.
///
/// # Errors
///
/// Propagates the first error from `record_resolved_collection_index`
/// (an index unknown to `ctx.types`).
#[expect(
    clippy::too_many_lines,
    reason = "expression traversal mirrors HIR variants"
)]
fn collect_resolved_collection_refs_from_expr_inner(
    expr: &hir::Expr,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    refs: &mut ResolvedCollectionRefs,
) -> Result<(), GraphcalError> {
    match &expr.kind {
        // Leaves that name no index.
        hir::ExprKind::Error
        | hir::ExprKind::Number(_)
        | hir::ExprKind::Integer(_)
        | hir::ExprKind::Bool(_)
        | hir::ExprKind::StringLiteral(_)
        | hir::ExprKind::TypeSystemRef(_)
        | hir::ExprKind::GraphRef(_)
        | hir::ExprKind::LocalRef(_)
        | hir::ExprKind::ConstRef(_)
        | hir::ExprKind::UnitLiteral { .. } => Ok(()),
        hir::ExprKind::VariantLiteral(variant) => record_resolved_collection_index(
            variant.variant.index(),
            ctx,
            src,
            variant.path_span(),
            refs,
        ),
        hir::ExprKind::BinOp { lhs, rhs, .. } => {
            collect_resolved_collection_refs_from_expr(lhs, ctx, src, refs)?;
            collect_resolved_collection_refs_from_expr(rhs, ctx, src, refs)
        }
        hir::ExprKind::UnaryOp { operand, .. } => {
            collect_resolved_collection_refs_from_expr(operand, ctx, src, refs)
        }
        hir::ExprKind::FnCall { args, .. } => {
            for arg in args {
                collect_resolved_collection_refs_from_expr(arg, ctx, src, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::If {
            condition,
            then_branch,
            else_branch,
        } => {
            collect_resolved_collection_refs_from_expr(condition, ctx, src, refs)?;
            collect_resolved_collection_refs_from_expr(then_branch, ctx, src, refs)?;
            collect_resolved_collection_refs_from_expr(else_branch, ctx, src, refs)
        }
        hir::ExprKind::Convert { expr, .. }
        | hir::ExprKind::DisplayTimezone { expr, .. }
        | hir::ExprKind::FieldAccess { expr, .. } => {
            collect_resolved_collection_refs_from_expr(expr, ctx, src, refs)
        }
        hir::ExprKind::ConstructorCall { fields, .. } => {
            for field in fields {
                collect_resolved_collection_refs_from_expr(&field.value, ctx, src, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::MapLiteral { entries } => {
            for entry in entries {
                for key in &entry.keys {
                    match key {
                        // NOTE(review): map-key variants are anchored at
                        // `variant_span`, while the other variant sites in
                        // this function use `path_span()`; confirm intended.
                        hir::expr::MapEntryKey::IndexVariant(variant) => {
                            record_resolved_collection_index(
                                variant.variant.index(),
                                ctx,
                                src,
                                variant.variant_span,
                                refs,
                            )?;
                        }
                        hir::expr::MapEntryKey::NatRangeVariant { .. } => {}
                    }
                }
                collect_resolved_collection_refs_from_expr(&entry.value, ctx, src, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::ForComp { bindings, body } => {
            // Only bindings over a named index record anything; range
            // bindings are skipped here.
            for binding in bindings {
                match &binding.index {
                    hir::expr::ForBindingIndex::Named(index) => {
                        record_resolved_collection_index(&index.value, ctx, src, index.span, refs)?;
                    }
                    hir::expr::ForBindingIndex::Range { .. } => {}
                }
            }
            collect_resolved_collection_refs_from_expr(body, ctx, src, refs)
        }
        hir::ExprKind::IndexAccess { expr, args } => {
            collect_resolved_collection_refs_from_expr(expr, ctx, src, refs)?;
            for arg in args {
                match arg {
                    hir::expr::IndexArg::Variant(variant) => {
                        record_resolved_collection_index(
                            variant.variant.index(),
                            ctx,
                            src,
                            variant.path_span(),
                            refs,
                        )?;
                    }
                    hir::expr::IndexArg::Expr(expr) => {
                        collect_resolved_collection_refs_from_expr(expr, ctx, src, refs)?;
                    }
                    hir::expr::IndexArg::Var(_) => {}
                }
            }
            Ok(())
        }
        hir::ExprKind::Scan {
            source, init, body, ..
        } => {
            collect_resolved_collection_refs_from_expr(source, ctx, src, refs)?;
            collect_resolved_collection_refs_from_expr(init, ctx, src, refs)?;
            collect_resolved_collection_refs_from_expr(body, ctx, src, refs)
        }
        hir::ExprKind::Unfold { init, body, .. } => {
            collect_resolved_collection_refs_from_expr(init, ctx, src, refs)?;
            collect_resolved_collection_refs_from_expr(body, ctx, src, refs)
        }
        hir::ExprKind::Match { scrutinee, arms } => {
            collect_resolved_collection_refs_from_expr(scrutinee, ctx, src, refs)?;
            for arm in arms {
                // Index-label patterns name the index of their variant.
                if let hir::expr::MatchPattern::IndexLabel { variant, span: _ } = &arm.pattern {
                    record_resolved_collection_index(
                        variant.variant.index(),
                        ctx,
                        src,
                        variant.path_span(),
                        refs,
                    )?;
                }
                collect_resolved_collection_refs_from_expr(&arm.body, ctx, src, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::InlineDagRef { args, .. } => {
            for arg in args {
                collect_resolved_collection_refs_from_expr(&arg.value, ctx, src, refs)?;
            }
            Ok(())
        }
    }
}
/// Collects index refs from an assert body.
///
/// For the tolerance form, the actual, expected and tolerance expressions
/// are scanned in that order.
fn collect_resolved_collection_refs_from_assert_body(
    body: &hir::AssertBody,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    refs: &mut ResolvedCollectionRefs,
) -> Result<(), GraphcalError> {
    let (actual, expected, tolerance) = match body {
        hir::AssertBody::Expr(condition) => {
            return collect_resolved_collection_refs_from_expr(condition, ctx, src, refs);
        }
        hir::AssertBody::Tolerance {
            actual,
            expected,
            tolerance,
            ..
        } => (actual, expected, tolerance),
    };
    collect_resolved_collection_refs_from_expr(actual, ctx, src, refs)?;
    collect_resolved_collection_refs_from_expr(expected, ctx, src, refs)?;
    collect_resolved_collection_refs_from_expr(tolerance, ctx, src, refs)?;
    Ok(())
}
/// Gathers the constructor targets referenced by a DAG's declarations.
///
/// Scans, in order: const bodies, param defaults, node bodies, domain
/// bounds, and assert bodies.
///
/// # Errors
///
/// Propagates the first internal error raised for a constructor unknown to
/// `ctx.types`.
fn collect_resolved_constructor_refs(
    exprs: &ResolvedExpressions,
    domain_bounds: &HashMap<ResolvedName<namespace::Decl>, Vec<ResolvedDomainBound>>,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
) -> Result<ResolvedConstructorRefs, GraphcalError> {
    let mut collected = ResolvedConstructorRefs::default();
    for body in exprs.consts.values() {
        collect_resolved_constructor_refs_from_expr(body, ctx, src, &mut collected)?;
    }
    for default in exprs.param_defaults.values() {
        collect_resolved_constructor_refs_from_expr(default, ctx, src, &mut collected)?;
    }
    for body in exprs.nodes.values() {
        collect_resolved_constructor_refs_from_expr(body, ctx, src, &mut collected)?;
    }
    for bound in domain_bounds.values().flatten() {
        collect_resolved_constructor_refs_from_expr(&bound.value, ctx, src, &mut collected)?;
    }
    for assertion in exprs.asserts.values() {
        collect_resolved_constructor_refs_from_assert_body(assertion, ctx, src, &mut collected)?;
    }
    Ok(collected)
}
/// Returns the resolved target for `constructor`.
///
/// The target is built from `ctx.types` on first use, cached in
/// `refs.constructor_defs`, and returned as a clone from the cache on later
/// calls.
///
/// # Errors
///
/// Returns an internal error anchored at `span` when `constructor` is not
/// known to `ctx.types`.
fn record_resolved_constructor_target(
    constructor: &ResolvedName<namespace::Constructor>,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    span: Span,
    refs: &mut ResolvedConstructorRefs,
) -> Result<ResolvedConstructorTarget, GraphcalError> {
    if let Some(cached) = refs.constructor_defs.get(constructor) {
        return Ok(cached.clone());
    }
    let Some(def) = ctx.types.lookup_constructor(constructor) else {
        return Err(internal_error(
            format!("semantic constructor metadata references unknown constructor `{constructor}`"),
            src,
            span,
        ));
    };
    let built = ResolvedConstructorTarget {
        constructor: constructor.clone(),
        owning_type: def.owning_type.clone(),
        type_def: def.type_def.clone(),
        variant: def.variant.clone(),
    };
    refs.constructor_defs
        .insert(constructor.clone(), built.clone());
    Ok(built)
}
/// Stack-safe entry point for collecting constructor refs from `expr` into
/// `refs`. Each recursion level runs under `with_stack_growth`.
fn collect_resolved_constructor_refs_from_expr(
    expr: &hir::Expr,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    refs: &mut ResolvedConstructorRefs,
) -> Result<(), GraphcalError> {
    let descend = || collect_resolved_constructor_refs_from_expr_inner(expr, ctx, src, refs);
    crate::stack::with_stack_growth(descend)
}
/// Single-level dispatch for `collect_resolved_constructor_refs_from_expr`.
///
/// Records every constructor referenced directly by `expr`: constructor
/// constants (`hir::ConstRef::Constructor`), constructor-call callees and
/// constructor match patterns. It then recurses into child expressions
/// through the stack-growing wrapper.
///
/// # Errors
///
/// Propagates the first error from `record_resolved_constructor_target`
/// (a constructor unknown to `ctx.types`).
#[expect(
    clippy::too_many_lines,
    reason = "expression traversal mirrors HIR variants"
)]
fn collect_resolved_constructor_refs_from_expr_inner(
    expr: &hir::Expr,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    refs: &mut ResolvedConstructorRefs,
) -> Result<(), GraphcalError> {
    match &expr.kind {
        // Leaves that cannot reference a constructor.
        hir::ExprKind::Error
        | hir::ExprKind::Number(_)
        | hir::ExprKind::Integer(_)
        | hir::ExprKind::Bool(_)
        | hir::ExprKind::StringLiteral(_)
        | hir::ExprKind::TypeSystemRef(_)
        | hir::ExprKind::GraphRef(_)
        | hir::ExprKind::LocalRef(_)
        | hir::ExprKind::UnitLiteral { .. }
        | hir::ExprKind::VariantLiteral(_) => Ok(()),
        // Only constructor-valued const refs are recorded; other const refs
        // are ignored here.
        hir::ExprKind::ConstRef(target) => {
            if let hir::ConstRef::Constructor(constructor) = &target.value {
                record_resolved_constructor_target(constructor, ctx, src, target.span, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::BinOp { lhs, rhs, .. } => {
            collect_resolved_constructor_refs_from_expr(lhs, ctx, src, refs)?;
            collect_resolved_constructor_refs_from_expr(rhs, ctx, src, refs)
        }
        hir::ExprKind::UnaryOp { operand, .. } => {
            collect_resolved_constructor_refs_from_expr(operand, ctx, src, refs)
        }
        hir::ExprKind::FnCall { args, .. } => {
            for arg in args {
                collect_resolved_constructor_refs_from_expr(arg, ctx, src, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::If {
            condition,
            then_branch,
            else_branch,
        } => {
            collect_resolved_constructor_refs_from_expr(condition, ctx, src, refs)?;
            collect_resolved_constructor_refs_from_expr(then_branch, ctx, src, refs)?;
            collect_resolved_constructor_refs_from_expr(else_branch, ctx, src, refs)
        }
        hir::ExprKind::Convert { expr, .. }
        | hir::ExprKind::DisplayTimezone { expr, .. }
        | hir::ExprKind::FieldAccess { expr, .. } => {
            collect_resolved_constructor_refs_from_expr(expr, ctx, src, refs)
        }
        // The callee is recorded before its field values are scanned.
        hir::ExprKind::ConstructorCall { callee, fields, .. } => {
            record_resolved_constructor_target(&callee.value, ctx, src, callee.span, refs)?;
            for field in fields {
                collect_resolved_constructor_refs_from_expr(&field.value, ctx, src, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::MapLiteral { entries } => {
            for entry in entries {
                collect_resolved_constructor_refs_from_expr(&entry.value, ctx, src, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::ForComp { body, .. } => {
            collect_resolved_constructor_refs_from_expr(body, ctx, src, refs)
        }
        hir::ExprKind::IndexAccess { expr, args } => {
            collect_resolved_constructor_refs_from_expr(expr, ctx, src, refs)?;
            for arg in args {
                if let hir::expr::IndexArg::Expr(expr) = arg {
                    collect_resolved_constructor_refs_from_expr(expr, ctx, src, refs)?;
                }
            }
            Ok(())
        }
        hir::ExprKind::Scan {
            source, init, body, ..
        } => {
            collect_resolved_constructor_refs_from_expr(source, ctx, src, refs)?;
            collect_resolved_constructor_refs_from_expr(init, ctx, src, refs)?;
            collect_resolved_constructor_refs_from_expr(body, ctx, src, refs)
        }
        hir::ExprKind::Unfold { init, body, .. } => {
            collect_resolved_constructor_refs_from_expr(init, ctx, src, refs)?;
            collect_resolved_constructor_refs_from_expr(body, ctx, src, refs)
        }
        hir::ExprKind::Match { scrutinee, arms } => {
            collect_resolved_constructor_refs_from_expr(scrutinee, ctx, src, refs)?;
            for arm in arms {
                // Constructor patterns reference their constructor.
                if let hir::expr::MatchPattern::Constructor { constructor, .. } = &arm.pattern {
                    record_resolved_constructor_target(
                        &constructor.value,
                        ctx,
                        src,
                        constructor.span,
                        refs,
                    )?;
                }
                collect_resolved_constructor_refs_from_expr(&arm.body, ctx, src, refs)?;
            }
            Ok(())
        }
        hir::ExprKind::InlineDagRef { args, .. } => {
            for arg in args {
                collect_resolved_constructor_refs_from_expr(&arg.value, ctx, src, refs)?;
            }
            Ok(())
        }
    }
}
/// Collects constructor refs from an assert body.
///
/// For the tolerance form, the actual, expected and tolerance expressions
/// are scanned in that order.
fn collect_resolved_constructor_refs_from_assert_body(
    body: &hir::AssertBody,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
    refs: &mut ResolvedConstructorRefs,
) -> Result<(), GraphcalError> {
    let (actual, expected, tolerance) = match body {
        hir::AssertBody::Expr(condition) => {
            return collect_resolved_constructor_refs_from_expr(condition, ctx, src, refs);
        }
        hir::AssertBody::Tolerance {
            actual,
            expected,
            tolerance,
            ..
        } => (actual, expected, tolerance),
    };
    collect_resolved_constructor_refs_from_expr(actual, ctx, src, refs)?;
    collect_resolved_constructor_refs_from_expr(expected, ctx, src, refs)?;
    collect_resolved_constructor_refs_from_expr(tolerance, ctx, src, refs)?;
    Ok(())
}
/// Collects every inline DAG call in a DAG's declarations.
///
/// Scans, in order: const bodies, param defaults, node bodies, and assert
/// bodies. Domain bounds are not scanned here.
fn collect_resolved_inline_dag_refs(exprs: &ResolvedExpressions) -> ResolvedInlineDagRefs {
    let mut calls = ResolvedInlineDagRefs::default();
    for body in exprs.consts.values() {
        collect_resolved_inline_dag_refs_from_expr(body, &mut calls);
    }
    for default in exprs.param_defaults.values() {
        collect_resolved_inline_dag_refs_from_expr(default, &mut calls);
    }
    for body in exprs.nodes.values() {
        collect_resolved_inline_dag_refs_from_expr(body, &mut calls);
    }
    for assertion in exprs.asserts.values() {
        collect_resolved_inline_dag_refs_from_assert_body(assertion, &mut calls);
    }
    calls
}
/// Stack-safe entry point for collecting inline DAG calls from `expr` into
/// `refs`. Each recursion level runs under `with_stack_growth`.
fn collect_resolved_inline_dag_refs_from_expr(expr: &hir::Expr, refs: &mut ResolvedInlineDagRefs) {
    let descend = || collect_resolved_inline_dag_refs_from_expr_inner(expr, refs);
    crate::stack::with_stack_growth(descend);
}
fn collect_resolved_inline_dag_refs_from_expr_inner(
expr: &hir::Expr,
refs: &mut ResolvedInlineDagRefs,
) {
match &expr.kind {
hir::ExprKind::Error
| hir::ExprKind::Number(_)
| hir::ExprKind::Integer(_)
| hir::ExprKind::Bool(_)
| hir::ExprKind::StringLiteral(_)
| hir::ExprKind::TypeSystemRef(_)
| hir::ExprKind::GraphRef(_)
| hir::ExprKind::ConstRef(_)
| hir::ExprKind::LocalRef(_)
| hir::ExprKind::UnitLiteral { .. }
| hir::ExprKind::VariantLiteral(_) => {}
hir::ExprKind::BinOp { lhs, rhs, .. } => {
collect_resolved_inline_dag_refs_from_expr(lhs, refs);
collect_resolved_inline_dag_refs_from_expr(rhs, refs);
}
hir::ExprKind::UnaryOp { operand, .. }
| hir::ExprKind::Convert { expr: operand, .. }
| hir::ExprKind::DisplayTimezone { expr: operand, .. }
| hir::ExprKind::FieldAccess { expr: operand, .. } => {
collect_resolved_inline_dag_refs_from_expr(operand, refs);
}
hir::ExprKind::FnCall { args, .. } => {
for arg in args {
collect_resolved_inline_dag_refs_from_expr(arg, refs);
}
}
hir::ExprKind::If {
condition,
then_branch,
else_branch,
} => {
collect_resolved_inline_dag_refs_from_expr(condition, refs);
collect_resolved_inline_dag_refs_from_expr(then_branch, refs);
collect_resolved_inline_dag_refs_from_expr(else_branch, refs);
}
hir::ExprKind::ConstructorCall { fields, .. } => {
for field in fields {
collect_resolved_inline_dag_refs_from_expr(&field.value, refs);
}
}
hir::ExprKind::MapLiteral { entries } => {
for entry in entries {
collect_resolved_inline_dag_refs_from_expr(&entry.value, refs);
}
}
hir::ExprKind::ForComp { body, .. } => {
collect_resolved_inline_dag_refs_from_expr(body, refs);
}
hir::ExprKind::IndexAccess { expr, args } => {
collect_resolved_inline_dag_refs_from_expr(expr, refs);
for arg in args {
if let hir::expr::IndexArg::Expr(expr) = arg {
collect_resolved_inline_dag_refs_from_expr(expr, refs);
}
}
}
hir::ExprKind::Scan {
source, init, body, ..
} => {
collect_resolved_inline_dag_refs_from_expr(source, refs);
collect_resolved_inline_dag_refs_from_expr(init, refs);
collect_resolved_inline_dag_refs_from_expr(body, refs);
}
hir::ExprKind::Unfold { init, body, .. } => {
collect_resolved_inline_dag_refs_from_expr(init, refs);
collect_resolved_inline_dag_refs_from_expr(body, refs);
}
hir::ExprKind::Match { scrutinee, arms } => {
collect_resolved_inline_dag_refs_from_expr(scrutinee, refs);
for arm in arms {
collect_resolved_inline_dag_refs_from_expr(&arm.body, refs);
}
}
hir::ExprKind::InlineDagRef {
target,
args,
output,
} => {
let arg_targets = args
.iter()
.map(|arg| (arg.target.span, arg.target.value.clone()))
.collect();
refs.calls.insert(
expr.span,
ResolvedInlineDagCall {
target: target.value.clone(),
arg_targets,
output: output.clone(),
},
);
for arg in args {
collect_resolved_inline_dag_refs_from_expr(&arg.value, refs);
}
}
}
}
/// Collects inline DAG calls from an assert body.
///
/// For the tolerance form, the actual, expected and tolerance expressions
/// are scanned in that order.
fn collect_resolved_inline_dag_refs_from_assert_body(
    body: &hir::AssertBody,
    refs: &mut ResolvedInlineDagRefs,
) {
    let (actual, expected, tolerance) = match body {
        hir::AssertBody::Expr(condition) => {
            collect_resolved_inline_dag_refs_from_expr(condition, refs);
            return;
        }
        hir::AssertBody::Tolerance {
            actual,
            expected,
            tolerance,
            ..
        } => (actual, expected, tolerance),
    };
    collect_resolved_inline_dag_refs_from_expr(actual, refs);
    collect_resolved_inline_dag_refs_from_expr(expected, refs);
    collect_resolved_inline_dag_refs_from_expr(tolerance, refs);
}
/// Builds the map from each declaration name visible in a DAG to its
/// canonical resolved declaration key.
///
/// Local consts, params and nodes (in that order) are keyed under `owner`.
/// Imported values are then keyed under the DAG they were imported from.
/// Because imports are added last, an import overwrites a local binding with
/// the same name.
///
/// # Errors
///
/// Returns an internal error if a local name cannot be turned into a
/// canonical declaration key.
fn collect_hir_decl_bindings(
    owner: &crate::dag_id::DagId,
    consts: &[crate::ir::lower::ConstEntry],
    params: &[crate::ir::lower::ParamEntry],
    nodes: &[crate::ir::lower::NodeEntry],
    imported_value_sources: &HashMap<ScopedName, crate::ir::lower::ImportedValueSource>,
    src: &NamedSource<Arc<String>>,
) -> Result<HashMap<ScopedName, ResolvedName<namespace::Decl>>, GraphcalError> {
    let const_names = consts.iter().map(|entry| &entry.name);
    let param_names = params.iter().map(|entry| &entry.name);
    let node_names = nodes.iter().map(|entry| &entry.name);

    let mut bindings: HashMap<ScopedName, ResolvedName<namespace::Decl>> = const_names
        .chain(param_names)
        .chain(node_names)
        .map(|name| {
            resolved_decl_key(owner, name)
                .map(|key| (name.clone(), key))
                .ok_or_else(|| {
                    internal_error(
                        format!("could not build canonical declaration key for `{name}`"),
                        src,
                        Span::new(0, 0),
                    )
                })
        })
        .collect::<Result<_, _>>()?;

    bindings.extend(imported_value_sources.iter().map(|(name, source)| {
        (
            name.clone(),
            ResolvedName::from_def(source.dag_id.clone(), source.source_name.clone()),
        )
    }));
    Ok(bindings)
}
/// Extends the local/imported-source bindings of a completed DAG with every
/// other visible imported declaration name.
///
/// Names already bound by `collect_hir_decl_bindings` are skipped. Each
/// remaining name is resolved through the module resolver. If resolution
/// fails for a name that has an imported value or an imported declared type,
/// a synthetic key is used instead: the name's qualifier segments are
/// appended to the owning DAG id, with the member as the declaration.
///
/// # Errors
///
/// Returns an internal error when a name cannot be converted to a
/// `NamePath`, and propagates resolver errors for names that have no
/// imported value or type to fall back on.
#[expect(
    clippy::too_many_arguments,
    reason = "collects local and imported declaration binding sources for a completed DAG"
)]
fn collect_resolved_decl_bindings(
    ctx: ModuleTypeContext<'_>,
    consts: &[crate::ir::lower::ConstEntry],
    params: &[crate::ir::lower::ParamEntry],
    nodes: &[crate::ir::lower::NodeEntry],
    imported_values: &HashMap<
        ScopedName,
        (
            crate::registry::runtime_value::RuntimeValue,
            crate::registry::declared_type::DeclaredType,
        ),
    >,
    imported_decl_types: &HashMap<ScopedName, crate::registry::declared_type::DeclaredType>,
    imported_value_sources: &HashMap<ScopedName, crate::ir::lower::ImportedValueSource>,
    src: &NamedSource<Arc<String>>,
) -> Result<HashMap<ScopedName, ResolvedName<namespace::Decl>>, GraphcalError> {
    let mut bindings = collect_hir_decl_bindings(
        ctx.owner,
        consts,
        params,
        nodes,
        imported_value_sources,
        src,
    )?;
    let visible_names = imported_values
        .keys()
        .chain(imported_decl_types.keys())
        .chain(imported_value_sources.keys());
    for name in visible_names {
        if bindings.contains_key(name) {
            continue;
        }
        let Some(path) = scoped_name_to_name_path(name) else {
            return Err(internal_error(
                format!("could not convert visible declaration `{name}` to a name path"),
                src,
                Span::new(0, 0),
            ));
        };
        let resolved = match ctx.resolver.resolve_decl_path(ctx.owner, &path) {
            Ok(found) => found,
            Err(err) => {
                let has_imported_binding =
                    imported_values.contains_key(name) || imported_decl_types.contains_key(name);
                if !has_imported_binding {
                    return Err(module_resolve_error(&err, src, Span::new(0, 0)));
                }
                // Synthesize a key: walk the qualifier segments down from the owner DAG.
                let synthetic_owner = name
                    .qualifier()
                    .iter()
                    .fold(ctx.owner.clone(), |parent, segment| {
                        parent.child(segment.as_ref())
                    });
                ResolvedName::from_def(synthetic_owner, DeclName::new(name.member()))
            }
        };
        bindings.insert(name.clone(), resolved);
    }
    Ok(bindings)
}
/// Converts a scoped declaration name into a `NamePath`.
///
/// Every qualifier segment and the member must parse as a `NameAtom`. If any
/// of them fails to parse, `None` is returned. An unqualified name becomes a
/// local path; otherwise a qualified path is built from the segments.
fn scoped_name_to_name_path(name: &ScopedName) -> Option<NamePath> {
    let segments: Option<Vec<_>> = name
        .qualifier()
        .iter()
        .map(|segment| NameAtom::parse(segment.as_ref()).ok())
        .collect();
    let segments = segments?;
    let member = NameAtom::parse(name.member()).ok()?;
    let path = if segments.is_empty() {
        NamePath::local(member)
    } else {
        NamePath::qualified_path(segments, member)
    };
    Some(path)
}
/// Resolves the index-variant parts of every `expected_fail` key.
///
/// `ExpectedFail::All` passes through unchanged. For `ExpectedFail::Variants`,
/// each key part that carries a source index path has its variant resolved
/// through the module resolver. Parts without an index path are kept as-is.
///
/// # Errors
///
/// Returns the first module-resolution failure, reported at the offending
/// part's span.
fn resolve_expected_fail_keys(
    expected_fail: HashMap<ScopedName, ExpectedFail>,
    ctx: ModuleTypeContext<'_>,
    src: &NamedSource<Arc<String>>,
) -> Result<HashMap<ScopedName, ExpectedFail>, GraphcalError> {
    let mut resolved_map = HashMap::new();
    for (assert_name, expected) in expected_fail {
        let resolved = match expected {
            ExpectedFail::All => ExpectedFail::All,
            ExpectedFail::Variants(keys) => ExpectedFail::Variants(
                keys.into_iter()
                    .map(|key| {
                        key.into_iter()
                            .map(|part| {
                                let Some(index_path) = part.source_index_path().cloned() else {
                                    // Nothing to resolve for parts without an index path.
                                    return Ok(part);
                                };
                                let variant = ctx
                                    .resolver
                                    .resolve_index_variant_parts(
                                        ctx.owner,
                                        &index_path,
                                        &part.variant(),
                                    )
                                    .map_err(|err| module_resolve_error(&err, src, part.span()))?;
                                Ok(part.with_resolved_variant(variant))
                            })
                            .collect::<Result<_, GraphcalError>>()
                    })
                    .collect::<Result<_, GraphcalError>>()?,
            ),
        };
        resolved_map.insert(assert_name, resolved);
    }
    Ok(resolved_map)
}
/// Partially assembled per-DAG state. `DagTIRSeed::with_body` consumes it,
/// together with the remaining IR pieces, to build a `DagTIR`.
struct DagTIRSeed {
/// Identifier of the DAG; moved into the `DagTIR` unchanged.
dag_id: crate::dag_id::DagId,
/// Lowered declarations. `with_body` reads them to collect declaration
/// bindings and then moves them into the `DagTIR`.
consts: Vec<crate::ir::lower::ConstEntry>,
params: Vec<crate::ir::lower::ParamEntry>,
nodes: Vec<crate::ir::lower::NodeEntry>,
/// Resolved declared type per declaration name; moved into the `DagTIR`.
resolved_decl_types: HashMap<ScopedName, ResolvedTypeExpr>,
/// Semantic body. `with_body` overwrites its `decl_bindings` and passes it
/// to `collect_plot_exprs` before moving it into the `DagTIR`.
semantic: DagSemanticBody,
}
impl DagTIRSeed {
    /// Consumes the seed plus the remaining IR pieces and produces the
    /// finished `DagTIR`.
    ///
    /// The steps run in this order:
    /// 1. Resolve the declaration bindings visible in the DAG.
    /// 2. Resolve the index-variant parts of the `expected_fail` keys.
    /// 3. Install the bindings into the semantic body.
    /// 4. Collect plot expressions into the semantic body.
    ///
    /// `domain_constraints` and `pub_nodes` start out empty.
    ///
    /// # Errors
    ///
    /// Propagates the first error from binding resolution, expected-fail key
    /// resolution, or plot-expression collection.
    #[expect(
        clippy::too_many_arguments,
        reason = "single conversion that absorbs every IR field beyond the resolved decls"
    )]
    fn with_body(
        self,
        asserts: Vec<crate::ir::lower::AssertEntry>,
        plots: Vec<crate::ir::lower::PlotEntry>,
        figures: Vec<crate::ir::lower::FigureEntry>,
        layers: Vec<crate::ir::lower::LayerEntry>,
        included_plots: Vec<crate::ir::lower::IncludedPlotEntry>,
        source_order: Vec<(ScopedName, DeclCategory)>,
        assert_names: std::collections::HashSet<ScopedName>,
        assumes_map: HashMap<ScopedName, Vec<ScopedName>>,
        expected_fail: HashMap<ScopedName, ExpectedFail>,
        imported_values: HashMap<
            ScopedName,
            (
                crate::registry::runtime_value::RuntimeValue,
                crate::registry::declared_type::DeclaredType,
            ),
        >,
        imported_decl_types: HashMap<ScopedName, crate::registry::declared_type::DeclaredType>,
        imported_value_sources: HashMap<ScopedName, crate::ir::lower::ImportedValueSource>,
        module_ctx: ModuleTypeContext<'_>,
        src: &NamedSource<Arc<String>>,
    ) -> Result<DagTIR, GraphcalError> {
        let Self {
            dag_id,
            consts,
            params,
            nodes,
            resolved_decl_types,
            mut semantic,
        } = self;

        let decl_bindings = collect_resolved_decl_bindings(
            module_ctx,
            &consts,
            &params,
            &nodes,
            &imported_values,
            &imported_decl_types,
            &imported_value_sources,
            src,
        )?;
        let expected_fail = resolve_expected_fail_keys(expected_fail, module_ctx, src)?;
        semantic.decl_bindings = decl_bindings;
        collect_plot_exprs(&plots, &figures, &layers, module_ctx, src, &mut semantic)?;

        Ok(DagTIR {
            dag_id,
            consts,
            params,
            nodes,
            asserts,
            plots,
            figures,
            layers,
            included_plots,
            semantic,
            source_order,
            assert_names,
            assumes_map,
            expected_fail,
            resolved_decl_types,
            domain_constraints: HashMap::new(),
            imported_values,
            imported_decl_types,
            imported_value_sources,
            pub_nodes: std::collections::HashSet::new(),
        })
    }
}
pub fn resolved_to_declared_type(
resolved: &ResolvedTypeExpr,
src: &NamedSource<Arc<String>>,
) -> Result<crate::registry::declared_type::DeclaredType, GraphcalError> {
use crate::registry::declared_type::{DeclaredType, StructTypeRef};
match resolved {
ResolvedTypeExpr::Dimensionless => Ok(DeclaredType::Scalar(Dimension::dimensionless())),
ResolvedTypeExpr::Bool => Ok(DeclaredType::Bool),
ResolvedTypeExpr::Int => Ok(DeclaredType::Int),
ResolvedTypeExpr::Datetime(scale) => Ok(DeclaredType::Datetime(*scale)),
ResolvedTypeExpr::IndexArg(index) => Err(GraphcalError::EvalError {
message: format!(
"index `{}` cannot be used as a value type",
format_resolved_index(index)
),
src: src.clone(),
span: resolved_type_expr_span(resolved).into(),
}),
ResolvedTypeExpr::Scalar(dim) => Ok(DeclaredType::Scalar(dim.clone())),
ResolvedTypeExpr::Struct(name, _) => Ok(DeclaredType::Struct(
StructTypeRef::from_resolved(name.clone()),
vec![],
)),
ResolvedTypeExpr::GenericStruct {
name, type_args, ..
} => {
let mut declared_args = Vec::with_capacity(type_args.len());
for arg in type_args {
declared_args.push(resolved_type_arg_to_declared_type(arg, src)?);
}
Ok(DeclaredType::Struct(
StructTypeRef::from_resolved(name.clone()),
declared_args,
))
}
ResolvedTypeExpr::GenericDimParam(name, span) => Err(GraphcalError::EvalError {
message: format!("cannot use generic dimension parameter `{name}` as a concrete type"),
src: src.clone(),
span: (*span).into(),
}),
ResolvedTypeExpr::GenericTypeParam(name, span) => Err(GraphcalError::EvalError {
message: format!("cannot use generic type parameter `{name}` as a concrete type"),
src: src.clone(),
span: (*span).into(),
}),
ResolvedTypeExpr::GenericDimExpr { span, .. } => Err(GraphcalError::EvalError {
message: "cannot use generic dimension expression as a concrete type".to_string(),
src: src.clone(),
span: (*span).into(),
}),
ResolvedTypeExpr::Indexed { base, indexes } => {
let mut result = resolved_to_declared_type(base, src)?;
for idx in indexes.iter().rev() {
match idx {
ResolvedIndex::Concrete(name, _) => {
result = DeclaredType::Indexed {
element: Box::new(result),
index: IndexTypeRef::from_resolved(name.clone()),
};
}
ResolvedIndex::NatExpr(form, span) => {
if !form.is_constant() {
return Err(GraphcalError::EvalError {
message: format!(
"cannot use generic nat expression `{}` as a concrete type",
form.format()
),
src: src.clone(),
span: (*span).into(),
});
}
let nat_range =
crate::registry::types::NatRangeIndex::try_from_u64(form.constant())
.map_err(|err| GraphcalError::EvalError {
message: err.to_string(),
src: src.clone(),
span: (*span).into(),
})?;
result = DeclaredType::Indexed {
element: Box::new(result),
index: IndexTypeRef::from_nat_range(nat_range),
};
}
ResolvedIndex::GenericParam(name, span) => {
return Err(GraphcalError::EvalError {
message: format!(
"cannot use generic index parameter `{name}` as a concrete type"
),
src: src.clone(),
span: (*span).into(),
});
}
}
}
Ok(result)
}
}
}
/// Converts a generic type argument into a `DeclaredType`.
///
/// Index arguments become `DeclaredType::IndexArg` via
/// `resolved_index_to_declared_arg`. Every other argument is converted as
/// an ordinary value type via `resolved_to_declared_type`.
fn resolved_type_arg_to_declared_type(
    resolved: &ResolvedTypeExpr,
    src: &NamedSource<Arc<String>>,
) -> Result<crate::registry::declared_type::DeclaredType, GraphcalError> {
    if let ResolvedTypeExpr::IndexArg(index) = resolved {
        resolved_index_to_declared_arg(index, src)
    } else {
        resolved_to_declared_type(resolved, src)
    }
}
/// Returns the best available source span for a resolved type expression.
///
/// Indexed types report the span of their base. Variants that carry no span
/// fall back to `Span::new(0, 0)`.
fn resolved_type_expr_span(resolved: &ResolvedTypeExpr) -> Span {
    match resolved {
        ResolvedTypeExpr::Indexed { base, .. } => resolved_type_expr_span(base),
        ResolvedTypeExpr::IndexArg(index) => resolved_index_span(index),
        ResolvedTypeExpr::GenericStruct { span, .. }
        | ResolvedTypeExpr::GenericDimExpr { span, .. }
        | ResolvedTypeExpr::Struct(_, span)
        | ResolvedTypeExpr::GenericDimParam(_, span)
        | ResolvedTypeExpr::GenericTypeParam(_, span) => *span,
        // These variants store no span of their own.
        ResolvedTypeExpr::Dimensionless
        | ResolvedTypeExpr::Bool
        | ResolvedTypeExpr::Int
        | ResolvedTypeExpr::Datetime(_)
        | ResolvedTypeExpr::Scalar(_) => Span::new(0, 0),
    }
}
/// Returns the source span stored in any `ResolvedIndex` variant.
const fn resolved_index_span(index: &ResolvedIndex) -> Span {
    // Every variant carries its span as the second field, so this
    // or-pattern is irrefutable.
    let (ResolvedIndex::Concrete(_, span)
    | ResolvedIndex::GenericParam(_, span)
    | ResolvedIndex::NatExpr(_, span)) = index;
    *span
}
fn resolved_index_to_declared_arg(
index: &ResolvedIndex,
src: &NamedSource<Arc<String>>,
) -> Result<crate::registry::declared_type::DeclaredType, GraphcalError> {
let reference = match index {
ResolvedIndex::Concrete(name, _) => IndexTypeRef::from_resolved(name.clone()),
ResolvedIndex::NatExpr(form, span) => IndexTypeRef::from_nat_range_form(form.clone())
.map_err(|err| GraphcalError::EvalError {
message: err.to_string(),
src: src.clone(),
span: (*span).into(),
})?,
ResolvedIndex::GenericParam(name, span) => {
return Err(GraphcalError::EvalError {
message: format!("generic index parameter `{name}` is not bound"),
src: src.clone(),
span: (*span).into(),
});
}
};
Ok(crate::registry::declared_type::DeclaredType::IndexArg(
reference,
))
}
fn resolved_index_to_inferred(
index: &ResolvedIndex,
src: &NamedSource<Arc<String>>,
) -> Result<crate::tir::dim_check::InferredIndex, GraphcalError> {
let reference = match index {
ResolvedIndex::Concrete(name, _) => IndexTypeRef::from_resolved(name.clone()),
ResolvedIndex::NatExpr(form, span) => IndexTypeRef::from_nat_range_form(form.clone())
.map_err(|err| GraphcalError::EvalError {
message: err.to_string(),
src: src.clone(),
span: (*span).into(),
})?,
ResolvedIndex::GenericParam(name, span) => {
return Err(GraphcalError::EvalError {
message: format!("generic index parameter `{name}` is not bound"),
src: src.clone(),
span: (*span).into(),
});
}
};
Ok(crate::tir::dim_check::InferredIndex::from_ref(reference))
}
/// Reports whether an inferred index is the one a resolved index expects.
///
/// - A concrete index must match by resolved name.
/// - A Nat expression must equal the inferred index's Nat range form exactly.
/// - A generic index parameter never matches here.
fn resolved_index_matches_inferred(
    expected: &ResolvedIndex,
    actual: &crate::tir::dim_check::InferredIndex,
) -> bool {
    match expected {
        ResolvedIndex::GenericParam(..) => false,
        ResolvedIndex::NatExpr(expected_form, _) => actual
            .nat_range_form()
            .is_some_and(|actual_form| actual_form == *expected_form),
        ResolvedIndex::Concrete(name, _) => actual.matches_resolved(name),
    }
}
/// Produces the `IndexName` used to describe a resolved index in diagnostics.
///
/// Nat expressions are rendered as `range(<form>)`.
fn resolved_index_display_name(index: &ResolvedIndex) -> IndexName {
    match index {
        ResolvedIndex::NatExpr(form, _) => {
            let rendered = form.format();
            IndexName::new(format!("range({rendered})"))
        }
        ResolvedIndex::GenericParam(param, _) => IndexName::from_atom(param.atom().clone()),
        ResolvedIndex::Concrete(resolved_name, _) => resolved_name.to_unowned_def_name(),
    }
}
/// Unifies a Nat polynomial `form` with `target`. `form` is the size of a
/// `range(..)` index in a declared type; `target` is the constant size of an
/// actual index.
///
/// Each monomial is first reduced against the bindings already in
/// `nat_sub`. Then:
/// - If no unbound parameter remains, the reduced constant must equal
///   `target`.
/// - If exactly one unbound parameter remains and every remaining monomial
///   is that parameter to the first power, the parameter is solved for and
///   recorded in `nat_sub`.
/// - Any other shape is rejected as not inferable from a single index.
///
/// `actual_idx` and `span` are used only to build diagnostics.
///
/// # Errors
///
/// Returns `IndexMismatch` when `form` cannot equal `target`. This covers
/// checked-arithmetic overflow, a failed monomial substitution, and a
/// negative or non-integral solution. Returns `EvalError` when the remaining
/// parameters cannot be inferred, or when a `NatRangeIndex` cannot be built
/// for an error message.
fn unify_nat_poly_form(
form: &NatPolyForm,
target: u64,
nat_sub: &mut HashMap<GenericParamName, u64>,
actual_idx: &IndexName,
src: &NamedSource<Arc<String>>,
span: Span,
) -> Result<(), GraphcalError> {
// Constant part and leftover (still-generic) monomials after substitution.
let mut reduced_constant: u64 = 0;
let mut reduced_terms: BTreeMap<Monomial, u64> = BTreeMap::new();
// Shared mismatch error reporting the unreduced form; used by every checked-arithmetic failure.
let form_mismatch = || GraphcalError::IndexMismatch {
expected: IndexName::new(format!("range({})", form.format())),
found: actual_idx.clone(),
src: src.clone(),
span: span.into(),
};
// NOTE(review): `Monomial::substitute` presumably replaces parameters bound in
// `nat_sub` and returns the leftover monomial plus the product of the
// substituted factors — confirm in `syntax::nat`.
for (mono, coeff) in &form.terms {
let (remaining_mono, factor) = mono.substitute(nat_sub).ok_or_else(form_mismatch)?;
let term_value = coeff.checked_mul(factor).ok_or_else(form_mismatch)?;
if remaining_mono.is_constant() {
reduced_constant = reduced_constant
.checked_add(term_value)
.ok_or_else(form_mismatch)?;
} else {
let entry = reduced_terms.entry(remaining_mono).or_insert(0);
*entry = entry.checked_add(term_value).ok_or_else(form_mismatch)?;
}
}
// Monomials whose accumulated coefficient is zero contribute nothing.
reduced_terms.retain(|_, c| *c != 0);
// Fully reduced: the form is now a constant and must equal `target` exactly.
if reduced_terms.is_empty() {
if reduced_constant != target {
// Report the evaluated size when `form` evaluates under `nat_sub`, else the symbolic form.
let expected = match form.evaluate(nat_sub) {
Some(n) => crate::registry::types::NatRangeIndex::try_from_u64(n)
.map_err(|err| GraphcalError::EvalError {
message: err.to_string(),
src: src.clone(),
span: span.into(),
})?
.display_name(),
None => IndexName::new(format!("range({})", form.format())),
};
return Err(GraphcalError::IndexMismatch {
expected,
found: actual_idx.clone(),
src: src.clone(),
span: span.into(),
});
}
return Ok(());
}
// Collect the parameters that still appear in the reduced form.
let mut unbound_vars = std::collections::BTreeSet::new();
for mono in reduced_terms.keys() {
for var in mono.0.keys() {
unbound_vars.insert(var.clone());
}
}
// Exactly one unbound parameter: solvable when it only appears to the first power.
if let [var] = unbound_vars.iter().collect::<Vec<_>>().as_slice() {
let var = (*var).clone();
let all_linear = reduced_terms
.keys()
.all(|m| m.0.len() == 1 && m.0.get(&var) == Some(&1));
if all_linear {
let total_coeff = reduced_terms
.values()
.try_fold(0u64, |acc, c| acc.checked_add(*c))
.ok_or_else(form_mismatch)?;
// Solve `total_coeff * var + reduced_constant == target` over the naturals.
if target < reduced_constant {
return Err(form_mismatch());
}
let remainder = target - reduced_constant;
if total_coeff == 0 || !remainder.is_multiple_of(total_coeff) {
return Err(form_mismatch());
}
let value = remainder / total_coeff;
// NOTE(review): `var` survived substitution, so it is presumably not yet bound
// in `nat_sub` and this conflict closure looks unreachable — confirm.
bind_or_check(nat_sub, var, value, |prev, _| {
match crate::registry::types::NatRangeIndex::try_from_u64(*prev) {
Ok(index) => GraphcalError::IndexMismatch {
expected: index.display_name(),
found: actual_idx.clone(),
src: src.clone(),
span: span.into(),
},
Err(err) => GraphcalError::EvalError {
message: err.to_string(),
src: src.clone(),
span: span.into(),
},
}
})?;
return Ok(());
}
}
// Several unbound parameters, or a non-linear occurrence of one.
let var_names: Vec<&str> = unbound_vars.iter().map(GenericParamName::as_str).collect();
Err(GraphcalError::EvalError {
message: format!(
"cannot infer Nat parameters [{}] from a single index — \
provide more arguments or use explicit type annotations",
var_names.join(", ")
),
src: src.clone(),
span: span.into(),
})
}
/// Binds `key` to `value` in `sub`, or checks an existing binding agrees.
///
/// A new key is inserted and `Ok(())` is returned. An existing binding equal
/// to `value` is left untouched and also yields `Ok(())`. A differing
/// existing binding is left in place and reported as
/// `Err(on_conflict(previous, new))`.
fn bind_or_check<K, V, E>(
    sub: &mut HashMap<K, V>,
    key: K,
    value: V,
    on_conflict: impl FnOnce(&V, &V) -> E,
) -> Result<(), E>
where
    K: Eq + std::hash::Hash,
    V: PartialEq,
{
    match sub.entry(key) {
        std::collections::hash_map::Entry::Vacant(slot) => {
            slot.insert(value);
            Ok(())
        }
        std::collections::hash_map::Entry::Occupied(slot) if *slot.get() != value => {
            Err(on_conflict(slot.get(), &value))
        }
        std::collections::hash_map::Entry::Occupied(_) => Ok(()),
    }
}
/// Unifies a declared (resolved) parameter type with the inferred type of an
/// actual argument, recording generic bindings along the way.
///
/// Where each kind of generic is recorded:
/// - generic dimension parameters in `dim_sub`;
/// - generic index parameters in `index_sub`;
/// - Nat parameters inside `range(..)` index expressions in `nat_sub`, via
///   `unify_nat_poly_form`.
///
/// A parameter that is already bound must agree with the new binding.
/// `span` is the location attached to every error reported here.
///
/// # Errors
///
/// - `DimensionMismatch` or `IndexMismatch` when the type shapes, struct
///   names, dimensions or indexes disagree.
/// - `EvalError` for a generic type parameter, which cannot be inferred
///   here, or from `unify_nat_poly_form`.
/// - `DimensionOverflow` / `InternalError` from dimension arithmetic in
///   generic dimension expressions.
///
/// Errors from `expect_scalar` propagate unchanged.
#[expect(
clippy::too_many_lines,
reason = "complex generic unification requires many match arms"
)]
#[expect(
clippy::implicit_hasher,
reason = "always called with standard HashMap"
)]
#[expect(
clippy::too_many_arguments,
reason = "unification needs all substitution maps, registry, and source context"
)]
pub fn unify_resolved_type(
resolved: &ResolvedTypeExpr,
actual: &crate::tir::dim_check::InferredType,
dim_sub: &mut HashMap<GenericParamName, Dimension>,
index_sub: &mut HashMap<GenericParamName, IndexTypeRef>,
nat_sub: &mut HashMap<GenericParamName, u64>,
registry: &Registry,
src: &NamedSource<Arc<String>>,
span: Span,
) -> Result<(), GraphcalError> {
use crate::tir::dim_check::InferredType;
match resolved {
ResolvedTypeExpr::Indexed { base, indexes } => {
// Peel one index layer off `actual` per declared index (outermost first),
// then unify what remains with `base`.
let mut current = actual;
for idx in indexes {
let InferredType::Indexed {
element,
index: actual_idx,
} = current
else {
return Err(GraphcalError::DimensionMismatch {
expected: "indexed type".to_string(),
found: crate::tir::dim_check::format_inferred_type(current, registry),
help: "expected an indexed value".to_string(),
src: src.clone(),
span: span.into(),
});
};
match idx {
// Bind the generic index to the actual one, or check an earlier binding.
ResolvedIndex::GenericParam(gp, _) => {
bind_or_check(
index_sub,
gp.clone(),
actual_idx.type_ref().clone(),
|prev, _| GraphcalError::IndexMismatch {
expected: prev.display_name(),
found: actual_idx.name(),
src: src.clone(),
span: span.into(),
},
)?;
}
ResolvedIndex::Concrete(name, _) => {
if !actual_idx.matches_resolved(name) {
return Err(GraphcalError::IndexMismatch {
expected: name.to_unowned_def_name(),
found: actual_idx.name(),
src: src.clone(),
span: span.into(),
});
}
}
// Only a constant Nat range on the actual side can be matched against a Nat expression.
ResolvedIndex::NatExpr(form, _) => {
let actual_nat = actual_idx
.nat_range_form()
.filter(NatPolyForm::is_constant)
.map(|actual_form| actual_form.constant())
.ok_or_else(|| GraphcalError::IndexMismatch {
expected: IndexName::new(format!("range({})", form.format())),
found: actual_idx.name(),
src: src.clone(),
span: span.into(),
})?;
let actual_idx_name = actual_idx.name();
unify_nat_poly_form(
form,
actual_nat,
nat_sub,
&actual_idx_name,
src,
span,
)?;
}
}
current = element;
}
unify_resolved_type(
base, current, dim_sub, index_sub, nat_sub, registry, src, span,
)
}
ResolvedTypeExpr::Bool => {
if *actual != InferredType::Bool {
return Err(GraphcalError::DimensionMismatch {
expected: "Bool".to_string(),
found: crate::tir::dim_check::format_inferred_type(actual, registry),
help: "expected Bool argument".to_string(),
src: src.clone(),
span: span.into(),
});
}
Ok(())
}
// Accepts any actual type for which `is_int_like` holds.
ResolvedTypeExpr::Int => {
if !actual.is_int_like() {
return Err(GraphcalError::DimensionMismatch {
expected: "Int".to_string(),
found: crate::tir::dim_check::format_inferred_type(actual, registry),
help: "expected Int argument".to_string(),
src: src.clone(),
span: span.into(),
});
}
Ok(())
}
// The time scale must match exactly.
ResolvedTypeExpr::Datetime(expected_scale) => {
if *actual != InferredType::Datetime(*expected_scale) {
let expected_str = if expected_scale.is_utc() {
"Datetime".to_string()
} else {
format!("Datetime<{expected_scale}>")
};
return Err(GraphcalError::DimensionMismatch {
expected: expected_str,
found: crate::tir::dim_check::format_inferred_type(actual, registry),
help: "expected Datetime argument".to_string(),
src: src.clone(),
span: span.into(),
});
}
Ok(())
}
ResolvedTypeExpr::IndexArg(expected_index) => {
let InferredType::NamedIndex(actual_index) = actual else {
return Err(GraphcalError::DimensionMismatch {
expected: format!("index {}", format_resolved_index(expected_index)),
found: crate::tir::dim_check::format_inferred_type(actual, registry),
help: "expected an index generic argument".to_string(),
src: src.clone(),
span: span.into(),
});
};
if !resolved_index_matches_inferred(expected_index, actual_index) {
return Err(GraphcalError::IndexMismatch {
expected: resolved_index_display_name(expected_index),
found: actual_index.name(),
src: src.clone(),
span: span.into(),
});
}
Ok(())
}
ResolvedTypeExpr::Dimensionless => {
let actual_dim = crate::tir::dim_check::expect_scalar(actual, registry, src, span)?;
if !actual_dim.is_dimensionless() {
return Err(GraphcalError::DimensionMismatch {
expected: "Dimensionless".to_string(),
found: registry.dimensions.format_dimension(&actual_dim),
help: "expected Dimensionless argument".to_string(),
src: src.clone(),
span: span.into(),
});
}
Ok(())
}
ResolvedTypeExpr::Scalar(expected_dim) => {
let actual_dim = crate::tir::dim_check::expect_scalar(actual, registry, src, span)?;
if *expected_dim != actual_dim {
return Err(GraphcalError::DimensionMismatch {
expected: registry.dimensions.format_dimension(expected_dim),
found: registry.dimensions.format_dimension(&actual_dim),
help: "dimension mismatch in function argument".to_string(),
src: src.clone(),
span: span.into(),
});
}
Ok(())
}
ResolvedTypeExpr::GenericStruct {
name, type_args, ..
} => {
let InferredType::Struct(actual_name, actual_args) = actual else {
return Err(GraphcalError::DimensionMismatch {
expected: name.as_str().to_string(),
found: crate::tir::dim_check::format_inferred_type(actual, registry),
help: format!("expected struct type `{}`", name.as_str()),
src: src.clone(),
span: span.into(),
});
};
if actual_name.resolved() != name {
return Err(GraphcalError::DimensionMismatch {
expected: name.as_str().to_string(),
found: crate::tir::dim_check::format_inferred_type(actual, registry),
help: format!("expected struct type `{}`", name.as_str()),
src: src.clone(),
span: span.into(),
});
}
// Unify type arguments pairwise.
// NOTE(review): `zip` stops at the shorter list, so an arity difference
// between declared and actual type arguments is not reported here.
for (declared_arg, actual_arg) in type_args.iter().zip(actual_args) {
unify_resolved_type(
declared_arg,
actual_arg,
dim_sub,
index_sub,
nat_sub,
registry,
src,
span,
)?;
}
Ok(())
}
// Non-generic struct: only the struct name is compared; actual type arguments are not inspected.
ResolvedTypeExpr::Struct(name, _) => {
let InferredType::Struct(actual_name, _) = actual else {
return Err(GraphcalError::DimensionMismatch {
expected: name.as_str().to_string(),
found: crate::tir::dim_check::format_inferred_type(actual, registry),
help: format!("expected struct type `{}`", name.as_str()),
src: src.clone(),
span: span.into(),
});
};
if actual_name.resolved() != name {
return Err(GraphcalError::DimensionMismatch {
expected: name.as_str().to_string(),
found: crate::tir::dim_check::format_inferred_type(actual, registry),
help: format!("expected struct type `{}`", name.as_str()),
src: src.clone(),
span: span.into(),
});
}
Ok(())
}
// Bind the generic dimension to the actual scalar dimension, or check an earlier binding.
ResolvedTypeExpr::GenericDimParam(gp, _) => {
let actual_dim = crate::tir::dim_check::expect_scalar(actual, registry, src, span)?;
bind_or_check(dim_sub, gp.clone(), actual_dim, |prev, new| {
GraphcalError::DimensionMismatch {
expected: registry.dimensions.format_dimension(prev),
found: registry.dimensions.format_dimension(new),
help: format!(
"generic `{gp}` was bound to {} but this argument requires {}",
registry.dimensions.format_dimension(prev),
registry.dimensions.format_dimension(new),
),
src: src.clone(),
span: span.into(),
}
})
}
// Generic type parameters cannot be inferred here; the error points at the parameter itself.
ResolvedTypeExpr::GenericTypeParam(gp, gp_span) => Err(GraphcalError::EvalError {
message: format!(
"cannot infer unconstrained generic type parameter `{gp}` in this position yet"
),
src: src.clone(),
span: (*gp_span).into(),
}),
ResolvedTypeExpr::GenericDimExpr { terms, .. } => {
let actual_dim = crate::tir::dim_check::expect_scalar(actual, registry, src, span)?;
// Fast path: a lone multiplied `G^p` term binds `G` to the actual
// dimension raised to `1/p` (the reciprocal built from `den/num`).
if terms.len() == 1
&& let ResolvedDimTerm::GenericParam {
name: gp,
power,
op: MulDivOp::Mul,
..
} = &terms[0]
{
let bound_dim = if *power == Rational::ONE {
actual_dim
} else {
let exponent = Rational::try_new(power.den(), power.num()).map_err(|_| {
GraphcalError::InternalError {
message: format!("generic dimension parameter `{gp}` has zero power"),
src: src.clone(),
span: span.into(),
}
})?;
actual_dim
.pow(exponent)
.map_err(|_| GraphcalError::DimensionOverflow {
src: src.clone(),
span: span.into(),
})?
};
bind_or_check(dim_sub, gp.clone(), bound_dim, |prev, new| {
GraphcalError::DimensionMismatch {
expected: registry.dimensions.format_dimension(prev),
found: registry.dimensions.format_dimension(new),
help: format!(
"generic `{gp}` was bound to {} but this argument requires {}",
registry.dimensions.format_dimension(prev),
registry.dimensions.format_dimension(new),
),
src: src.clone(),
span: span.into(),
}
})?;
return Ok(());
}
// General path: every generic in the expression must already be bound.
// Evaluate the expression and compare it with the actual dimension.
let mut expected_dim = Dimension::dimensionless();
for term in terms {
let overflow_err = || GraphcalError::DimensionOverflow {
src: src.clone(),
span: span.into(),
};
let term_dim = match term {
ResolvedDimTerm::Concrete { dim, power, .. } => {
dim.pow(*power).map_err(|_| overflow_err())?
}
ResolvedDimTerm::GenericParam {
name: gp, power, ..
} => {
if let Some(prev) = dim_sub.get(gp) {
prev.pow(*power).map_err(|_| overflow_err())?
} else {
return Err(GraphcalError::DimensionMismatch {
expected: format!("generic `{gp}` (unresolved)"),
found: registry.dimensions.format_dimension(&actual_dim),
help: format!(
"generic `{gp}` could not be inferred from this argument"
),
src: src.clone(),
span: span.into(),
});
}
}
};
expected_dim = match term.op() {
MulDivOp::Mul => (expected_dim * term_dim).map_err(|_| overflow_err())?,
MulDivOp::Div => (expected_dim / term_dim).map_err(|_| overflow_err())?,
};
}
if expected_dim != actual_dim {
return Err(GraphcalError::DimensionMismatch {
expected: registry.dimensions.format_dimension(&expected_dim),
found: registry.dimensions.format_dimension(&actual_dim),
help: "dimension mismatch in function argument".to_string(),
src: src.clone(),
span: span.into(),
});
}
Ok(())
}
}
}
/// Instantiates a resolved type with the given dimension, index and Nat
/// substitutions. No generic type parameters are bound.
///
/// This is `substitute_resolved_type_with_types` with an empty type
/// substitution.
///
/// # Errors
///
/// Same as `substitute_resolved_type_with_types`. In particular, any
/// generic type parameter yields an "not bound" `EvalError`.
#[expect(
    clippy::implicit_hasher,
    reason = "always called with standard HashMap"
)]
pub fn substitute_resolved_type(
    resolved: &ResolvedTypeExpr,
    dim_sub: &HashMap<GenericParamName, Dimension>,
    index_sub: &HashMap<GenericParamName, IndexTypeRef>,
    nat_sub: &HashMap<GenericParamName, u64>,
    src: &NamedSource<Arc<String>>,
) -> Result<crate::tir::dim_check::InferredType, GraphcalError> {
    substitute_resolved_type_with_types(
        resolved,
        dim_sub,
        index_sub,
        nat_sub,
        &HashMap::new(),
        src,
    )
}
#[expect(
clippy::implicit_hasher,
reason = "always called with standard HashMap"
)]
#[expect(
clippy::too_many_lines,
reason = "single dispatch over ResolvedTypeExpr variants with per-variant generic-substitution + dimension-arithmetic overflow handling"
)]
pub fn substitute_resolved_type_with_types(
resolved: &ResolvedTypeExpr,
dim_sub: &HashMap<GenericParamName, Dimension>,
index_sub: &HashMap<GenericParamName, IndexTypeRef>,
nat_sub: &HashMap<GenericParamName, u64>,
type_sub: &HashMap<GenericParamName, crate::tir::dim_check::InferredType>,
src: &NamedSource<Arc<String>>,
) -> Result<crate::tir::dim_check::InferredType, GraphcalError> {
use crate::tir::dim_check::InferredType;
match resolved {
ResolvedTypeExpr::Dimensionless => Ok(InferredType::Scalar(Dimension::dimensionless())),
ResolvedTypeExpr::Bool => Ok(InferredType::Bool),
ResolvedTypeExpr::Int => Ok(InferredType::Int),
ResolvedTypeExpr::Datetime(scale) => Ok(InferredType::Datetime(*scale)),
ResolvedTypeExpr::IndexArg(index) => {
resolved_index_to_inferred(index, src).map(InferredType::NamedIndex)
}
ResolvedTypeExpr::Scalar(dim) => Ok(InferredType::Scalar(dim.clone())),
ResolvedTypeExpr::Struct(name, _) => Ok(InferredType::Struct(
crate::tir::dim_check::InferredStructType::from_resolved(name.clone()),
vec![],
)),
ResolvedTypeExpr::GenericStruct {
name, type_args, ..
} => {
let mut inferred_args = Vec::with_capacity(type_args.len());
for arg in type_args {
inferred_args.push(substitute_resolved_type_with_types(
arg, dim_sub, index_sub, nat_sub, type_sub, src,
)?);
}
Ok(InferredType::Struct(
crate::tir::dim_check::InferredStructType::from_resolved(name.clone()),
inferred_args,
))
}
ResolvedTypeExpr::GenericDimParam(gp, span) => dim_sub.get(gp).map_or_else(
|| {
Err(GraphcalError::EvalError {
message: format!("generic `{gp}` not bound during substitution"),
src: src.clone(),
span: (*span).into(),
})
},
|dim| Ok(InferredType::Scalar(dim.clone())),
),
ResolvedTypeExpr::GenericTypeParam(gp, span) => type_sub.get(gp).map_or_else(
|| {
Err(GraphcalError::EvalError {
message: format!("generic type parameter `{gp}` not bound during substitution"),
src: src.clone(),
span: (*span).into(),
})
},
|ty| Ok(ty.clone()),
),
ResolvedTypeExpr::GenericDimExpr { terms, span } => {
let overflow_err = || GraphcalError::DimensionOverflow {
src: src.clone(),
span: (*span).into(),
};
let mut result = Dimension::dimensionless();
for term in terms {
let term_dim = match term {
ResolvedDimTerm::Concrete { dim, power, .. } => {
dim.pow(*power).map_err(|_| overflow_err())?
}
ResolvedDimTerm::GenericParam {
name: gp,
power,
span: term_span,
..
} => {
let base = dim_sub.get(gp).ok_or_else(|| GraphcalError::EvalError {
message: format!("generic `{gp}` not bound during substitution"),
src: src.clone(),
span: (*term_span).into(),
})?;
base.pow(*power).map_err(|_| overflow_err())?
}
};
result = match term.op() {
MulDivOp::Mul => (result * term_dim).map_err(|_| overflow_err())?,
MulDivOp::Div => (result / term_dim).map_err(|_| overflow_err())?,
};
}
Ok(InferredType::Scalar(result))
}
ResolvedTypeExpr::Indexed { base, indexes } => {
let mut result = substitute_resolved_type_with_types(
base, dim_sub, index_sub, nat_sub, type_sub, src,
)?;
for idx in indexes.iter().rev() {
let resolved_idx = match idx {
ResolvedIndex::Concrete(name, _) => {
result = InferredType::Indexed {
element: Box::new(result),
index: crate::tir::dim_check::InferredIndex::from_resolved(
name.clone(),
),
};
continue;
}
ResolvedIndex::GenericParam(gp, span) => {
crate::tir::dim_check::InferredIndex::from_ref(
index_sub
.get(gp)
.cloned()
.ok_or_else(|| GraphcalError::EvalError {
message: format!(
"generic index `{gp}` not bound during substitution"
),
src: src.clone(),
span: (*span).into(),
})?,
)
}
ResolvedIndex::NatExpr(form, span) => {
let n = form.evaluate(nat_sub).ok_or_else(|| {
let vars = form.variables();
let unbound: Vec<&str> = vars
.iter()
.filter(|k| !nat_sub.contains_key(*k))
.map(GenericParamName::as_str)
.collect();
GraphcalError::EvalError {
message: format!(
"generic nat parameter(s) [{}] not bound during substitution",
unbound.join(", ")
),
src: src.clone(),
span: (*span).into(),
}
})?;
crate::tir::dim_check::InferredIndex::from_nat_range_form(
NatPolyForm::from_constant(n),
)
.map_err(|err| GraphcalError::EvalError {
message: err.to_string(),
src: src.clone(),
span: (*span).into(),
})?
}
};
result = InferredType::Indexed {
element: Box::new(result),
index: resolved_idx,
};
}
Ok(result)
}
}
}
fn require_local_type_level_path<'a>(
path: &'a NamePath,
span: Span,
src: &NamedSource<Arc<String>>,
) -> Result<&'a str, GraphcalError> {
path.as_bare()
.map(super::super::syntax::names::NameAtom::as_str)
.ok_or_else(|| GraphcalError::EvalError {
message: format!(
"qualified type-level reference `{path}` needs module-aware resolution"
),
src: src.clone(),
span: span.into(),
})
}
/// Wraps a module-resolution failure as an `EvalError`. The message is the
/// resolver error's `Display` text, reported at `span` in `src`.
fn module_resolve_error(
    err: &ModuleResolveError,
    src: &NamedSource<Arc<String>>,
    span: Span,
) -> GraphcalError {
    let message = err.to_string();
    GraphcalError::EvalError {
        span: span.into(),
        src: src.clone(),
        message,
    }
}
/// Builds a `GraphcalError::InternalError` carrying `message`, reported at
/// `span` in `src`.
fn internal_error(message: String, src: &NamedSource<Arc<String>>, span: Span) -> GraphcalError {
    let span = span.into();
    let src = src.clone();
    GraphcalError::InternalError { message, src, span }
}
/// Returns `true` when a module lookup failed because the name is unknown
/// (`ModuleResolveError::UnknownName`). Every other resolution error returns
/// `false`.
const fn module_lookup_is_absent(err: &ModuleResolveError) -> bool {
matches!(err, ModuleResolveError::UnknownName { .. })
}
/// Converts a HIR type-lowering error into a `GraphcalError`, using the
/// source type annotation to pick a more specific diagnostic.
///
/// An unknown type path is checked against `type_ann` at the error's span:
/// - If it sits in an index position and parses as an `IndexName`, the
///   result is `UnknownIndex`.
/// - Otherwise, if it sits in a dimension term and parses as a `DimName`,
///   the result is `UnknownDimension`.
///
/// Everything else falls back to `hir_lower_error_to_graphcal`.
fn type_lower_error_to_graphcal(
    err: &hir::HirLowerError,
    type_ann: &TypeExpr,
    src: &NamedSource<Arc<String>>,
) -> GraphcalError {
    let hir::HirLowerError::UnknownTypePath { path, span } = err else {
        return hir_lower_error_to_graphcal(err, src);
    };
    let span = *span;
    if type_expr_has_index_name_at_span(type_ann, span)
        && let Ok(name) = IndexName::try_new(path.clone())
    {
        return GraphcalError::UnknownIndex {
            name,
            src: src.clone(),
            span: span.into(),
        };
    }
    if type_expr_has_dim_term_at_span(type_ann, span)
        && let Ok(name) = DimName::try_new(path.clone())
    {
        return GraphcalError::UnknownDimension {
            name,
            src: src.clone(),
            span: span.into(),
        };
    }
    hir_lower_error_to_graphcal(err, src)
}
/// Reports whether `type_ann` contains a named index whose span is exactly
/// `span`.
///
/// The search covers the bases and index lists of indexed types, and the
/// type arguments of (datetime) type applications.
fn type_expr_has_index_name_at_span(type_ann: &TypeExpr, span: Span) -> bool {
    use crate::desugar::desugared_ast::IndexExpr;

    let names_index_at_span = |index: &IndexExpr| match index {
        IndexExpr::Name(name) => name.span == span,
        IndexExpr::NatExpr(_) => false,
    };

    match &type_ann.kind {
        TypeExprKind::Dimensionless
        | TypeExprKind::Bool
        | TypeExprKind::Int
        | TypeExprKind::Datetime
        | TypeExprKind::DimExpr(_) => false,
        TypeExprKind::TypeApplication { type_args, .. }
        | TypeExprKind::DatetimeApplication { type_args } => type_args
            .iter()
            .any(|arg| type_expr_has_index_name_at_span(arg, span)),
        TypeExprKind::Indexed { base, indexes } => {
            type_expr_has_index_name_at_span(base, span)
                || indexes.iter().any(names_index_at_span)
        }
    }
}
/// Reports whether `span` is exactly the name span of some dimension term in
/// `type_ann`.
///
/// Looks through `Indexed` bases and type-application arguments. Index lists
/// are not searched.
fn type_expr_has_dim_term_at_span(type_ann: &TypeExpr, span: Span) -> bool {
    match &type_ann.kind {
        TypeExprKind::Dimensionless
        | TypeExprKind::Bool
        | TypeExprKind::Int
        | TypeExprKind::Datetime => false,
        TypeExprKind::Indexed { base, .. } => type_expr_has_dim_term_at_span(base, span),
        TypeExprKind::TypeApplication { type_args, .. }
        | TypeExprKind::DatetimeApplication { type_args } => {
            type_args
                .iter()
                .any(|arg| type_expr_has_dim_term_at_span(arg, span))
        }
        TypeExprKind::DimExpr(dim_expr) => {
            dim_expr
                .terms
                .iter()
                .any(|item| item.term.name.span == span)
        }
    }
}
/// Lookup tables threaded by copy through HIR type resolution.
#[derive(Clone, Copy)]
struct HirTypeResolutionContext<'a> {
    /// Module resolver; used when lowering generic-parameter defaults.
    resolver: &'a ModuleResolver,
    /// Module-qualified dimensions, indexes and struct types.
    module_types: &'a ModuleTypeRegistry,
    /// Optional local registry, consulted by `hir_dimension` as a fallback
    /// when a dimension is missing from `module_types`.
    registry: Option<&'a Registry>,
    /// Prelude scope installed when lowering generic-parameter defaults.
    prelude: &'a hir::PreludeTypeScope,
    /// Source file that diagnostics point into.
    src: &'a NamedSource<Arc<String>>,
}
/// Resolves an already-lowered HIR type annotation against the module type
/// tables in `module_ctx`, with the Graphcal prelude in scope.
///
/// The `_registry` argument is currently ignored. No local-registry fallback
/// is installed, so every dimension, index and struct type must be found in
/// `module_ctx.types`.
///
/// # Errors
///
/// Returns a [`GraphcalError`] for unknown names, misuse of an index as a type,
/// arity/constraint mismatches in type applications, and dimension overflow.
pub fn resolve_hir_type_expr(
    type_ann: &hir::TypeExpr,
    _registry: &Registry,
    src: &NamedSource<Arc<String>>,
    module_ctx: ModuleTypeContext<'_>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let prelude = hir::PreludeTypeScope::graphcal();
    resolve_hir_type_expr_inner(
        type_ann,
        HirTypeResolutionContext {
            resolver: module_ctx.resolver,
            module_types: module_ctx.types,
            registry: None,
            prelude: &prelude,
            src,
        },
    )
}
/// Resolves an AST type annotation by lowering it to HIR first.
///
/// Lowering uses an empty generic scope plus the prelude. The HIR is then
/// resolved with `registry` available as a dimension fallback. Lowering
/// failures go through `type_lower_error_to_graphcal`, so unknown index and
/// dimension names keep their specific error kind.
fn resolve_ast_type_expr_via_hir(
    type_ann: &TypeExpr,
    registry: &Registry,
    src: &NamedSource<Arc<String>>,
    module_ctx: ModuleTypeContext<'_>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let generic_scope = hir::GenericScope::new();
    let prelude = hir::PreludeTypeScope::graphcal();

    let lowering =
        hir::TypeLoweringContext::new(module_ctx.owner, module_ctx.resolver, &generic_scope)
            .with_prelude(&prelude);
    let hir_type = hir::lower_type_expr(type_ann, lowering)
        .map_err(|lower_err| type_lower_error_to_graphcal(&lower_err, type_ann, src))?;

    resolve_hir_type_expr_inner(
        &hir_type,
        HirTypeResolutionContext {
            resolver: module_ctx.resolver,
            module_types: module_ctx.types,
            registry: Some(registry),
            prelude: &prelude,
            src,
        },
    )
}
fn resolve_hir_type_expr_inner(
type_ann: &hir::TypeExpr,
ctx: HirTypeResolutionContext<'_>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
match &type_ann.kind {
hir::TypeExprKind::Builtin(builtin) => Ok(resolve_hir_builtin_type(*builtin)),
hir::TypeExprKind::DimExpr(dim_expr) => resolve_hir_dim_expr(dim_expr, ctx),
hir::TypeExprKind::Index(index) => Err(GraphcalError::EvalError {
message: format!(
"index `{}` cannot be used as a type",
format_hir_index_ref(index)
),
src: ctx.src.clone(),
span: hir_index_ref_span(index).into(),
}),
hir::TypeExprKind::Struct(name) => {
hir_struct_type_def(&name.value, name.span, ctx)?;
Ok(ResolvedTypeExpr::Struct(name.value.clone(), name.span))
}
hir::TypeExprKind::GenericTypeParam(param) => Ok(ResolvedTypeExpr::GenericTypeParam(
param.value.name.clone(),
param.span,
)),
hir::TypeExprKind::TypeApplication { name, type_args } => {
resolve_hir_type_application(type_ann, name, type_args, ctx)
}
hir::TypeExprKind::Indexed { base, indexes } => {
let resolved_base = resolve_hir_type_expr_inner(base, ctx)?;
let resolved_indexes = indexes
.iter()
.map(|index| resolve_hir_index_ref(index, ctx))
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedTypeExpr::Indexed {
base: Box::new(resolved_base),
indexes: resolved_indexes,
})
}
}
}
/// Source span of an index reference, whichever of its three forms it takes.
const fn hir_index_ref_span(index_ref: &hir::IndexRef) -> Span {
    match index_ref {
        hir::IndexRef::NatExpr(expr) => expr.span(),
        hir::IndexRef::GenericParam(generic) => generic.span,
        hir::IndexRef::Concrete(named) => named.span,
    }
}
/// Renders an index reference for diagnostics.
///
/// Concrete names and generic parameters print as their name; Nat expressions
/// print as `range(<expr>)`.
fn format_hir_index_ref(index: &hir::IndexRef) -> String {
    match index {
        hir::IndexRef::NatExpr(nat_expr) => {
            let inner = format_hir_nat_expr(nat_expr);
            format!("range({inner})")
        }
        hir::IndexRef::GenericParam(param) => param.value.name.to_string(),
        hir::IndexRef::Concrete(name) => String::from(name.value.as_str()),
    }
}
/// Renders a HIR Nat expression for diagnostics.
///
/// `+` binds looser than `*`, so an addition used as an operand of a
/// multiplication is parenthesized (`2 * (N + 1)`). Without the parentheses
/// the text would read as a different expression (`2 * N + 1`). Both
/// operators are associative, so no other parentheses are needed.
fn format_hir_nat_expr(nat_expr: &hir::NatExpr) -> String {
    match nat_expr {
        hir::NatExpr::Literal(n, _) => n.to_string(),
        hir::NatExpr::Param(param) => param.value.name.to_string(),
        hir::NatExpr::Add(lhs, rhs, _) => {
            format!(
                "{} + {}",
                format_hir_nat_expr(lhs),
                format_hir_nat_expr(rhs)
            )
        }
        hir::NatExpr::Mul(lhs, rhs, _) => {
            format!(
                "{} * {}",
                format_hir_nat_mul_operand(lhs),
                format_hir_nat_mul_operand(rhs)
            )
        }
    }
}

/// Formats one operand of a `*`, wrapping an addition in parentheses so the
/// rendered text keeps the expression tree's grouping.
fn format_hir_nat_mul_operand(operand: &hir::NatExpr) -> String {
    if matches!(operand, hir::NatExpr::Add(..)) {
        format!("({})", format_hir_nat_expr(operand))
    } else {
        format_hir_nat_expr(operand)
    }
}
/// Maps a HIR builtin type onto its resolved form; `Datetime` keeps its time scale.
const fn resolve_hir_builtin_type(builtin: hir::BuiltinType) -> ResolvedTypeExpr {
    match builtin {
        hir::BuiltinType::Datetime(time_scale) => ResolvedTypeExpr::Datetime(time_scale.scale()),
        hir::BuiltinType::Int => ResolvedTypeExpr::Int,
        hir::BuiltinType::Bool => ResolvedTypeExpr::Bool,
        hir::BuiltinType::Dimensionless => ResolvedTypeExpr::Dimensionless,
    }
}
/// Looks up the concrete [`Dimension`] behind a resolved dimension name.
///
/// The lookup order is:
/// 1. `ctx.module_types`.
/// 2. If a local registry is present, that registry, by the bare definition name.
///
/// Otherwise fails with `UnknownDimension` at `span`.
fn hir_dimension(
    name: &ResolvedName<namespace::Dim>,
    span: Span,
    ctx: HirTypeResolutionContext<'_>,
) -> Result<Dimension, GraphcalError> {
    if let Some(found) = ctx.module_types.get_dimension(name) {
        return Ok(found.clone());
    }
    let from_registry = ctx.registry.and_then(|registry| {
        registry
            .dimensions
            .get_dimension(name.to_unowned_def_name().as_str())
            .cloned()
    });
    from_registry.ok_or_else(|| GraphcalError::UnknownDimension {
        name: name.to_unowned_def_name(),
        src: ctx.src.clone(),
        span: span.into(),
    })
}
/// Confirms that a resolved index exists in `ctx.module_types` and returns its
/// unqualified definition name; otherwise fails with `UnknownIndex` at `span`.
fn hir_index_name(
    name: &ResolvedName<namespace::Index>,
    span: Span,
    ctx: HirTypeResolutionContext<'_>,
) -> Result<IndexName, GraphcalError> {
    let def_name = name.to_unowned_def_name();
    if ctx.module_types.get_index(name).is_none() {
        return Err(GraphcalError::UnknownIndex {
            name: def_name,
            src: ctx.src.clone(),
            span: span.into(),
        });
    }
    Ok(def_name)
}
/// Fetches the definition of a resolved struct type from `ctx.module_types`,
/// or fails with `UnknownStructType` at `span`.
fn hir_struct_type_def<'a>(
    name: &ResolvedName<namespace::StructType>,
    span: Span,
    ctx: HirTypeResolutionContext<'a>,
) -> Result<&'a TypeDef, GraphcalError> {
    let Some(type_def) = ctx.module_types.get_struct_type(name) else {
        return Err(GraphcalError::UnknownStructType {
            name: name.to_string(),
            src: ctx.src.clone(),
            span: span.into(),
        });
    };
    Ok(type_def)
}
fn resolve_hir_dim_expr(
dim_expr: &hir::DimExpr,
ctx: HirTypeResolutionContext<'_>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
let terms = dim_expr
.terms
.iter()
.map(|item| resolve_hir_dim_expr_item(item, ctx))
.collect::<Result<Vec<_>, _>>()?;
if let [
ResolvedDimTerm::GenericParam {
name,
power,
op: MulDivOp::Mul,
span,
},
] = terms.as_slice()
&& *power == Rational::ONE
{
return Ok(ResolvedTypeExpr::GenericDimParam(name.clone(), *span));
}
let has_generic = terms
.iter()
.any(|term| matches!(term, ResolvedDimTerm::GenericParam { .. }));
if has_generic {
return Ok(ResolvedTypeExpr::GenericDimExpr {
terms,
span: dim_expr.span,
});
}
let result = terms.iter().try_fold(
Dimension::dimensionless(),
|acc, term| -> Result<Dimension, GraphcalError> {
let ResolvedDimTerm::Concrete { dim, power, op } = term else {
return Err(GraphcalError::InternalError {
message: "generic dimension term reached concrete dimension folding"
.to_string(),
src: ctx.src.clone(),
span: dim_expr.span.into(),
});
};
let overflow_err = || GraphcalError::DimensionOverflow {
src: ctx.src.clone(),
span: dim_expr.span.into(),
};
let powered = dim.pow(*power).map_err(|_| overflow_err())?;
match op {
MulDivOp::Mul => (acc * powered).map_err(|_| overflow_err()),
MulDivOp::Div => (acc / powered).map_err(|_| overflow_err()),
}
},
)?;
Ok(ResolvedTypeExpr::Scalar(result))
}
/// Resolves one `op term^power` item of a HIR dimension expression.
///
/// A missing power defaults to 1. Concrete dimensions are looked up via
/// `hir_dimension`; generic parameters stay symbolic with the term's span.
fn resolve_hir_dim_expr_item(
    item: &hir::DimExprItem,
    ctx: HirTypeResolutionContext<'_>,
) -> Result<ResolvedDimTerm, GraphcalError> {
    let op = item.op;
    let power = item.term.power.unwrap_or(Rational::ONE);
    let resolved = match &item.term.target {
        hir::DimTermTarget::GenericParam(param) => ResolvedDimTerm::GenericParam {
            name: param.value.name.clone(),
            power,
            op,
            span: item.term.span,
        },
        hir::DimTermTarget::Dimension(name) => ResolvedDimTerm::Concrete {
            dim: hir_dimension(&name.value, name.span, ctx)?,
            power,
            op,
        },
    };
    Ok(resolved)
}
/// Resolves a HIR index reference.
///
/// - Concrete indexes must exist in the module tables.
/// - Generic parameters pass through unchanged.
/// - Nat expressions are normalized to polynomial form; overflow is reported
///   at the expression's span.
fn resolve_hir_index_ref(
    index: &hir::IndexRef,
    ctx: HirTypeResolutionContext<'_>,
) -> Result<ResolvedIndex, GraphcalError> {
    let resolved = match index {
        hir::IndexRef::NatExpr(nat_expr) => {
            let span = nat_expr.span();
            let form = normalize_hir_nat_expr(nat_expr)
                .map_err(|err| nat_overflow_error(err, ctx.src, span))?;
            ResolvedIndex::NatExpr(form, span)
        }
        hir::IndexRef::GenericParam(param) => {
            ResolvedIndex::GenericParam(param.value.name.clone(), param.span)
        }
        hir::IndexRef::Concrete(name) => {
            hir_index_name(&name.value, name.span, ctx)?;
            ResolvedIndex::Concrete(name.value.clone(), name.span)
        }
    };
    Ok(resolved)
}
/// Normalizes a HIR Nat expression into a [`NatPolyForm`].
///
/// Sub-expressions are normalized left operand first. Overflow from
/// `add`/`mul` is propagated.
fn normalize_hir_nat_expr(
    expr: &hir::NatExpr,
) -> Result<NatPolyForm, crate::syntax::nat::NatOverflowError> {
    match expr {
        hir::NatExpr::Literal(value, _) => Ok(NatPolyForm::from_constant(*value)),
        hir::NatExpr::Param(param) => Ok(NatPolyForm::from_var(param.value.name.clone())),
        hir::NatExpr::Add(lhs, rhs, _) => {
            let left = normalize_hir_nat_expr(lhs)?;
            let right = normalize_hir_nat_expr(rhs)?;
            left.add(&right)
        }
        hir::NatExpr::Mul(lhs, rhs, _) => {
            let left = normalize_hir_nat_expr(lhs)?;
            let right = normalize_hir_nat_expr(rhs)?;
            left.mul(&right)
        }
    }
}
/// Checks that `arg_count` explicit type arguments fit `type_def`'s generic
/// parameter list.
///
/// The accepted count is inclusive: at least the number of leading parameters
/// without a default, and at most the total number of parameters.
fn check_type_application_arity(
    type_name: &str,
    type_def: &TypeDef,
    arg_count: usize,
    span: Span,
    src: &NamedSource<Arc<String>>,
) -> Result<(), GraphcalError> {
    let params = &type_def.generic_params;
    let max_args = params.len();
    // Leading parameters without a default are mandatory.
    let min_args = params
        .iter()
        .position(|p| p.default.is_some())
        .unwrap_or(max_args);

    if (min_args..=max_args).contains(&arg_count) {
        return Ok(());
    }

    let hint = if min_args == max_args {
        max_args.to_string()
    } else {
        format!("{min_args}..{max_args}")
    };
    Err(GraphcalError::EvalError {
        message: format!("type `{type_name}` expects {hint} type argument(s), got {arg_count}"),
        src: src.clone(),
        span: span.into(),
    })
}
/// Resolves a HIR type application `Name<args...>`.
///
/// Steps:
/// 1. Check that the struct exists and that the argument count fits.
/// 2. Resolve each explicit argument against its parameter's constraint.
/// 3. Lower and resolve the declared default of every trailing parameter that
///    was not supplied.
fn resolve_hir_type_application(
    type_ann: &hir::TypeExpr,
    name: &crate::syntax::span::Spanned<ResolvedName<namespace::StructType>>,
    type_args: &[hir::TypeExpr],
    ctx: HirTypeResolutionContext<'_>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let type_def = hir_struct_type_def(&name.value, name.span, ctx)?;
    check_type_application_arity(
        name.value.as_str(),
        type_def,
        type_args.len(),
        type_ann.span,
        ctx.src,
    )?;

    let params = &type_def.generic_params;
    let mut resolved_args = params
        .iter()
        .zip(type_args)
        .map(|(param, arg)| resolve_hir_type_arg_for_param(param, arg, ctx))
        .collect::<Result<Vec<_>, _>>()?;

    for param in params.iter().skip(type_args.len()) {
        // The arity check guarantees omitted parameters have defaults.
        let Some(default_expr) = param.default.as_ref() else {
            return Err(GraphcalError::EvalError {
                message: format!(
                    "internal: generic parameter `{}` has no default",
                    param.name
                ),
                src: ctx.src.clone(),
                span: type_ann.span.into(),
            });
        };
        let default_hir = lower_type_generic_default(default_expr, &name.value, type_def, ctx)?;
        resolved_args.push(resolve_hir_type_arg_for_param(param, &default_hir, ctx)?);
    }

    Ok(ResolvedTypeExpr::GenericStruct {
        name: name.value.clone(),
        type_args: resolved_args,
        span: type_ann.span,
    })
}
/// Resolves one HIR type argument against the constraint of its generic
/// parameter.
///
/// - Index-constrained parameters require an index node.
/// - Nat-constrained parameters always reject a type argument.
/// - Dim and unconstrained parameters resolve the argument as an ordinary type.
fn resolve_hir_type_arg_for_param(
    param: &crate::registry::types::TypeGenericParam,
    arg: &hir::TypeExpr,
    ctx: HirTypeResolutionContext<'_>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let reject = |message: String| GraphcalError::EvalError {
        message,
        src: ctx.src.clone(),
        span: arg.span.into(),
    };
    match (param.constraint, &arg.kind) {
        (TypeGenericConstraint::Dim | TypeGenericConstraint::Unconstrained, _) => {
            resolve_hir_type_expr_inner(arg, ctx)
        }
        (TypeGenericConstraint::Index, hir::TypeExprKind::Index(index)) => {
            resolve_hir_index_ref(index, ctx).map(ResolvedTypeExpr::IndexArg)
        }
        (TypeGenericConstraint::Index, _) => Err(reject(format!(
            "generic parameter `{}` expects an Index argument",
            param.name
        ))),
        (TypeGenericConstraint::Nat, _) => Err(reject(format!(
            "generic parameter `{}` expects a Nat argument, got a type argument",
            param.name
        ))),
    }
}
/// Lowers a generic parameter's default type expression to HIR.
///
/// Every generic parameter of `type_def` is bound in a fresh scope owned by
/// `type_owner`, with its constraint translated to the AST constraint kind.
/// The default is then lowered in the owner's module with the prelude in scope.
fn lower_type_generic_default(
    default_expr: &TypeExpr,
    type_owner: &ResolvedName<namespace::StructType>,
    type_def: &TypeDef,
    ctx: HirTypeResolutionContext<'_>,
) -> Result<hir::TypeExpr, GraphcalError> {
    use crate::syntax::ast::GenericConstraint;

    let mut scope = hir::GenericScope::new();
    for param in &type_def.generic_params {
        let constraint = match param.constraint {
            TypeGenericConstraint::Unconstrained => GenericConstraint::Type,
            TypeGenericConstraint::Nat => GenericConstraint::Nat,
            TypeGenericConstraint::Index => GenericConstraint::Index,
            TypeGenericConstraint::Dim => GenericConstraint::Dim,
        };
        let owner = hir::GenericParamOwner::Type(type_owner.clone());
        let binding = hir::GenericParamBinding::new(
            hir::GenericParamId::new(owner, param.name.clone()),
            constraint,
            default_expr.span,
        );
        scope
            .insert_binding(binding)
            .map_err(|err| hir_lower_error_to_graphcal(&err, ctx.src))?;
    }

    let lowering = hir::TypeLoweringContext::new(type_owner.owner(), ctx.resolver, &scope)
        .with_prelude(ctx.prelude);
    hir::lower_type_expr(default_expr, lowering)
        .map_err(|err| hir_lower_error_to_graphcal(&err, ctx.src))
}
#[expect(
    clippy::too_many_arguments,
    reason = "resolves one AST index path against generic params, local registry, and module context"
)]
/// Resolves a named index in an AST `Indexed` type.
///
/// Lookup order:
/// 1. For bare paths, generic Nat parameters (as a Nat variable), then generic
///    Index parameters.
/// 2. The module resolver, when a module context is present. An absent bare
///    name falls through; any other resolver failure is reported.
/// 3. The local `registry`, with the name attributed to `owner`.
fn resolve_index_expr_name(
    path: &NamePath,
    span: Span,
    registry: &Registry,
    owner: &crate::dag_id::DagId,
    index_params: &[GenericParamName],
    nat_params: &[GenericParamName],
    src: &NamedSource<Arc<String>>,
    module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<ResolvedIndex, GraphcalError> {
    if let Some(atom) = path.as_bare() {
        let wanted = atom.as_str();
        if let Some(nat) = nat_params.iter().find(|p| p.as_str() == wanted) {
            return Ok(ResolvedIndex::NatExpr(
                NatPolyForm::from_var(nat.clone()),
                span,
            ));
        }
        if let Some(index_param) = index_params.iter().find(|p| p.as_str() == wanted) {
            return Ok(ResolvedIndex::GenericParam(index_param.clone(), span));
        }
    }

    if let Some(ctx) = module_ctx {
        match ctx.resolver.resolve_index_path(ctx.owner, path) {
            Ok(resolved) if ctx.types.get_index(&resolved).is_some() => {
                return Ok(ResolvedIndex::Concrete(resolved, span));
            }
            Ok(resolved) => {
                return Err(GraphcalError::UnknownIndex {
                    name: resolved.to_unowned_def_name(),
                    src: src.clone(),
                    span: span.into(),
                });
            }
            Err(err) if path.is_bare() && module_lookup_is_absent(&err) => {}
            Err(err) => return Err(module_resolve_error(&err, src, span)),
        }
    }

    let text = require_local_type_level_path(path, span, src)?;
    if registry.indexes.get_index(text).is_none() {
        return Err(GraphcalError::UnknownIndex {
            name: IndexName::new(text),
            src: src.clone(),
            span: span.into(),
        });
    }
    Ok(ResolvedIndex::Concrete(
        ResolvedName::from_def(owner.clone(), IndexName::new(text)),
        span,
    ))
}
/// Tries to interpret `path` as a concrete *named* index (`Named` or
/// `RequiredNamed` kind).
///
/// Returns `Ok(None)` when the path is not such an index, including indexes of
/// other kinds. With a module context:
/// - A name the resolver finds but the module tables lack is an
///   `UnknownIndex` error.
/// - Absent names, and any resolver failure on a bare path, fall back to the
///   local `registry`.
fn resolve_concrete_index_path(
    path: &NamePath,
    span: Span,
    registry: &Registry,
    owner: &crate::dag_id::DagId,
    src: &NamedSource<Arc<String>>,
    module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<Option<ResolvedName<namespace::Index>>, GraphcalError> {
    fn is_named(kind: &crate::registry::types::IndexKind) -> bool {
        matches!(
            kind,
            crate::registry::types::IndexKind::Named { .. }
                | crate::registry::types::IndexKind::RequiredNamed
        )
    }

    if let Some(ctx) = module_ctx {
        match ctx.resolver.resolve_index_path(ctx.owner, path) {
            Ok(resolved) => {
                let Some(index) = ctx.types.get_index(&resolved) else {
                    return Err(GraphcalError::UnknownIndex {
                        name: resolved.to_unowned_def_name(),
                        src: src.clone(),
                        span: span.into(),
                    });
                };
                return Ok(is_named(&index.kind).then_some(resolved));
            }
            Err(err) if module_lookup_is_absent(&err) || path.is_bare() => {}
            Err(err) => return Err(module_resolve_error(&err, src, span)),
        }
    }

    let Some(atom) = path.as_bare() else {
        return Ok(None);
    };
    let Some(index) = registry.indexes.get_index(atom.as_str()) else {
        return Ok(None);
    };
    Ok(is_named(&index.kind)
        .then(|| ResolvedName::from_def(owner.clone(), IndexName::from_atom(atom.clone()))))
}
/// A struct-type lookup result: the resolved name plus its definition, or
/// `None` when the path does not name a struct type.
type ResolvedStructTypeLookup<'a> = Option<(ResolvedName<namespace::StructType>, &'a TypeDef)>;
/// Tries to interpret `path` as a struct type.
///
/// With a module context:
/// - A name the resolver finds but the module tables lack is an
///   `UnknownStructType` error.
/// - Absent names, and any resolver failure on a bare path, fall back to the
///   local `registry`, with the name attributed to `owner`.
///
/// Qualified paths never match locally.
fn resolve_struct_type_path<'a>(
    path: &NamePath,
    span: Span,
    registry: &'a Registry,
    owner: &crate::dag_id::DagId,
    src: &NamedSource<Arc<String>>,
    module_ctx: Option<ModuleTypeContext<'a>>,
) -> Result<ResolvedStructTypeLookup<'a>, GraphcalError> {
    if let Some(ctx) = module_ctx {
        match ctx.resolver.resolve_struct_type_path(ctx.owner, path) {
            Ok(resolved) => {
                let Some(type_def) = ctx.types.get_struct_type(&resolved) else {
                    return Err(GraphcalError::UnknownStructType {
                        name: resolved.to_string(),
                        src: src.clone(),
                        span: span.into(),
                    });
                };
                return Ok(Some((resolved, type_def)));
            }
            Err(err) if module_lookup_is_absent(&err) || path.is_bare() => {}
            Err(err) => return Err(module_resolve_error(&err, src, span)),
        }
    }

    let local = path.as_bare().and_then(|atom| {
        registry.types.get_type(atom.as_str()).map(|type_def| {
            let name =
                ResolvedName::from_def(owner.clone(), StructTypeName::from_atom(atom.clone()));
            (name, type_def)
        })
    });
    Ok(local)
}
/// Tries to interpret `path` as a concrete dimension.
///
/// With a module context:
/// - A name the resolver finds but the module tables lack is an
///   `UnknownDimension` error.
/// - Only an absent *bare* name falls back to the local `registry`; other
///   resolver failures are reported.
///
/// Locally, qualified paths are rejected by `require_local_type_level_path`.
fn resolve_dimension_path(
    path: &NamePath,
    span: Span,
    registry: &Registry,
    src: &NamedSource<Arc<String>>,
    module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<Option<Dimension>, GraphcalError> {
    if let Some(ctx) = module_ctx {
        match ctx.resolver.resolve_dimension_path(ctx.owner, path) {
            Ok(resolved) => {
                let Some(dimension) = ctx.types.get_dimension(&resolved) else {
                    return Err(GraphcalError::UnknownDimension {
                        name: resolved.to_unowned_def_name(),
                        src: src.clone(),
                        span: span.into(),
                    });
                };
                return Ok(Some(dimension.clone()));
            }
            Err(err) if path.is_bare() && module_lookup_is_absent(&err) => {}
            Err(err) => return Err(module_resolve_error(&err, src, span)),
        }
    }

    let local_name = require_local_type_level_path(path, span, src)?;
    Ok(registry.dimensions.get_dimension(local_name).cloned())
}
/// Resolves an AST type annotation against `registry` alone, without module
/// context.
///
/// The given generic Dim, Index and Nat parameter names are in scope. Names
/// resolved from the local registry are attributed to a synthetic
/// `<type-resolution>` root owner.
///
/// # Errors
///
/// Returns a [`GraphcalError`] for unknown names, malformed type applications,
/// invalid time scales and dimension overflow.
pub fn resolve_type_expr(
    type_ann: &TypeExpr,
    registry: &Registry,
    dim_params: &[GenericParamName],
    index_params: &[GenericParamName],
    nat_params: &[GenericParamName],
    src: &NamedSource<Arc<String>>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let synthetic_owner = crate::dag_id::DagId::root("<type-resolution>");
    resolve_type_expr_inner(
        type_ann,
        registry,
        &synthetic_owner,
        dim_params,
        index_params,
        nat_params,
        src,
        None,
    )
}
/// Resolves an AST type annotation with module-aware name resolution.
///
/// The owning module of `module_ctx` is used as the resolution owner. When no
/// generic parameters are in scope, resolution goes through HIR lowering.
///
/// # Errors
///
/// Returns a [`GraphcalError`] for unknown or unresolvable names, malformed
/// type applications, invalid time scales and dimension overflow.
pub fn resolve_type_expr_with_modules(
    type_ann: &TypeExpr,
    registry: &Registry,
    dim_params: &[GenericParamName],
    index_params: &[GenericParamName],
    nat_params: &[GenericParamName],
    src: &NamedSource<Arc<String>>,
    module_ctx: ModuleTypeContext<'_>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let owner = module_ctx.owner;
    resolve_type_expr_inner(
        type_ann,
        registry,
        owner,
        dim_params,
        index_params,
        nat_params,
        src,
        Some(module_ctx),
    )
}
#[expect(
clippy::too_many_arguments,
reason = "recursive resolver threads generic parameter scopes and optional module context"
)]
fn resolve_type_expr_inner(
type_ann: &TypeExpr,
registry: &Registry,
owner: &crate::dag_id::DagId,
dim_params: &[GenericParamName],
index_params: &[GenericParamName],
nat_params: &[GenericParamName],
src: &NamedSource<Arc<String>>,
module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
if let Some(ctx) = module_ctx
&& dim_params.is_empty()
&& index_params.is_empty()
&& nat_params.is_empty()
{
return resolve_ast_type_expr_via_hir(type_ann, registry, src, ctx);
}
match &type_ann.kind {
TypeExprKind::Dimensionless => Ok(ResolvedTypeExpr::Dimensionless),
TypeExprKind::Bool => Ok(ResolvedTypeExpr::Bool),
TypeExprKind::Int => Ok(ResolvedTypeExpr::Int),
TypeExprKind::Datetime => Ok(ResolvedTypeExpr::Datetime(TimeScale::UTC)),
TypeExprKind::DatetimeApplication { type_args } => {
resolve_datetime_application(type_ann, type_args, src)
}
TypeExprKind::Indexed { base, indexes } => {
let resolved_base = resolve_type_expr_inner(
base,
registry,
owner,
dim_params,
index_params,
nat_params,
src,
module_ctx,
)?;
let mut resolved_indexes = Vec::with_capacity(indexes.len());
for idx in indexes {
match idx {
crate::desugar::desugared_ast::IndexExpr::NatExpr(nat_expr) => {
let form = normalize_nat_expr(nat_expr, nat_params, src)?;
resolved_indexes.push(ResolvedIndex::NatExpr(form, nat_expr.span()));
}
crate::desugar::desugared_ast::IndexExpr::Name(path) => {
resolved_indexes.push(resolve_index_expr_name(
&path.value,
path.span,
registry,
owner,
index_params,
nat_params,
src,
module_ctx,
)?);
}
}
}
Ok(ResolvedTypeExpr::Indexed {
base: Box::new(resolved_base),
indexes: resolved_indexes,
})
}
TypeExprKind::DimExpr(dim_expr) => resolve_dim_expr(
dim_expr,
registry,
owner,
dim_params,
index_params,
src,
module_ctx,
),
TypeExprKind::TypeApplication { name, type_args } => resolve_type_application(
type_ann,
name,
type_args,
registry,
owner,
dim_params,
index_params,
nat_params,
src,
module_ctx,
),
}
}
fn resolve_dim_expr(
dim_expr: &crate::desugar::desugared_ast::DimExpr,
registry: &Registry,
owner: &crate::dag_id::DagId,
dim_params: &[GenericParamName],
index_params: &[GenericParamName],
src: &NamedSource<Arc<String>>,
module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
if dim_expr.terms.len() == 1 && dim_expr.terms[0].term.power.is_none() {
let term = &dim_expr.terms[0].term;
if let Some(index) = resolve_concrete_index_path(
&term.name.value,
term.name.span,
registry,
owner,
src,
module_ctx,
)? {
return Ok(ResolvedTypeExpr::IndexArg(ResolvedIndex::Concrete(
index, term.span,
)));
}
if let Some(atom) = term.name.value.as_bare()
&& let Some(gp) = index_params.iter().find(|p| p.as_str() == atom.as_str())
{
return Ok(ResolvedTypeExpr::IndexArg(ResolvedIndex::GenericParam(
gp.clone(),
term.span,
)));
}
if let Some((type_name, _)) = resolve_struct_type_path(
&term.name.value,
term.name.span,
registry,
owner,
src,
module_ctx,
)? {
return Ok(ResolvedTypeExpr::Struct(type_name, term.span));
}
if let Some(atom) = term.name.value.as_bare()
&& let Some(gp) = dim_params.iter().find(|p| p.as_str() == atom.as_str())
{
return Ok(ResolvedTypeExpr::GenericDimParam(gp.clone(), term.span));
}
}
let has_generic = dim_expr.terms.iter().any(|item| {
item.term
.name
.value
.as_bare()
.is_some_and(|atom| dim_params.iter().any(|p| p.as_str() == atom.as_str()))
});
if has_generic {
let terms = dim_expr
.terms
.iter()
.map(|item| {
resolve_dim_term_in_generic_expr(item, registry, dim_params, src, module_ctx)
})
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedTypeExpr::GenericDimExpr {
terms,
span: dim_expr.span,
})
} else {
let result = dim_expr.terms.iter().try_fold(
Dimension::dimensionless(),
|acc, item| -> Result<Dimension, GraphcalError> {
let base = concrete_dimension_for_term(item, registry, src, module_ctx)?;
let exp = item.term.power.unwrap_or(Rational::ONE);
let overflow_err = || GraphcalError::DimensionOverflow {
src: src.clone(),
span: item.term.span.into(),
};
let powered = base.pow(exp).map_err(|_| overflow_err())?;
match item.op {
MulDivOp::Mul => (acc * powered).map_err(|_| overflow_err()),
MulDivOp::Div => (acc / powered).map_err(|_| overflow_err()),
}
},
)?;
Ok(ResolvedTypeExpr::Scalar(result))
}
}
/// Resolves one term of a dimension expression that mentions generic Dim
/// parameters.
///
/// Bare names matching a Dim parameter stay symbolic; every other term must be
/// a concrete dimension. A missing power defaults to 1.
fn resolve_dim_term_in_generic_expr(
    item: &crate::desugar::desugared_ast::DimExprItem,
    registry: &Registry,
    dim_params: &[GenericParamName],
    src: &NamedSource<Arc<String>>,
    module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<ResolvedDimTerm, GraphcalError> {
    let (power, op) = (item.term.power.unwrap_or(Rational::ONE), item.op);
    let generic = item
        .term
        .name
        .value
        .as_bare()
        .and_then(|atom| dim_params.iter().find(|p| p.as_str() == atom.as_str()));
    if let Some(gp) = generic {
        return Ok(ResolvedDimTerm::GenericParam {
            name: gp.clone(),
            power,
            op,
            span: item.term.span,
        });
    }
    let dim = concrete_dimension_for_term(item, registry, src, module_ctx)?;
    Ok(ResolvedDimTerm::Concrete { dim, power, op })
}
/// Resolves a dimension term's name to a concrete [`Dimension`].
///
/// A name that resolves to nothing becomes `UnknownDimension` at the term's
/// span, labelled with the bare name or, for qualified paths, the displayed
/// path.
fn concrete_dimension_for_term(
    item: &crate::desugar::desugared_ast::DimExprItem,
    registry: &Registry,
    src: &NamedSource<Arc<String>>,
    module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<Dimension, GraphcalError> {
    let name_path = &item.term.name;
    let unknown = || {
        let label = name_path
            .value
            .as_bare()
            .map_or_else(|| name_path.value.display_path(), ToString::to_string);
        GraphcalError::UnknownDimension {
            name: DimName::new(label),
            src: src.clone(),
            span: item.term.span.into(),
        }
    };
    resolve_dimension_path(&name_path.value, name_path.span, registry, src, module_ctx)?
        .ok_or_else(unknown)
}
/// Resolves `Datetime<Scale>`.
///
/// Exactly one argument is required, and it must be a single unpowered local
/// name that parses as a [`TimeScale`].
fn resolve_datetime_application(
    type_ann: &TypeExpr,
    type_args: &[TypeExpr],
    src: &NamedSource<Arc<String>>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let [arg] = type_args else {
        return Err(GraphcalError::EvalError {
            message: format!(
                "type `Datetime` expects 0 or 1 type argument(s), got {}",
                type_args.len()
            ),
            src: src.clone(),
            span: type_ann.span.into(),
        });
    };

    let scale_term = match &arg.kind {
        TypeExprKind::DimExpr(dim_expr) => match dim_expr.terms.as_slice() {
            [only] if only.term.power.is_none() => Some(&only.term),
            _ => None,
        },
        _ => None,
    };
    let Some(term) = scale_term else {
        return Err(GraphcalError::EvalError {
            message: "expected a time scale name (e.g., UTC, TAI, TT, TDB, GPST)".to_string(),
            src: src.clone(),
            span: arg.span.into(),
        });
    };

    let name = require_local_type_level_path(&term.name.value, term.name.span, src)?;
    name.parse::<TimeScale>()
        .map(ResolvedTypeExpr::Datetime)
        .map_err(|_| GraphcalError::EvalError {
            message: format!(
                "unknown time scale `{name}`; \
                 expected one of: UTC, TAI, TT, TDB, ET, GPST, GST, BDT"
            ),
            src: src.clone(),
            span: arg.span.into(),
        })
}
#[expect(
    clippy::too_many_arguments,
    reason = "passes full type resolution context from resolve_type_expr"
)]
/// Resolves an AST type application `Name<args...>`.
///
/// Steps:
/// 1. Look up the struct type and check the argument count.
/// 2. Resolve each explicit argument against its parameter's constraint.
/// 3. Resolve the declared default for every omitted trailing parameter, with
///    the struct's own owner as the resolution owner.
///
/// NOTE(review): the caller's generic parameter lists stay in scope while
/// defaults are resolved.
fn resolve_type_application(
    type_ann: &TypeExpr,
    name: &crate::syntax::span::Spanned<NamePath>,
    type_args: &[TypeExpr],
    registry: &Registry,
    owner: &crate::dag_id::DagId,
    dim_params: &[GenericParamName],
    index_params: &[GenericParamName],
    nat_params: &[GenericParamName],
    src: &NamedSource<Arc<String>>,
    module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let Some((type_name, type_def)) =
        resolve_struct_type_path(&name.value, name.span, registry, owner, src, module_ctx)?
    else {
        return Err(GraphcalError::UnknownStructType {
            name: name.value.display_path(),
            src: src.clone(),
            span: name.span.into(),
        });
    };
    check_type_application_arity(
        type_name.as_str(),
        type_def,
        type_args.len(),
        type_ann.span,
        src,
    )?;

    let params = &type_def.generic_params;
    let mut resolved_args = params
        .iter()
        .zip(type_args)
        .map(|(param, arg)| {
            resolve_type_arg_for_param(
                param,
                arg,
                registry,
                owner,
                dim_params,
                index_params,
                nat_params,
                src,
                module_ctx,
            )
        })
        .collect::<Result<Vec<_>, _>>()?;

    let default_ctx = module_ctx
        .map(|ctx| ModuleTypeContext::new(type_name.owner(), ctx.resolver, ctx.types));
    for param in params.iter().skip(type_args.len()) {
        // The arity check guarantees omitted parameters have defaults.
        let Some(default_expr) = param.default.as_ref() else {
            return Err(GraphcalError::EvalError {
                message: format!(
                    "internal: generic parameter `{}` has no default",
                    param.name
                ),
                src: src.clone(),
                span: type_ann.span.into(),
            });
        };
        resolved_args.push(resolve_type_arg_for_param(
            param,
            default_expr,
            registry,
            type_name.owner(),
            dim_params,
            index_params,
            nat_params,
            src,
            default_ctx,
        )?);
    }

    Ok(ResolvedTypeExpr::GenericStruct {
        name: type_name,
        type_args: resolved_args,
        span: type_ann.span,
    })
}
#[expect(
    clippy::too_many_arguments,
    reason = "passes full type resolution context from resolve_type_application"
)]
/// Resolves one AST type argument and checks it against its generic
/// parameter's constraint.
///
/// - Index parameters accept only index arguments.
/// - Nat parameters reject every type argument.
/// - Dim and unconstrained parameters accept anything except an index.
fn resolve_type_arg_for_param(
    param: &crate::registry::types::TypeGenericParam,
    arg: &TypeExpr,
    registry: &Registry,
    owner: &crate::dag_id::DagId,
    dim_params: &[GenericParamName],
    index_params: &[GenericParamName],
    nat_params: &[GenericParamName],
    src: &NamedSource<Arc<String>>,
    module_ctx: Option<ModuleTypeContext<'_>>,
) -> Result<ResolvedTypeExpr, GraphcalError> {
    let resolved = resolve_type_expr_inner(
        arg,
        registry,
        owner,
        dim_params,
        index_params,
        nat_params,
        src,
        module_ctx,
    )?;
    let reject = |message: String| GraphcalError::EvalError {
        message,
        src: src.clone(),
        span: arg.span.into(),
    };

    match param.constraint {
        TypeGenericConstraint::Nat => Err(reject(format!(
            "generic parameter `{}` expects a Nat argument, got a type argument",
            param.name
        ))),
        TypeGenericConstraint::Index => {
            if matches!(resolved, ResolvedTypeExpr::IndexArg(_)) {
                Ok(resolved)
            } else {
                Err(reject(format!(
                    "generic parameter `{}` expects an Index argument",
                    param.name
                )))
            }
        }
        TypeGenericConstraint::Dim => {
            if let ResolvedTypeExpr::IndexArg(index) = &resolved {
                return Err(reject(format!(
                    "index `{}` cannot be used as a Dim argument",
                    format_resolved_index(index)
                )));
            }
            Ok(resolved)
        }
        TypeGenericConstraint::Unconstrained => {
            if let ResolvedTypeExpr::IndexArg(index) = &resolved {
                return Err(reject(format!(
                    "index `{}` cannot be used as a Type argument",
                    format_resolved_index(index)
                )));
            }
            Ok(resolved)
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::registry::prelude::load_prelude;
use crate::registry::types::RegistryBuilder;
use crate::syntax::dimension::BaseDimId;
use crate::syntax::parser::Parser;
fn make_registry() -> Registry {
let mut b = RegistryBuilder::new();
load_prelude(&mut b).unwrap();
b.build()
}
fn make_dim_term_name(
name: &str,
) -> crate::syntax::span::Spanned<crate::syntax::names::NamePath> {
crate::syntax::span::Spanned::new(
crate::syntax::names::NamePath::from(name),
Span::new(0, 0),
)
}
fn make_dim_type_expr(name: &str) -> crate::desugar::desugared_ast::TypeExpr {
crate::desugar::desugared_ast::TypeExpr {
kind: crate::desugar::desugared_ast::TypeExprKind::DimExpr(
crate::desugar::desugared_ast::DimExpr {
terms: vec![crate::desugar::desugared_ast::DimExprItem {
op: crate::desugar::desugared_ast::MulDivOp::Mul,
term: crate::desugar::desugared_ast::DimTerm {
name: make_dim_term_name(name),
power: None,
span: Span::new(0, 0),
},
}],
span: Span::new(0, 0),
},
),
constraints: vec![],
span: Span::new(0, 0),
}
}
fn make_registry_with_struct() -> Registry {
let mut b = RegistryBuilder::new();
load_prelude(&mut b).unwrap();
b.register_type(crate::registry::types::TypeDef {
name: StructTypeName::new("TransferResult"),
generic_params: vec![],
kind: crate::registry::types::TypeDefKind::Union {
members: vec![crate::registry::types::UnionMemberDef {
name: crate::syntax::names::ConstructorName::new("TransferResult"),
fields: vec![
crate::registry::types::StructField {
name: crate::syntax::names::FieldName::new("dv1"),
type_ann: make_dim_type_expr("Velocity"),
},
crate::registry::types::StructField {
name: crate::syntax::names::FieldName::new("dv2"),
type_ann: make_dim_type_expr("Velocity"),
},
],
}],
},
});
b.build()
}
fn make_registry_with_index() -> Registry {
let mut b = RegistryBuilder::new();
load_prelude(&mut b).unwrap();
b.register_index(crate::registry::types::IndexDef {
name: IndexName::new("Maneuver"),
kind: crate::registry::types::IndexKind::Named {
variants: vec![
crate::syntax::names::IndexVariantName::new("Departure"),
crate::syntax::names::IndexVariantName::new("Insertion"),
],
},
});
b.build()
}
fn make_src() -> NamedSource<Arc<String>> {
NamedSource::new("test", Arc::new(String::new()))
}
fn parse_type(source: &str) -> TypeExpr {
let full = format!("param x: {source} = 0.0;");
let raw_file = Parser::new(&full).parse_file().unwrap();
let desugared = crate::syntax::desugar::desugar_multi_decls_in_file(raw_file);
let file = desugared;
match &file.declarations[0].kind {
crate::desugar::desugared_ast::DeclKind::Param(p) => p.type_ann.clone(),
_ => panic!("expected param"),
}
}
#[test]
fn resolve_dimensionless() {
let r = make_registry();
let te = parse_type("Dimensionless");
let resolved = resolve_type_expr(&te, &r, &[], &[], &[], &make_src()).unwrap();
assert_eq!(resolved, ResolvedTypeExpr::Dimensionless);
}
#[test]
fn resolve_bool() {
let r = make_registry();
let te = parse_type("Bool");
let resolved = resolve_type_expr(&te, &r, &[], &[], &[], &make_src()).unwrap();
assert_eq!(resolved, ResolvedTypeExpr::Bool);
}
#[test]
fn resolve_int() {
let r = make_registry();
let te = parse_type("Int");
let resolved = resolve_type_expr(&te, &r, &[], &[], &[], &make_src()).unwrap();
assert_eq!(resolved, ResolvedTypeExpr::Int);
}
#[test]
fn resolve_concrete_dimension() {
let r = make_registry();
let te = parse_type("Length");
let resolved = resolve_type_expr(&te, &r, &[], &[], &[], &make_src()).unwrap();
assert_eq!(
resolved,
ResolvedTypeExpr::Scalar(Dimension::base(BaseDimId::Prelude("Length".to_string())))
);
}
/// A compound dimension expression (`Length / Time^2`) resolves to the
/// equivalent concrete scalar dimension.
#[test]
fn resolve_compound_dimension() {
    let registry = make_registry();
    let ty = parse_type("Length / Time^2");
    let actual = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    let length = Dimension::base(BaseDimId::Prelude("Length".to_string()));
    let time_squared = Dimension::base(BaseDimId::Prelude("Time".to_string()))
        .pow_int(2)
        .unwrap();
    let acceleration = (length / time_squared).unwrap();
    assert_eq!(actual, ResolvedTypeExpr::Scalar(acceleration));
}
/// A registered struct name resolves to `ResolvedTypeExpr::Struct` carrying
/// that name.
#[test]
fn resolve_struct_type() {
    let registry = make_registry_with_struct();
    let ty = parse_type("TransferResult");
    let actual = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    let is_transfer_result = matches!(
        &actual,
        ResolvedTypeExpr::Struct(name, _) if name.as_str() == "TransferResult"
    );
    assert!(is_transfer_result);
}
/// A bare name listed in the first generic slice (`D`) resolves to
/// `GenericDimParam`.
#[test]
fn resolve_generic_dim_param() {
    let registry = make_registry();
    let dim_params = [GenericParamName::new("D")];
    let ty = parse_type("D");
    let actual = resolve_type_expr(&ty, &registry, &dim_params, &[], &[], &make_src()).unwrap();
    let is_param_d = matches!(
        &actual,
        ResolvedTypeExpr::GenericDimParam(name, _) if name.as_str() == "D"
    );
    assert!(is_param_d);
}
/// `D^2` with `D` as a generic dimension parameter resolves to a
/// `GenericDimExpr` with a single `D` term raised to the power 2.
#[test]
fn resolve_generic_dim_expr_with_power() {
    let registry = make_registry();
    let dim_params = [GenericParamName::new("D")];
    let ty = parse_type("D^2");
    let actual = resolve_type_expr(&ty, &registry, &dim_params, &[], &[], &make_src()).unwrap();
    let ResolvedTypeExpr::GenericDimExpr { terms, .. } = &actual else {
        panic!("expected GenericDimExpr");
    };
    assert_eq!(terms.len(), 1);
    let ResolvedDimTerm::GenericParam { name, power, .. } = &terms[0] else {
        panic!("expected GenericParam term");
    };
    assert_eq!(name.as_str(), "D");
    assert_eq!(*power, Rational::from_int(2));
}
/// Mixing a generic parameter with a concrete dimension (`D * Length`)
/// produces a two-term `GenericDimExpr`: the `D` term first, then the
/// concrete term.
#[test]
fn resolve_mixed_generic_concrete() {
    let registry = make_registry();
    let dim_params = [GenericParamName::new("D")];
    let ty = parse_type("D * Length");
    let resolved = resolve_type_expr(&ty, &registry, &dim_params, &[], &[], &make_src()).unwrap();
    let ResolvedTypeExpr::GenericDimExpr { terms, .. } = &resolved else {
        panic!("expected GenericDimExpr, got {resolved:?}");
    };
    assert_eq!(terms.len(), 2);
    let first_is_d =
        matches!(&terms[0], ResolvedDimTerm::GenericParam { name, .. } if name.as_str() == "D");
    assert!(first_is_d);
    assert!(matches!(&terms[1], ResolvedDimTerm::Concrete { .. }));
}
/// `Length[Maneuver]` resolves to an `Indexed` type: a `Length` scalar base
/// and the concrete `Maneuver` index.
#[test]
fn resolve_concrete_indexed() {
    let registry = make_registry_with_index();
    let ty = parse_type("Length[Maneuver]");
    let resolved = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    let ResolvedTypeExpr::Indexed { base, indexes } = &resolved else {
        panic!("expected Indexed");
    };
    let length = Dimension::base(BaseDimId::Prelude("Length".to_string()));
    assert_eq!(**base, ResolvedTypeExpr::Scalar(length));
    assert_eq!(indexes.len(), 1);
    let is_maneuver = matches!(
        &indexes[0],
        ResolvedIndex::Concrete(name, _) if name.as_str() == "Maneuver"
    );
    assert!(is_maneuver);
}
/// `D[I]`, with `D` in the dimension-parameter slice and `I` in the
/// index-parameter slice, resolves to an `Indexed` type whose base and index
/// are both generic.
#[test]
fn resolve_generic_indexed() {
    let registry = make_registry();
    let dim_params = [GenericParamName::new("D")];
    let index_params = [GenericParamName::new("I")];
    let ty = parse_type("D[I]");
    let resolved =
        resolve_type_expr(&ty, &registry, &dim_params, &index_params, &[], &make_src()).unwrap();
    let ResolvedTypeExpr::Indexed { base, indexes } = &resolved else {
        panic!("expected Indexed");
    };
    let base_is_d = matches!(
        **base,
        ResolvedTypeExpr::GenericDimParam(ref name, _) if name.as_str() == "D"
    );
    assert!(base_is_d);
    assert_eq!(indexes.len(), 1);
    let index_is_i = matches!(
        &indexes[0],
        ResolvedIndex::GenericParam(name, _) if name.as_str() == "I"
    );
    assert!(index_is_i);
}
/// An unregistered dimension name is rejected with `UnknownDimension`.
#[test]
fn resolve_unknown_dimension_error() {
    let registry = make_registry();
    let ty = parse_type("UnknownDim");
    let error = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap_err();
    assert!(matches!(error, GraphcalError::UnknownDimension { .. }));
}
/// Indexing by an unregistered index name is rejected with `UnknownIndex`.
#[test]
fn resolve_unknown_index_error() {
    let registry = make_registry();
    let ty = parse_type("Length[UnknownIdx]");
    let error = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap_err();
    assert!(matches!(error, GraphcalError::UnknownIndex { .. }));
}
/// When a name is both a registered struct and a generic dimension parameter,
/// it resolves as the struct.
#[test]
fn resolve_struct_takes_priority_over_dim_param() {
    let registry = make_registry_with_struct();
    let shadowing_params = [GenericParamName::new("TransferResult")];
    let ty = parse_type("TransferResult");
    let actual =
        resolve_type_expr(&ty, &registry, &shadowing_params, &[], &[], &make_src()).unwrap();
    assert!(matches!(actual, ResolvedTypeExpr::Struct(..)));
}
/// The derived prelude dimension `Velocity` resolves to `Length / Time`.
#[test]
fn resolve_velocity_derived_dimension() {
    let registry = make_registry();
    let ty = parse_type("Velocity");
    let actual = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    let length = Dimension::base(BaseDimId::Prelude("Length".to_string()));
    let time = Dimension::base(BaseDimId::Prelude("Time".to_string()));
    let velocity = (length / time).unwrap();
    assert_eq!(actual, ResolvedTypeExpr::Scalar(velocity));
}
/// Parses `source` as a whole file named `test.gcl` and runs module-aware
/// type resolution on it. Inline dag bodies are then compiled via
/// `compile_inline_dag_bodies_test`.
///
/// The root module id is derived from the path `test.gcl`. Each inline `dag`
/// declaration is registered as a child module of that id.
///
/// # Errors
///
/// Returns any `GraphcalError` from lowering, type resolution, or inline-dag
/// compilation. Resolver or prelude setup failures are wrapped with
/// `internal_error` at an empty span.
///
/// # Panics
///
/// Panics if `source` does not parse, or if `test.gcl` is not accepted as a
/// relative dag path.
fn parse_and_type_resolve(source: &str) -> Result<TIR, GraphcalError> {
    let raw_file = Parser::new(source).parse_file().unwrap();
    let file = crate::syntax::desugar::desugar_multi_decls_in_file(raw_file);
    let src = NamedSource::new("test.gcl", Arc::new(source.to_string()));
    // Harness setup failures have no meaningful source location.
    let setup_error = |message: String| internal_error(message, &src, Span::new(0, 0));

    let ir = crate::ir::lower::lower(&file, &src)?;
    let parent_dag_id =
        crate::dag_id::DagId::from_relative_path(std::path::Path::new("test.gcl")).unwrap();

    let mut resolver = ModuleResolver::default();
    resolver
        .add_module(parent_dag_id.clone(), &file.declarations)
        .map_err(|err| setup_error(format!("test module resolver failed for root module: {err}")))?;
    for decl in &file.declarations {
        if let crate::desugar::desugared_ast::DeclKind::Dag(dag) = &decl.kind {
            resolver
                .add_module(parent_dag_id.child(dag.name.value.as_str()), &dag.body)
                .map_err(|err| {
                    setup_error(format!(
                        "test module resolver failed for inline dag `{}`: {err}",
                        dag.name.value
                    ))
                })?;
        }
    }

    let mut module_types = ModuleTypeRegistry::default();
    module_types
        .insert_graphcal_prelude()
        .map_err(|err| setup_error(format!("test module type prelude failed: {err}")))?;
    module_types.insert_registry(&parent_dag_id, &ir.registry);

    let mut tir =
        type_resolve_with_modules(ir, parent_dag_id.clone(), &src, &resolver, &module_types)?;
    compile_inline_dag_bodies_test(&mut tir, &src, &parent_dag_id, &file.declarations)?;
    Ok(tir)
}
/// Compiles every dag recorded in `tir.registry.dags`. Each compiled dag is
/// stored in `tir.dags` under the child module id `parent_dag_id.child(name)`.
///
/// A fresh module resolver is built from `parent_declarations` plus every
/// registered dag body. A fresh module type registry is seeded with the
/// graphcal prelude and `tir.registry`.
///
/// # Errors
///
/// Resolver or prelude setup failures are wrapped with `internal_error` at an
/// empty span. Lowering and type-resolution errors for a dag body are
/// propagated unchanged.
fn compile_inline_dag_bodies_test(
    tir: &mut TIR,
    src: &NamedSource<Arc<String>>,
    parent_dag_id: &crate::dag_id::DagId,
    parent_declarations: &[crate::desugar::desugared_ast::Declaration],
) -> Result<(), GraphcalError> {
    // Snapshot names and bodies up front, because `tir` is mutated in the
    // compile loop below.
    let dag_bodies: Vec<_> = tir
        .registry
        .dags
        .all_dags()
        .map(|(name, dag)| (name.clone(), dag.body.clone()))
        .collect();
    let setup_error = |message: String| internal_error(message, src, Span::new(0, 0));

    let mut resolver = ModuleResolver::default();
    resolver
        .add_module(parent_dag_id.clone(), parent_declarations)
        .map_err(|err| {
            setup_error(format!("test module resolver failed for parent module: {err}"))
        })?;
    for (name, body) in &dag_bodies {
        resolver
            .add_module(parent_dag_id.child(name.as_str()), body)
            .map_err(|err| {
                setup_error(format!(
                    "test module resolver failed for inline dag `{name}`: {err}"
                ))
            })?;
    }

    let mut module_types = ModuleTypeRegistry::default();
    module_types
        .insert_graphcal_prelude()
        .map_err(|err| setup_error(format!("test module type prelude failed: {err}")))?;
    module_types.insert_registry(parent_dag_id, &tir.registry);

    for (name, body) in dag_bodies {
        let lowered = crate::ir::lower::lower_dag_body_to_ir(
            name.as_str(),
            &body,
            &tir.registry,
            &resolver,
            &crate::ir::resolve::ImportedValueNames::default(),
            HashMap::new(),
            HashMap::new(),
            src,
            parent_dag_id,
        )?;
        let child_id = parent_dag_id.child(name.as_str());
        let mut compiled =
            type_resolve_single_with_modules(lowered, &child_id, src, &resolver, &module_types)?;
        compiled.populate_pub_nodes(&body);
        tir.dags.insert(child_id, compiled);
    }
    Ok(())
}
/// Module-aware type resolution records semantic dependencies on the root
/// dag.
///
/// Const dependencies are recorded for `D -> C`, and `C` has none. Runtime
/// dependencies are recorded for `x -> p` (via `@p`), and `p` has none.
#[test]
fn module_aware_type_resolve_records_semantic_deps() {
    let source = "const node C: Dimensionless = 1.0;\n\
                  const node D: Dimensionless = C;\n\
                  param p: Dimensionless;\n\
                  node x: Dimensionless = @p + D;";
    let raw_file = Parser::new(source).parse_file().unwrap();
    let file = crate::syntax::desugar::desugar_multi_decls_in_file(raw_file);
    let src = NamedSource::new("test.gcl", Arc::new(source.to_string()));
    let dag_id =
        crate::dag_id::DagId::from_relative_path(std::path::Path::new("test.gcl")).unwrap();
    let ir = crate::ir::lower::lower(&file, &src).unwrap();

    let mut resolver = ModuleResolver::default();
    resolver
        .add_module(dag_id.clone(), &file.declarations)
        .unwrap();
    let mut module_types = ModuleTypeRegistry::default();
    module_types.insert_graphcal_prelude().unwrap();
    module_types.insert_registry(&dag_id, &ir.registry);

    let tir =
        type_resolve_with_modules(ir, dag_id.clone(), &src, &resolver, &module_types).unwrap();
    let deps = &tir.root().semantic.dependencies;

    // All four declarations live in the root module `dag_id`.
    let decl = |name: &'static str| ResolvedName::from_def(dag_id.clone(), DeclName::new(name));
    let (c, d, p, x) = (decl("C"), decl("D"), decl("p"), decl("x"));

    assert!(deps.const_deps[&d].contains(&c));
    assert!(deps.const_deps[&c].is_empty());
    assert!(deps.runtime_deps[&x].contains(&p));
    assert!(deps.runtime_deps[&p].is_empty());
}
/// The `rocket.gcl` fixture type-resolves, and the root dag has resolved
/// types for `dry_mass`, `delta_v`, and `g0`.
#[test]
fn type_resolve_rocket() {
    let source = include_str!("../../../../tests/fixtures/valid/rocket.gcl");
    let tir = parse_and_type_resolve(source).unwrap();
    let resolved_types = &tir.root().resolved_decl_types;
    for decl in ["dry_mass", "delta_v", "g0"] {
        assert!(resolved_types.contains_key(&ScopedName::local(decl)));
    }
}
/// In the `indexed.gcl` fixture, `delta_v` resolves to an `Indexed` type.
#[test]
fn type_resolve_indexed() {
    let tir =
        parse_and_type_resolve(include_str!("../../../../tests/fixtures/valid/indexed.gcl"))
            .unwrap();
    let delta_v = ScopedName::local("delta_v");
    assert!(matches!(
        &tir.root().resolved_decl_types[&delta_v],
        ResolvedTypeExpr::Indexed { .. }
    ));
}
/// Type-resolving the `hohmann.gcl` fixture is expected to fail with an error
/// whose message mentions `transfer`.
///
/// NOTE(review): the fixture lives under `valid/` yet this test expects an
/// error — confirm this is intentional.
#[test]
fn type_resolve_hohmann() {
    let source = include_str!("../../../../tests/fixtures/valid/hohmann.gcl");
    let message = parse_and_type_resolve(source).unwrap_err().to_string();
    assert!(message.contains("transfer"), "unexpected error: {message}");
}
/// In the `generics.gcl` fixture, `pos_eci` resolves to `Vec3<Length, Eci>`
/// and `x_pos` resolves to a plain `Length` scalar.
#[test]
fn type_resolve_generics() {
    let source = include_str!("../../../../tests/fixtures/valid/generics.gcl");
    let tir = parse_and_type_resolve(source).unwrap();
    let types = &tir.root().resolved_decl_types;
    let length = ResolvedTypeExpr::Scalar(Dimension::base(BaseDimId::Prelude(
        "Length".to_string(),
    )));

    let pos_type = &types[&ScopedName::local("pos_eci")];
    let ResolvedTypeExpr::GenericStruct {
        name, type_args, ..
    } = pos_type
    else {
        panic!("expected GenericStruct, got {pos_type:?}");
    };
    assert_eq!(name.as_str(), "Vec3");
    assert_eq!(type_args.len(), 2);
    assert_eq!(type_args[0], length);
    let frame_is_eci =
        matches!(&type_args[1], ResolvedTypeExpr::Struct(n, _) if n.as_str() == "Eci");
    assert!(frame_is_eci);

    assert_eq!(types[&ScopedName::local("x_pos")], length);
}
/// In the `generics.gcl` fixture, both `pos3_eci` and `pos3_default` resolve
/// to `Pos3` with two type arguments.
///
/// The explicitly annotated one gets an `Eci` frame. The one relying on the
/// default type parameter is filled in with `Unframed`.
#[test]
fn type_resolve_default_type_params() {
    let source = include_str!("../../../../tests/fixtures/valid/generics.gcl");
    let tir = parse_and_type_resolve(source).unwrap();
    let types = &tir.root().resolved_decl_types;
    let length = ResolvedTypeExpr::Scalar(Dimension::base(BaseDimId::Prelude(
        "Length".to_string(),
    )));

    // Explicit frame argument.
    let pos3_eci = &types[&ScopedName::local("pos3_eci")];
    let ResolvedTypeExpr::GenericStruct {
        name, type_args, ..
    } = pos3_eci
    else {
        panic!("expected GenericStruct, got {pos3_eci:?}");
    };
    assert_eq!(name.as_str(), "Pos3");
    assert_eq!(type_args.len(), 2);
    assert_eq!(type_args[0], length);
    assert!(matches!(&type_args[1], ResolvedTypeExpr::Struct(n, _) if n.as_str() == "Eci"));

    // Frame argument omitted: the default is substituted.
    let pos3_default = &types[&ScopedName::local("pos3_default")];
    let ResolvedTypeExpr::GenericStruct {
        name, type_args, ..
    } = pos3_default
    else {
        panic!("expected GenericStruct, got {pos3_default:?}");
    };
    assert_eq!(name.as_str(), "Pos3");
    assert_eq!(type_args.len(), 2);
    assert_eq!(type_args[0], length);
    assert!(
        matches!(&type_args[1], ResolvedTypeExpr::Struct(n, _) if n.as_str() == "Unframed"),
        "expected Struct(Unframed), got {:?}",
        type_args[1]
    );
}
use crate::registry::declared_type::{DeclaredType, IndexTypeRef, StructTypeRef};
/// Unifying a generic index parameter `I` against a concrete indexed type
/// records the concrete index's resolved name, including its owning dag.
/// Substituting the generic type back yields an index with that same
/// resolved name.
#[test]
fn generic_index_substitution_preserves_resolved_owner() {
    use crate::tir::dim_check::{InferredIndex, InferredType};
    let src = make_src();
    let registry = make_registry();
    // The index is owned by dag `a`, not the registry's own module.
    let owner = crate::dag_id::DagId::root("a");
    let resolved_index = ResolvedName::from_def(owner, IndexName::new("Phase"));
    let generic = GenericParamName::new("I");
    let resolved_type = ResolvedTypeExpr::Indexed {
        base: Box::new(ResolvedTypeExpr::Dimensionless),
        indexes: vec![ResolvedIndex::GenericParam(
            generic.clone(),
            Span::new(0, 0),
        )],
    };
    let actual = InferredType::Indexed {
        element: Box::new(InferredType::Scalar(Dimension::dimensionless())),
        index: InferredIndex::from_resolved(resolved_index.clone()),
    };
    let mut dim_sub = HashMap::new();
    let mut index_sub = HashMap::new();
    let mut nat_sub = HashMap::new();
    unify_resolved_type(
        &resolved_type,
        &actual,
        &mut dim_sub,
        &mut index_sub,
        &mut nat_sub,
        &registry,
        &src,
        Span::new(0, 0),
    )
    .unwrap();
    assert_eq!(
        index_sub[&generic].declared_resolved(),
        Some(&resolved_index)
    );
    let substituted =
        substitute_resolved_type(&resolved_type, &dim_sub, &index_sub, &nat_sub, &src).unwrap();
    let InferredType::Indexed { index, .. } = substituted else {
        panic!("expected indexed type after substitution");
    };
    assert_eq!(index.declared_resolved(), Some(&resolved_index));
}
/// `Dimensionless` converts to a scalar with the dimensionless dimension.
#[test]
fn convert_dimensionless() {
    let src = make_src();
    let declared = resolved_to_declared_type(&ResolvedTypeExpr::Dimensionless, &src).unwrap();
    assert_eq!(declared, DeclaredType::Scalar(Dimension::dimensionless()));
}
/// `Bool` converts to `DeclaredType::Bool`.
#[test]
fn convert_bool() {
    let src = make_src();
    let declared = resolved_to_declared_type(&ResolvedTypeExpr::Bool, &src).unwrap();
    assert_eq!(declared, DeclaredType::Bool);
}
/// `Int` converts to `DeclaredType::Int`.
#[test]
fn convert_int() {
    let src = make_src();
    let declared = resolved_to_declared_type(&ResolvedTypeExpr::Int, &src).unwrap();
    assert_eq!(declared, DeclaredType::Int);
}
/// A concrete scalar converts to a declared scalar with the same dimension.
#[test]
fn convert_scalar() {
    let length = Dimension::base(BaseDimId::Prelude("Length".to_string()));
    let resolved = ResolvedTypeExpr::Scalar(length.clone());
    let declared = resolved_to_declared_type(&resolved, &make_src()).unwrap();
    assert_eq!(declared, DeclaredType::Scalar(length));
}
/// A non-generic resolved struct converts to `DeclaredType::Struct` with the
/// same resolved name and no type arguments.
#[test]
fn convert_struct() {
    let foo = ResolvedName::from_def(
        crate::dag_id::DagId::root("test"),
        StructTypeName::new("Foo"),
    );
    let resolved = ResolvedTypeExpr::Struct(foo.clone(), Span::new(0, 0));
    let declared = resolved_to_declared_type(&resolved, &make_src()).unwrap();
    let expected = DeclaredType::Struct(StructTypeRef::from_resolved(foo), vec![]);
    assert_eq!(declared, expected);
}
/// A concretely indexed scalar converts to `DeclaredType::Indexed` with the
/// same element dimension and index reference.
#[test]
fn convert_indexed() {
    let length = || Dimension::base(BaseDimId::Prelude("Length".to_string()));
    let index_name = ResolvedName::from_def(
        crate::dag_id::DagId::root("test"),
        IndexName::new("M"),
    );
    let resolved = ResolvedTypeExpr::Indexed {
        base: Box::new(ResolvedTypeExpr::Scalar(length())),
        indexes: vec![ResolvedIndex::Concrete(index_name.clone(), Span::new(0, 0))],
    };
    let declared = resolved_to_declared_type(&resolved, &make_src()).unwrap();
    let expected = DeclaredType::Indexed {
        element: Box::new(DeclaredType::Scalar(length())),
        index: IndexTypeRef::from_resolved(index_name),
    };
    assert_eq!(declared, expected);
}
/// An unsubstituted generic dimension parameter cannot become a declared
/// type; conversion fails with `EvalError`.
#[test]
fn convert_generic_dim_param_fails() {
    let generic = ResolvedTypeExpr::GenericDimParam(GenericParamName::new("D"), Span::new(0, 0));
    let error = resolved_to_declared_type(&generic, &make_src()).unwrap_err();
    assert!(matches!(error, GraphcalError::EvalError { .. }));
}
/// A type indexed by an unsubstituted generic index parameter cannot become a
/// declared type; conversion fails with `EvalError`.
#[test]
fn convert_generic_index_fails() {
    let generic_index = ResolvedIndex::GenericParam(GenericParamName::new("I"), Span::new(0, 0));
    let indexed = ResolvedTypeExpr::Indexed {
        base: Box::new(ResolvedTypeExpr::Dimensionless),
        indexes: vec![generic_index],
    };
    let error = resolved_to_declared_type(&indexed, &make_src()).unwrap_err();
    assert!(matches!(error, GraphcalError::EvalError { .. }));
}
/// A bare `Datetime` annotation resolves with the UTC time scale.
#[test]
fn resolve_bare_datetime() {
    let registry = make_registry();
    let ty = parse_type("Datetime");
    let actual = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    assert_eq!(actual, ResolvedTypeExpr::Datetime(TimeScale::UTC));
}
/// `Datetime<UTC>` resolves with the UTC time scale.
#[test]
fn resolve_datetime_utc() {
    let registry = make_registry();
    let ty = parse_type("Datetime<UTC>");
    let actual = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    assert_eq!(actual, ResolvedTypeExpr::Datetime(TimeScale::UTC));
}
/// `Datetime<TT>` resolves with the TT time scale.
#[test]
fn resolve_datetime_tt() {
    let registry = make_registry();
    let ty = parse_type("Datetime<TT>");
    let actual = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    assert_eq!(actual, ResolvedTypeExpr::Datetime(TimeScale::TT));
}
/// `Datetime<TAI>` resolves with the TAI time scale.
#[test]
fn resolve_datetime_tai() {
    let registry = make_registry();
    let ty = parse_type("Datetime<TAI>");
    let actual = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    assert_eq!(actual, ResolvedTypeExpr::Datetime(TimeScale::TAI));
}
/// `Datetime<GPST>` resolves with the GPST time scale.
#[test]
fn resolve_datetime_gpst() {
    let registry = make_registry();
    let ty = parse_type("Datetime<GPST>");
    let actual = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap();
    assert_eq!(actual, ResolvedTypeExpr::Datetime(TimeScale::GPST));
}
/// An unrecognised time scale inside `Datetime<...>` is rejected with
/// `EvalError`.
#[test]
fn resolve_datetime_unknown_scale_error() {
    let registry = make_registry();
    let ty = parse_type("Datetime<XYZ>");
    let error = resolve_type_expr(&ty, &registry, &[], &[], &[], &make_src()).unwrap_err();
    assert!(matches!(error, GraphcalError::EvalError { .. }));
}
/// A UTC datetime converts to a declared UTC datetime.
#[test]
fn convert_datetime_utc() {
    let src = make_src();
    let declared =
        resolved_to_declared_type(&ResolvedTypeExpr::Datetime(TimeScale::UTC), &src).unwrap();
    assert_eq!(declared, DeclaredType::Datetime(TimeScale::UTC));
}
/// A TT datetime converts to a declared TT datetime; the time scale is kept.
#[test]
fn convert_datetime_tt() {
    let src = make_src();
    let declared =
        resolved_to_declared_type(&ResolvedTypeExpr::Datetime(TimeScale::TT), &src).unwrap();
    assert_eq!(declared, DeclaredType::Datetime(TimeScale::TT));
}
/// Equal constants compare as `<=`.
#[test]
fn nat_leq_constant_equal() {
    let three = NatPolyForm::from_constant(3);
    let also_three = NatPolyForm::from_constant(3);
    assert!(three.is_leq(&also_three));
}
/// A smaller constant is `<=` a larger one.
#[test]
fn nat_leq_constant_less() {
    let smaller = NatPolyForm::from_constant(2);
    let larger = NatPolyForm::from_constant(5);
    assert!(smaller.is_leq(&larger));
}
/// A larger constant is not `<=` a smaller one.
#[test]
fn nat_leq_constant_greater() {
    let larger = NatPolyForm::from_constant(5);
    let smaller = NatPolyForm::from_constant(3);
    assert!(!larger.is_leq(&smaller));
}
/// A variable is `<=` itself.
#[test]
fn nat_leq_same_var() {
    let lhs = NatPolyForm::from_var(GenericParamName::new("N"));
    let rhs = NatPolyForm::from_var(GenericParamName::new("N"));
    assert!(lhs.is_leq(&rhs));
}
/// `N <= N + 1` holds symbolically.
#[test]
fn nat_leq_var_plus_constant() {
    let n = NatPolyForm::from_var(GenericParamName::new("N"));
    let one = NatPolyForm::from_constant(1);
    let n_plus_one = NatPolyForm::from_var(GenericParamName::new("N"))
        .add(&one)
        .unwrap();
    assert!(n.is_leq(&n_plus_one));
}
/// `N + 1 <= N` does not hold.
#[test]
fn nat_leq_var_plus_constant_reverse() {
    let one = NatPolyForm::from_constant(1);
    let n_plus_one = NatPolyForm::from_var(GenericParamName::new("N"))
        .add(&one)
        .unwrap();
    let n = NatPolyForm::from_var(GenericParamName::new("N"));
    assert!(!n_plus_one.is_leq(&n));
}
/// Two unrelated variables are not provably ordered; `N <= M` is false.
#[test]
fn nat_leq_different_vars() {
    let n = NatPolyForm::from_var(GenericParamName::new("N"));
    let m = NatPolyForm::from_var(GenericParamName::new("M"));
    assert!(!n.is_leq(&m));
}
/// The constant 0 is `<=` a symbolic variable.
#[test]
fn nat_leq_zero_leq_anything() {
    let zero = NatPolyForm::from_constant(0);
    let n = NatPolyForm::from_var(GenericParamName::new("N"));
    assert!(zero.is_leq(&n));
}
/// A constant range size (3) becomes an index reference that carries a
/// concrete nat range of size 3 and displays as `range(3)`.
#[test]
fn nat_range_identity_concrete_to_index_type_ref() -> Result<(), Box<dyn std::error::Error>> {
    let identity = NatPolyForm::from_constant(3).to_nat_range_identity()?;
    let reference = identity.to_index_type_ref()?;
    let size = reference
        .nat_range()
        .map(crate::registry::types::NatRangeIndex::size_u64);
    assert_eq!(size, Some(3));
    assert_eq!(reference.display_name().as_str(), "range(3)");
    Ok(())
}
/// A symbolic range size (`N + 1`) becomes an index reference with no
/// concrete nat range; it displays as `range(N + 1)`.
#[test]
fn nat_range_identity_symbolic_to_display_only_index_type_ref()
-> Result<(), Box<dyn std::error::Error>> {
    let size = NatPolyForm::from_var(GenericParamName::new("N"))
        .add(&NatPolyForm::from_constant(1))
        .unwrap();
    let identity = size.to_nat_range_identity()?;
    let reference = identity.to_index_type_ref()?;
    assert_eq!(reference.nat_range(), None);
    assert_eq!(reference.display_name().as_str(), "range(N + 1)");
    Ok(())
}
/// Multiplying two constants folds to their product: 3 * 4 = 12.
#[test]
fn nat_mul_constants() {
    let product = NatPolyForm::from_constant(3)
        .mul(&NatPolyForm::from_constant(4))
        .unwrap();
    assert_eq!(product, NatPolyForm::from_constant(12));
}
/// `N * 3` formats as `3 * N` and evaluates to 15 when `N = 5`.
#[test]
fn nat_mul_var_by_constant() {
    let tripled = NatPolyForm::from_var(GenericParamName::new("N"))
        .mul(&NatPolyForm::from_constant(3))
        .unwrap();
    assert_eq!(tripled.format(), "3 * N");
    let bindings = HashMap::from([(GenericParamName::new("N"), 5)]);
    assert_eq!(tripled.evaluate(&bindings), Some(15));
}
/// `M * N` formats as `M * N` and evaluates to 12 when `M = 3, N = 4`.
#[test]
fn nat_mul_two_vars() {
    let product = NatPolyForm::from_var(GenericParamName::new("M"))
        .mul(&NatPolyForm::from_var(GenericParamName::new("N")))
        .unwrap();
    assert_eq!(product.format(), "M * N");
    let bindings = HashMap::from([
        (GenericParamName::new("M"), 3),
        (GenericParamName::new("N"), 4),
    ]);
    assert_eq!(product.evaluate(&bindings), Some(12));
}
/// `(M + 1) * N` evaluates to 9 when `M = 2, N = 3`. This checks that
/// multiplication distributes over the constant term.
#[test]
fn nat_mul_distributive() {
    let m = NatPolyForm::from_var(GenericParamName::new("M"));
    let n = NatPolyForm::from_var(GenericParamName::new("N"));
    let product = m
        .add(&NatPolyForm::from_constant(1))
        .unwrap()
        .mul(&n)
        .unwrap();
    let bindings = HashMap::from([
        (GenericParamName::new("M"), 2),
        (GenericParamName::new("N"), 3),
    ]);
    assert_eq!(product.evaluate(&bindings), Some(9));
}
/// `M * N + 1` formats as `M * N + 1` and evaluates to 7 when `M = 2, N = 3`.
#[test]
fn nat_mul_mixed_add() {
    let m = NatPolyForm::from_var(GenericParamName::new("M"));
    let n = NatPolyForm::from_var(GenericParamName::new("N"));
    let product = m.mul(&n).unwrap();
    let shifted = product.add(&NatPolyForm::from_constant(1)).unwrap();
    assert_eq!(shifted.format(), "M * N + 1");
    let bindings = HashMap::from([
        (GenericParamName::new("M"), 2),
        (GenericParamName::new("N"), 3),
    ]);
    assert_eq!(shifted.evaluate(&bindings), Some(7));
}
/// `is_constant` is true for a literal and false for a variable or a product
/// of variables.
#[test]
fn nat_poly_is_constant() {
    assert!(NatPolyForm::from_constant(5).is_constant());
    assert!(!NatPolyForm::from_var(GenericParamName::new("N")).is_constant());
    let m_times_n = NatPolyForm::from_var(GenericParamName::new("M"))
        .mul(&NatPolyForm::from_var(GenericParamName::new("N")))
        .unwrap();
    assert!(!m_times_n.is_constant());
}
/// For the nonlinear term `M * N`: `M * N <= M * N + 1` holds, and the
/// reverse does not.
#[test]
fn nat_poly_leq_with_mul() {
    let m = NatPolyForm::from_var(GenericParamName::new("M"));
    let n = NatPolyForm::from_var(GenericParamName::new("N"));
    let product = m.mul(&n).unwrap();
    let successor = product.add(&NatPolyForm::from_constant(1)).unwrap();
    assert!(product.is_leq(&successor));
    assert!(!successor.is_leq(&product));
}
/// Adding 1 to `u64::MAX` reports an error.
#[test]
fn nat_add_overflow_errors() {
    let sum = NatPolyForm::from_constant(u64::MAX).add(&NatPolyForm::from_constant(1));
    assert!(sum.is_err());
}
/// Multiplying `u64::MAX` by 2 reports an error.
#[test]
fn nat_mul_overflow_errors() {
    let product = NatPolyForm::from_constant(u64::MAX).mul(&NatPolyForm::from_constant(2));
    assert!(product.is_err());
}
#[test]
fn nat_unify_substituted_term_overflow_errors() {
let form = NatPolyForm::from_constant(2)
.mul(&NatPolyForm::from_var(GenericParamName::new("N")))
.unwrap();
let mut nat_sub = HashMap::new();
nat_sub.insert(GenericParamName::new("N"), u64::MAX / 2 + 1);
let src = NamedSource::new("<test>", Arc::new(String::new()));
let result = unify_nat_poly_form(
&form,
4,
&mut nat_sub,
&IndexName::new("range(4)"),
&src,
Span::new(0, 0),
);
assert!(result.is_err());
}
/// The zero polynomial formats as `0`, not as an empty string.
#[test]
fn nat_poly_format_zero() {
    assert_eq!(NatPolyForm::from_constant(0).format(), "0");
}
}