1use crate::types::ProblemSize;
4use std::collections::{HashMap, HashSet};
5use std::fmt;
6
/// Symbolic real-valued expression tree over named variables.
///
/// Nodes are not simplified on construction. Subtraction, division and negation have no
/// variants of their own: the `Sub`, `Div` and `Neg` impls in this module encode them as
/// `a + (-1) * b`, `a * b^(-1)` and `(-1) * a` respectively.
///
/// NOTE(review): variant order may be significant for non-self-describing `serde`
/// formats (which encode the variant index) — append new variants rather than reordering.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Expr {
    /// Numeric constant.
    Const(f64),
    /// Named variable; `eval` resolves it through a `ProblemSize`, treating a missing
    /// entry as `0`.
    Var(&'static str),
    /// Sum `a + b`.
    Add(Box<Expr>, Box<Expr>),
    /// Product `a * b`.
    Mul(Box<Expr>, Box<Expr>),
    /// Power `base ^ exponent`.
    Pow(Box<Expr>, Box<Expr>),
    /// Natural exponential `e^a`.
    Exp(Box<Expr>),
    /// Natural logarithm (`f64::ln` in `eval`).
    Log(Box<Expr>),
    /// Square root.
    Sqrt(Box<Expr>),
    /// Factorial; `eval` uses an exact product for non-negative integers, Stirling's
    /// approximation for other non-negative values, and `NaN` for negative ones.
    Factorial(Box<Expr>),
}
29
30impl Expr {
31 pub fn pow(base: Expr, exp: Expr) -> Self {
33 Expr::Pow(Box::new(base), Box::new(exp))
34 }
35
36 pub fn scale(self, c: f64) -> Self {
38 Expr::Const(c) * self
39 }
40
41 pub fn eval(&self, vars: &ProblemSize) -> f64 {
43 match self {
44 Expr::Const(c) => *c,
45 Expr::Var(name) => vars.get(name).unwrap_or(0) as f64,
46 Expr::Add(a, b) => a.eval(vars) + b.eval(vars),
47 Expr::Mul(a, b) => a.eval(vars) * b.eval(vars),
48 Expr::Pow(base, exp) => base.eval(vars).powf(exp.eval(vars)),
49 Expr::Exp(a) => a.eval(vars).exp(),
50 Expr::Log(a) => a.eval(vars).ln(),
51 Expr::Sqrt(a) => a.eval(vars).sqrt(),
52 Expr::Factorial(a) => gamma_factorial(a.eval(vars)),
53 }
54 }
55
56 pub fn variables(&self) -> HashSet<&'static str> {
58 let mut vars = HashSet::new();
59 self.collect_variables(&mut vars);
60 vars
61 }
62
63 fn collect_variables(&self, vars: &mut HashSet<&'static str>) {
64 match self {
65 Expr::Const(_) => {}
66 Expr::Var(name) => {
67 vars.insert(name);
68 }
69 Expr::Add(a, b) | Expr::Mul(a, b) | Expr::Pow(a, b) => {
70 a.collect_variables(vars);
71 b.collect_variables(vars);
72 }
73 Expr::Exp(a) | Expr::Log(a) | Expr::Sqrt(a) | Expr::Factorial(a) => {
74 a.collect_variables(vars);
75 }
76 }
77 }
78
79 pub fn substitute(&self, mapping: &HashMap<&str, &Expr>) -> Expr {
81 match self {
82 Expr::Const(c) => Expr::Const(*c),
83 Expr::Var(name) => {
84 if let Some(replacement) = mapping.get(name) {
85 (*replacement).clone()
86 } else {
87 Expr::Var(name)
88 }
89 }
90 Expr::Add(a, b) => a.substitute(mapping) + b.substitute(mapping),
91 Expr::Mul(a, b) => a.substitute(mapping) * b.substitute(mapping),
92 Expr::Pow(a, b) => Expr::pow(a.substitute(mapping), b.substitute(mapping)),
93 Expr::Exp(a) => Expr::Exp(Box::new(a.substitute(mapping))),
94 Expr::Log(a) => Expr::Log(Box::new(a.substitute(mapping))),
95 Expr::Sqrt(a) => Expr::Sqrt(Box::new(a.substitute(mapping))),
96 Expr::Factorial(a) => Expr::Factorial(Box::new(a.substitute(mapping))),
97 }
98 }
99
100 pub fn parse(input: &str) -> Expr {
111 Self::try_parse(input)
112 .unwrap_or_else(|e| panic!("failed to parse expression \"{input}\": {e}"))
113 }
114
115 pub fn try_parse(input: &str) -> Result<Expr, String> {
117 parse_to_expr(input)
118 }
119
120 pub fn is_polynomial(&self) -> bool {
122 match self {
123 Expr::Const(_) | Expr::Var(_) => true,
124 Expr::Add(a, b) | Expr::Mul(a, b) => a.is_polynomial() && b.is_polynomial(),
125 Expr::Pow(base, exp) => {
126 base.is_polynomial()
127 && matches!(exp.as_ref(), Expr::Const(c) if *c >= 0.0 && (*c - c.round()).abs() < 1e-10)
128 }
129 Expr::Exp(_) | Expr::Log(_) | Expr::Sqrt(_) | Expr::Factorial(_) => false,
130 }
131 }
132
133 pub fn is_valid_complexity_notation(&self) -> bool {
144 self.is_valid_complexity_notation_inner()
145 }
146
147 fn is_valid_complexity_notation_inner(&self) -> bool {
148 match self {
149 Expr::Const(c) => (*c - 1.0).abs() < 1e-10,
150 Expr::Var(_) => true,
151 Expr::Add(a, b) => {
152 a.constant_value().is_none()
153 && b.constant_value().is_none()
154 && a.is_valid_complexity_notation_inner()
155 && b.is_valid_complexity_notation_inner()
156 }
157 Expr::Mul(a, b) => {
158 a.constant_value().is_none()
159 && b.constant_value().is_none()
160 && a.is_valid_complexity_notation_inner()
161 && b.is_valid_complexity_notation_inner()
162 }
163 Expr::Pow(base, exp) => {
164 let base_is_constant = base.constant_value().is_some();
165 let exp_is_constant = exp.constant_value().is_some();
166
167 let base_ok = if base_is_constant {
168 base.is_valid_exponential_base()
169 } else {
170 base.is_valid_complexity_notation_inner()
171 };
172
173 let exp_ok = if exp_is_constant {
174 true
175 } else {
176 exp.is_valid_complexity_notation_inner()
177 };
178
179 base_ok && exp_ok
180 }
181 Expr::Exp(a) | Expr::Log(a) | Expr::Sqrt(a) | Expr::Factorial(a) => {
182 a.is_valid_complexity_notation_inner()
183 }
184 }
185 }
186
187 fn is_valid_exponential_base(&self) -> bool {
188 self.constant_value().is_some_and(|c| c > 0.0)
189 }
190
191 pub(crate) fn constant_value(&self) -> Option<f64> {
192 match self {
193 Expr::Const(c) => Some(*c),
194 Expr::Var(_) => None,
195 Expr::Add(a, b) => Some(a.constant_value()? + b.constant_value()?),
196 Expr::Mul(a, b) => Some(a.constant_value()? * b.constant_value()?),
197 Expr::Pow(base, exp) => Some(base.constant_value()?.powf(exp.constant_value()?)),
198 Expr::Exp(a) => Some(a.constant_value()?.exp()),
199 Expr::Log(a) => Some(a.constant_value()?.ln()),
200 Expr::Sqrt(a) => Some(a.constant_value()?.sqrt()),
201 Expr::Factorial(a) => Some(gamma_factorial(a.constant_value()?)),
202 }
203 }
204}
205
206impl fmt::Display for Expr {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 match self {
209 Expr::Const(c) => {
210 let ci = c.round() as i64;
211 if (*c - ci as f64).abs() < 1e-10 {
212 write!(f, "{ci}")
213 } else {
214 write!(f, "{c}")
215 }
216 }
217 Expr::Var(name) => write!(f, "{name}"),
218 Expr::Add(a, b) => write!(f, "{a} + {b}"),
219 Expr::Mul(a, b) => {
220 let left = if matches!(a.as_ref(), Expr::Add(_, _)) {
221 format!("({a})")
222 } else {
223 format!("{a}")
224 };
225 let right = if matches!(b.as_ref(), Expr::Add(_, _)) {
226 format!("({b})")
227 } else {
228 format!("{b}")
229 };
230 write!(f, "{left} * {right}")
231 }
232 Expr::Pow(base, exp) => {
233 if let Expr::Const(e) = exp.as_ref() {
235 if (*e - 0.5).abs() < 1e-15 {
236 return write!(f, "sqrt({base})");
237 }
238 }
239 let base_str = if matches!(base.as_ref(), Expr::Add(_, _) | Expr::Mul(_, _)) {
240 format!("({base})")
241 } else {
242 format!("{base}")
243 };
244 let exp_str = if matches!(exp.as_ref(), Expr::Add(_, _) | Expr::Mul(_, _)) {
245 format!("({exp})")
246 } else {
247 format!("{exp}")
248 };
249 write!(f, "{base_str}^{exp_str}")
250 }
251 Expr::Exp(a) => write!(f, "exp({a})"),
252 Expr::Log(a) => write!(f, "log({a})"),
253 Expr::Sqrt(a) => write!(f, "sqrt({a})"),
254 Expr::Factorial(a) => write!(f, "factorial({a})"),
255 }
256 }
257}
258
259impl std::ops::Add for Expr {
260 type Output = Self;
261
262 fn add(self, other: Self) -> Self {
263 Expr::Add(Box::new(self), Box::new(other))
264 }
265}
266
267impl std::ops::Mul for Expr {
268 type Output = Self;
269
270 fn mul(self, other: Self) -> Self {
271 Expr::Mul(Box::new(self), Box::new(other))
272 }
273}
274
275impl std::ops::Sub for Expr {
276 type Output = Self;
277
278 fn sub(self, other: Self) -> Self {
279 self + Expr::Const(-1.0) * other
280 }
281}
282
283impl std::ops::Div for Expr {
284 type Output = Self;
285
286 fn div(self, other: Self) -> Self {
287 self * Expr::pow(other, Expr::Const(-1.0))
288 }
289}
290
291impl std::ops::Neg for Expr {
292 type Output = Self;
293
294 fn neg(self) -> Self {
295 Expr::Const(-1.0) * self
296 }
297}
298
/// Error produced when an expression cannot be brought into asymptotic normal form.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AsymptoticAnalysisError {
    /// The analysis does not handle this expression; the payload identifies it.
    Unsupported(String),
}

