1use std::collections::HashMap;
7use std::fmt;
8use std::hash::Hash;
9use std::str::FromStr;
10use std::sync::Arc;
11
12use egg::{
13 define_language, Analysis, CostFunction, EGraph, Id, Language, RecExpr, Rewrite, Runner, Symbol,
14};
15use scirs2_core::Complex64;
16
17use crate::error::{SymEngineError, SymEngineResult};
18
define_language! {
    // Node language for symbolic expressions. Every `Expression` is a `RecExpr`
    // (flat node list) over these variants.
    pub enum ExprLang {
        // Leaf node. Holds either a symbol name or a numeric literal in its
        // textual form (see `Expression::symbol`, `Expression::int`,
        // `Expression::float`); the two are told apart by trying to parse the
        // text as a number.
        Num(Symbol),

        // Binary arithmetic. There is no subtraction node: `a - b` is built by
        // the `Sub` operator impl as `(+ a (neg b))`.
        "+" = Add([Id; 2]),
        "*" = Mul([Id; 2]),
        "/" = Div([Id; 2]),
        "^" = Pow([Id; 2]),

        // Unary arithmetic.
        "neg" = Neg([Id; 1]),
        "inv" = Inv([Id; 1]),
        "abs" = Abs([Id; 1]),

        // Elementary functions (one argument each).
        "sin" = Sin([Id; 1]),
        "cos" = Cos([Id; 1]),
        "tan" = Tan([Id; 1]),
        "exp" = Exp([Id; 1]),
        "log" = Log([Id; 1]),
        "sqrt" = Sqrt([Id; 1]),
        "asin" = Asin([Id; 1]),
        "acos" = Acos([Id; 1]),
        "atan" = Atan([Id; 1]),
        "sinh" = Sinh([Id; 1]),
        "cosh" = Cosh([Id; 1]),
        "tanh" = Tanh([Id; 1]),

        // Complex-number parts; `conj` is what `Expression::conjugate` emits.
        "re" = Re([Id; 1]),
        "im" = Im([Id; 1]),
        "conj" = Conj([Id; 1]),

        // Operator / matrix constructs. NOTE(review): the exact semantics
        // (e.g. operator vs. matrix interpretation) are defined by the
        // evaluation/simplification modules, not here — confirm there.
        "comm" = Commutator([Id; 2]), "anticomm" = Anticommutator([Id; 2]), "tensor" = TensorProduct([Id; 2]), "trace" = Trace([Id; 1]),
        "dagger" = Dagger([Id; 1]), "det" = Determinant([Id; 1]),
        "transpose" = Transpose([Id; 1]),
    }
}
76
77#[derive(Clone, Debug)]
82pub struct Expression {
83 expr: RecExpr<ExprLang>,
85}
86
87impl Expression {
88 #[must_use]
100 pub fn symbol(name: &str) -> Self {
101 let mut expr = RecExpr::default();
102 expr.add(ExprLang::Num(Symbol::from(name)));
103 Self { expr }
104 }
105
106 #[must_use]
108 pub fn int(value: i64) -> Self {
109 let mut expr = RecExpr::default();
110 expr.add(ExprLang::Num(Symbol::from(value.to_string())));
111 Self { expr }
112 }
113
114 pub fn float(value: f64) -> SymEngineResult<Self> {
119 if value.is_nan() {
120 return Err(SymEngineError::Undefined(
121 "NaN is not a valid symbolic value".into(),
122 ));
123 }
124 let mut expr = RecExpr::default();
125 expr.add(ExprLang::Num(Symbol::from(value.to_string())));
126 Ok(Self { expr })
127 }
128
129 #[must_use]
131 pub fn float_unchecked(value: f64) -> Self {
132 let v = if value.is_nan() { 0.0 } else { value };
133 let mut expr = RecExpr::default();
134 expr.add(ExprLang::Num(Symbol::from(v.to_string())));
135 Self { expr }
136 }
137
138 #[must_use]
140 pub fn zero() -> Self {
141 Self::int(0)
142 }
143
144 #[must_use]
146 pub fn one() -> Self {
147 Self::int(1)
148 }
149
150 #[must_use]
152 pub fn i() -> Self {
153 Self::symbol("I")
154 }
155
156 #[must_use]
158 pub fn pi() -> Self {
159 Self::symbol("pi")
160 }
161
162 #[must_use]
164 pub fn e() -> Self {
165 Self::symbol("e")
166 }
167
168 #[must_use]
172 pub fn from_complex64(c: Complex64) -> Self {
173 const EPSILON: f64 = 1e-15;
174 if c.im.abs() < EPSILON {
175 Self::float_unchecked(c.re)
176 } else if c.re.abs() < EPSILON {
177 Self::float_unchecked(c.im) * Self::i()
179 } else {
180 Self::float_unchecked(c.re) + Self::float_unchecked(c.im) * Self::i()
182 }
183 }
184
185 pub fn parse(input: &str) -> SymEngineResult<Self> {
190 let trimmed = input.trim();
191 if trimmed.is_empty() {
192 return Err(SymEngineError::parse("empty expression"));
193 }
194
195 if let Ok(n) = trimmed.parse::<i64>() {
197 return Ok(Self::int(n));
198 }
199 if let Ok(f) = trimmed.parse::<f64>() {
200 return Self::float(f);
201 }
202
203 Ok(Self::symbol(trimmed))
205 }
206
207 #[must_use]
209 pub fn new(input: impl AsRef<str>) -> Self {
210 Self::parse(input.as_ref()).unwrap_or_else(|_| Self::symbol(input.as_ref()))
211 }
212
213 fn root(&self) -> &ExprLang {
219 &self.expr[self.root_id()]
220 }
221
222 fn root_id(&self) -> Id {
224 Id::from(self.expr.as_ref().len() - 1)
225 }
226
227 #[must_use]
229 pub fn is_symbol(&self) -> bool {
230 matches!(self.root(), ExprLang::Num(_))
231 }
232
233 #[must_use]
235 pub fn is_number(&self) -> bool {
236 if let ExprLang::Num(s) = self.root() {
237 s.as_str().parse::<f64>().is_ok()
238 } else {
239 false
240 }
241 }
242
243 #[must_use]
245 pub fn is_zero(&self) -> bool {
246 if let ExprLang::Num(s) = self.root() {
247 s.as_str() == "0" || s.as_str().parse::<f64>().is_ok_and(|v| v.abs() < 1e-15)
248 } else {
249 false
250 }
251 }
252
253 #[must_use]
255 pub fn is_one(&self) -> bool {
256 if let ExprLang::Num(s) = self.root() {
257 s.as_str() == "1"
258 || s.as_str()
259 .parse::<f64>()
260 .is_ok_and(|v| (v - 1.0).abs() < 1e-15)
261 } else {
262 false
263 }
264 }
265
266 #[must_use]
268 pub fn as_symbol(&self) -> Option<&str> {
269 if let ExprLang::Num(s) = self.root() {
270 if s.as_str().parse::<f64>().is_err() {
272 return Some(s.as_str());
273 }
274 }
275 None
276 }
277
278 #[must_use]
280 pub fn to_f64(&self) -> Option<f64> {
281 if let ExprLang::Num(s) = self.root() {
282 s.as_str().parse::<f64>().ok()
283 } else {
284 None
285 }
286 }
287
288 #[must_use]
290 pub fn to_i64(&self) -> Option<i64> {
291 if let ExprLang::Num(s) = self.root() {
292 s.as_str().parse::<i64>().ok()
293 } else {
294 None
295 }
296 }
297
298 #[must_use]
300 pub fn is_add(&self) -> bool {
301 matches!(self.root(), ExprLang::Add(_))
302 }
303
304 #[must_use]
306 pub fn is_mul(&self) -> bool {
307 matches!(self.root(), ExprLang::Mul(_))
308 }
309
310 #[must_use]
312 pub fn is_pow(&self) -> bool {
313 matches!(self.root(), ExprLang::Pow(_))
314 }
315
316 #[must_use]
318 pub fn is_neg(&self) -> bool {
319 matches!(self.root(), ExprLang::Neg(_))
320 }
321
322 #[must_use]
324 pub fn as_neg(&self) -> Option<Self> {
325 if let ExprLang::Neg([inner_id]) = self.root() {
326 Some(self.extract_subexpr(*inner_id))
327 } else {
328 None
329 }
330 }
331
332 #[must_use]
334 pub fn as_add(&self) -> Option<Vec<Self>> {
335 if let ExprLang::Add([lhs_id, rhs_id]) = self.root() {
336 Some(vec![
337 self.extract_subexpr(*lhs_id),
338 self.extract_subexpr(*rhs_id),
339 ])
340 } else {
341 None
342 }
343 }
344
345 #[must_use]
347 pub fn as_mul(&self) -> Option<Vec<Self>> {
348 if let ExprLang::Mul([lhs_id, rhs_id]) = self.root() {
349 Some(vec![
350 self.extract_subexpr(*lhs_id),
351 self.extract_subexpr(*rhs_id),
352 ])
353 } else {
354 None
355 }
356 }
357
358 #[must_use]
360 pub fn as_pow(&self) -> Option<(Self, Self)> {
361 if let ExprLang::Pow([base_id, exp_id]) = self.root() {
362 Some((
363 self.extract_subexpr(*base_id),
364 self.extract_subexpr(*exp_id),
365 ))
366 } else {
367 None
368 }
369 }
370
371 fn extract_subexpr(&self, id: Id) -> Self {
373 let target_idx = usize::from(id);
374 let mut new_expr = RecExpr::default();
375
376 let mut id_map = std::collections::HashMap::new();
378
379 for (idx, node) in self.expr.as_ref().iter().enumerate() {
381 if idx > target_idx {
382 break;
383 }
384 let new_node = node
385 .clone()
386 .map_children(|old_id| *id_map.get(&old_id).unwrap_or(&old_id));
387 let new_id = new_expr.add(new_node);
388 id_map.insert(Id::from(idx), new_id);
389 }
390
391 Self { expr: new_expr }
392 }
393
394 #[must_use]
400 pub fn add(&self, other: &Self) -> Self {
401 self.clone() + other.clone()
402 }
403
404 #[must_use]
406 pub fn sub(&self, other: &Self) -> Self {
407 self.clone() - other.clone()
408 }
409
410 #[must_use]
412 pub fn mul(&self, other: &Self) -> Self {
413 self.clone() * other.clone()
414 }
415
416 #[must_use]
418 pub fn div(&self, other: &Self) -> Self {
419 self.clone() / other.clone()
420 }
421
422 #[must_use]
424 pub fn pow(&self, exp: &Self) -> Self {
425 let mut expr = self.expr.clone();
426 let lhs_id = Id::from(expr.as_ref().len() - 1);
427
428 let rhs_id = merge_expr(&mut expr, &exp.expr);
430
431 expr.add(ExprLang::Pow([lhs_id, rhs_id]));
432 Self { expr }
433 }
434
435 #[must_use]
437 pub fn neg(&self) -> Self {
438 let mut expr = self.expr.clone();
439 let id = Id::from(expr.as_ref().len() - 1);
440 expr.add(ExprLang::Neg([id]));
441 Self { expr }
442 }
443
444 #[must_use]
446 pub fn conjugate(&self) -> Self {
447 let mut expr = self.expr.clone();
448 let id = Id::from(expr.as_ref().len() - 1);
449 expr.add(ExprLang::Conj([id]));
450 Self { expr }
451 }
452
453 #[must_use]
459 pub fn diff(&self, var: &Self) -> Self {
460 crate::diff::differentiate(self, var)
461 }
462
463 #[must_use]
465 pub fn gradient(&self, vars: &[Self]) -> Vec<Self> {
466 vars.iter().map(|v| self.diff(v)).collect()
467 }
468
469 #[must_use]
471 pub fn hessian(&self, vars: &[Self]) -> Vec<Vec<Self>> {
472 let grad = self.gradient(vars);
473 grad.iter().map(|g| g.gradient(vars)).collect()
474 }
475
476 #[must_use]
482 pub fn expand(&self) -> Self {
483 crate::simplify::expand(self)
484 }
485
486 #[must_use]
488 pub fn simplify(&self) -> Self {
489 crate::simplify::simplify(self)
490 }
491
492 pub fn eval(&self, values: &HashMap<String, f64>) -> SymEngineResult<f64> {
501 crate::eval::evaluate(self, values)
502 }
503
504 pub fn eval_complex(
518 &self,
519 values: &HashMap<String, f64>,
520 ) -> SymEngineResult<scirs2_core::Complex64> {
521 crate::eval::evaluate_complex(self, values)
522 }
523
524 #[must_use]
526 pub fn substitute(&self, var: &Self, value: &Self) -> Self {
527 crate::simplify::substitute(self, var, value)
528 }
529
530 #[must_use]
532 pub fn substitute_many(&self, values: &HashMap<Self, Self>) -> Self {
533 let mut result = self.clone();
534 for (var, value) in values {
535 result = result.substitute(var, value);
536 }
537 result
538 }
539
540 #[must_use]
549 pub fn free_symbols(&self) -> std::collections::HashSet<String> {
550 let mut symbols = std::collections::HashSet::new();
551 collect_free_symbols(
552 self.expr.as_ref(),
553 self.expr.as_ref().len() - 1,
554 &mut symbols,
555 );
556 symbols
557 }
558
559 pub(crate) const fn as_rec_expr(&self) -> &RecExpr<ExprLang> {
565 &self.expr
566 }
567
568 pub(crate) const fn from_rec_expr(expr: RecExpr<ExprLang>) -> Self {
570 Self { expr }
571 }
572}
573
574fn collect_free_symbols(
578 nodes: &[ExprLang],
579 idx: usize,
580 symbols: &mut std::collections::HashSet<String>,
581) {
582 match &nodes[idx] {
583 ExprLang::Num(s) => {
584 let name = s.as_str();
585 if name.parse::<f64>().is_err() && !matches!(name, "pi" | "e" | "I") {
587 symbols.insert(name.to_string());
588 }
589 }
590 node => {
591 node.for_each(|child_id| {
592 collect_free_symbols(nodes, usize::from(child_id), symbols);
593 });
594 }
595 }
596}
597
598fn merge_expr(target: &mut RecExpr<ExprLang>, source: &RecExpr<ExprLang>) -> Id {
600 let offset = target.as_ref().len();
601 for node in source.as_ref() {
602 let shifted = node
603 .clone()
604 .map_children(|id| Id::from(usize::from(id) + offset));
605 target.add(shifted);
606 }
607 Id::from(target.as_ref().len() - 1)
608}
609
610impl fmt::Display for Expression {
615 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
616 write!(f, "{}", self.expr.pretty(80))
617 }
618}
619
620impl PartialEq for Expression {
621 fn eq(&self, other: &Self) -> bool {
622 self.expr == other.expr
623 }
624}
625
626impl Eq for Expression {}
627
628impl Hash for Expression {
629 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
630 self.to_string().hash(state);
631 }
632}
633
634impl From<i64> for Expression {
635 fn from(n: i64) -> Self {
636 Self::int(n)
637 }
638}
639
640impl From<i32> for Expression {
641 fn from(n: i32) -> Self {
642 Self::int(i64::from(n))
643 }
644}
645
646impl From<f64> for Expression {
647 fn from(f: f64) -> Self {
648 Self::float_unchecked(f)
649 }
650}
651
652impl From<Complex64> for Expression {
653 fn from(c: Complex64) -> Self {
654 Self::from_complex64(c)
655 }
656}
657
658impl std::ops::Add for Expression {
660 type Output = Self;
661
662 #[allow(clippy::suspicious_arithmetic_impl)]
663 fn add(self, rhs: Self) -> Self::Output {
664 let mut expr = self.expr;
665 let lhs_id = Id::from(expr.as_ref().len() - 1);
666 let rhs_id = merge_expr(&mut expr, &rhs.expr);
667 expr.add(ExprLang::Add([lhs_id, rhs_id]));
668 Self { expr }
669 }
670}
671
672impl std::ops::Sub for Expression {
673 type Output = Self;
674
675 #[allow(clippy::suspicious_arithmetic_impl)]
676 fn sub(self, rhs: Self) -> Self::Output {
677 self + rhs.neg()
678 }
679}
680
681impl std::ops::Mul for Expression {
682 type Output = Self;
683
684 #[allow(clippy::suspicious_arithmetic_impl)]
685 fn mul(self, rhs: Self) -> Self::Output {
686 let mut expr = self.expr;
687 let lhs_id = Id::from(expr.as_ref().len() - 1);
688 let rhs_id = merge_expr(&mut expr, &rhs.expr);
689 expr.add(ExprLang::Mul([lhs_id, rhs_id]));
690 Self { expr }
691 }
692}
693
694impl std::ops::Div for Expression {
695 type Output = Self;
696
697 #[allow(clippy::suspicious_arithmetic_impl)]
698 fn div(self, rhs: Self) -> Self::Output {
699 let mut expr = self.expr;
700 let lhs_id = Id::from(expr.as_ref().len() - 1);
701 let rhs_id = merge_expr(&mut expr, &rhs.expr);
702 expr.add(ExprLang::Div([lhs_id, rhs_id]));
703 Self { expr }
704 }
705}
706
707impl std::ops::Neg for Expression {
708 type Output = Self;
709
710 fn neg(self) -> Self::Output {
711 Self::neg(&self)
712 }
713}
714
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symbol_creation() {
        let x = Expression::symbol("x");
        assert!(x.is_symbol());
        assert_eq!(x.as_symbol(), Some("x"));
    }

    #[test]
    fn test_integer_creation() {
        let n = Expression::int(42);
        assert!(n.is_number());
        assert_eq!(n.to_i64(), Some(42));
    }

    #[test]
    fn test_float_creation() {
        let f = Expression::float(2.5).expect("valid float");
        assert!(f.is_number());
        let val = f.to_f64().expect("should be f64");
        assert!((val - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_zero_and_one() {
        let zero = Expression::zero();
        let one = Expression::one();

        assert!(zero.is_zero());
        assert!(!zero.is_one());
        assert!(one.is_one());
        assert!(!one.is_zero());
    }

    #[test]
    fn test_from_complex64() {
        let c = Complex64::new(3.0, 4.0);
        let expr = Expression::from_complex64(c);
        assert!(!expr.is_number());
    }

    #[test]
    fn test_arithmetic_operators() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");

        let sum = x.clone() + y.clone();
        let product = x.clone() * y.clone();
        let diff = x.clone() - y.clone();
        let quot = x / y;

        assert!(!sum.is_symbol());
        assert!(!product.is_symbol());
        assert!(!diff.is_symbol());
        assert!(!quot.is_symbol());
    }

    #[test]
    fn test_parse_atoms_and_errors() {
        assert_eq!(Expression::parse("42").expect("integer").to_i64(), Some(42));

        let f = Expression::parse(" 2.5 ")
            .expect("float")
            .to_f64()
            .expect("numeric");
        assert!((f - 2.5).abs() < 1e-10);

        let theta = Expression::parse("  theta ").expect("symbol");
        assert_eq!(theta.as_symbol(), Some("theta"));

        assert!(Expression::parse("   ").is_err());
        assert!(Expression::parse("NaN").is_err());
    }

    #[test]
    fn test_structural_accessors() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");

        let sum = x.clone() + y;
        let terms = sum.as_add().expect("sum should be an Add node");
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].as_symbol(), Some("x"));
        assert_eq!(terms[1].as_symbol(), Some("y"));
        assert!(sum.as_mul().is_none());

        let (base, exponent) = x.pow(&Expression::int(2)).as_pow().expect("Pow node");
        assert_eq!(base.as_symbol(), Some("x"));
        assert_eq!(exponent.to_i64(), Some(2));

        let negated = -x;
        assert!(negated.is_neg());
        assert_eq!(negated.as_neg().expect("Neg node").as_symbol(), Some("x"));
    }

    #[test]
    fn test_free_symbols_skip_constants_and_numbers() {
        let x = Expression::symbol("x");
        let expr = (x.clone() + Expression::pi()) * Expression::int(3)
            + Expression::e() * Expression::i()
            - x;
        let symbols = expr.free_symbols();
        assert_eq!(symbols.len(), 1);
        assert!(symbols.contains("x"));
    }
}