Skip to main content

problemreductions/
expr.rs

1//! General symbolic expression AST for reduction overhead.
2
3use crate::types::ProblemSize;
4use std::collections::{HashMap, HashSet};
5use std::fmt;
6
/// A symbolic math expression over problem-size variables.
///
/// Trees are built structurally and never simplified. The `std::ops` operator
/// impls for `Expr` encode `a - b` as `a + (-1) * b`, `-a` as `(-1) * a`, and
/// `a / b` as `a * b^(-1)`, so there are no dedicated subtraction, negation or
/// division variants.
///
/// NOTE(review): the derived serde impls see variants by position as well as by
/// name; reordering variants may change index-based encodings — confirm before
/// reordering.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Expr {
    /// Numeric constant.
    Const(f64),
    /// Named variable (e.g., "num_vertices"). `Expr::eval` reads its value from
    /// the `ProblemSize`, using 0 when the lookup returns `None`.
    Var(&'static str),
    /// Addition: a + b.
    Add(Box<Expr>, Box<Expr>),
    /// Multiplication: a * b.
    Mul(Box<Expr>, Box<Expr>),
    /// Exponentiation: base ^ exponent (evaluated with `f64::powf`).
    Pow(Box<Expr>, Box<Expr>),
    /// Exponential function: exp(a).
    Exp(Box<Expr>),
    /// Natural logarithm: log(a).
    Log(Box<Expr>),
    /// Square root: sqrt(a).
    Sqrt(Box<Expr>),
    /// Factorial: factorial(a). Evaluated numerically by `gamma_factorial`.
    Factorial(Box<Expr>),
}
29
30impl Expr {
31    /// Convenience constructor for exponentiation.
32    pub fn pow(base: Expr, exp: Expr) -> Self {
33        Expr::Pow(Box::new(base), Box::new(exp))
34    }
35
36    /// Multiply expression by a scalar constant.
37    pub fn scale(self, c: f64) -> Self {
38        Expr::Const(c) * self
39    }
40
41    /// Evaluate the expression given concrete variable values.
42    pub fn eval(&self, vars: &ProblemSize) -> f64 {
43        match self {
44            Expr::Const(c) => *c,
45            Expr::Var(name) => vars.get(name).unwrap_or(0) as f64,
46            Expr::Add(a, b) => a.eval(vars) + b.eval(vars),
47            Expr::Mul(a, b) => a.eval(vars) * b.eval(vars),
48            Expr::Pow(base, exp) => base.eval(vars).powf(exp.eval(vars)),
49            Expr::Exp(a) => a.eval(vars).exp(),
50            Expr::Log(a) => a.eval(vars).ln(),
51            Expr::Sqrt(a) => a.eval(vars).sqrt(),
52            Expr::Factorial(a) => gamma_factorial(a.eval(vars)),
53        }
54    }
55
56    /// Collect all variable names referenced in this expression.
57    pub fn variables(&self) -> HashSet<&'static str> {
58        let mut vars = HashSet::new();
59        self.collect_variables(&mut vars);
60        vars
61    }
62
63    fn collect_variables(&self, vars: &mut HashSet<&'static str>) {
64        match self {
65            Expr::Const(_) => {}
66            Expr::Var(name) => {
67                vars.insert(name);
68            }
69            Expr::Add(a, b) | Expr::Mul(a, b) | Expr::Pow(a, b) => {
70                a.collect_variables(vars);
71                b.collect_variables(vars);
72            }
73            Expr::Exp(a) | Expr::Log(a) | Expr::Sqrt(a) | Expr::Factorial(a) => {
74                a.collect_variables(vars);
75            }
76        }
77    }
78
79    /// Substitute variables with other expressions.
80    pub fn substitute(&self, mapping: &HashMap<&str, &Expr>) -> Expr {
81        match self {
82            Expr::Const(c) => Expr::Const(*c),
83            Expr::Var(name) => {
84                if let Some(replacement) = mapping.get(name) {
85                    (*replacement).clone()
86                } else {
87                    Expr::Var(name)
88                }
89            }
90            Expr::Add(a, b) => a.substitute(mapping) + b.substitute(mapping),
91            Expr::Mul(a, b) => a.substitute(mapping) * b.substitute(mapping),
92            Expr::Pow(a, b) => Expr::pow(a.substitute(mapping), b.substitute(mapping)),
93            Expr::Exp(a) => Expr::Exp(Box::new(a.substitute(mapping))),
94            Expr::Log(a) => Expr::Log(Box::new(a.substitute(mapping))),
95            Expr::Sqrt(a) => Expr::Sqrt(Box::new(a.substitute(mapping))),
96            Expr::Factorial(a) => Expr::Factorial(Box::new(a.substitute(mapping))),
97        }
98    }
99
100    /// Parse an expression string into an `Expr` at runtime.
101    ///
102    /// **Memory note:** Variable names are leaked to `&'static str` via `Box::leak`
103    /// since `Expr::Var` requires static lifetimes. Each unique variable name leaks
104    /// a small allocation that is never freed. This is acceptable for testing and
105    /// one-time cross-check evaluation, but should not be used in hot loops with
106    /// dynamic input.
107    ///
108    /// # Panics
109    /// Panics if the expression string has invalid syntax.
110    pub fn parse(input: &str) -> Expr {
111        Self::try_parse(input)
112            .unwrap_or_else(|e| panic!("failed to parse expression \"{input}\": {e}"))
113    }
114
115    /// Parse an expression string into an `Expr`, returning a normal error on failure.
116    pub fn try_parse(input: &str) -> Result<Expr, String> {
117        parse_to_expr(input)
118    }
119
120    /// Check if this expression is a polynomial (no exp/log/sqrt, integer exponents only).
121    pub fn is_polynomial(&self) -> bool {
122        match self {
123            Expr::Const(_) | Expr::Var(_) => true,
124            Expr::Add(a, b) | Expr::Mul(a, b) => a.is_polynomial() && b.is_polynomial(),
125            Expr::Pow(base, exp) => {
126                base.is_polynomial()
127                    && matches!(exp.as_ref(), Expr::Const(c) if *c >= 0.0 && (*c - c.round()).abs() < 1e-10)
128            }
129            Expr::Exp(_) | Expr::Log(_) | Expr::Sqrt(_) | Expr::Factorial(_) => false,
130        }
131    }
132
133    /// Check whether this expression is suitable for asymptotic complexity notation.
134    ///
135    /// This is intentionally conservative for symbolic size formulas:
136    /// - rejects explicit multiplicative constant factors like `3 * n`
137    /// - rejects additive constant terms like `n + 1`
138    /// - allows constants used as exponents (e.g. `n^(1/3)`)
139    /// - allows constants used as exponential bases (e.g. `2^n`)
140    ///
141    /// The goal is to accept expressions that already look like reduced
142    /// asymptotic notation, rather than exact-count formulas.
143    pub fn is_valid_complexity_notation(&self) -> bool {
144        self.is_valid_complexity_notation_inner()
145    }
146
147    fn is_valid_complexity_notation_inner(&self) -> bool {
148        match self {
149            Expr::Const(c) => (*c - 1.0).abs() < 1e-10,
150            Expr::Var(_) => true,
151            Expr::Add(a, b) => {
152                a.constant_value().is_none()
153                    && b.constant_value().is_none()
154                    && a.is_valid_complexity_notation_inner()
155                    && b.is_valid_complexity_notation_inner()
156            }
157            Expr::Mul(a, b) => {
158                a.constant_value().is_none()
159                    && b.constant_value().is_none()
160                    && a.is_valid_complexity_notation_inner()
161                    && b.is_valid_complexity_notation_inner()
162            }
163            Expr::Pow(base, exp) => {
164                let base_is_constant = base.constant_value().is_some();
165                let exp_is_constant = exp.constant_value().is_some();
166
167                let base_ok = if base_is_constant {
168                    base.is_valid_exponential_base()
169                } else {
170                    base.is_valid_complexity_notation_inner()
171                };
172
173                let exp_ok = if exp_is_constant {
174                    true
175                } else {
176                    exp.is_valid_complexity_notation_inner()
177                };
178
179                base_ok && exp_ok
180            }
181            Expr::Exp(a) | Expr::Log(a) | Expr::Sqrt(a) | Expr::Factorial(a) => {
182                a.is_valid_complexity_notation_inner()
183            }
184        }
185    }
186
187    fn is_valid_exponential_base(&self) -> bool {
188        self.constant_value().is_some_and(|c| c > 0.0)
189    }
190
191    pub(crate) fn constant_value(&self) -> Option<f64> {
192        match self {
193            Expr::Const(c) => Some(*c),
194            Expr::Var(_) => None,
195            Expr::Add(a, b) => Some(a.constant_value()? + b.constant_value()?),
196            Expr::Mul(a, b) => Some(a.constant_value()? * b.constant_value()?),
197            Expr::Pow(base, exp) => Some(base.constant_value()?.powf(exp.constant_value()?)),
198            Expr::Exp(a) => Some(a.constant_value()?.exp()),
199            Expr::Log(a) => Some(a.constant_value()?.ln()),
200            Expr::Sqrt(a) => Some(a.constant_value()?.sqrt()),
201            Expr::Factorial(a) => Some(gamma_factorial(a.constant_value()?)),
202        }
203    }
204}
205
206impl fmt::Display for Expr {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        match self {
209            Expr::Const(c) => {
210                let ci = c.round() as i64;
211                if (*c - ci as f64).abs() < 1e-10 {
212                    write!(f, "{ci}")
213                } else {
214                    write!(f, "{c}")
215                }
216            }
217            Expr::Var(name) => write!(f, "{name}"),
218            Expr::Add(a, b) => write!(f, "{a} + {b}"),
219            Expr::Mul(a, b) => {
220                let left = if matches!(a.as_ref(), Expr::Add(_, _)) {
221                    format!("({a})")
222                } else {
223                    format!("{a}")
224                };
225                let right = if matches!(b.as_ref(), Expr::Add(_, _)) {
226                    format!("({b})")
227                } else {
228                    format!("{b}")
229                };
230                write!(f, "{left} * {right}")
231            }
232            Expr::Pow(base, exp) => {
233                // Special case: x^0.5 → sqrt(x)
234                if let Expr::Const(e) = exp.as_ref() {
235                    if (*e - 0.5).abs() < 1e-15 {
236                        return write!(f, "sqrt({base})");
237                    }
238                }
239                let base_str = if matches!(base.as_ref(), Expr::Add(_, _) | Expr::Mul(_, _)) {
240                    format!("({base})")
241                } else {
242                    format!("{base}")
243                };
244                let exp_str = if matches!(exp.as_ref(), Expr::Add(_, _) | Expr::Mul(_, _)) {
245                    format!("({exp})")
246                } else {
247                    format!("{exp}")
248                };
249                write!(f, "{base_str}^{exp_str}")
250            }
251            Expr::Exp(a) => write!(f, "exp({a})"),
252            Expr::Log(a) => write!(f, "log({a})"),
253            Expr::Sqrt(a) => write!(f, "sqrt({a})"),
254            Expr::Factorial(a) => write!(f, "factorial({a})"),
255        }
256    }
257}
258
259impl std::ops::Add for Expr {
260    type Output = Self;
261
262    fn add(self, other: Self) -> Self {
263        Expr::Add(Box::new(self), Box::new(other))
264    }
265}
266
267impl std::ops::Mul for Expr {
268    type Output = Self;
269
270    fn mul(self, other: Self) -> Self {
271        Expr::Mul(Box::new(self), Box::new(other))
272    }
273}
274
275impl std::ops::Sub for Expr {
276    type Output = Self;
277
278    fn sub(self, other: Self) -> Self {
279        self + Expr::Const(-1.0) * other
280    }
281}
282
283impl std::ops::Div for Expr {
284    type Output = Self;
285
286    fn div(self, other: Self) -> Self {
287        self * Expr::pow(other, Expr::Const(-1.0))
288    }
289}
290
291impl std::ops::Neg for Expr {
292    type Output = Self;
293
294    fn neg(self) -> Self {
295        Expr::Const(-1.0) * self
296    }
297}
298
/// Failure reported while analyzing the asymptotic behavior of an expression.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AsymptoticAnalysisError {
    /// The expression is outside what the analysis supports. The payload is
    /// printed after `unsupported asymptotic expression: ` by `Display`.
    Unsupported(String),
}

