Skip to main content

graphcal_compiler/syntax/
nat.rs

1//! Typed representation of type-level Nat arithmetic.
2//!
3//! Nat forms are used by both type resolution and declared type references, so
4//! they live in the syntax layer rather than in TIR. Rendering a form to a
5//! string is a display operation only; semantic comparisons use the normalized
6//! polynomial structure directly.
7
8use std::collections::{BTreeMap, BTreeSet, HashMap};
9
10use crate::syntax::names::GenericParamName;
11
/// Error signalling that combining type-level Nat forms exceeded the `u64` range.
///
/// Every coefficient and exponent is held in a `u64`. All arithmetic on them is
/// checked rather than wrapping, because a wrapped form could spuriously unify
/// with an unrelated type. Overflow is therefore reported through this error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NatOverflowError;

impl std::fmt::Display for NatOverflowError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("type-level Nat arithmetic overflow (values are stored as `u64`)")
    }
}

/// The error carries no payload, so it has no underlying `source`.
impl std::error::Error for NatOverflowError {}
30
31/// A monomial: product of variables raised to natural number exponents.
32///
33/// Represented as a sorted map from variable name to exponent. The empty map
34/// represents the constant monomial (= 1).
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub(crate) struct Monomial(pub(crate) BTreeMap<GenericParamName, u64>);
37
38impl Monomial {
39    /// The constant monomial (empty product = 1).
40    #[must_use]
41    pub(crate) const fn constant() -> Self {
42        Self(BTreeMap::new())
43    }
44
45    /// A single-variable monomial with exponent 1.
46    #[must_use]
47    pub(crate) fn var(name: GenericParamName) -> Self {
48        let mut m = BTreeMap::new();
49        m.insert(name, 1);
50        Self(m)
51    }
52
53    /// Returns `true` if this is the constant monomial (no variables).
54    #[must_use]
55    pub(crate) fn is_constant(&self) -> bool {
56        self.0.is_empty()
57    }
58
59    /// Multiply two monomials: add exponents of each variable.
60    ///
61    /// Returns an error if an exponent overflows.
62    pub(crate) fn mul(&self, other: &Self) -> Result<Self, NatOverflowError> {
63        let mut result = self.0.clone();
64        for (var, exp) in &other.0 {
65            let entry = result.entry(var.clone()).or_insert(0);
66            *entry = entry.checked_add(*exp).ok_or(NatOverflowError)?;
67        }
68        Ok(Self(result))
69    }
70
71    /// Evaluate the monomial given variable bindings.
72    ///
73    /// Returns `None` if any variable is unbound or arithmetic overflows.
74    #[must_use]
75    pub(crate) fn evaluate(&self, bindings: &HashMap<GenericParamName, u64>) -> Option<u64> {
76        let mut result: u64 = 1;
77        for (var, exp) in &self.0 {
78            let val = bindings.get(var)?;
79            result = result.checked_mul(val.checked_pow(u32::try_from(*exp).ok()?)?)?;
80        }
81        Some(result)
82    }
83
84    /// Substitute bound variables, returning a new monomial with only unbound
85    /// variables and the multiplicative factor contributed by bound variables.
86    ///
87    /// Returns `None` if arithmetic overflows.
88    #[must_use]
89    pub(crate) fn substitute(
90        &self,
91        bindings: &HashMap<GenericParamName, u64>,
92    ) -> Option<(Self, u64)> {
93        let mut remaining = BTreeMap::new();
94        let mut factor: u64 = 1;
95        for (var, exp) in &self.0 {
96            if let Some(val) = bindings.get(var) {
97                factor = factor.checked_mul(val.checked_pow(u32::try_from(*exp).ok()?)?)?;
98            } else {
99                remaining.insert(var.clone(), *exp);
100            }
101        }
102        Some((Self(remaining), factor))
103    }
104
105    /// Format as a human-readable string, e.g. `""`, `"N"`, `"M * N"`, `"N^2"`.
106    #[must_use]
107    pub(crate) fn format(&self) -> String {
108        let mut parts = Vec::new();
109        for (var, exp) in &self.0 {
110            if *exp == 1 {
111                parts.push(var.to_string());
112            } else {
113                parts.push(format!("{var}^{exp}"));
114            }
115        }
116        parts.join(" * ")
117    }
118}
119
120impl PartialOrd for Monomial {
121    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122        Some(self.cmp(other))
123    }
124}
125
126impl Ord for Monomial {
127    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
128        // Compare by iterating entries in sorted order (BTreeMap guarantees this).
129        let a: Vec<_> = self.0.iter().collect();
130        let b: Vec<_> = other.0.iter().collect();
131        a.cmp(&b)
132    }
133}
134
135/// A normalized polynomial form for Nat expressions.
136///
137/// This is the canonical representation for Nat arithmetic (Level 1 addition +
138/// Level 2 multiplication). Each term is a monomial mapped to its coefficient.
139/// Two `NatPolyForm`s are equal iff their normalized terms match.
140#[derive(Debug, Clone, PartialEq, Eq, Hash)]
141pub struct NatPolyForm {
142    /// Monomial → coefficient mapping (only non-zero coefficients).
143    pub(crate) terms: BTreeMap<Monomial, u64>,
144}
145
146/// Backward-compatible alias for code that still speaks in linear Nat forms.
147pub type NatLinearForm = NatPolyForm;
148
149impl NatPolyForm {
150    /// Create a polynomial from a constant.
151    #[must_use]
152    pub fn from_constant(c: u64) -> Self {
153        let mut terms = BTreeMap::new();
154        if c != 0 {
155            terms.insert(Monomial::constant(), c);
156        }
157        Self { terms }
158    }
159
160    /// Create a polynomial from a single variable with coefficient 1.
161    #[must_use]
162    pub fn from_var(name: GenericParamName) -> Self {
163        let mut terms = BTreeMap::new();
164        terms.insert(Monomial::var(name), 1);
165        Self { terms }
166    }
167
168    /// Add two polynomials.
169    ///
170    /// Returns an error if a coefficient overflows.
171    pub fn add(&self, other: &Self) -> Result<Self, NatOverflowError> {
172        let mut terms = self.terms.clone();
173        for (mono, coeff) in &other.terms {
174            let entry = terms.entry(mono.clone()).or_insert(0);
175            *entry = entry.checked_add(*coeff).ok_or(NatOverflowError)?;
176        }
177        terms.retain(|_, c| *c != 0);
178        Ok(Self { terms })
179    }
180
181    /// Multiply two polynomials (distributive law).
182    ///
183    /// Returns an error if a coefficient or exponent overflows.
184    pub fn mul(&self, other: &Self) -> Result<Self, NatOverflowError> {
185        let mut terms: BTreeMap<Monomial, u64> = BTreeMap::new();
186        for (m1, c1) in &self.terms {
187            for (m2, c2) in &other.terms {
188                let mono = m1.mul(m2)?;
189                let term = c1.checked_mul(*c2).ok_or(NatOverflowError)?;
190                let entry = terms.entry(mono).or_insert(0);
191                *entry = entry.checked_add(term).ok_or(NatOverflowError)?;
192            }
193        }
194        terms.retain(|_, c| *c != 0);
195        Ok(Self { terms })
196    }
197
198    /// Returns the constant term (coefficient of the empty monomial).
199    #[must_use]
200    pub fn constant(&self) -> u64 {
201        self.terms.get(&Monomial::constant()).copied().unwrap_or(0)
202    }
203
204    /// Returns `true` if this form has no variables (is a constant).
205    #[must_use]
206    pub fn is_constant(&self) -> bool {
207        self.terms.iter().all(|(m, _)| m.is_constant())
208    }
209
210    /// Evaluate to a concrete value given variable bindings.
211    ///
212    /// Returns `None` if any variable is unbound or arithmetic overflows.
213    #[must_use]
214    pub fn evaluate(&self, bindings: &HashMap<GenericParamName, u64>) -> Option<u64> {
215        let mut result: u64 = 0;
216        for (mono, coeff) in &self.terms {
217            result = result.checked_add(coeff.checked_mul(mono.evaluate(bindings)?)?)?;
218        }
219        Some(result)
220    }
221
222    /// Format as a human-readable string.
223    ///
224    /// Examples: `"3"`, `"N"`, `"N + 1"`, `"M * N"`, `"2 * N^2 + N + 1"`.
225    #[must_use]
226    pub fn format(&self) -> String {
227        if self.terms.is_empty() {
228            return "0".to_string();
229        }
230        let mut parts = Vec::new();
231        // Non-constant terms first (sorted by monomial), then constant.
232        for (mono, coeff) in &self.terms {
233            if mono.is_constant() {
234                continue;
235            }
236            let mono_str = mono.format();
237            if *coeff == 1 {
238                parts.push(mono_str);
239            } else {
240                parts.push(format!("{coeff} * {mono_str}"));
241            }
242        }
243        if let Some(&c) = self.terms.get(&Monomial::constant())
244            && (c > 0 || parts.is_empty())
245        {
246            parts.push(c.to_string());
247        }
248        if parts.is_empty() {
249            "0".to_string()
250        } else {
251            parts.join(" + ")
252        }
253    }
254
255    /// Check if `self <= other` for all non-negative variable assignments.
256    ///
257    /// Returns `true` iff for every monomial, the coefficient in `self` is <=
258    /// the coefficient in `other`. This is sound because all `Nat` variables
259    /// are non-negative, so each monomial evaluates to a non-negative value.
260    #[must_use]
261    pub fn is_leq(&self, other: &Self) -> bool {
262        self.terms.iter().all(|(mono, &coeff)| {
263            let other_coeff = other.terms.get(mono).copied().unwrap_or(0);
264            coeff <= other_coeff
265        })
266    }
267
268    /// Collect all variable names that appear in any monomial of this polynomial.
269    #[must_use]
270    pub fn variables(&self) -> BTreeSet<GenericParamName> {
271        self.terms
272            .keys()
273            .flat_map(|mono| mono.0.keys().cloned())
274            .collect()
275    }
276}
277
278impl std::fmt::Display for NatPolyForm {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        write!(f, "{}", self.format())
281    }
282}