impl fmt::Display for AsymptoticAnalysisError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so the pattern is irrefutable.
        let Self::Unsupported(expr) = self;
        f.write_fmt(format_args!("unsupported asymptotic expression: {expr}"))
    }
}

impl std::error::Error for AsymptoticAnalysisError {}
314
/// Error produced when an expression cannot be canonicalized.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CanonicalizationError {
    /// Canonicalization does not handle this expression; the payload identifies it.
    Unsupported(String),
}

impl fmt::Display for CanonicalizationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so the pattern is irrefutable.
        let Self::Unsupported(expr) = self;
        f.write_fmt(format_args!(
            "unsupported expression for canonicalization: {expr}"
        ))
    }
}

impl std::error::Error for CanonicalizationError {}
333
334pub fn asymptotic_normal_form(expr: &Expr) -> Result<Expr, AsymptoticAnalysisError> {
338 crate::big_o::big_o_normal_form(expr)
339}
340
/// Factorial extended to real arguments; backs `Expr::Factorial`.
///
/// * Negative inputs return `NaN` (and `NaN` propagates).
/// * Inputs within `1e-10` of a non-negative integer `k` return `k!`, computed in `f64`:
///   exact up to `22!`, correctly rounded-ish beyond, and `+inf` from `171!` onward.
/// * Any other input uses Stirling's approximation `sqrt(2πn) · (n/e)^n`, which is only
///   an estimate of `Γ(n + 1)`.
fn gamma_factorial(n: f64) -> f64 {
    if n < 0.0 {
        return f64::NAN;
    }
    let rounded = n.round();
    if (n - rounded).abs() < 1e-10 && rounded >= 0.0 {
        // Accumulate in f64: a saturating u64 product overflows at 21! and would clamp
        // every larger factorial to u64::MAX (~1.8e19). Stopping once the product is
        // +inf also bounds the loop when `rounded as u64` saturates for huge inputs.
        let k = rounded as u64;
        let mut result = 1.0_f64;
        for i in 2..=k {
            result *= i as f64;
            if result.is_infinite() {
                break;
            }
        }
        result
    } else {
        (2.0 * std::f64::consts::PI * n).sqrt() * (n / std::f64::consts::E).powf(n)
    }
}
363
364fn parse_to_expr(input: &str) -> Result<Expr, String> {
371 let tokens = tokenize_expr(input)?;
372 let mut parser = ExprParser::new(tokens);
373 let expr = parser.parse_additive()?;
374 if parser.pos != parser.tokens.len() {
375 return Err(format!("trailing tokens at position {}", parser.pos));
376 }
377 Ok(expr)
378}
379
/// Lexical token produced by [`tokenize_expr`].
#[derive(Debug, Clone, PartialEq)]
enum ExprToken {
    /// Unsigned numeric literal made of ASCII digits and `.` (no exponent syntax).
    Number(f64),
    /// Identifier: ASCII letter or `_`, followed by ASCII alphanumerics or `_`.
    Ident(String),
    Plus,
    Minus,
    Star,
    Slash,
    Caret,
    LParen,
    RParen,
}