impl fmt::Display for AsymptoticAnalysisError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum: the destructuring is irrefutable.
        let Self::Unsupported(expr) = self;
        write!(f, "unsupported asymptotic expression: {expr}")
    }
}

impl std::error::Error for AsymptoticAnalysisError {}
314
/// Failure reported when an expression cannot be brought into exact canonical form.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CanonicalizationError {
    /// Expression cannot be canonicalized (e.g., variable in both base and exponent).
    /// The payload is printed after `unsupported expression for canonicalization: `
    /// by `Display`.
    Unsupported(String),
}

impl fmt::Display for CanonicalizationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self::Unsupported(expr) = self;
        write!(f, "unsupported expression for canonicalization: {expr}")
    }
}

impl std::error::Error for CanonicalizationError {}
333
334/// Return a normalized `Expr` representing the asymptotic behavior of `expr`.
335///
336/// This is now a compatibility wrapper for `big_o_normal_form()`.
337pub fn asymptotic_normal_form(expr: &Expr) -> Result<Expr, AsymptoticAnalysisError> {
338    crate::big_o::big_o_normal_form(expr)
339}
340
/// Compute `n!` for non-negative `n`; negative input yields `NaN`.
///
/// - When `n` is within `1e-10` of a non-negative integer `k`, the result is the
///   running product `2 · 3 · … · k` in `f64`. It is exact through `22!` and
///   subject only to ordinary floating-point rounding above that. For `k > 170`
///   the true value exceeds `f64::MAX`, so `f64::INFINITY` is returned directly,
///   without an `O(k)` loop. (Previously a saturating `u64` product capped every
///   result from `21!` on at `u64::MAX` and looped `k` times even for huge `k`.)
/// - Otherwise Stirling's approximation of the gamma function is used:
///   n! = Γ(n+1) ≈ √(2πn) · (n/e)^n. It is coarse for small `n`
///   (about 14% low at `n = 0.5`).
fn gamma_factorial(n: f64) -> f64 {
    // 170! ≈ 7.26e306 is the largest factorial that is finite in f64; 171! overflows.
    const MAX_FINITE_FACTORIAL: u64 = 170;

    if n < 0.0 {
        return f64::NAN;
    }
    let rounded = n.round();
    if (n - rounded).abs() < 1e-10 {
        // `n >= 0` here, so `rounded` is a non-negative integer. The cast saturates
        // for values beyond `u64::MAX`, which still lands in the overflow branch.
        let k = rounded as u64;
        if k > MAX_FINITE_FACTORIAL {
            return f64::INFINITY;
        }
        (2..=k).map(|i| i as f64).product()
    } else {
        // Stirling's approximation: Γ(n+1) ≈ √(2πn) · (n/e)^n
        (2.0 * std::f64::consts::PI * n).sqrt() * (n / std::f64::consts::E).powf(n)
    }
}
363
364// --- Runtime expression parser ---
365
366/// Parse an expression string into an `Expr`.
367///
368/// Uses the same grammar as the proc macro parser. Variable names are leaked
369/// to `&'static str` for compatibility with `Expr::Var`.
370fn parse_to_expr(input: &str) -> Result<Expr, String> {
371    let tokens = tokenize_expr(input)?;
372    let mut parser = ExprParser::new(tokens);
373    let expr = parser.parse_additive()?;
374    if parser.pos != parser.tokens.len() {
375        return Err(format!("trailing tokens at position {}", parser.pos));
376    }
377    Ok(expr)
378}
379
/// A lexical token of the runtime expression grammar.
#[derive(Debug, Clone, PartialEq)]
enum ExprToken {
    /// Unsigned numeric literal made of ASCII digits and `.` (e.g. `2`, `0.5`).
    Number(f64),
    /// Identifier: an ASCII letter or `_`, then ASCII letters, digits or `_`.
    Ident(String),
    /// `+`
    Plus,
    /// `-` (binary subtraction or unary negation; the parser decides).
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `^`
    Caret,
    /// `(`
    LParen,
    /// `)`
    RParen,
}

