quantrs2_symengine_pure/
expr.rs

1//! Expression AST and core types for symbolic computation.
2//!
3//! This module defines the symbolic expression type using egg's e-graph
4//! for efficient representation and manipulation.
5
6use std::collections::HashMap;
7use std::fmt;
8use std::hash::Hash;
9use std::str::FromStr;
10use std::sync::Arc;
11
12use egg::{
13    define_language, Analysis, CostFunction, EGraph, Id, Language, RecExpr, Rewrite, Runner, Symbol,
14};
15use scirs2_core::Complex64;
16
17use crate::error::{SymEngineError, SymEngineResult};
18
// The symbolic expression language, declared with egg's `define_language!` macro.
//
// Node kinds:
// - `Num`: a leaf holding a `Symbol`; used for both numeric literals (stored
//   in their string form, e.g. "42" or "3.14") and variable names
// - Binary arithmetic: `+`, `*`, `/`, `^`
// - Unary arithmetic: `neg`, `inv`, `abs`
// - Transcendental functions: `sin`, `cos`, `tan`, `exp`, `log`, `sqrt`,
//   the inverse trigonometric functions and the hyperbolic functions
// - Complex-number operations: `re`, `im`, `conj`
// - Quantum-specific operations: `comm`, `anticomm`, `tensor`, `trace`, `dagger`
// - Matrix operations: `det`, `transpose`
define_language! {
    /// The symbolic expression language
    pub enum ExprLang {
        // Leaf node. A `Symbol` is used for both variable names and numeric
        // literals; numbers are stored as strings like "42" or "3.14" and are
        // told apart from variables by attempting to parse the text.
        Num(Symbol),

        // Binary arithmetic operations
        "+" = Add([Id; 2]),
        "*" = Mul([Id; 2]),
        "/" = Div([Id; 2]),
        "^" = Pow([Id; 2]),

        // Unary operations
        "neg" = Neg([Id; 1]),
        "inv" = Inv([Id; 1]),
        "abs" = Abs([Id; 1]),

        // Transcendental functions
        "sin" = Sin([Id; 1]),
        "cos" = Cos([Id; 1]),
        "tan" = Tan([Id; 1]),
        "exp" = Exp([Id; 1]),
        "log" = Log([Id; 1]),
        "sqrt" = Sqrt([Id; 1]),
        "asin" = Asin([Id; 1]),
        "acos" = Acos([Id; 1]),
        "atan" = Atan([Id; 1]),
        "sinh" = Sinh([Id; 1]),
        "cosh" = Cosh([Id; 1]),
        "tanh" = Tanh([Id; 1]),

        // Complex number operations
        "re" = Re([Id; 1]),
        "im" = Im([Id; 1]),
        "conj" = Conj([Id; 1]),

        // Quantum-specific operations
        "comm" = Commutator([Id; 2]),      // [A, B] = AB - BA
        "anticomm" = Anticommutator([Id; 2]), // {A, B} = AB + BA
        "tensor" = TensorProduct([Id; 2]),  // A ⊗ B
        "trace" = Trace([Id; 1]),
        "dagger" = Dagger([Id; 1]),         // Hermitian conjugate

        // Matrix operations
        "det" = Determinant([Id; 1]),
        "transpose" = Transpose([Id; 1]),
    }
}
76
/// A symbolic mathematical expression.
///
/// This type wraps egg's `RecExpr` over [`ExprLang`] and provides a
/// user-friendly API for symbolic computation. The last node stored in the
/// `RecExpr` is treated as the root of the expression.
#[derive(Clone, Debug)]
pub struct Expression {
    /// The underlying recursive expression; its final node is the root.
    expr: RecExpr<ExprLang>,
}
86
87impl Expression {
88    // =========================================================================
89    // Construction
90    // =========================================================================
91
92    /// Create a new symbolic variable
93    ///
94    /// # Example
95    /// ```ignore
96    /// use quantrs2_symengine_pure::Expression;
97    /// let x = Expression::symbol("x");
98    /// ```
99    #[must_use]
100    pub fn symbol(name: &str) -> Self {
101        let mut expr = RecExpr::default();
102        expr.add(ExprLang::Num(Symbol::from(name)));
103        Self { expr }
104    }
105
106    /// Create an integer constant
107    #[must_use]
108    pub fn int(value: i64) -> Self {
109        let mut expr = RecExpr::default();
110        expr.add(ExprLang::Num(Symbol::from(value.to_string())));
111        Self { expr }
112    }
113
114    /// Create a floating-point constant
115    ///
116    /// # Errors
117    /// Returns an error if the value is NaN
118    pub fn float(value: f64) -> SymEngineResult<Self> {
119        if value.is_nan() {
120            return Err(SymEngineError::Undefined(
121                "NaN is not a valid symbolic value".into(),
122            ));
123        }
124        let mut expr = RecExpr::default();
125        expr.add(ExprLang::Num(Symbol::from(value.to_string())));
126        Ok(Self { expr })
127    }
128
129    /// Create a floating-point constant, using 0 for NaN
130    #[must_use]
131    pub fn float_unchecked(value: f64) -> Self {
132        let v = if value.is_nan() { 0.0 } else { value };
133        let mut expr = RecExpr::default();
134        expr.add(ExprLang::Num(Symbol::from(v.to_string())));
135        Self { expr }
136    }
137
138    /// Create the constant zero
139    #[must_use]
140    pub fn zero() -> Self {
141        Self::int(0)
142    }
143
144    /// Create the constant one
145    #[must_use]
146    pub fn one() -> Self {
147        Self::int(1)
148    }
149
150    /// Create the imaginary unit i
151    #[must_use]
152    pub fn i() -> Self {
153        Self::symbol("I")
154    }
155
156    /// Create the constant π
157    #[must_use]
158    pub fn pi() -> Self {
159        Self::symbol("pi")
160    }
161
162    /// Create the constant e (Euler's number)
163    #[must_use]
164    pub fn e() -> Self {
165        Self::symbol("e")
166    }
167
168    /// Create from a complex number
169    ///
170    /// If imaginary part is negligible, returns just the real part.
171    #[must_use]
172    pub fn from_complex64(c: Complex64) -> Self {
173        const EPSILON: f64 = 1e-15;
174        if c.im.abs() < EPSILON {
175            Self::float_unchecked(c.re)
176        } else if c.re.abs() < EPSILON {
177            // Pure imaginary
178            Self::float_unchecked(c.im) * Self::i()
179        } else {
180            // General complex
181            Self::float_unchecked(c.re) + Self::float_unchecked(c.im) * Self::i()
182        }
183    }
184
185    /// Parse an expression from a string
186    ///
187    /// # Errors
188    /// Returns an error if parsing fails
189    pub fn parse(input: &str) -> SymEngineResult<Self> {
190        let trimmed = input.trim();
191        if trimmed.is_empty() {
192            return Err(SymEngineError::parse("empty expression"));
193        }
194
195        // Try to parse as number
196        if let Ok(n) = trimmed.parse::<i64>() {
197            return Ok(Self::int(n));
198        }
199        if let Ok(f) = trimmed.parse::<f64>() {
200            return Self::float(f);
201        }
202
203        // Otherwise treat as symbol/expression
204        Ok(Self::symbol(trimmed))
205    }
206
207    /// Create an expression from a string (alias for parse)
208    #[must_use]
209    pub fn new(input: impl AsRef<str>) -> Self {
210        Self::parse(input.as_ref()).unwrap_or_else(|_| Self::symbol(input.as_ref()))
211    }
212
213    // =========================================================================
214    // Accessors
215    // =========================================================================
216
217    /// Get the root node of the expression
218    fn root(&self) -> &ExprLang {
219        &self.expr[self.root_id()]
220    }
221
222    /// Get the root ID
223    fn root_id(&self) -> Id {
224        Id::from(self.expr.as_ref().len() - 1)
225    }
226
227    /// Check if this expression is a symbol (variable or number literal)
228    #[must_use]
229    pub fn is_symbol(&self) -> bool {
230        matches!(self.root(), ExprLang::Num(_))
231    }
232
233    /// Check if this expression is a number
234    #[must_use]
235    pub fn is_number(&self) -> bool {
236        if let ExprLang::Num(s) = self.root() {
237            s.as_str().parse::<f64>().is_ok()
238        } else {
239            false
240        }
241    }
242
243    /// Check if this expression is zero
244    #[must_use]
245    pub fn is_zero(&self) -> bool {
246        if let ExprLang::Num(s) = self.root() {
247            s.as_str() == "0" || s.as_str().parse::<f64>().is_ok_and(|v| v.abs() < 1e-15)
248        } else {
249            false
250        }
251    }
252
253    /// Check if this expression is one
254    #[must_use]
255    pub fn is_one(&self) -> bool {
256        if let ExprLang::Num(s) = self.root() {
257            s.as_str() == "1"
258                || s.as_str()
259                    .parse::<f64>()
260                    .is_ok_and(|v| (v - 1.0).abs() < 1e-15)
261        } else {
262            false
263        }
264    }
265
266    /// Get the symbol name if this is a symbol
267    #[must_use]
268    pub fn as_symbol(&self) -> Option<&str> {
269        if let ExprLang::Num(s) = self.root() {
270            // Only return as symbol if it's not a number
271            if s.as_str().parse::<f64>().is_err() {
272                return Some(s.as_str());
273            }
274        }
275        None
276    }
277
278    /// Convert to f64 if this is a numeric constant
279    #[must_use]
280    pub fn to_f64(&self) -> Option<f64> {
281        if let ExprLang::Num(s) = self.root() {
282            s.as_str().parse::<f64>().ok()
283        } else {
284            None
285        }
286    }
287
288    /// Convert to i64 if this is an integer constant
289    #[must_use]
290    pub fn to_i64(&self) -> Option<i64> {
291        if let ExprLang::Num(s) = self.root() {
292            s.as_str().parse::<i64>().ok()
293        } else {
294            None
295        }
296    }
297
298    /// Check if this expression is an addition operation
299    #[must_use]
300    pub fn is_add(&self) -> bool {
301        matches!(self.root(), ExprLang::Add(_))
302    }
303
304    /// Check if this expression is a multiplication operation
305    #[must_use]
306    pub fn is_mul(&self) -> bool {
307        matches!(self.root(), ExprLang::Mul(_))
308    }
309
310    /// Check if this expression is a power operation
311    #[must_use]
312    pub fn is_pow(&self) -> bool {
313        matches!(self.root(), ExprLang::Pow(_))
314    }
315
316    /// Check if this expression is a negation operation
317    #[must_use]
318    pub fn is_neg(&self) -> bool {
319        matches!(self.root(), ExprLang::Neg(_))
320    }
321
322    /// Get the inner expression if this is a negation
323    #[must_use]
324    pub fn as_neg(&self) -> Option<Self> {
325        if let ExprLang::Neg([inner_id]) = self.root() {
326            Some(self.extract_subexpr(*inner_id))
327        } else {
328            None
329        }
330    }
331
332    /// Get the operands if this is an addition operation
333    #[must_use]
334    pub fn as_add(&self) -> Option<Vec<Self>> {
335        if let ExprLang::Add([lhs_id, rhs_id]) = self.root() {
336            Some(vec![
337                self.extract_subexpr(*lhs_id),
338                self.extract_subexpr(*rhs_id),
339            ])
340        } else {
341            None
342        }
343    }
344
345    /// Get the operands if this is a multiplication operation
346    #[must_use]
347    pub fn as_mul(&self) -> Option<Vec<Self>> {
348        if let ExprLang::Mul([lhs_id, rhs_id]) = self.root() {
349            Some(vec![
350                self.extract_subexpr(*lhs_id),
351                self.extract_subexpr(*rhs_id),
352            ])
353        } else {
354            None
355        }
356    }
357
358    /// Get the base and exponent if this is a power operation
359    #[must_use]
360    pub fn as_pow(&self) -> Option<(Self, Self)> {
361        if let ExprLang::Pow([base_id, exp_id]) = self.root() {
362            Some((
363                self.extract_subexpr(*base_id),
364                self.extract_subexpr(*exp_id),
365            ))
366        } else {
367            None
368        }
369    }
370
371    /// Extract a subexpression by its ID
372    fn extract_subexpr(&self, id: Id) -> Self {
373        let target_idx = usize::from(id);
374        let mut new_expr = RecExpr::default();
375
376        // Build a mapping from old IDs to new IDs
377        let mut id_map = std::collections::HashMap::new();
378
379        // Traverse the expression up to and including the target node
380        for (idx, node) in self.expr.as_ref().iter().enumerate() {
381            if idx > target_idx {
382                break;
383            }
384            let new_node = node
385                .clone()
386                .map_children(|old_id| *id_map.get(&old_id).unwrap_or(&old_id));
387            let new_id = new_expr.add(new_node);
388            id_map.insert(Id::from(idx), new_id);
389        }
390
391        Self { expr: new_expr }
392    }
393
394    // =========================================================================
395    // Basic Operations
396    // =========================================================================
397
398    /// Add two expressions
399    #[must_use]
400    pub fn add(&self, other: &Self) -> Self {
401        self.clone() + other.clone()
402    }
403
404    /// Subtract two expressions
405    #[must_use]
406    pub fn sub(&self, other: &Self) -> Self {
407        self.clone() - other.clone()
408    }
409
410    /// Multiply two expressions
411    #[must_use]
412    pub fn mul(&self, other: &Self) -> Self {
413        self.clone() * other.clone()
414    }
415
416    /// Divide two expressions
417    #[must_use]
418    pub fn div(&self, other: &Self) -> Self {
419        self.clone() / other.clone()
420    }
421
422    /// Raise to a power
423    #[must_use]
424    pub fn pow(&self, exp: &Self) -> Self {
425        let mut expr = self.expr.clone();
426        let lhs_id = Id::from(expr.as_ref().len() - 1);
427
428        // Merge the exponent expression
429        let rhs_id = merge_expr(&mut expr, &exp.expr);
430
431        expr.add(ExprLang::Pow([lhs_id, rhs_id]));
432        Self { expr }
433    }
434
435    /// Negate the expression
436    #[must_use]
437    pub fn neg(&self) -> Self {
438        let mut expr = self.expr.clone();
439        let id = Id::from(expr.as_ref().len() - 1);
440        expr.add(ExprLang::Neg([id]));
441        Self { expr }
442    }
443
444    /// Complex conjugate
445    #[must_use]
446    pub fn conjugate(&self) -> Self {
447        let mut expr = self.expr.clone();
448        let id = Id::from(expr.as_ref().len() - 1);
449        expr.add(ExprLang::Conj([id]));
450        Self { expr }
451    }
452
453    // =========================================================================
454    // Calculus
455    // =========================================================================
456
457    /// Compute the derivative with respect to a variable
458    #[must_use]
459    pub fn diff(&self, var: &Self) -> Self {
460        crate::diff::differentiate(self, var)
461    }
462
463    /// Compute the gradient with respect to multiple variables
464    #[must_use]
465    pub fn gradient(&self, vars: &[Self]) -> Vec<Self> {
466        vars.iter().map(|v| self.diff(v)).collect()
467    }
468
469    /// Compute the Hessian matrix (second derivatives)
470    #[must_use]
471    pub fn hessian(&self, vars: &[Self]) -> Vec<Vec<Self>> {
472        let grad = self.gradient(vars);
473        grad.iter().map(|g| g.gradient(vars)).collect()
474    }
475
476    // =========================================================================
477    // Simplification
478    // =========================================================================
479
480    /// Expand the expression (distribute products over sums)
481    #[must_use]
482    pub fn expand(&self) -> Self {
483        crate::simplify::expand(self)
484    }
485
486    /// Simplify the expression
487    #[must_use]
488    pub fn simplify(&self) -> Self {
489        crate::simplify::simplify(self)
490    }
491
492    // =========================================================================
493    // Evaluation
494    // =========================================================================
495
496    /// Evaluate the expression with given variable values
497    ///
498    /// # Errors
499    /// Returns an error if a variable is not found in the values map
500    pub fn eval(&self, values: &HashMap<String, f64>) -> SymEngineResult<f64> {
501        crate::eval::evaluate(self, values)
502    }
503
504    /// Evaluate the expression to a complex number.
505    ///
506    /// This can handle expressions containing the imaginary unit `I`,
507    /// which is essential for quantum computing applications.
508    ///
509    /// # Arguments
510    /// * `values` - Map of variable names to real values
511    ///
512    /// # Returns
513    /// The complex result of the evaluation.
514    ///
515    /// # Errors
516    /// Returns an error if a variable is not found in the values map
517    pub fn eval_complex(
518        &self,
519        values: &HashMap<String, f64>,
520    ) -> SymEngineResult<scirs2_core::Complex64> {
521        crate::eval::evaluate_complex(self, values)
522    }
523
524    /// Substitute a variable with an expression
525    #[must_use]
526    pub fn substitute(&self, var: &Self, value: &Self) -> Self {
527        crate::simplify::substitute(self, var, value)
528    }
529
530    /// Substitute multiple variables
531    #[must_use]
532    pub fn substitute_many(&self, values: &HashMap<Self, Self>) -> Self {
533        let mut result = self.clone();
534        for (var, value) in values {
535            result = result.substitute(var, value);
536        }
537        result
538    }
539
540    // =========================================================================
541    // Internal helpers
542    // =========================================================================
543
544    /// Get the underlying RecExpr (for advanced usage)
545    pub(crate) const fn as_rec_expr(&self) -> &RecExpr<ExprLang> {
546        &self.expr
547    }
548
549    /// Create from RecExpr (for internal use)
550    pub(crate) const fn from_rec_expr(expr: RecExpr<ExprLang>) -> Self {
551        Self { expr }
552    }
553}
554
555/// Merge another expression into a RecExpr and return the new root ID
556fn merge_expr(target: &mut RecExpr<ExprLang>, source: &RecExpr<ExprLang>) -> Id {
557    let offset = target.as_ref().len();
558    for node in source.as_ref() {
559        let shifted = node
560            .clone()
561            .map_children(|id| Id::from(usize::from(id) + offset));
562        target.add(shifted);
563    }
564    Id::from(target.as_ref().len() - 1)
565}
566
567// =========================================================================
568// Trait Implementations
569// =========================================================================
570
571impl fmt::Display for Expression {
572    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
573        write!(f, "{}", self.expr.pretty(80))
574    }
575}
576
577impl PartialEq for Expression {
578    fn eq(&self, other: &Self) -> bool {
579        self.expr == other.expr
580    }
581}
582
// Marker impl: equality is delegated to the comparison of the wrapped `RecExpr`.
impl Eq for Expression {}
584
585impl Hash for Expression {
586    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
587        self.to_string().hash(state);
588    }
589}
590
591impl From<i64> for Expression {
592    fn from(n: i64) -> Self {
593        Self::int(n)
594    }
595}
596
597impl From<i32> for Expression {
598    fn from(n: i32) -> Self {
599        Self::int(i64::from(n))
600    }
601}
602
603impl From<f64> for Expression {
604    fn from(f: f64) -> Self {
605        Self::float_unchecked(f)
606    }
607}
608
609impl From<Complex64> for Expression {
610    fn from(c: Complex64) -> Self {
611        Self::from_complex64(c)
612    }
613}
614
615// Implement arithmetic operators
616impl std::ops::Add for Expression {
617    type Output = Self;
618
619    #[allow(clippy::suspicious_arithmetic_impl)]
620    fn add(self, rhs: Self) -> Self::Output {
621        let mut expr = self.expr;
622        let lhs_id = Id::from(expr.as_ref().len() - 1);
623        let rhs_id = merge_expr(&mut expr, &rhs.expr);
624        expr.add(ExprLang::Add([lhs_id, rhs_id]));
625        Self { expr }
626    }
627}
628
629impl std::ops::Sub for Expression {
630    type Output = Self;
631
632    #[allow(clippy::suspicious_arithmetic_impl)]
633    fn sub(self, rhs: Self) -> Self::Output {
634        self + rhs.neg()
635    }
636}
637
638impl std::ops::Mul for Expression {
639    type Output = Self;
640
641    #[allow(clippy::suspicious_arithmetic_impl)]
642    fn mul(self, rhs: Self) -> Self::Output {
643        let mut expr = self.expr;
644        let lhs_id = Id::from(expr.as_ref().len() - 1);
645        let rhs_id = merge_expr(&mut expr, &rhs.expr);
646        expr.add(ExprLang::Mul([lhs_id, rhs_id]));
647        Self { expr }
648    }
649}
650
651impl std::ops::Div for Expression {
652    type Output = Self;
653
654    #[allow(clippy::suspicious_arithmetic_impl)]
655    fn div(self, rhs: Self) -> Self::Output {
656        let mut expr = self.expr;
657        let lhs_id = Id::from(expr.as_ref().len() - 1);
658        let rhs_id = merge_expr(&mut expr, &rhs.expr);
659        expr.add(ExprLang::Div([lhs_id, rhs_id]));
660        Self { expr }
661    }
662}
663
664impl std::ops::Neg for Expression {
665    type Output = Self;
666
667    fn neg(self) -> Self::Output {
668        Self::neg(&self)
669    }
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    /// A named leaf is reported as a symbol and exposes its name.
    #[test]
    fn test_symbol_creation() {
        let sym = Expression::symbol("x");
        assert!(sym.is_symbol());
        assert_eq!(sym.as_symbol(), Some("x"));
    }

    /// An integer literal is a number and round-trips through `to_i64`.
    #[test]
    fn test_integer_creation() {
        let literal = Expression::int(42);
        assert!(literal.is_number());
        assert_eq!(literal.to_i64(), Some(42));
    }

    /// A float literal is a number and round-trips through `to_f64`.
    #[test]
    fn test_float_creation() {
        let literal = Expression::float(2.5).expect("valid float");
        assert!(literal.is_number());
        let recovered = literal.to_f64().expect("should be f64");
        assert!((recovered - 2.5).abs() < 1e-10);
    }

    /// `zero` and `one` are recognised by exactly one of the two predicates.
    #[test]
    fn test_zero_and_one() {
        let zero = Expression::zero();
        assert!(zero.is_zero());
        assert!(!zero.is_one());

        let one = Expression::one();
        assert!(one.is_one());
        assert!(!one.is_zero());
    }

    /// A complex value with both parts non-negligible is not a plain number.
    #[test]
    fn test_from_complex64() {
        let expr = Expression::from_complex64(Complex64::new(3.0, 4.0));
        assert!(!expr.is_number());
    }

    /// Every binary operator produces a compound (non-leaf) expression.
    #[test]
    fn test_arithmetic_operators() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");

        let sum = x.clone() + y.clone();
        assert!(!sum.is_symbol());

        let product = x.clone() * y.clone();
        assert!(!product.is_symbol());

        let diff = x.clone() - y.clone();
        assert!(!diff.is_symbol());

        let quot = x / y;
        assert!(!quot.is_symbol());
    }
}
731}