// constraint_solver/compiler.rs

/*
MIT License

Copyright (c) 2026 Raja Lehtihet & Wael El Oraiby

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

use std::collections::HashMap;
use std::fmt;

use crate::exp::{Exp, MissingVarError};

/// Errors reported while lowering [`Exp`] trees into a [`CompiledSystem`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompileError {
    /// A variable node carried an empty name.
    EmptyVariableName,
}

impl fmt::Display for CompileError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Map each variant to its fixed, user-facing message.
        let message = match self {
            Self::EmptyVariableName => "Variable name cannot be empty",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CompileError {}

/// Opaque integer handle for a variable.
///
/// As issued by [`VarTable`], the wrapped value is the variable's position in
/// that table's name list, so an id is only meaningful relative to the table
/// that produced it.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub(crate) struct VarId(usize);

impl VarId {
    /// Wraps a raw index as a variable id.
    pub(crate) const fn new(index: usize) -> Self {
        VarId(index)
    }
}

54#[derive(Debug, Clone)]
55pub(crate) struct VarTable {
56    names: Vec<String>,
57    name_to_id: HashMap<String, VarId>,
58}
59
60impl VarTable {
61    fn new() -> Self {
62        Self {
63            names: Vec::new(),
64            name_to_id: HashMap::new(),
65        }
66    }
67
68    fn register(&mut self, name: &str) -> Result<VarId, CompileError> {
69        if name.is_empty() {
70            return Err(CompileError::EmptyVariableName);
71        }
72        if let Some(id) = self.name_to_id.get(name) {
73            return Ok(*id);
74        }
75        let id = VarId::new(self.names.len());
76        self.names.push(name.to_string());
77        self.name_to_id.insert(name.to_string(), id);
78        Ok(id)
79    }
80
81    pub(crate) fn get_id(&self, name: &str) -> Option<VarId> {
82        self.name_to_id.get(name).copied()
83    }
84
85    pub(crate) fn len(&self) -> usize {
86        self.names.len()
87    }
88
89    pub(crate) fn all_var_ids(&self) -> Vec<VarId> {
90        (0..self.names.len()).map(VarId::new).collect()
91    }
92
93    pub(crate) fn names(&self) -> &[String] {
94        &self.names
95    }
96}
97
98#[derive(Debug, Clone, PartialEq)]
99pub(crate) enum CompiledExp {
100    Val(f64),
101    Var(String, VarId),
102    Add(Box<CompiledExp>, Box<CompiledExp>),
103    Sub(Box<CompiledExp>, Box<CompiledExp>),
104    Mul(Box<CompiledExp>, Box<CompiledExp>),
105    Div(Box<CompiledExp>, Box<CompiledExp>),
106    Power(Box<CompiledExp>, f64),
107    Neg(Box<CompiledExp>),
108    Sin(Box<CompiledExp>),
109    Cos(Box<CompiledExp>),
110    Ln(Box<CompiledExp>),
111    Exp(Box<CompiledExp>),
112}
113
114impl CompiledExp {
115    pub(crate) fn evaluate_checked(
116        &self,
117        vars: &HashMap<VarId, f64>,
118    ) -> Result<f64, MissingVarError> {
119        match self {
120            CompiledExp::Val(v) => Ok(*v),
121            CompiledExp::Var(name, id) => vars.get(id).copied().ok_or_else(|| MissingVarError {
122                var_name: name.clone(),
123            }),
124            CompiledExp::Add(l, r) => Ok(l.evaluate_checked(vars)? + r.evaluate_checked(vars)?),
125            CompiledExp::Sub(l, r) => Ok(l.evaluate_checked(vars)? - r.evaluate_checked(vars)?),
126            CompiledExp::Mul(l, r) => Ok(l.evaluate_checked(vars)? * r.evaluate_checked(vars)?),
127            CompiledExp::Div(l, r) => Ok(l.evaluate_checked(vars)? / r.evaluate_checked(vars)?),
128            CompiledExp::Power(base, exp) => Ok(base.evaluate_checked(vars)?.powf(*exp)),
129            CompiledExp::Neg(e) => Ok(-e.evaluate_checked(vars)?),
130            CompiledExp::Sin(e) => Ok(e.evaluate_checked(vars)?.sin()),
131            CompiledExp::Cos(e) => Ok(e.evaluate_checked(vars)?.cos()),
132            CompiledExp::Ln(e) => Ok(e.evaluate_checked(vars)?.ln()),
133            CompiledExp::Exp(e) => Ok(e.evaluate_checked(vars)?.exp()),
134        }
135    }
136
137    pub(crate) fn differentiate(&self, var_id: VarId) -> CompiledExp {
138        match self {
139            CompiledExp::Val(_) => CompiledExp::Val(0.0),
140            CompiledExp::Var(_, id) => {
141                if *id == var_id {
142                    CompiledExp::Val(1.0)
143                } else {
144                    CompiledExp::Val(0.0)
145                }
146            }
147            CompiledExp::Add(l, r) => CompiledExp::Add(
148                Box::new(l.differentiate(var_id)),
149                Box::new(r.differentiate(var_id)),
150            ),
151            CompiledExp::Sub(l, r) => CompiledExp::Sub(
152                Box::new(l.differentiate(var_id)),
153                Box::new(r.differentiate(var_id)),
154            ),
155            CompiledExp::Mul(l, r) => {
156                let dl = l.differentiate(var_id);
157                let dr = r.differentiate(var_id);
158                CompiledExp::Add(
159                    Box::new(CompiledExp::Mul(Box::new(dl), r.clone())),
160                    Box::new(CompiledExp::Mul(l.clone(), Box::new(dr))),
161                )
162            }
163            CompiledExp::Div(l, r) => {
164                let dl = l.differentiate(var_id);
165                let dr = r.differentiate(var_id);
166                CompiledExp::Div(
167                    Box::new(CompiledExp::Sub(
168                        Box::new(CompiledExp::Mul(Box::new(dl), r.clone())),
169                        Box::new(CompiledExp::Mul(l.clone(), Box::new(dr))),
170                    )),
171                    Box::new(CompiledExp::Power(r.clone(), 2.0)),
172                )
173            }
174            CompiledExp::Power(base, exp) => {
175                let db = base.differentiate(var_id);
176                CompiledExp::Mul(
177                    Box::new(CompiledExp::Mul(
178                        Box::new(CompiledExp::Val(*exp)),
179                        Box::new(CompiledExp::Power(base.clone(), exp - 1.0)),
180                    )),
181                    Box::new(db),
182                )
183            }
184            CompiledExp::Neg(e) => CompiledExp::Neg(Box::new(e.differentiate(var_id))),
185            CompiledExp::Sin(e) => {
186                let de = e.differentiate(var_id);
187                CompiledExp::Mul(Box::new(CompiledExp::Cos(e.clone())), Box::new(de))
188            }
189            CompiledExp::Cos(e) => {
190                let de = e.differentiate(var_id);
191                CompiledExp::Neg(Box::new(CompiledExp::Mul(
192                    Box::new(CompiledExp::Sin(e.clone())),
193                    Box::new(de),
194                )))
195            }
196            CompiledExp::Ln(e) => {
197                let de = e.differentiate(var_id);
198                CompiledExp::Div(Box::new(de), e.clone())
199            }
200            CompiledExp::Exp(e) => {
201                let de = e.differentiate(var_id);
202                CompiledExp::Mul(Box::new(CompiledExp::Exp(e.clone())), Box::new(de))
203            }
204        }
205    }
206
207    pub(crate) fn simplify(&self) -> CompiledExp {
208        match self {
209            CompiledExp::Add(l, r) => {
210                let ls = l.simplify();
211                let rs = r.simplify();
212                match (&ls, &rs) {
213                    (CompiledExp::Val(lv), CompiledExp::Val(rv)) => CompiledExp::Val(lv + rv),
214                    (CompiledExp::Val(0.0), _) => rs,
215                    (_, CompiledExp::Val(0.0)) => ls,
216                    _ => CompiledExp::Add(Box::new(ls), Box::new(rs)),
217                }
218            }
219            CompiledExp::Sub(l, r) => {
220                let ls = l.simplify();
221                let rs = r.simplify();
222                match (&ls, &rs) {
223                    (CompiledExp::Val(lv), CompiledExp::Val(rv)) => CompiledExp::Val(lv - rv),
224                    (_, CompiledExp::Val(0.0)) => ls,
225                    _ => CompiledExp::Sub(Box::new(ls), Box::new(rs)),
226                }
227            }
228            CompiledExp::Mul(l, r) => {
229                let ls = l.simplify();
230                let rs = r.simplify();
231                match (&ls, &rs) {
232                    (CompiledExp::Val(lv), CompiledExp::Val(rv)) => CompiledExp::Val(lv * rv),
233                    (CompiledExp::Val(0.0), _) | (_, CompiledExp::Val(0.0)) => {
234                        CompiledExp::Val(0.0)
235                    }
236                    (CompiledExp::Val(1.0), _) => rs,
237                    (_, CompiledExp::Val(1.0)) => ls,
238                    _ => CompiledExp::Mul(Box::new(ls), Box::new(rs)),
239                }
240            }
241            CompiledExp::Div(l, r) => {
242                let ls = l.simplify();
243                let rs = r.simplify();
244                match (&ls, &rs) {
245                    (CompiledExp::Val(lv), CompiledExp::Val(rv)) if *rv != 0.0 => {
246                        CompiledExp::Val(lv / rv)
247                    }
248                    (CompiledExp::Val(0.0), _) => CompiledExp::Val(0.0),
249                    (_, CompiledExp::Val(1.0)) => ls,
250                    _ => CompiledExp::Div(Box::new(ls), Box::new(rs)),
251                }
252            }
253            CompiledExp::Power(base, exp) => {
254                let bs = base.simplify();
255                match &bs {
256                    CompiledExp::Val(v) => CompiledExp::Val(v.powf(*exp)),
257                    _ if *exp == 0.0 => CompiledExp::Val(1.0),
258                    _ if *exp == 1.0 => bs,
259                    _ => CompiledExp::Power(Box::new(bs), *exp),
260                }
261            }
262            CompiledExp::Neg(e) => {
263                let es = e.simplify();
264                match &es {
265                    CompiledExp::Val(v) => CompiledExp::Val(-v),
266                    _ => CompiledExp::Neg(Box::new(es)),
267                }
268            }
269            _ => self.clone(),
270        }
271    }
272}
273
274#[derive(Debug, Clone)]
275pub struct CompiledSystem {
276    pub(crate) equations: Vec<CompiledExp>,
277    pub(crate) var_table: VarTable,
278}
279
280impl CompiledSystem {
281}
282
283pub struct Compiler;
284
285impl Compiler {
286    pub fn compile(equations: &[Exp]) -> Result<CompiledSystem, CompileError> {
287        let mut var_table = VarTable::new();
288        let mut compiled = Vec::with_capacity(equations.len());
289
290        for eq in equations {
291            compiled.push(Self::compile_exp(eq, &mut var_table)?);
292        }
293
294        Ok(CompiledSystem {
295            equations: compiled,
296            var_table,
297        })
298    }
299
300    fn compile_exp(exp: &Exp, var_table: &mut VarTable) -> Result<CompiledExp, CompileError> {
301        match exp {
302            Exp::Val(v) => Ok(CompiledExp::Val(*v)),
303            Exp::Var(name) => {
304                let id = var_table.register(name)?;
305                Ok(CompiledExp::Var(name.clone(), id))
306            }
307            Exp::Add(l, r) => Ok(CompiledExp::Add(
308                Box::new(Self::compile_exp(l, var_table)?),
309                Box::new(Self::compile_exp(r, var_table)?),
310            )),
311            Exp::Sub(l, r) => Ok(CompiledExp::Sub(
312                Box::new(Self::compile_exp(l, var_table)?),
313                Box::new(Self::compile_exp(r, var_table)?),
314            )),
315            Exp::Mul(l, r) => Ok(CompiledExp::Mul(
316                Box::new(Self::compile_exp(l, var_table)?),
317                Box::new(Self::compile_exp(r, var_table)?),
318            )),
319            Exp::Div(l, r) => Ok(CompiledExp::Div(
320                Box::new(Self::compile_exp(l, var_table)?),
321                Box::new(Self::compile_exp(r, var_table)?),
322            )),
323            Exp::Power(base, exp) => Ok(CompiledExp::Power(
324                Box::new(Self::compile_exp(base, var_table)?),
325                *exp,
326            )),
327            Exp::Neg(e) => Ok(CompiledExp::Neg(Box::new(Self::compile_exp(e, var_table)?))),
328            Exp::Sin(e) => Ok(CompiledExp::Sin(Box::new(Self::compile_exp(e, var_table)?))),
329            Exp::Cos(e) => Ok(CompiledExp::Cos(Box::new(Self::compile_exp(e, var_table)?))),
330            Exp::Ln(e) => Ok(CompiledExp::Ln(Box::new(Self::compile_exp(e, var_table)?))),
331            Exp::Exp(e) => Ok(CompiledExp::Exp(Box::new(Self::compile_exp(e, var_table)?))),
332        }
333    }
334}