graphcal_compiler/syntax/
nat.rs1use std::collections::{BTreeMap, BTreeSet, HashMap};
9
10use crate::syntax::names::GenericParamName;
11
/// Error produced when type-level `Nat` arithmetic leaves the range of `u64`.
///
/// Every coefficient and exponent in this module is a `u64`; arithmetic on them
/// is checked and yields this error instead of wrapping.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NatOverflowError;

impl std::fmt::Display for NatOverflowError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("type-level Nat arithmetic overflow (values are stored as `u64`)")
    }
}

impl std::error::Error for NatOverflowError {}
30
31#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub(crate) struct Monomial(pub(crate) BTreeMap<GenericParamName, u64>);
37
38impl Monomial {
39 #[must_use]
41 pub(crate) const fn constant() -> Self {
42 Self(BTreeMap::new())
43 }
44
45 #[must_use]
47 pub(crate) fn var(name: GenericParamName) -> Self {
48 let mut m = BTreeMap::new();
49 m.insert(name, 1);
50 Self(m)
51 }
52
53 #[must_use]
55 pub(crate) fn is_constant(&self) -> bool {
56 self.0.is_empty()
57 }
58
59 pub(crate) fn mul(&self, other: &Self) -> Result<Self, NatOverflowError> {
63 let mut result = self.0.clone();
64 for (var, exp) in &other.0 {
65 let entry = result.entry(var.clone()).or_insert(0);
66 *entry = entry.checked_add(*exp).ok_or(NatOverflowError)?;
67 }
68 Ok(Self(result))
69 }
70
71 #[must_use]
75 pub(crate) fn evaluate(&self, bindings: &HashMap<GenericParamName, u64>) -> Option<u64> {
76 let mut result: u64 = 1;
77 for (var, exp) in &self.0 {
78 let val = bindings.get(var)?;
79 result = result.checked_mul(val.checked_pow(u32::try_from(*exp).ok()?)?)?;
80 }
81 Some(result)
82 }
83
84 #[must_use]
89 pub(crate) fn substitute(
90 &self,
91 bindings: &HashMap<GenericParamName, u64>,
92 ) -> Option<(Self, u64)> {
93 let mut remaining = BTreeMap::new();
94 let mut factor: u64 = 1;
95 for (var, exp) in &self.0 {
96 if let Some(val) = bindings.get(var) {
97 factor = factor.checked_mul(val.checked_pow(u32::try_from(*exp).ok()?)?)?;
98 } else {
99 remaining.insert(var.clone(), *exp);
100 }
101 }
102 Some((Self(remaining), factor))
103 }
104
105 #[must_use]
107 pub(crate) fn format(&self) -> String {
108 let mut parts = Vec::new();
109 for (var, exp) in &self.0 {
110 if *exp == 1 {
111 parts.push(var.to_string());
112 } else {
113 parts.push(format!("{var}^{exp}"));
114 }
115 }
116 parts.join(" * ")
117 }
118}
119
120impl PartialOrd for Monomial {
121 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
122 Some(self.cmp(other))
123 }
124}
125
126impl Ord for Monomial {
127 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
128 let a: Vec<_> = self.0.iter().collect();
130 let b: Vec<_> = other.0.iter().collect();
131 a.cmp(&b)
132 }
133}
134
135#[derive(Debug, Clone, PartialEq, Eq, Hash)]
141pub struct NatPolyForm {
142 pub(crate) terms: BTreeMap<Monomial, u64>,
144}
145
146pub type NatLinearForm = NatPolyForm;
148
149impl NatPolyForm {
150 #[must_use]
152 pub fn from_constant(c: u64) -> Self {
153 let mut terms = BTreeMap::new();
154 if c != 0 {
155 terms.insert(Monomial::constant(), c);
156 }
157 Self { terms }
158 }
159
160 #[must_use]
162 pub fn from_var(name: GenericParamName) -> Self {
163 let mut terms = BTreeMap::new();
164 terms.insert(Monomial::var(name), 1);
165 Self { terms }
166 }
167
168 pub fn add(&self, other: &Self) -> Result<Self, NatOverflowError> {
172 let mut terms = self.terms.clone();
173 for (mono, coeff) in &other.terms {
174 let entry = terms.entry(mono.clone()).or_insert(0);
175 *entry = entry.checked_add(*coeff).ok_or(NatOverflowError)?;
176 }
177 terms.retain(|_, c| *c != 0);
178 Ok(Self { terms })
179 }
180
181 pub fn mul(&self, other: &Self) -> Result<Self, NatOverflowError> {
185 let mut terms: BTreeMap<Monomial, u64> = BTreeMap::new();
186 for (m1, c1) in &self.terms {
187 for (m2, c2) in &other.terms {
188 let mono = m1.mul(m2)?;
189 let term = c1.checked_mul(*c2).ok_or(NatOverflowError)?;
190 let entry = terms.entry(mono).or_insert(0);
191 *entry = entry.checked_add(term).ok_or(NatOverflowError)?;
192 }
193 }
194 terms.retain(|_, c| *c != 0);
195 Ok(Self { terms })
196 }
197
198 #[must_use]
200 pub fn constant(&self) -> u64 {
201 self.terms.get(&Monomial::constant()).copied().unwrap_or(0)
202 }
203
204 #[must_use]
206 pub fn is_constant(&self) -> bool {
207 self.terms.iter().all(|(m, _)| m.is_constant())
208 }
209
210 #[must_use]
214 pub fn evaluate(&self, bindings: &HashMap<GenericParamName, u64>) -> Option<u64> {
215 let mut result: u64 = 0;
216 for (mono, coeff) in &self.terms {
217 result = result.checked_add(coeff.checked_mul(mono.evaluate(bindings)?)?)?;
218 }
219 Some(result)
220 }
221
222 #[must_use]
226 pub fn format(&self) -> String {
227 if self.terms.is_empty() {
228 return "0".to_string();
229 }
230 let mut parts = Vec::new();
231 for (mono, coeff) in &self.terms {
233 if mono.is_constant() {
234 continue;
235 }
236 let mono_str = mono.format();
237 if *coeff == 1 {
238 parts.push(mono_str);
239 } else {
240 parts.push(format!("{coeff} * {mono_str}"));
241 }
242 }
243 if let Some(&c) = self.terms.get(&Monomial::constant())
244 && (c > 0 || parts.is_empty())
245 {
246 parts.push(c.to_string());
247 }
248 if parts.is_empty() {
249 "0".to_string()
250 } else {
251 parts.join(" + ")
252 }
253 }
254
255 #[must_use]
261 pub fn is_leq(&self, other: &Self) -> bool {
262 self.terms.iter().all(|(mono, &coeff)| {
263 let other_coeff = other.terms.get(mono).copied().unwrap_or(0);
264 coeff <= other_coeff
265 })
266 }
267
268 #[must_use]
270 pub fn variables(&self) -> BTreeSet<GenericParamName> {
271 self.terms
272 .keys()
273 .flat_map(|mono| mono.0.keys().cloned())
274 .collect()
275 }
276}
277
278impl std::fmt::Display for NatPolyForm {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 write!(f, "{}", self.format())
281 }
282}