/// Split `input` into `ExprToken`s.
///
/// Any Unicode whitespace separates tokens and is otherwise ignored. This
/// includes the `\r` of CRLF line endings, which the previous
/// space/tab/newline-only check rejected.
///
/// # Errors
/// Returns `invalid number: …` when a run of digits and dots does not parse as
/// an `f64` (e.g. `1.2.3`), and `unexpected character: '…'` for any character
/// that cannot start a token.
fn tokenize_expr(input: &str) -> Result<Vec<ExprToken>, String> {
    let mut tokens = Vec::new();
    let mut chars = input.chars().peekable();
    while let Some(&ch) = chars.peek() {
        if ch.is_whitespace() {
            chars.next();
            continue;
        }
        let punct = match ch {
            '+' => Some(ExprToken::Plus),
            '-' => Some(ExprToken::Minus),
            '*' => Some(ExprToken::Star),
            '/' => Some(ExprToken::Slash),
            '^' => Some(ExprToken::Caret),
            '(' => Some(ExprToken::LParen),
            ')' => Some(ExprToken::RParen),
            _ => None,
        };
        if let Some(tok) = punct {
            chars.next();
            tokens.push(tok);
        } else if ch.is_ascii_digit() || ch == '.' {
            let num = take_run(&mut chars, |c| c.is_ascii_digit() || c == '.');
            let value: f64 = num.parse().map_err(|_| format!("invalid number: {num}"))?;
            tokens.push(ExprToken::Number(value));
        } else if ch.is_ascii_alphabetic() || ch == '_' {
            let ident = take_run(&mut chars, |c| c.is_ascii_alphanumeric() || c == '_');
            tokens.push(ExprToken::Ident(ident));
        } else {
            return Err(format!("unexpected character: '{ch}'"));
        }
    }
    Ok(tokens)
}

