1use std::collections::HashMap;
24
25use scirs2_core::Complex64;
26
27use crate::error::{SymEngineError, SymEngineResult};
28use crate::expr::{ExprLang, Expression};
29
30pub fn evaluate(expr: &Expression, values: &HashMap<String, f64>) -> SymEngineResult<f64> {
32 let rec_expr = expr.as_rec_expr();
33 let root_idx = rec_expr.as_ref().len() - 1;
34 evaluate_rec(rec_expr.as_ref(), root_idx, values)
35}
36
37fn evaluate_rec(
39 nodes: &[ExprLang],
40 idx: usize,
41 values: &HashMap<String, f64>,
42) -> SymEngineResult<f64> {
43 let node = &nodes[idx];
44
45 match node {
46 ExprLang::Num(s) => {
47 let name = s.as_str();
48 if let Ok(n) = name.parse::<f64>() {
50 return Ok(n);
51 }
52 match name {
54 "pi" => Ok(std::f64::consts::PI),
55 "e" => Ok(std::f64::consts::E),
56 "I" => Err(SymEngineError::eval(
57 "Cannot evaluate complex unit i as real",
58 )),
59 _ => values
60 .get(name)
61 .copied()
62 .ok_or_else(|| SymEngineError::eval(format!("Undefined variable: {name}"))),
63 }
64 }
65
66 ExprLang::Add([a, b]) => {
67 let va = evaluate_rec(nodes, usize::from(*a), values)?;
68 let vb = evaluate_rec(nodes, usize::from(*b), values)?;
69 Ok(va + vb)
70 }
71
72 ExprLang::Mul([a, b]) => {
73 let va = evaluate_rec(nodes, usize::from(*a), values)?;
74 let vb = evaluate_rec(nodes, usize::from(*b), values)?;
75 Ok(va * vb)
76 }
77
78 ExprLang::Div([a, b]) => {
79 let va = evaluate_rec(nodes, usize::from(*a), values)?;
80 let vb = evaluate_rec(nodes, usize::from(*b), values)?;
81 if vb.abs() < 1e-15 {
82 return Err(SymEngineError::DivisionByZero);
83 }
84 Ok(va / vb)
85 }
86
87 ExprLang::Pow([a, b]) => {
88 let va = evaluate_rec(nodes, usize::from(*a), values)?;
89 let vb = evaluate_rec(nodes, usize::from(*b), values)?;
90 Ok(va.powf(vb))
91 }
92
93 ExprLang::Neg([a]) => {
94 let va = evaluate_rec(nodes, usize::from(*a), values)?;
95 Ok(-va)
96 }
97
98 ExprLang::Inv([a]) => {
99 let va = evaluate_rec(nodes, usize::from(*a), values)?;
100 if va.abs() < 1e-15 {
101 return Err(SymEngineError::DivisionByZero);
102 }
103 Ok(1.0 / va)
104 }
105
106 ExprLang::Abs([a]) => {
107 let va = evaluate_rec(nodes, usize::from(*a), values)?;
108 Ok(va.abs())
109 }
110
111 ExprLang::Sin([a]) => {
112 let va = evaluate_rec(nodes, usize::from(*a), values)?;
113 Ok(va.sin())
114 }
115
116 ExprLang::Cos([a]) => {
117 let va = evaluate_rec(nodes, usize::from(*a), values)?;
118 Ok(va.cos())
119 }
120
121 ExprLang::Tan([a]) => {
122 let va = evaluate_rec(nodes, usize::from(*a), values)?;
123 Ok(va.tan())
124 }
125
126 ExprLang::Exp([a]) => {
127 let va = evaluate_rec(nodes, usize::from(*a), values)?;
128 Ok(va.exp())
129 }
130
131 ExprLang::Log([a]) => {
132 let va = evaluate_rec(nodes, usize::from(*a), values)?;
133 if va <= 0.0 {
134 return Err(SymEngineError::Undefined(
135 "log of non-positive number".into(),
136 ));
137 }
138 Ok(va.ln())
139 }
140
141 ExprLang::Sqrt([a]) => {
142 let va = evaluate_rec(nodes, usize::from(*a), values)?;
143 if va < 0.0 {
144 return Err(SymEngineError::Undefined("sqrt of negative number".into()));
145 }
146 Ok(va.sqrt())
147 }
148
149 ExprLang::Asin([a]) => {
150 let va = evaluate_rec(nodes, usize::from(*a), values)?;
151 Ok(va.asin())
152 }
153
154 ExprLang::Acos([a]) => {
155 let va = evaluate_rec(nodes, usize::from(*a), values)?;
156 Ok(va.acos())
157 }
158
159 ExprLang::Atan([a]) => {
160 let va = evaluate_rec(nodes, usize::from(*a), values)?;
161 Ok(va.atan())
162 }
163
164 ExprLang::Sinh([a]) => {
165 let va = evaluate_rec(nodes, usize::from(*a), values)?;
166 Ok(va.sinh())
167 }
168
169 ExprLang::Cosh([a]) => {
170 let va = evaluate_rec(nodes, usize::from(*a), values)?;
171 Ok(va.cosh())
172 }
173
174 ExprLang::Tanh([a]) => {
175 let va = evaluate_rec(nodes, usize::from(*a), values)?;
176 Ok(va.tanh())
177 }
178
179 ExprLang::Re([a]) | ExprLang::Im([a]) | ExprLang::Conj([a]) => {
181 evaluate_rec(nodes, usize::from(*a), values)
183 }
184
185 ExprLang::Commutator([_, _])
187 | ExprLang::Anticommutator([_, _])
188 | ExprLang::TensorProduct([_, _])
189 | ExprLang::Trace([_])
190 | ExprLang::Dagger([_])
191 | ExprLang::Determinant([_])
192 | ExprLang::Transpose([_]) => Err(SymEngineError::eval(
193 "Cannot evaluate symbolic quantum operation numerically",
194 )),
195 }
196}
197
198pub fn evaluate_batch(
200 expr: &Expression,
201 values_list: &[HashMap<String, f64>],
202) -> Vec<SymEngineResult<f64>> {
203 values_list.iter().map(|v| evaluate(expr, v)).collect()
204}
205
206pub fn evaluate_complex(
239 expr: &Expression,
240 values: &HashMap<String, f64>,
241) -> SymEngineResult<Complex64> {
242 let rec_expr = expr.as_rec_expr();
243 let root_idx = rec_expr.as_ref().len() - 1;
244 evaluate_complex_rec(rec_expr.as_ref(), root_idx, values)
245}
246
247fn evaluate_complex_rec(
249 nodes: &[ExprLang],
250 idx: usize,
251 values: &HashMap<String, f64>,
252) -> SymEngineResult<Complex64> {
253 let node = &nodes[idx];
254
255 match node {
256 ExprLang::Num(s) => {
257 let name = s.as_str();
258 if let Ok(n) = name.parse::<f64>() {
260 return Ok(Complex64::new(n, 0.0));
261 }
262 match name {
264 "pi" => Ok(Complex64::new(std::f64::consts::PI, 0.0)),
265 "e" => Ok(Complex64::new(std::f64::consts::E, 0.0)),
266 "I" => Ok(Complex64::new(0.0, 1.0)), _ => values
268 .get(name)
269 .copied()
270 .map(|v| Complex64::new(v, 0.0))
271 .ok_or_else(|| SymEngineError::eval(format!("Undefined variable: {name}"))),
272 }
273 }
274
275 ExprLang::Add([a, b]) => {
276 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
277 let vb = evaluate_complex_rec(nodes, usize::from(*b), values)?;
278 Ok(va + vb)
279 }
280
281 ExprLang::Mul([a, b]) => {
282 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
283 let vb = evaluate_complex_rec(nodes, usize::from(*b), values)?;
284 Ok(va * vb)
285 }
286
287 ExprLang::Div([a, b]) => {
288 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
289 let vb = evaluate_complex_rec(nodes, usize::from(*b), values)?;
290 if vb.norm() < 1e-15 {
291 return Err(SymEngineError::DivisionByZero);
292 }
293 Ok(va / vb)
294 }
295
296 ExprLang::Pow([a, b]) => {
297 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
298 let vb = evaluate_complex_rec(nodes, usize::from(*b), values)?;
299 Ok(va.powc(vb))
300 }
301
302 ExprLang::Neg([a]) => {
303 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
304 Ok(-va)
305 }
306
307 ExprLang::Inv([a]) => {
308 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
309 if va.norm() < 1e-15 {
310 return Err(SymEngineError::DivisionByZero);
311 }
312 Ok(Complex64::new(1.0, 0.0) / va)
313 }
314
315 ExprLang::Abs([a]) => {
316 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
317 Ok(Complex64::new(va.norm(), 0.0))
318 }
319
320 ExprLang::Sin([a]) => {
321 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
322 Ok(va.sin())
323 }
324
325 ExprLang::Cos([a]) => {
326 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
327 Ok(va.cos())
328 }
329
330 ExprLang::Tan([a]) => {
331 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
332 Ok(va.tan())
333 }
334
335 ExprLang::Exp([a]) => {
336 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
337 Ok(va.exp())
338 }
339
340 ExprLang::Log([a]) => {
341 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
342 if va.norm() < 1e-15 {
343 return Err(SymEngineError::Undefined("log of zero".into()));
344 }
345 Ok(va.ln())
346 }
347
348 ExprLang::Sqrt([a]) => {
349 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
350 Ok(va.sqrt())
351 }
352
353 ExprLang::Asin([a]) => {
354 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
355 Ok(va.asin())
356 }
357
358 ExprLang::Acos([a]) => {
359 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
360 Ok(va.acos())
361 }
362
363 ExprLang::Atan([a]) => {
364 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
365 Ok(va.atan())
366 }
367
368 ExprLang::Sinh([a]) => {
369 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
370 Ok(va.sinh())
371 }
372
373 ExprLang::Cosh([a]) => {
374 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
375 Ok(va.cosh())
376 }
377
378 ExprLang::Tanh([a]) => {
379 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
380 Ok(va.tanh())
381 }
382
383 ExprLang::Re([a]) => {
385 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
386 Ok(Complex64::new(va.re, 0.0))
387 }
388
389 ExprLang::Im([a]) => {
390 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
391 Ok(Complex64::new(va.im, 0.0))
392 }
393
394 ExprLang::Conj([a]) => {
395 let va = evaluate_complex_rec(nodes, usize::from(*a), values)?;
396 Ok(va.conj())
397 }
398
399 ExprLang::Commutator([_, _])
401 | ExprLang::Anticommutator([_, _])
402 | ExprLang::TensorProduct([_, _])
403 | ExprLang::Trace([_])
404 | ExprLang::Dagger([_])
405 | ExprLang::Determinant([_])
406 | ExprLang::Transpose([_]) => Err(SymEngineError::eval(
407 "Cannot evaluate symbolic quantum operation numerically",
408 )),
409 }
410}
411
412pub fn evaluate_complex_batch(
414 expr: &Expression,
415 values_list: &[HashMap<String, f64>],
416) -> Vec<SymEngineResult<Complex64>> {
417 values_list
418 .iter()
419 .map(|v| evaluate_complex(expr, v))
420 .collect()
421}
422
423pub fn evaluate_complex_with_complex_values(
428 expr: &Expression,
429 values: &HashMap<String, Complex64>,
430) -> SymEngineResult<Complex64> {
431 let rec_expr = expr.as_rec_expr();
432 let root_idx = rec_expr.as_ref().len() - 1;
433 evaluate_complex_full_rec(rec_expr.as_ref(), root_idx, values)
434}
435
436fn evaluate_complex_full_rec(
438 nodes: &[ExprLang],
439 idx: usize,
440 values: &HashMap<String, Complex64>,
441) -> SymEngineResult<Complex64> {
442 let node = &nodes[idx];
443
444 match node {
445 ExprLang::Num(s) => {
446 let name = s.as_str();
447 if let Ok(n) = name.parse::<f64>() {
449 return Ok(Complex64::new(n, 0.0));
450 }
451 match name {
453 "pi" => Ok(Complex64::new(std::f64::consts::PI, 0.0)),
454 "e" => Ok(Complex64::new(std::f64::consts::E, 0.0)),
455 "I" => Ok(Complex64::new(0.0, 1.0)),
456 _ => values
457 .get(name)
458 .copied()
459 .ok_or_else(|| SymEngineError::eval(format!("Undefined variable: {name}"))),
460 }
461 }
462
463 ExprLang::Add([a, b]) => {
464 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
465 let vb = evaluate_complex_full_rec(nodes, usize::from(*b), values)?;
466 Ok(va + vb)
467 }
468
469 ExprLang::Mul([a, b]) => {
470 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
471 let vb = evaluate_complex_full_rec(nodes, usize::from(*b), values)?;
472 Ok(va * vb)
473 }
474
475 ExprLang::Div([a, b]) => {
476 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
477 let vb = evaluate_complex_full_rec(nodes, usize::from(*b), values)?;
478 if vb.norm() < 1e-15 {
479 return Err(SymEngineError::DivisionByZero);
480 }
481 Ok(va / vb)
482 }
483
484 ExprLang::Pow([a, b]) => {
485 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
486 let vb = evaluate_complex_full_rec(nodes, usize::from(*b), values)?;
487 Ok(va.powc(vb))
488 }
489
490 ExprLang::Neg([a]) => {
491 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
492 Ok(-va)
493 }
494
495 ExprLang::Inv([a]) => {
496 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
497 if va.norm() < 1e-15 {
498 return Err(SymEngineError::DivisionByZero);
499 }
500 Ok(Complex64::new(1.0, 0.0) / va)
501 }
502
503 ExprLang::Abs([a]) => {
504 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
505 Ok(Complex64::new(va.norm(), 0.0))
506 }
507
508 ExprLang::Sin([a]) => {
509 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
510 Ok(va.sin())
511 }
512
513 ExprLang::Cos([a]) => {
514 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
515 Ok(va.cos())
516 }
517
518 ExprLang::Tan([a]) => {
519 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
520 Ok(va.tan())
521 }
522
523 ExprLang::Exp([a]) => {
524 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
525 Ok(va.exp())
526 }
527
528 ExprLang::Log([a]) => {
529 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
530 if va.norm() < 1e-15 {
531 return Err(SymEngineError::Undefined("log of zero".into()));
532 }
533 Ok(va.ln())
534 }
535
536 ExprLang::Sqrt([a]) => {
537 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
538 Ok(va.sqrt())
539 }
540
541 ExprLang::Asin([a]) => {
542 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
543 Ok(va.asin())
544 }
545
546 ExprLang::Acos([a]) => {
547 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
548 Ok(va.acos())
549 }
550
551 ExprLang::Atan([a]) => {
552 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
553 Ok(va.atan())
554 }
555
556 ExprLang::Sinh([a]) => {
557 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
558 Ok(va.sinh())
559 }
560
561 ExprLang::Cosh([a]) => {
562 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
563 Ok(va.cosh())
564 }
565
566 ExprLang::Tanh([a]) => {
567 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
568 Ok(va.tanh())
569 }
570
571 ExprLang::Re([a]) => {
572 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
573 Ok(Complex64::new(va.re, 0.0))
574 }
575
576 ExprLang::Im([a]) => {
577 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
578 Ok(Complex64::new(va.im, 0.0))
579 }
580
581 ExprLang::Conj([a]) => {
582 let va = evaluate_complex_full_rec(nodes, usize::from(*a), values)?;
583 Ok(va.conj())
584 }
585
586 ExprLang::Commutator([_, _])
587 | ExprLang::Anticommutator([_, _])
588 | ExprLang::TensorProduct([_, _])
589 | ExprLang::Trace([_])
590 | ExprLang::Dagger([_])
591 | ExprLang::Determinant([_])
592 | ExprLang::Transpose([_]) => Err(SymEngineError::eval(
593 "Cannot evaluate symbolic quantum operation numerically",
594 )),
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every floating-point comparison below.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` lies strictly within `EPS` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    /// Asserts that both components of `actual` lie within `EPS` of `(re, im)`.
    fn assert_complex_close(actual: Complex64, re: f64, im: f64) {
        assert_close(actual.re, re);
        assert_close(actual.im, im);
    }

    /// Builds the expression `3 + 2*I` used by several complex tests.
    fn three_plus_two_i() -> Expression {
        Expression::int(3) + Expression::int(2) * Expression::i()
    }

    #[test]
    fn test_eval_constant() {
        let value = evaluate(&Expression::int(42), &HashMap::new()).expect("should evaluate");
        assert_close(value, 42.0);
    }

    #[test]
    fn test_eval_variable() {
        let values = HashMap::from([("x".to_string(), 2.5)]);
        let value = evaluate(&Expression::symbol("x"), &values).expect("should evaluate");
        assert_close(value, 2.5);
    }

    #[test]
    fn test_eval_expression() {
        let x = Expression::symbol("x");
        let x_squared = x.clone() * x;
        let values = HashMap::from([("x".to_string(), 3.0)]);
        let value = evaluate(&x_squared, &values).expect("should evaluate");
        assert_close(value, 9.0);
    }

    #[test]
    fn test_eval_trig() {
        let sin_x = crate::ops::trig::sin(&Expression::symbol("x"));
        let values = HashMap::from([("x".to_string(), std::f64::consts::PI / 2.0)]);
        let value = evaluate(&sin_x, &values).expect("should evaluate");
        assert_close(value, 1.0);
    }

    #[test]
    fn test_eval_division_by_zero() {
        let quotient = Expression::one() / Expression::zero();
        assert!(evaluate(&quotient, &HashMap::new()).is_err());
    }

    #[test]
    fn test_eval_complex_imaginary_unit() {
        let value = evaluate_complex(&Expression::i(), &HashMap::new()).expect("should evaluate");
        assert_complex_close(value, 0.0, 1.0);
    }

    #[test]
    fn test_eval_complex_i_squared() {
        let i = Expression::i();
        let i_squared = i.clone() * i;
        let value = evaluate_complex(&i_squared, &HashMap::new()).expect("should evaluate");
        assert_complex_close(value, -1.0, 0.0);
    }

    #[test]
    fn test_eval_complex_expression() {
        let z = three_plus_two_i();
        let value = evaluate_complex(&z, &HashMap::new()).expect("should evaluate");
        assert_complex_close(value, 3.0, 2.0);
    }

    #[test]
    fn test_eval_complex_with_variable() {
        let x_times_i = Expression::symbol("x") * Expression::i();
        let values = HashMap::from([("x".to_string(), 5.0)]);
        let value = evaluate_complex(&x_times_i, &values).expect("should evaluate");
        assert_complex_close(value, 0.0, 5.0);
    }

    #[test]
    fn test_eval_complex_exp() {
        let i_pi = Expression::i() * Expression::float_unchecked(std::f64::consts::PI);
        let euler = crate::ops::trig::exp(&i_pi);
        let value = evaluate_complex(&euler, &HashMap::new()).expect("should evaluate");
        assert_complex_close(value, -1.0, 0.0);
    }

    #[test]
    fn test_eval_complex_conjugate() {
        let conj_z = crate::ops::complex::conj(&three_plus_two_i());
        let value = evaluate_complex(&conj_z, &HashMap::new()).expect("should evaluate");
        assert_complex_close(value, 3.0, -2.0);
    }

    #[test]
    fn test_eval_complex_real_part() {
        let re_z = crate::ops::complex::re(&three_plus_two_i());
        let value = evaluate_complex(&re_z, &HashMap::new()).expect("should evaluate");
        assert_complex_close(value, 3.0, 0.0);
    }

    #[test]
    fn test_eval_complex_imag_part() {
        let im_z = crate::ops::complex::im(&three_plus_two_i());
        let value = evaluate_complex(&im_z, &HashMap::new()).expect("should evaluate");
        assert_complex_close(value, 2.0, 0.0);
    }

    #[test]
    fn test_eval_complex_with_complex_values() {
        let z = Expression::symbol("z");
        let values = HashMap::from([("z".to_string(), Complex64::new(1.0, 2.0))]);
        let value = evaluate_complex_with_complex_values(&z, &values).expect("should evaluate");
        assert_complex_close(value, 1.0, 2.0);
    }

    #[test]
    fn test_eval_complex_abs() {
        let z = Expression::int(3) + Expression::int(4) * Expression::i();
        let abs_z = crate::ops::trig::abs(&z);
        let value = evaluate_complex(&abs_z, &HashMap::new()).expect("should evaluate");
        assert_complex_close(value, 5.0, 0.0);
    }
}
781}