Skip to main content

quantrs2_symengine_pure/
expr.rs

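//! Expression AST and core types for symbolic computation.
//!
//! This module defines the `ExprLang` node language and the `Expression` type,
//! which stores an expression as an egg `RecExpr` (a flat node array) — the term
//! representation egg uses when adding expressions to, and extracting them from, e-graphs.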
5
6use std::collections::HashMap;
7use std::fmt;
8use std::hash::Hash;
9use std::str::FromStr;
10use std::sync::Arc;
11
12use egg::{
13    define_language, Analysis, CostFunction, EGraph, Id, Language, RecExpr, Rewrite, Runner, Symbol,
14};
15use scirs2_core::Complex64;
16
17use crate::error::{SymEngineError, SymEngineResult};
18
19// The symbolic expression language using egg's macro.
20//
21// This language supports:
22// - Numeric constants (represented as strings to avoid trait issues)
23// - Symbols (variables)
24// - Arithmetic operations (add, mul, pow, neg, inv)
25// - Transcendental functions (sin, cos, exp, log, sqrt)
26// - Quantum-specific operations (commutator, anticommutator, tensor_product)
27define_language! {
28    /// The symbolic expression language
29    pub enum ExprLang {
30        // Use Symbol for both variable names and numeric literals
31        // Numbers are stored as strings like "42" or "3.14"
32        Num(Symbol),
33
34        // Binary arithmetic operations
35        "+" = Add([Id; 2]),
36        "*" = Mul([Id; 2]),
37        "/" = Div([Id; 2]),
38        "^" = Pow([Id; 2]),
39
40        // Unary operations
41        "neg" = Neg([Id; 1]),
42        "inv" = Inv([Id; 1]),
43        "abs" = Abs([Id; 1]),
44
45        // Transcendental functions
46        "sin" = Sin([Id; 1]),
47        "cos" = Cos([Id; 1]),
48        "tan" = Tan([Id; 1]),
49        "exp" = Exp([Id; 1]),
50        "log" = Log([Id; 1]),
51        "sqrt" = Sqrt([Id; 1]),
52        "asin" = Asin([Id; 1]),
53        "acos" = Acos([Id; 1]),
54        "atan" = Atan([Id; 1]),
55        "sinh" = Sinh([Id; 1]),
56        "cosh" = Cosh([Id; 1]),
57        "tanh" = Tanh([Id; 1]),
58
59        // Complex number operations
60        "re" = Re([Id; 1]),
61        "im" = Im([Id; 1]),
62        "conj" = Conj([Id; 1]),
63
64        // Quantum-specific operations
65        "comm" = Commutator([Id; 2]),      // [A, B] = AB - BA
66        "anticomm" = Anticommutator([Id; 2]), // {A, B} = AB + BA
67        "tensor" = TensorProduct([Id; 2]),  // A ⊗ B
68        "trace" = Trace([Id; 1]),
69        "dagger" = Dagger([Id; 1]),         // Hermitian conjugate
70
71        // Matrix operations
72        "det" = Determinant([Id; 1]),
73        "transpose" = Transpose([Id; 1]),
74    }
75}
76
77/// A symbolic mathematical expression.
78///
79/// This type wraps egg's `RecExpr` and provides a user-friendly API
80/// for symbolic computation.
81#[derive(Clone, Debug)]
82pub struct Expression {
83    /// The underlying recursive expression
84    expr: RecExpr<ExprLang>,
85}
86
87impl Expression {
88    // =========================================================================
89    // Construction
90    // =========================================================================
91
92    /// Create a new symbolic variable
93    ///
94    /// # Example
95    /// ```ignore
96    /// use quantrs2_symengine_pure::Expression;
97    /// let x = Expression::symbol("x");
98    /// ```
99    #[must_use]
100    pub fn symbol(name: &str) -> Self {
101        let mut expr = RecExpr::default();
102        expr.add(ExprLang::Num(Symbol::from(name)));
103        Self { expr }
104    }
105
106    /// Create an integer constant
107    #[must_use]
108    pub fn int(value: i64) -> Self {
109        let mut expr = RecExpr::default();
110        expr.add(ExprLang::Num(Symbol::from(value.to_string())));
111        Self { expr }
112    }
113
114    /// Create a floating-point constant
115    ///
116    /// # Errors
117    /// Returns an error if the value is NaN
118    pub fn float(value: f64) -> SymEngineResult<Self> {
119        if value.is_nan() {
120            return Err(SymEngineError::Undefined(
121                "NaN is not a valid symbolic value".into(),
122            ));
123        }
124        let mut expr = RecExpr::default();
125        expr.add(ExprLang::Num(Symbol::from(value.to_string())));
126        Ok(Self { expr })
127    }
128
129    /// Create a floating-point constant, using 0 for NaN
130    #[must_use]
131    pub fn float_unchecked(value: f64) -> Self {
132        let v = if value.is_nan() { 0.0 } else { value };
133        let mut expr = RecExpr::default();
134        expr.add(ExprLang::Num(Symbol::from(v.to_string())));
135        Self { expr }
136    }
137
138    /// Create the constant zero
139    #[must_use]
140    pub fn zero() -> Self {
141        Self::int(0)
142    }
143
144    /// Create the constant one
145    #[must_use]
146    pub fn one() -> Self {
147        Self::int(1)
148    }
149
150    /// Create the imaginary unit i
151    #[must_use]
152    pub fn i() -> Self {
153        Self::symbol("I")
154    }
155
156    /// Create the constant π
157    #[must_use]
158    pub fn pi() -> Self {
159        Self::symbol("pi")
160    }
161
162    /// Create the constant e (Euler's number)
163    #[must_use]
164    pub fn e() -> Self {
165        Self::symbol("e")
166    }
167
168    /// Create from a complex number
169    ///
170    /// If imaginary part is negligible, returns just the real part.
171    #[must_use]
172    pub fn from_complex64(c: Complex64) -> Self {
173        const EPSILON: f64 = 1e-15;
174        if c.im.abs() < EPSILON {
175            Self::float_unchecked(c.re)
176        } else if c.re.abs() < EPSILON {
177            // Pure imaginary
178            Self::float_unchecked(c.im) * Self::i()
179        } else {
180            // General complex
181            Self::float_unchecked(c.re) + Self::float_unchecked(c.im) * Self::i()
182        }
183    }
184
185    /// Parse an expression from a string
186    ///
187    /// # Errors
188    /// Returns an error if parsing fails
189    pub fn parse(input: &str) -> SymEngineResult<Self> {
190        let trimmed = input.trim();
191        if trimmed.is_empty() {
192            return Err(SymEngineError::parse("empty expression"));
193        }
194
195        // Try to parse as number
196        if let Ok(n) = trimmed.parse::<i64>() {
197            return Ok(Self::int(n));
198        }
199        if let Ok(f) = trimmed.parse::<f64>() {
200            return Self::float(f);
201        }
202
203        // Otherwise treat as symbol/expression
204        Ok(Self::symbol(trimmed))
205    }
206
207    /// Create an expression from a string (alias for parse)
208    #[must_use]
209    pub fn new(input: impl AsRef<str>) -> Self {
210        Self::parse(input.as_ref()).unwrap_or_else(|_| Self::symbol(input.as_ref()))
211    }
212
213    // =========================================================================
214    // Accessors
215    // =========================================================================
216
217    /// Get the root node of the expression
218    fn root(&self) -> &ExprLang {
219        &self.expr[self.root_id()]
220    }
221
222    /// Get the root ID
223    fn root_id(&self) -> Id {
224        Id::from(self.expr.as_ref().len() - 1)
225    }
226
227    /// Check if this expression is a symbol (variable or number literal)
228    #[must_use]
229    pub fn is_symbol(&self) -> bool {
230        matches!(self.root(), ExprLang::Num(_))
231    }
232
233    /// Check if this expression is a number
234    #[must_use]
235    pub fn is_number(&self) -> bool {
236        if let ExprLang::Num(s) = self.root() {
237            s.as_str().parse::<f64>().is_ok()
238        } else {
239            false
240        }
241    }
242
243    /// Check if this expression is zero
244    #[must_use]
245    pub fn is_zero(&self) -> bool {
246        if let ExprLang::Num(s) = self.root() {
247            s.as_str() == "0" || s.as_str().parse::<f64>().is_ok_and(|v| v.abs() < 1e-15)
248        } else {
249            false
250        }
251    }
252
253    /// Check if this expression is one
254    #[must_use]
255    pub fn is_one(&self) -> bool {
256        if let ExprLang::Num(s) = self.root() {
257            s.as_str() == "1"
258                || s.as_str()
259                    .parse::<f64>()
260                    .is_ok_and(|v| (v - 1.0).abs() < 1e-15)
261        } else {
262            false
263        }
264    }
265
266    /// Get the symbol name if this is a symbol
267    #[must_use]
268    pub fn as_symbol(&self) -> Option<&str> {
269        if let ExprLang::Num(s) = self.root() {
270            // Only return as symbol if it's not a number
271            if s.as_str().parse::<f64>().is_err() {
272                return Some(s.as_str());
273            }
274        }
275        None
276    }
277
278    /// Convert to f64 if this is a numeric constant
279    #[must_use]
280    pub fn to_f64(&self) -> Option<f64> {
281        if let ExprLang::Num(s) = self.root() {
282            s.as_str().parse::<f64>().ok()
283        } else {
284            None
285        }
286    }
287
288    /// Convert to i64 if this is an integer constant
289    #[must_use]
290    pub fn to_i64(&self) -> Option<i64> {
291        if let ExprLang::Num(s) = self.root() {
292            s.as_str().parse::<i64>().ok()
293        } else {
294            None
295        }
296    }
297
298    /// Check if this expression is an addition operation
299    #[must_use]
300    pub fn is_add(&self) -> bool {
301        matches!(self.root(), ExprLang::Add(_))
302    }
303
304    /// Check if this expression is a multiplication operation
305    #[must_use]
306    pub fn is_mul(&self) -> bool {
307        matches!(self.root(), ExprLang::Mul(_))
308    }
309
310    /// Check if this expression is a power operation
311    #[must_use]
312    pub fn is_pow(&self) -> bool {
313        matches!(self.root(), ExprLang::Pow(_))
314    }
315
316    /// Check if this expression is a negation operation
317    #[must_use]
318    pub fn is_neg(&self) -> bool {
319        matches!(self.root(), ExprLang::Neg(_))
320    }
321
322    /// Get the inner expression if this is a negation
323    #[must_use]
324    pub fn as_neg(&self) -> Option<Self> {
325        if let ExprLang::Neg([inner_id]) = self.root() {
326            Some(self.extract_subexpr(*inner_id))
327        } else {
328            None
329        }
330    }
331
332    /// Get the operands if this is an addition operation
333    #[must_use]
334    pub fn as_add(&self) -> Option<Vec<Self>> {
335        if let ExprLang::Add([lhs_id, rhs_id]) = self.root() {
336            Some(vec![
337                self.extract_subexpr(*lhs_id),
338                self.extract_subexpr(*rhs_id),
339            ])
340        } else {
341            None
342        }
343    }
344
345    /// Get the operands if this is a multiplication operation
346    #[must_use]
347    pub fn as_mul(&self) -> Option<Vec<Self>> {
348        if let ExprLang::Mul([lhs_id, rhs_id]) = self.root() {
349            Some(vec![
350                self.extract_subexpr(*lhs_id),
351                self.extract_subexpr(*rhs_id),
352            ])
353        } else {
354            None
355        }
356    }
357
358    /// Get the base and exponent if this is a power operation
359    #[must_use]
360    pub fn as_pow(&self) -> Option<(Self, Self)> {
361        if let ExprLang::Pow([base_id, exp_id]) = self.root() {
362            Some((
363                self.extract_subexpr(*base_id),
364                self.extract_subexpr(*exp_id),
365            ))
366        } else {
367            None
368        }
369    }
370
371    /// Extract a subexpression by its ID
372    fn extract_subexpr(&self, id: Id) -> Self {
373        let target_idx = usize::from(id);
374        let mut new_expr = RecExpr::default();
375
376        // Build a mapping from old IDs to new IDs
377        let mut id_map = std::collections::HashMap::new();
378
379        // Traverse the expression up to and including the target node
380        for (idx, node) in self.expr.as_ref().iter().enumerate() {
381            if idx > target_idx {
382                break;
383            }
384            let new_node = node
385                .clone()
386                .map_children(|old_id| *id_map.get(&old_id).unwrap_or(&old_id));
387            let new_id = new_expr.add(new_node);
388            id_map.insert(Id::from(idx), new_id);
389        }
390
391        Self { expr: new_expr }
392    }
393
394    // =========================================================================
395    // Basic Operations
396    // =========================================================================
397
398    /// Add two expressions
399    #[must_use]
400    pub fn add(&self, other: &Self) -> Self {
401        self.clone() + other.clone()
402    }
403
404    /// Subtract two expressions
405    #[must_use]
406    pub fn sub(&self, other: &Self) -> Self {
407        self.clone() - other.clone()
408    }
409
410    /// Multiply two expressions
411    #[must_use]
412    pub fn mul(&self, other: &Self) -> Self {
413        self.clone() * other.clone()
414    }
415
416    /// Divide two expressions
417    #[must_use]
418    pub fn div(&self, other: &Self) -> Self {
419        self.clone() / other.clone()
420    }
421
422    /// Raise to a power
423    #[must_use]
424    pub fn pow(&self, exp: &Self) -> Self {
425        let mut expr = self.expr.clone();
426        let lhs_id = Id::from(expr.as_ref().len() - 1);
427
428        // Merge the exponent expression
429        let rhs_id = merge_expr(&mut expr, &exp.expr);
430
431        expr.add(ExprLang::Pow([lhs_id, rhs_id]));
432        Self { expr }
433    }
434
435    /// Negate the expression
436    #[must_use]
437    pub fn neg(&self) -> Self {
438        let mut expr = self.expr.clone();
439        let id = Id::from(expr.as_ref().len() - 1);
440        expr.add(ExprLang::Neg([id]));
441        Self { expr }
442    }
443
444    /// Complex conjugate
445    #[must_use]
446    pub fn conjugate(&self) -> Self {
447        let mut expr = self.expr.clone();
448        let id = Id::from(expr.as_ref().len() - 1);
449        expr.add(ExprLang::Conj([id]));
450        Self { expr }
451    }
452
453    // =========================================================================
454    // Calculus
455    // =========================================================================
456
457    /// Compute the derivative with respect to a variable
458    #[must_use]
459    pub fn diff(&self, var: &Self) -> Self {
460        crate::diff::differentiate(self, var)
461    }
462
463    /// Compute the gradient with respect to multiple variables
464    #[must_use]
465    pub fn gradient(&self, vars: &[Self]) -> Vec<Self> {
466        vars.iter().map(|v| self.diff(v)).collect()
467    }
468
469    /// Compute the Hessian matrix (second derivatives)
470    #[must_use]
471    pub fn hessian(&self, vars: &[Self]) -> Vec<Vec<Self>> {
472        let grad = self.gradient(vars);
473        grad.iter().map(|g| g.gradient(vars)).collect()
474    }
475
476    // =========================================================================
477    // Simplification
478    // =========================================================================
479
480    /// Expand the expression (distribute products over sums)
481    #[must_use]
482    pub fn expand(&self) -> Self {
483        crate::simplify::expand(self)
484    }
485
486    /// Simplify the expression
487    #[must_use]
488    pub fn simplify(&self) -> Self {
489        crate::simplify::simplify(self)
490    }
491
492    // =========================================================================
493    // Evaluation
494    // =========================================================================
495
496    /// Evaluate the expression with given variable values
497    ///
498    /// # Errors
499    /// Returns an error if a variable is not found in the values map
500    pub fn eval(&self, values: &HashMap<String, f64>) -> SymEngineResult<f64> {
501        crate::eval::evaluate(self, values)
502    }
503
504    /// Evaluate the expression to a complex number.
505    ///
506    /// This can handle expressions containing the imaginary unit `I`,
507    /// which is essential for quantum computing applications.
508    ///
509    /// # Arguments
510    /// * `values` - Map of variable names to real values
511    ///
512    /// # Returns
513    /// The complex result of the evaluation.
514    ///
515    /// # Errors
516    /// Returns an error if a variable is not found in the values map
517    pub fn eval_complex(
518        &self,
519        values: &HashMap<String, f64>,
520    ) -> SymEngineResult<scirs2_core::Complex64> {
521        crate::eval::evaluate_complex(self, values)
522    }
523
524    /// Substitute a variable with an expression
525    #[must_use]
526    pub fn substitute(&self, var: &Self, value: &Self) -> Self {
527        crate::simplify::substitute(self, var, value)
528    }
529
530    /// Substitute multiple variables
531    #[must_use]
532    pub fn substitute_many(&self, values: &HashMap<Self, Self>) -> Self {
533        let mut result = self.clone();
534        for (var, value) in values {
535            result = result.substitute(var, value);
536        }
537        result
538    }
539
540    // =========================================================================
541    // Symbol collection
542    // =========================================================================
543
544    /// Collect all free symbols (variable names) in this expression.
545    ///
546    /// Returns a `HashSet` of variable names that appear in the expression,
547    /// excluding numeric literals and the special constants `pi`, `e`, and `I`.
548    #[must_use]
549    pub fn free_symbols(&self) -> std::collections::HashSet<String> {
550        let mut symbols = std::collections::HashSet::new();
551        collect_free_symbols(
552            self.expr.as_ref(),
553            self.expr.as_ref().len() - 1,
554            &mut symbols,
555        );
556        symbols
557    }
558
559    // =========================================================================
560    // Internal helpers
561    // =========================================================================
562
563    /// Get the underlying RecExpr (for advanced usage)
564    pub(crate) const fn as_rec_expr(&self) -> &RecExpr<ExprLang> {
565        &self.expr
566    }
567
568    /// Create from RecExpr (for internal use)
569    pub(crate) const fn from_rec_expr(expr: RecExpr<ExprLang>) -> Self {
570        Self { expr }
571    }
572}
573
574/// Collect all free symbol names (variables) from a RecExpr node, recursively.
575///
576/// Numeric literals and special constants (`pi`, `e`, `I`) are excluded.
577fn collect_free_symbols(
578    nodes: &[ExprLang],
579    idx: usize,
580    symbols: &mut std::collections::HashSet<String>,
581) {
582    match &nodes[idx] {
583        ExprLang::Num(s) => {
584            let name = s.as_str();
585            // Exclude numeric literals and known constants
586            if name.parse::<f64>().is_err() && !matches!(name, "pi" | "e" | "I") {
587                symbols.insert(name.to_string());
588            }
589        }
590        node => {
591            node.for_each(|child_id| {
592                collect_free_symbols(nodes, usize::from(child_id), symbols);
593            });
594        }
595    }
596}
597
598/// Merge another expression into a RecExpr and return the new root ID
599fn merge_expr(target: &mut RecExpr<ExprLang>, source: &RecExpr<ExprLang>) -> Id {
600    let offset = target.as_ref().len();
601    for node in source.as_ref() {
602        let shifted = node
603            .clone()
604            .map_children(|id| Id::from(usize::from(id) + offset));
605        target.add(shifted);
606    }
607    Id::from(target.as_ref().len() - 1)
608}
609
610// =========================================================================
611// Trait Implementations
612// =========================================================================
613
614impl fmt::Display for Expression {
615    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
616        write!(f, "{}", self.expr.pretty(80))
617    }
618}
619
620impl PartialEq for Expression {
621    fn eq(&self, other: &Self) -> bool {
622        self.expr == other.expr
623    }
624}
625
626impl Eq for Expression {}
627
628impl Hash for Expression {
629    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
630        self.to_string().hash(state);
631    }
632}
633
634impl From<i64> for Expression {
635    fn from(n: i64) -> Self {
636        Self::int(n)
637    }
638}
639
640impl From<i32> for Expression {
641    fn from(n: i32) -> Self {
642        Self::int(i64::from(n))
643    }
644}
645
646impl From<f64> for Expression {
647    fn from(f: f64) -> Self {
648        Self::float_unchecked(f)
649    }
650}
651
652impl From<Complex64> for Expression {
653    fn from(c: Complex64) -> Self {
654        Self::from_complex64(c)
655    }
656}
657
658// Implement arithmetic operators
659impl std::ops::Add for Expression {
660    type Output = Self;
661
662    #[allow(clippy::suspicious_arithmetic_impl)]
663    fn add(self, rhs: Self) -> Self::Output {
664        let mut expr = self.expr;
665        let lhs_id = Id::from(expr.as_ref().len() - 1);
666        let rhs_id = merge_expr(&mut expr, &rhs.expr);
667        expr.add(ExprLang::Add([lhs_id, rhs_id]));
668        Self { expr }
669    }
670}
671
672impl std::ops::Sub for Expression {
673    type Output = Self;
674
675    #[allow(clippy::suspicious_arithmetic_impl)]
676    fn sub(self, rhs: Self) -> Self::Output {
677        self + rhs.neg()
678    }
679}
680
681impl std::ops::Mul for Expression {
682    type Output = Self;
683
684    #[allow(clippy::suspicious_arithmetic_impl)]
685    fn mul(self, rhs: Self) -> Self::Output {
686        let mut expr = self.expr;
687        let lhs_id = Id::from(expr.as_ref().len() - 1);
688        let rhs_id = merge_expr(&mut expr, &rhs.expr);
689        expr.add(ExprLang::Mul([lhs_id, rhs_id]));
690        Self { expr }
691    }
692}
693
694impl std::ops::Div for Expression {
695    type Output = Self;
696
697    #[allow(clippy::suspicious_arithmetic_impl)]
698    fn div(self, rhs: Self) -> Self::Output {
699        let mut expr = self.expr;
700        let lhs_id = Id::from(expr.as_ref().len() - 1);
701        let rhs_id = merge_expr(&mut expr, &rhs.expr);
702        expr.add(ExprLang::Div([lhs_id, rhs_id]));
703        Self { expr }
704    }
705}
706
707impl std::ops::Neg for Expression {
708    type Output = Self;
709
710    fn neg(self) -> Self::Output {
711        Self::neg(&self)
712    }
713}
714
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symbol_creation() {
        let var = Expression::symbol("x");
        assert!(var.is_symbol());
        assert_eq!(var.as_symbol(), Some("x"));
    }

    #[test]
    fn test_integer_creation() {
        let forty_two = Expression::int(42);
        assert!(forty_two.is_number());
        assert_eq!(forty_two.to_i64(), Some(42));
    }

    #[test]
    fn test_float_creation() {
        let constant = Expression::float(2.5).expect("valid float");
        assert!(constant.is_number());
        let numeric = constant.to_f64().expect("should be f64");
        assert!((numeric - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_zero_and_one() {
        let (zero, one) = (Expression::zero(), Expression::one());
        assert!(zero.is_zero() && !zero.is_one());
        assert!(one.is_one() && !one.is_zero());
    }

    #[test]
    fn test_from_complex64() {
        let general = Expression::from_complex64(Complex64::new(3.0, 4.0));
        assert!(!general.is_number());
    }

    #[test]
    fn test_arithmetic_operators() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");

        let compound = [
            x.clone() + y.clone(),
            x.clone() * y.clone(),
            x.clone() - y.clone(),
            x / y,
        ];

        for expr in &compound {
            assert!(!expr.is_symbol());
        }
    }
}