/// Consume and return the longest prefix of `chars` whose characters satisfy `pred`.
fn take_run(
    chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
    pred: impl Fn(char) -> bool,
) -> String {
    let mut run = String::new();
    while let Some(c) = chars.next_if(|&c| pred(c)) {
        run.push(c);
    }
    run
}
460
461struct ExprParser {
462    tokens: Vec<ExprToken>,
463    pos: usize,
464}
465
466impl ExprParser {
467    fn new(tokens: Vec<ExprToken>) -> Self {
468        Self { tokens, pos: 0 }
469    }
470
471    fn peek(&self) -> Option<&ExprToken> {
472        self.tokens.get(self.pos)
473    }
474
475    fn advance(&mut self) -> Option<ExprToken> {
476        let tok = self.tokens.get(self.pos).cloned();
477        self.pos += 1;
478        tok
479    }
480
481    fn expect(&mut self, expected: &ExprToken) -> Result<(), String> {
482        match self.advance() {
483            Some(ref tok) if tok == expected => Ok(()),
484            Some(tok) => Err(format!("expected {expected:?}, got {tok:?}")),
485            None => Err(format!("expected {expected:?}, got end of input")),
486        }
487    }
488
489    fn parse_additive(&mut self) -> Result<Expr, String> {
490        let mut left = self.parse_multiplicative()?;
491        while matches!(self.peek(), Some(ExprToken::Plus) | Some(ExprToken::Minus)) {
492            let op = self.advance().unwrap();
493            let right = self.parse_multiplicative()?;
494            left = match op {
495                ExprToken::Plus => left + right,
496                ExprToken::Minus => left - right,
497                _ => unreachable!(),
498            };
499        }
500        Ok(left)
501    }
502
503    fn parse_multiplicative(&mut self) -> Result<Expr, String> {
504        let mut left = self.parse_unary()?;
505        while matches!(self.peek(), Some(ExprToken::Star) | Some(ExprToken::Slash)) {
506            let op = self.advance().unwrap();
507            let right = self.parse_unary()?;
508            left = match op {
509                ExprToken::Star => left * right,
510                ExprToken::Slash => left / right,
511                _ => unreachable!(),
512            };
513        }
514        Ok(left)
515    }
516
517    fn parse_power(&mut self) -> Result<Expr, String> {
518        let base = self.parse_primary()?;
519        if matches!(self.peek(), Some(ExprToken::Caret)) {
520            self.advance();
521            let exp = self.parse_unary()?; // right-associative, allows unary minus in exponent
522            Ok(Expr::pow(base, exp))
523        } else {
524            Ok(base)
525        }
526    }
527
528    fn parse_unary(&mut self) -> Result<Expr, String> {
529        if matches!(self.peek(), Some(ExprToken::Minus)) {
530            self.advance();
531            let expr = self.parse_unary()?;
532            Ok(-expr)
533        } else {
534            self.parse_power()
535        }
536    }
537
538    fn parse_primary(&mut self) -> Result<Expr, String> {
539        match self.advance() {
540            Some(ExprToken::Number(n)) => Ok(Expr::Const(n)),
541            Some(ExprToken::Ident(name)) => {
542                if matches!(self.peek(), Some(ExprToken::LParen)) {
543                    self.advance();
544                    let arg = self.parse_additive()?;
545                    self.expect(&ExprToken::RParen)?;
546                    match name.as_str() {
547                        "exp" => Ok(Expr::Exp(Box::new(arg))),
548                        "log" => Ok(Expr::Log(Box::new(arg))),
549                        "sqrt" => Ok(Expr::Sqrt(Box::new(arg))),
550                        "factorial" => Ok(Expr::Factorial(Box::new(arg))),
551                        _ => Err(format!("unknown function: {name}")),
552                    }
553                } else {
554                    // Leak the string to get &'static str for Expr::Var
555                    let leaked: &'static str = Box::leak(name.into_boxed_str());
556                    Ok(Expr::Var(leaked))
557                }
558            }
559            Some(ExprToken::LParen) => {
560                let expr = self.parse_additive()?;
561                self.expect(&ExprToken::RParen)?;
562                Ok(expr)
563            }
564            Some(tok) => Err(format!("unexpected token: {tok:?}")),
565            None => Err("unexpected end of input".to_string()),
566        }
567    }
568}
569
570#[cfg(test)]
571#[path = "unit_tests/expr.rs"]
572mod tests;