1use std::collections::HashMap;
26use std::fmt;
27
28use crate::exp::{Exp, MissingVarError};
29
/// Errors reported while compiling expressions into a compiled system.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompileError {
    /// A variable reference used the empty string as its name.
    EmptyVariableName,
}

impl fmt::Display for CompileError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            CompileError::EmptyVariableName => "Variable name cannot be empty",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CompileError {}
44
/// Dense, zero-based index identifying one variable of a compiled system.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub(crate) struct VarId(usize);

impl VarId {
    /// Wraps a raw index as a `VarId`. Usable in `const` contexts.
    pub(crate) const fn new(index: usize) -> Self {
        VarId(index)
    }
}
53
54#[derive(Debug, Clone)]
55pub(crate) struct VarTable {
56 names: Vec<String>,
57 name_to_id: HashMap<String, VarId>,
58}
59
60impl VarTable {
61 fn new() -> Self {
62 Self {
63 names: Vec::new(),
64 name_to_id: HashMap::new(),
65 }
66 }
67
68 fn register(&mut self, name: &str) -> Result<VarId, CompileError> {
69 if name.is_empty() {
70 return Err(CompileError::EmptyVariableName);
71 }
72 if let Some(id) = self.name_to_id.get(name) {
73 return Ok(*id);
74 }
75 let id = VarId::new(self.names.len());
76 self.names.push(name.to_string());
77 self.name_to_id.insert(name.to_string(), id);
78 Ok(id)
79 }
80
81 pub(crate) fn get_id(&self, name: &str) -> Option<VarId> {
82 self.name_to_id.get(name).copied()
83 }
84
85 pub(crate) fn len(&self) -> usize {
86 self.names.len()
87 }
88
89 pub(crate) fn all_var_ids(&self) -> Vec<VarId> {
90 (0..self.names.len()).map(VarId::new).collect()
91 }
92
93 pub(crate) fn names(&self) -> &[String] {
94 &self.names
95 }
96}
97
98#[derive(Debug, Clone, PartialEq)]
99pub(crate) enum CompiledExp {
100 Val(f64),
101 Var(String, VarId),
102 Add(Box<CompiledExp>, Box<CompiledExp>),
103 Sub(Box<CompiledExp>, Box<CompiledExp>),
104 Mul(Box<CompiledExp>, Box<CompiledExp>),
105 Div(Box<CompiledExp>, Box<CompiledExp>),
106 Power(Box<CompiledExp>, f64),
107 Neg(Box<CompiledExp>),
108 Sin(Box<CompiledExp>),
109 Cos(Box<CompiledExp>),
110 Ln(Box<CompiledExp>),
111 Exp(Box<CompiledExp>),
112}
113
114impl CompiledExp {
115 pub(crate) fn evaluate_checked(
116 &self,
117 vars: &HashMap<VarId, f64>,
118 ) -> Result<f64, MissingVarError> {
119 match self {
120 CompiledExp::Val(v) => Ok(*v),
121 CompiledExp::Var(name, id) => vars.get(id).copied().ok_or_else(|| MissingVarError {
122 var_name: name.clone(),
123 }),
124 CompiledExp::Add(l, r) => Ok(l.evaluate_checked(vars)? + r.evaluate_checked(vars)?),
125 CompiledExp::Sub(l, r) => Ok(l.evaluate_checked(vars)? - r.evaluate_checked(vars)?),
126 CompiledExp::Mul(l, r) => Ok(l.evaluate_checked(vars)? * r.evaluate_checked(vars)?),
127 CompiledExp::Div(l, r) => Ok(l.evaluate_checked(vars)? / r.evaluate_checked(vars)?),
128 CompiledExp::Power(base, exp) => Ok(base.evaluate_checked(vars)?.powf(*exp)),
129 CompiledExp::Neg(e) => Ok(-e.evaluate_checked(vars)?),
130 CompiledExp::Sin(e) => Ok(e.evaluate_checked(vars)?.sin()),
131 CompiledExp::Cos(e) => Ok(e.evaluate_checked(vars)?.cos()),
132 CompiledExp::Ln(e) => Ok(e.evaluate_checked(vars)?.ln()),
133 CompiledExp::Exp(e) => Ok(e.evaluate_checked(vars)?.exp()),
134 }
135 }
136
137 pub(crate) fn differentiate(&self, var_id: VarId) -> CompiledExp {
138 match self {
139 CompiledExp::Val(_) => CompiledExp::Val(0.0),
140 CompiledExp::Var(_, id) => {
141 if *id == var_id {
142 CompiledExp::Val(1.0)
143 } else {
144 CompiledExp::Val(0.0)
145 }
146 }
147 CompiledExp::Add(l, r) => CompiledExp::Add(
148 Box::new(l.differentiate(var_id)),
149 Box::new(r.differentiate(var_id)),
150 ),
151 CompiledExp::Sub(l, r) => CompiledExp::Sub(
152 Box::new(l.differentiate(var_id)),
153 Box::new(r.differentiate(var_id)),
154 ),
155 CompiledExp::Mul(l, r) => {
156 let dl = l.differentiate(var_id);
157 let dr = r.differentiate(var_id);
158 CompiledExp::Add(
159 Box::new(CompiledExp::Mul(Box::new(dl), r.clone())),
160 Box::new(CompiledExp::Mul(l.clone(), Box::new(dr))),
161 )
162 }
163 CompiledExp::Div(l, r) => {
164 let dl = l.differentiate(var_id);
165 let dr = r.differentiate(var_id);
166 CompiledExp::Div(
167 Box::new(CompiledExp::Sub(
168 Box::new(CompiledExp::Mul(Box::new(dl), r.clone())),
169 Box::new(CompiledExp::Mul(l.clone(), Box::new(dr))),
170 )),
171 Box::new(CompiledExp::Power(r.clone(), 2.0)),
172 )
173 }
174 CompiledExp::Power(base, exp) => {
175 let db = base.differentiate(var_id);
176 CompiledExp::Mul(
177 Box::new(CompiledExp::Mul(
178 Box::new(CompiledExp::Val(*exp)),
179 Box::new(CompiledExp::Power(base.clone(), exp - 1.0)),
180 )),
181 Box::new(db),
182 )
183 }
184 CompiledExp::Neg(e) => CompiledExp::Neg(Box::new(e.differentiate(var_id))),
185 CompiledExp::Sin(e) => {
186 let de = e.differentiate(var_id);
187 CompiledExp::Mul(Box::new(CompiledExp::Cos(e.clone())), Box::new(de))
188 }
189 CompiledExp::Cos(e) => {
190 let de = e.differentiate(var_id);
191 CompiledExp::Neg(Box::new(CompiledExp::Mul(
192 Box::new(CompiledExp::Sin(e.clone())),
193 Box::new(de),
194 )))
195 }
196 CompiledExp::Ln(e) => {
197 let de = e.differentiate(var_id);
198 CompiledExp::Div(Box::new(de), e.clone())
199 }
200 CompiledExp::Exp(e) => {
201 let de = e.differentiate(var_id);
202 CompiledExp::Mul(Box::new(CompiledExp::Exp(e.clone())), Box::new(de))
203 }
204 }
205 }
206
207 pub(crate) fn simplify(&self) -> CompiledExp {
208 match self {
209 CompiledExp::Add(l, r) => {
210 let ls = l.simplify();
211 let rs = r.simplify();
212 match (&ls, &rs) {
213 (CompiledExp::Val(lv), CompiledExp::Val(rv)) => CompiledExp::Val(lv + rv),
214 (CompiledExp::Val(0.0), _) => rs,
215 (_, CompiledExp::Val(0.0)) => ls,
216 _ => CompiledExp::Add(Box::new(ls), Box::new(rs)),
217 }
218 }
219 CompiledExp::Sub(l, r) => {
220 let ls = l.simplify();
221 let rs = r.simplify();
222 match (&ls, &rs) {
223 (CompiledExp::Val(lv), CompiledExp::Val(rv)) => CompiledExp::Val(lv - rv),
224 (_, CompiledExp::Val(0.0)) => ls,
225 _ => CompiledExp::Sub(Box::new(ls), Box::new(rs)),
226 }
227 }
228 CompiledExp::Mul(l, r) => {
229 let ls = l.simplify();
230 let rs = r.simplify();
231 match (&ls, &rs) {
232 (CompiledExp::Val(lv), CompiledExp::Val(rv)) => CompiledExp::Val(lv * rv),
233 (CompiledExp::Val(0.0), _) | (_, CompiledExp::Val(0.0)) => {
234 CompiledExp::Val(0.0)
235 }
236 (CompiledExp::Val(1.0), _) => rs,
237 (_, CompiledExp::Val(1.0)) => ls,
238 _ => CompiledExp::Mul(Box::new(ls), Box::new(rs)),
239 }
240 }
241 CompiledExp::Div(l, r) => {
242 let ls = l.simplify();
243 let rs = r.simplify();
244 match (&ls, &rs) {
245 (CompiledExp::Val(lv), CompiledExp::Val(rv)) if *rv != 0.0 => {
246 CompiledExp::Val(lv / rv)
247 }
248 (CompiledExp::Val(0.0), _) => CompiledExp::Val(0.0),
249 (_, CompiledExp::Val(1.0)) => ls,
250 _ => CompiledExp::Div(Box::new(ls), Box::new(rs)),
251 }
252 }
253 CompiledExp::Power(base, exp) => {
254 let bs = base.simplify();
255 match &bs {
256 CompiledExp::Val(v) => CompiledExp::Val(v.powf(*exp)),
257 _ if *exp == 0.0 => CompiledExp::Val(1.0),
258 _ if *exp == 1.0 => bs,
259 _ => CompiledExp::Power(Box::new(bs), *exp),
260 }
261 }
262 CompiledExp::Neg(e) => {
263 let es = e.simplify();
264 match &es {
265 CompiledExp::Val(v) => CompiledExp::Val(-v),
266 _ => CompiledExp::Neg(Box::new(es)),
267 }
268 }
269 _ => self.clone(),
270 }
271 }
272}
273
274#[derive(Debug, Clone)]
275pub struct CompiledSystem {
276 pub(crate) equations: Vec<CompiledExp>,
277 pub(crate) var_table: VarTable,
278}
279
280impl CompiledSystem {
281}
282
283pub struct Compiler;
284
285impl Compiler {
286 pub fn compile(equations: &[Exp]) -> Result<CompiledSystem, CompileError> {
287 let mut var_table = VarTable::new();
288 let mut compiled = Vec::with_capacity(equations.len());
289
290 for eq in equations {
291 compiled.push(Self::compile_exp(eq, &mut var_table)?);
292 }
293
294 Ok(CompiledSystem {
295 equations: compiled,
296 var_table,
297 })
298 }
299
300 fn compile_exp(exp: &Exp, var_table: &mut VarTable) -> Result<CompiledExp, CompileError> {
301 match exp {
302 Exp::Val(v) => Ok(CompiledExp::Val(*v)),
303 Exp::Var(name) => {
304 let id = var_table.register(name)?;
305 Ok(CompiledExp::Var(name.clone(), id))
306 }
307 Exp::Add(l, r) => Ok(CompiledExp::Add(
308 Box::new(Self::compile_exp(l, var_table)?),
309 Box::new(Self::compile_exp(r, var_table)?),
310 )),
311 Exp::Sub(l, r) => Ok(CompiledExp::Sub(
312 Box::new(Self::compile_exp(l, var_table)?),
313 Box::new(Self::compile_exp(r, var_table)?),
314 )),
315 Exp::Mul(l, r) => Ok(CompiledExp::Mul(
316 Box::new(Self::compile_exp(l, var_table)?),
317 Box::new(Self::compile_exp(r, var_table)?),
318 )),
319 Exp::Div(l, r) => Ok(CompiledExp::Div(
320 Box::new(Self::compile_exp(l, var_table)?),
321 Box::new(Self::compile_exp(r, var_table)?),
322 )),
323 Exp::Power(base, exp) => Ok(CompiledExp::Power(
324 Box::new(Self::compile_exp(base, var_table)?),
325 *exp,
326 )),
327 Exp::Neg(e) => Ok(CompiledExp::Neg(Box::new(Self::compile_exp(e, var_table)?))),
328 Exp::Sin(e) => Ok(CompiledExp::Sin(Box::new(Self::compile_exp(e, var_table)?))),
329 Exp::Cos(e) => Ok(CompiledExp::Cos(Box::new(Self::compile_exp(e, var_table)?))),
330 Exp::Ln(e) => Ok(CompiledExp::Ln(Box::new(Self::compile_exp(e, var_table)?))),
331 Exp::Exp(e) => Ok(CompiledExp::Exp(Box::new(Self::compile_exp(e, var_table)?))),
332 }
333 }
334}