1use std::collections::HashMap;
7use std::fmt;
8use std::hash::Hash;
9use std::str::FromStr;
10use std::sync::Arc;
11
12use egg::{
13 define_language, Analysis, CostFunction, EGraph, Id, Language, RecExpr, Rewrite, Runner, Symbol,
14};
15use scirs2_core::Complex64;
16
17use crate::error::{SymEngineError, SymEngineResult};
18
define_language! {
    /// Node language for symbolic expressions stored in an egg `RecExpr`.
    ///
    /// Every leaf, whether a named symbol such as `x` or a numeric literal such as
    /// `42` or `2.5`, is a `Num(Symbol)` holding its textual form. All other
    /// variants are operator nodes whose `Id` children refer to other nodes of
    /// the same `RecExpr`.
    pub enum ExprLang {
        // Leaf: symbol name or numeric literal, kept as text (the variant name
        // notwithstanding, `Expression::symbol` also produces this variant).
        Num(Symbol),

        // Binary arithmetic. There is no subtraction node: `a - b` is built as
        // `a + (neg b)` by the `Sub` impl further down.
        "+" = Add([Id; 2]),
        "*" = Mul([Id; 2]),
        "/" = Div([Id; 2]),
        "^" = Pow([Id; 2]),

        // Unary arithmetic.
        "neg" = Neg([Id; 1]),
        "inv" = Inv([Id; 1]),
        "abs" = Abs([Id; 1]),

        // Elementary functions (one argument each).
        "sin" = Sin([Id; 1]),
        "cos" = Cos([Id; 1]),
        "tan" = Tan([Id; 1]),
        "exp" = Exp([Id; 1]),
        "log" = Log([Id; 1]),
        "sqrt" = Sqrt([Id; 1]),
        "asin" = Asin([Id; 1]),
        "acos" = Acos([Id; 1]),
        "atan" = Atan([Id; 1]),
        "sinh" = Sinh([Id; 1]),
        "cosh" = Cosh([Id; 1]),
        "tanh" = Tanh([Id; 1]),

        // Complex-number parts.
        "re" = Re([Id; 1]),
        "im" = Im([Id; 1]),
        "conj" = Conj([Id; 1]),

        // Operator / matrix-level nodes. Their meaning is defined by whichever
        // modules evaluate or rewrite them (not visible here).
        "comm" = Commutator([Id; 2]), "anticomm" = Anticommutator([Id; 2]), "tensor" = TensorProduct([Id; 2]), "trace" = Trace([Id; 1]),
        "dagger" = Dagger([Id; 1]), "det" = Determinant([Id; 1]),
        "transpose" = Transpose([Id; 1]),
    }
}
76
77#[derive(Clone, Debug)]
82pub struct Expression {
83 expr: RecExpr<ExprLang>,
85}
86
87impl Expression {
88 #[must_use]
100 pub fn symbol(name: &str) -> Self {
101 let mut expr = RecExpr::default();
102 expr.add(ExprLang::Num(Symbol::from(name)));
103 Self { expr }
104 }
105
106 #[must_use]
108 pub fn int(value: i64) -> Self {
109 let mut expr = RecExpr::default();
110 expr.add(ExprLang::Num(Symbol::from(value.to_string())));
111 Self { expr }
112 }
113
114 pub fn float(value: f64) -> SymEngineResult<Self> {
119 if value.is_nan() {
120 return Err(SymEngineError::Undefined(
121 "NaN is not a valid symbolic value".into(),
122 ));
123 }
124 let mut expr = RecExpr::default();
125 expr.add(ExprLang::Num(Symbol::from(value.to_string())));
126 Ok(Self { expr })
127 }
128
129 #[must_use]
131 pub fn float_unchecked(value: f64) -> Self {
132 let v = if value.is_nan() { 0.0 } else { value };
133 let mut expr = RecExpr::default();
134 expr.add(ExprLang::Num(Symbol::from(v.to_string())));
135 Self { expr }
136 }
137
138 #[must_use]
140 pub fn zero() -> Self {
141 Self::int(0)
142 }
143
144 #[must_use]
146 pub fn one() -> Self {
147 Self::int(1)
148 }
149
150 #[must_use]
152 pub fn i() -> Self {
153 Self::symbol("I")
154 }
155
156 #[must_use]
158 pub fn pi() -> Self {
159 Self::symbol("pi")
160 }
161
162 #[must_use]
164 pub fn e() -> Self {
165 Self::symbol("e")
166 }
167
168 #[must_use]
172 pub fn from_complex64(c: Complex64) -> Self {
173 const EPSILON: f64 = 1e-15;
174 if c.im.abs() < EPSILON {
175 Self::float_unchecked(c.re)
176 } else if c.re.abs() < EPSILON {
177 Self::float_unchecked(c.im) * Self::i()
179 } else {
180 Self::float_unchecked(c.re) + Self::float_unchecked(c.im) * Self::i()
182 }
183 }
184
185 pub fn parse(input: &str) -> SymEngineResult<Self> {
190 let trimmed = input.trim();
191 if trimmed.is_empty() {
192 return Err(SymEngineError::parse("empty expression"));
193 }
194
195 if let Ok(n) = trimmed.parse::<i64>() {
197 return Ok(Self::int(n));
198 }
199 if let Ok(f) = trimmed.parse::<f64>() {
200 return Self::float(f);
201 }
202
203 Ok(Self::symbol(trimmed))
205 }
206
207 #[must_use]
209 pub fn new(input: impl AsRef<str>) -> Self {
210 Self::parse(input.as_ref()).unwrap_or_else(|_| Self::symbol(input.as_ref()))
211 }
212
213 fn root(&self) -> &ExprLang {
219 &self.expr[self.root_id()]
220 }
221
222 fn root_id(&self) -> Id {
224 Id::from(self.expr.as_ref().len() - 1)
225 }
226
227 #[must_use]
229 pub fn is_symbol(&self) -> bool {
230 matches!(self.root(), ExprLang::Num(_))
231 }
232
233 #[must_use]
235 pub fn is_number(&self) -> bool {
236 if let ExprLang::Num(s) = self.root() {
237 s.as_str().parse::<f64>().is_ok()
238 } else {
239 false
240 }
241 }
242
243 #[must_use]
245 pub fn is_zero(&self) -> bool {
246 if let ExprLang::Num(s) = self.root() {
247 s.as_str() == "0" || s.as_str().parse::<f64>().is_ok_and(|v| v.abs() < 1e-15)
248 } else {
249 false
250 }
251 }
252
253 #[must_use]
255 pub fn is_one(&self) -> bool {
256 if let ExprLang::Num(s) = self.root() {
257 s.as_str() == "1"
258 || s.as_str()
259 .parse::<f64>()
260 .is_ok_and(|v| (v - 1.0).abs() < 1e-15)
261 } else {
262 false
263 }
264 }
265
266 #[must_use]
268 pub fn as_symbol(&self) -> Option<&str> {
269 if let ExprLang::Num(s) = self.root() {
270 if s.as_str().parse::<f64>().is_err() {
272 return Some(s.as_str());
273 }
274 }
275 None
276 }
277
278 #[must_use]
280 pub fn to_f64(&self) -> Option<f64> {
281 if let ExprLang::Num(s) = self.root() {
282 s.as_str().parse::<f64>().ok()
283 } else {
284 None
285 }
286 }
287
288 #[must_use]
290 pub fn to_i64(&self) -> Option<i64> {
291 if let ExprLang::Num(s) = self.root() {
292 s.as_str().parse::<i64>().ok()
293 } else {
294 None
295 }
296 }
297
298 #[must_use]
300 pub fn is_add(&self) -> bool {
301 matches!(self.root(), ExprLang::Add(_))
302 }
303
304 #[must_use]
306 pub fn is_mul(&self) -> bool {
307 matches!(self.root(), ExprLang::Mul(_))
308 }
309
310 #[must_use]
312 pub fn is_pow(&self) -> bool {
313 matches!(self.root(), ExprLang::Pow(_))
314 }
315
316 #[must_use]
318 pub fn is_neg(&self) -> bool {
319 matches!(self.root(), ExprLang::Neg(_))
320 }
321
322 #[must_use]
324 pub fn as_neg(&self) -> Option<Self> {
325 if let ExprLang::Neg([inner_id]) = self.root() {
326 Some(self.extract_subexpr(*inner_id))
327 } else {
328 None
329 }
330 }
331
332 #[must_use]
334 pub fn as_add(&self) -> Option<Vec<Self>> {
335 if let ExprLang::Add([lhs_id, rhs_id]) = self.root() {
336 Some(vec![
337 self.extract_subexpr(*lhs_id),
338 self.extract_subexpr(*rhs_id),
339 ])
340 } else {
341 None
342 }
343 }
344
345 #[must_use]
347 pub fn as_mul(&self) -> Option<Vec<Self>> {
348 if let ExprLang::Mul([lhs_id, rhs_id]) = self.root() {
349 Some(vec![
350 self.extract_subexpr(*lhs_id),
351 self.extract_subexpr(*rhs_id),
352 ])
353 } else {
354 None
355 }
356 }
357
358 #[must_use]
360 pub fn as_pow(&self) -> Option<(Self, Self)> {
361 if let ExprLang::Pow([base_id, exp_id]) = self.root() {
362 Some((
363 self.extract_subexpr(*base_id),
364 self.extract_subexpr(*exp_id),
365 ))
366 } else {
367 None
368 }
369 }
370
371 fn extract_subexpr(&self, id: Id) -> Self {
373 let target_idx = usize::from(id);
374 let mut new_expr = RecExpr::default();
375
376 let mut id_map = std::collections::HashMap::new();
378
379 for (idx, node) in self.expr.as_ref().iter().enumerate() {
381 if idx > target_idx {
382 break;
383 }
384 let new_node = node
385 .clone()
386 .map_children(|old_id| *id_map.get(&old_id).unwrap_or(&old_id));
387 let new_id = new_expr.add(new_node);
388 id_map.insert(Id::from(idx), new_id);
389 }
390
391 Self { expr: new_expr }
392 }
393
394 #[must_use]
400 pub fn add(&self, other: &Self) -> Self {
401 self.clone() + other.clone()
402 }
403
404 #[must_use]
406 pub fn sub(&self, other: &Self) -> Self {
407 self.clone() - other.clone()
408 }
409
410 #[must_use]
412 pub fn mul(&self, other: &Self) -> Self {
413 self.clone() * other.clone()
414 }
415
416 #[must_use]
418 pub fn div(&self, other: &Self) -> Self {
419 self.clone() / other.clone()
420 }
421
422 #[must_use]
424 pub fn pow(&self, exp: &Self) -> Self {
425 let mut expr = self.expr.clone();
426 let lhs_id = Id::from(expr.as_ref().len() - 1);
427
428 let rhs_id = merge_expr(&mut expr, &exp.expr);
430
431 expr.add(ExprLang::Pow([lhs_id, rhs_id]));
432 Self { expr }
433 }
434
435 #[must_use]
437 pub fn neg(&self) -> Self {
438 let mut expr = self.expr.clone();
439 let id = Id::from(expr.as_ref().len() - 1);
440 expr.add(ExprLang::Neg([id]));
441 Self { expr }
442 }
443
444 #[must_use]
446 pub fn conjugate(&self) -> Self {
447 let mut expr = self.expr.clone();
448 let id = Id::from(expr.as_ref().len() - 1);
449 expr.add(ExprLang::Conj([id]));
450 Self { expr }
451 }
452
453 #[must_use]
459 pub fn diff(&self, var: &Self) -> Self {
460 crate::diff::differentiate(self, var)
461 }
462
463 #[must_use]
465 pub fn gradient(&self, vars: &[Self]) -> Vec<Self> {
466 vars.iter().map(|v| self.diff(v)).collect()
467 }
468
469 #[must_use]
471 pub fn hessian(&self, vars: &[Self]) -> Vec<Vec<Self>> {
472 let grad = self.gradient(vars);
473 grad.iter().map(|g| g.gradient(vars)).collect()
474 }
475
476 #[must_use]
482 pub fn expand(&self) -> Self {
483 crate::simplify::expand(self)
484 }
485
486 #[must_use]
488 pub fn simplify(&self) -> Self {
489 crate::simplify::simplify(self)
490 }
491
492 pub fn eval(&self, values: &HashMap<String, f64>) -> SymEngineResult<f64> {
501 crate::eval::evaluate(self, values)
502 }
503
504 pub fn eval_complex(
518 &self,
519 values: &HashMap<String, f64>,
520 ) -> SymEngineResult<scirs2_core::Complex64> {
521 crate::eval::evaluate_complex(self, values)
522 }
523
524 #[must_use]
526 pub fn substitute(&self, var: &Self, value: &Self) -> Self {
527 crate::simplify::substitute(self, var, value)
528 }
529
530 #[must_use]
532 pub fn substitute_many(&self, values: &HashMap<Self, Self>) -> Self {
533 let mut result = self.clone();
534 for (var, value) in values {
535 result = result.substitute(var, value);
536 }
537 result
538 }
539
540 pub(crate) const fn as_rec_expr(&self) -> &RecExpr<ExprLang> {
546 &self.expr
547 }
548
549 pub(crate) const fn from_rec_expr(expr: RecExpr<ExprLang>) -> Self {
551 Self { expr }
552 }
553}
554
555fn merge_expr(target: &mut RecExpr<ExprLang>, source: &RecExpr<ExprLang>) -> Id {
557 let offset = target.as_ref().len();
558 for node in source.as_ref() {
559 let shifted = node
560 .clone()
561 .map_children(|id| Id::from(usize::from(id) + offset));
562 target.add(shifted);
563 }
564 Id::from(target.as_ref().len() - 1)
565}
566
567impl fmt::Display for Expression {
572 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
573 write!(f, "{}", self.expr.pretty(80))
574 }
575}
576
577impl PartialEq for Expression {
578 fn eq(&self, other: &Self) -> bool {
579 self.expr == other.expr
580 }
581}
582
583impl Eq for Expression {}
584
585impl Hash for Expression {
586 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
587 self.to_string().hash(state);
588 }
589}
590
591impl From<i64> for Expression {
592 fn from(n: i64) -> Self {
593 Self::int(n)
594 }
595}
596
597impl From<i32> for Expression {
598 fn from(n: i32) -> Self {
599 Self::int(i64::from(n))
600 }
601}
602
603impl From<f64> for Expression {
604 fn from(f: f64) -> Self {
605 Self::float_unchecked(f)
606 }
607}
608
609impl From<Complex64> for Expression {
610 fn from(c: Complex64) -> Self {
611 Self::from_complex64(c)
612 }
613}
614
615impl std::ops::Add for Expression {
617 type Output = Self;
618
619 #[allow(clippy::suspicious_arithmetic_impl)]
620 fn add(self, rhs: Self) -> Self::Output {
621 let mut expr = self.expr;
622 let lhs_id = Id::from(expr.as_ref().len() - 1);
623 let rhs_id = merge_expr(&mut expr, &rhs.expr);
624 expr.add(ExprLang::Add([lhs_id, rhs_id]));
625 Self { expr }
626 }
627}
628
629impl std::ops::Sub for Expression {
630 type Output = Self;
631
632 #[allow(clippy::suspicious_arithmetic_impl)]
633 fn sub(self, rhs: Self) -> Self::Output {
634 self + rhs.neg()
635 }
636}
637
638impl std::ops::Mul for Expression {
639 type Output = Self;
640
641 #[allow(clippy::suspicious_arithmetic_impl)]
642 fn mul(self, rhs: Self) -> Self::Output {
643 let mut expr = self.expr;
644 let lhs_id = Id::from(expr.as_ref().len() - 1);
645 let rhs_id = merge_expr(&mut expr, &rhs.expr);
646 expr.add(ExprLang::Mul([lhs_id, rhs_id]));
647 Self { expr }
648 }
649}
650
651impl std::ops::Div for Expression {
652 type Output = Self;
653
654 #[allow(clippy::suspicious_arithmetic_impl)]
655 fn div(self, rhs: Self) -> Self::Output {
656 let mut expr = self.expr;
657 let lhs_id = Id::from(expr.as_ref().len() - 1);
658 let rhs_id = merge_expr(&mut expr, &rhs.expr);
659 expr.add(ExprLang::Div([lhs_id, rhs_id]));
660 Self { expr }
661 }
662}
663
664impl std::ops::Neg for Expression {
665 type Output = Self;
666
667 fn neg(self) -> Self::Output {
668 Self::neg(&self)
669 }
670}
671
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symbol_creation() {
        let sym = Expression::symbol("x");
        assert!(sym.is_symbol());
        assert_eq!(sym.as_symbol(), Some("x"));
    }

    #[test]
    fn test_integer_creation() {
        let answer = Expression::int(42);
        assert!(answer.is_number());
        assert_eq!(answer.to_i64(), Some(42));
    }

    #[test]
    fn test_float_creation() {
        let value = Expression::float(2.5).expect("valid float");
        assert!(value.is_number());
        let parsed = value.to_f64().expect("should be f64");
        assert!((parsed - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_zero_and_one() {
        let (zero, one) = (Expression::zero(), Expression::one());
        assert!(zero.is_zero());
        assert!(!zero.is_one());
        assert!(one.is_one());
        assert!(!one.is_zero());
    }

    #[test]
    fn test_from_complex64() {
        let mixed = Expression::from_complex64(Complex64::new(3.0, 4.0));
        assert!(!mixed.is_number());
    }

    #[test]
    fn test_arithmetic_operators() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");

        // Every binary operator must produce a compound (non-leaf) root.
        let compounds = [
            x.clone() + y.clone(),
            x.clone() * y.clone(),
            x.clone() - y.clone(),
            x / y,
        ];
        for compound in &compounds {
            assert!(!compound.is_symbol());
        }
    }
}