/// Maps a one-character operator or bracket to its token, if `ch` is one.
fn single_char_token(ch: char) -> Option<ExprToken> {
    Some(match ch {
        '+' => ExprToken::Plus,
        '-' => ExprToken::Minus,
        '*' => ExprToken::Star,
        '/' => ExprToken::Slash,
        '^' => ExprToken::Caret,
        '(' => ExprToken::LParen,
        ')' => ExprToken::RParen,
        _ => return None,
    })
}

/// Consumes the longest prefix of `chars` whose characters satisfy `accept`.
fn take_char_run<I>(chars: &mut std::iter::Peekable<I>, accept: impl Fn(char) -> bool) -> String
where
    I: Iterator<Item = char>,
{
    let mut run = String::new();
    while let Some(c) = chars.next_if(|&c| accept(c)) {
        run.push(c);
    }
    run
}

/// Splits `input` into [`ExprToken`]s.
///
/// Any Unicode whitespace (including `\r`, so CRLF input works) separates tokens.
///
/// # Errors
///
/// Returns `invalid number: …` for a digit/dot run that is not a valid `f64`
/// (e.g. `1.2.3`), and `unexpected character: '…'` for any other unrecognized char.
fn tokenize_expr(input: &str) -> Result<Vec<ExprToken>, String> {
    let mut tokens = Vec::new();
    let mut chars = input.chars().peekable();
    while let Some(&ch) = chars.peek() {
        if ch.is_whitespace() {
            chars.next();
        } else if let Some(tok) = single_char_token(ch) {
            chars.next();
            tokens.push(tok);
        } else if ch.is_ascii_digit() || ch == '.' {
            let num = take_char_run(&mut chars, |c| c.is_ascii_digit() || c == '.');
            let value: f64 = num.parse().map_err(|_| format!("invalid number: {num}"))?;
            tokens.push(ExprToken::Number(value));
        } else if ch.is_ascii_alphabetic() || ch == '_' {
            let ident = take_char_run(&mut chars, |c| c.is_ascii_alphanumeric() || c == '_');
            tokens.push(ExprToken::Ident(ident));
        } else {
            return Err(format!("unexpected character: '{ch}'"));
        }
    }
    Ok(tokens)
}
460
461struct ExprParser {
462 tokens: Vec<ExprToken>,
463 pos: usize,
464}
465
466impl ExprParser {
467 fn new(tokens: Vec<ExprToken>) -> Self {
468 Self { tokens, pos: 0 }
469 }
470
471 fn peek(&self) -> Option<&ExprToken> {
472 self.tokens.get(self.pos)
473 }
474
475 fn advance(&mut self) -> Option<ExprToken> {
476 let tok = self.tokens.get(self.pos).cloned();
477 self.pos += 1;
478 tok
479 }
480
481 fn expect(&mut self, expected: &ExprToken) -> Result<(), String> {
482 match self.advance() {
483 Some(ref tok) if tok == expected => Ok(()),
484 Some(tok) => Err(format!("expected {expected:?}, got {tok:?}")),
485 None => Err(format!("expected {expected:?}, got end of input")),
486 }
487 }
488
489 fn parse_additive(&mut self) -> Result<Expr, String> {
490 let mut left = self.parse_multiplicative()?;
491 while matches!(self.peek(), Some(ExprToken::Plus) | Some(ExprToken::Minus)) {
492 let op = self.advance().unwrap();
493 let right = self.parse_multiplicative()?;
494 left = match op {
495 ExprToken::Plus => left + right,
496 ExprToken::Minus => left - right,
497 _ => unreachable!(),
498 };
499 }
500 Ok(left)
501 }
502
503 fn parse_multiplicative(&mut self) -> Result<Expr, String> {
504 let mut left = self.parse_unary()?;
505 while matches!(self.peek(), Some(ExprToken::Star) | Some(ExprToken::Slash)) {
506 let op = self.advance().unwrap();
507 let right = self.parse_unary()?;
508 left = match op {
509 ExprToken::Star => left * right,
510 ExprToken::Slash => left / right,
511 _ => unreachable!(),
512 };
513 }
514 Ok(left)
515 }
516
517 fn parse_power(&mut self) -> Result<Expr, String> {
518 let base = self.parse_primary()?;
519 if matches!(self.peek(), Some(ExprToken::Caret)) {
520 self.advance();
521 let exp = self.parse_unary()?; Ok(Expr::pow(base, exp))
523 } else {
524 Ok(base)
525 }
526 }
527
528 fn parse_unary(&mut self) -> Result<Expr, String> {
529 if matches!(self.peek(), Some(ExprToken::Minus)) {
530 self.advance();
531 let expr = self.parse_unary()?;
532 Ok(-expr)
533 } else {
534 self.parse_power()
535 }
536 }
537
538 fn parse_primary(&mut self) -> Result<Expr, String> {
539 match self.advance() {
540 Some(ExprToken::Number(n)) => Ok(Expr::Const(n)),
541 Some(ExprToken::Ident(name)) => {
542 if matches!(self.peek(), Some(ExprToken::LParen)) {
543 self.advance();
544 let arg = self.parse_additive()?;
545 self.expect(&ExprToken::RParen)?;
546 match name.as_str() {
547 "exp" => Ok(Expr::Exp(Box::new(arg))),
548 "log" => Ok(Expr::Log(Box::new(arg))),
549 "sqrt" => Ok(Expr::Sqrt(Box::new(arg))),
550 "factorial" => Ok(Expr::Factorial(Box::new(arg))),
551 _ => Err(format!("unknown function: {name}")),
552 }
553 } else {
554 let leaked: &'static str = Box::leak(name.into_boxed_str());
556 Ok(Expr::Var(leaked))
557 }
558 }
559 Some(ExprToken::LParen) => {
560 let expr = self.parse_additive()?;
561 self.expect(&ExprToken::RParen)?;
562 Ok(expr)
563 }
564 Some(tok) => Err(format!("unexpected token: {tok:?}")),
565 None => Err("unexpected end of input".to_string()),
566 }
567 }
568}
569
570#[cfg(test)]
571#[path = "unit_tests/expr.rs"]
572mod tests;