Skip to main content

quantrs2_core/
symbolic.rs

1//! Symbolic computation module for QuantRS2
2//!
3//! This module provides symbolic computation capabilities using SymEngine,
4//! enabling symbolic parameter manipulation, calculus operations, and
5//! advanced mathematical analysis for quantum circuits and algorithms.
6
7#[cfg(feature = "symbolic")]
8pub use quantrs2_symengine_pure::{Expression as SymEngine, SymEngineError, SymEngineResult};
9
10use crate::error::{QuantRS2Error, QuantRS2Result};
11use scirs2_core::num_traits::{One, Zero}; // SciRS2 POLICY compliant
12use scirs2_core::Complex64;
13use std::collections::HashMap;
14use std::fmt;
15
16/// A symbolic expression that can represent constants, variables, or complex expressions
17#[derive(Debug, Clone, PartialEq)]
18pub enum SymbolicExpression {
19    /// Constant floating-point value
20    Constant(f64),
21
22    /// Complex constant value
23    ComplexConstant(Complex64),
24
25    /// Variable with a name
26    Variable(String),
27
28    /// SymEngine expression (only available with "symbolic" feature)
29    #[cfg(feature = "symbolic")]
30    SymEngine(SymEngine),
31
32    /// Simple arithmetic expression for when SymEngine is not available
33    #[cfg(not(feature = "symbolic"))]
34    Simple(SimpleExpression),
35}
36
37/// Simple expression representation for when SymEngine is not available
38#[cfg(not(feature = "symbolic"))]
39#[derive(Debug, Clone, PartialEq)]
40pub enum SimpleExpression {
41    Add(Box<SymbolicExpression>, Box<SymbolicExpression>),
42    Sub(Box<SymbolicExpression>, Box<SymbolicExpression>),
43    Mul(Box<SymbolicExpression>, Box<SymbolicExpression>),
44    Div(Box<SymbolicExpression>, Box<SymbolicExpression>),
45    Pow(Box<SymbolicExpression>, Box<SymbolicExpression>),
46    Sin(Box<SymbolicExpression>),
47    Cos(Box<SymbolicExpression>),
48    Exp(Box<SymbolicExpression>),
49    Log(Box<SymbolicExpression>),
50}
51
52impl SymbolicExpression {
53    /// Create a constant expression
54    pub const fn constant(value: f64) -> Self {
55        Self::Constant(value)
56    }
57
58    pub const fn zero() -> Self {
59        Self::Constant(0.0)
60    }
61
62    /// Create a complex constant expression
63    pub const fn complex_constant(value: Complex64) -> Self {
64        Self::ComplexConstant(value)
65    }
66
67    /// Create a variable expression
68    pub fn variable(name: &str) -> Self {
69        Self::Variable(name.to_string())
70    }
71
72    /// Create a SymEngine expression (requires "symbolic" feature)
73    #[cfg(feature = "symbolic")]
74    pub const fn from_symengine(expr: SymEngine) -> Self {
75        Self::SymEngine(expr)
76    }
77
78    /// Parse an expression from a string
79    pub fn parse(expr: &str) -> QuantRS2Result<Self> {
80        #[cfg(feature = "symbolic")]
81        {
82            match quantrs2_symengine_pure::parser::parse(expr) {
83                Ok(sym_expr) => Ok(Self::SymEngine(sym_expr)),
84                Err(_) => {
85                    // Fallback to simple parsing
86                    Self::parse_simple(expr)
87                }
88            }
89        }
90
91        #[cfg(not(feature = "symbolic"))]
92        {
93            Self::parse_simple(expr)
94        }
95    }
96
97    /// Simple expression parsing (fallback)
98    fn parse_simple(expr: &str) -> QuantRS2Result<Self> {
99        let trimmed = expr.trim();
100
101        // Try to parse as a number
102        if let Ok(value) = trimmed.parse::<f64>() {
103            return Ok(Self::Constant(value));
104        }
105
106        // Otherwise treat as a variable
107        Ok(Self::Variable(trimmed.to_string()))
108    }
109
110    /// Evaluate the expression with given variable values
111    pub fn evaluate(&self, variables: &HashMap<String, f64>) -> QuantRS2Result<f64> {
112        match self {
113            Self::Constant(value) => Ok(*value),
114            Self::ComplexConstant(value) => {
115                if value.im.abs() < 1e-12 {
116                    Ok(value.re)
117                } else {
118                    Err(QuantRS2Error::InvalidInput(
119                        "Cannot evaluate complex expression to real number".to_string(),
120                    ))
121                }
122            }
123            Self::Variable(name) => variables
124                .get(name)
125                .copied()
126                .ok_or_else(|| QuantRS2Error::InvalidInput(format!("Variable '{name}' not found"))),
127
128            #[cfg(feature = "symbolic")]
129            Self::SymEngine(expr) => expr
130                .eval(variables)
131                .map_err(|e| QuantRS2Error::UnsupportedOperation(e.to_string())),
132
133            #[cfg(not(feature = "symbolic"))]
134            Self::Simple(simple_expr) => Self::evaluate_simple(simple_expr, variables),
135        }
136    }
137
138    /// Evaluate complex expression with given variable values
139    pub fn evaluate_complex(
140        &self,
141        variables: &HashMap<String, Complex64>,
142    ) -> QuantRS2Result<Complex64> {
143        match self {
144            Self::Constant(value) => Ok(Complex64::new(*value, 0.0)),
145            Self::ComplexConstant(value) => Ok(*value),
146            Self::Variable(name) => variables
147                .get(name)
148                .copied()
149                .ok_or_else(|| QuantRS2Error::InvalidInput(format!("Variable '{name}' not found"))),
150
151            #[cfg(feature = "symbolic")]
152            Self::SymEngine(expr) => {
153                quantrs2_symengine_pure::eval::evaluate_complex_with_complex_values(expr, variables)
154                    .map_err(|e| QuantRS2Error::UnsupportedOperation(e.to_string()))
155            }
156
157            #[cfg(not(feature = "symbolic"))]
158            Self::Simple(simple_expr) => Self::evaluate_simple_complex(simple_expr, variables),
159        }
160    }
161
162    #[cfg(not(feature = "symbolic"))]
163    fn evaluate_simple(
164        expr: &SimpleExpression,
165        variables: &HashMap<String, f64>,
166    ) -> QuantRS2Result<f64> {
167        match expr {
168            SimpleExpression::Add(a, b) => Ok(a.evaluate(variables)? + b.evaluate(variables)?),
169            SimpleExpression::Sub(a, b) => Ok(a.evaluate(variables)? - b.evaluate(variables)?),
170            SimpleExpression::Mul(a, b) => Ok(a.evaluate(variables)? * b.evaluate(variables)?),
171            SimpleExpression::Div(a, b) => {
172                let b_val = b.evaluate(variables)?;
173                if b_val.abs() < 1e-12 {
174                    Err(QuantRS2Error::DivisionByZero)
175                } else {
176                    Ok(a.evaluate(variables)? / b_val)
177                }
178            }
179            SimpleExpression::Pow(a, b) => Ok(a.evaluate(variables)?.powf(b.evaluate(variables)?)),
180            SimpleExpression::Sin(a) => Ok(a.evaluate(variables)?.sin()),
181            SimpleExpression::Cos(a) => Ok(a.evaluate(variables)?.cos()),
182            SimpleExpression::Exp(a) => Ok(a.evaluate(variables)?.exp()),
183            SimpleExpression::Log(a) => {
184                let a_val = a.evaluate(variables)?;
185                if a_val <= 0.0 {
186                    Err(QuantRS2Error::InvalidInput(
187                        "Logarithm of non-positive number".to_string(),
188                    ))
189                } else {
190                    Ok(a_val.ln())
191                }
192            }
193        }
194    }
195
196    #[cfg(not(feature = "symbolic"))]
197    fn evaluate_simple_complex(
198        expr: &SimpleExpression,
199        variables: &HashMap<String, Complex64>,
200    ) -> QuantRS2Result<Complex64> {
201        // Convert variables to real for this simple implementation
202        let real_vars: HashMap<String, f64> = variables
203            .iter()
204            .filter_map(|(k, v)| {
205                if v.im.abs() < 1e-12 {
206                    Some((k.clone(), v.re))
207                } else {
208                    None
209                }
210            })
211            .collect();
212
213        let real_result = Self::evaluate_simple(expr, &real_vars)?;
214        Ok(Complex64::new(real_result, 0.0))
215    }
216
217    /// Get all variable names in the expression
218    pub fn variables(&self) -> Vec<String> {
219        match self {
220            Self::Constant(_) | Self::ComplexConstant(_) => Vec::new(),
221            Self::Variable(name) => vec![name.clone()],
222
223            #[cfg(feature = "symbolic")]
224            Self::SymEngine(expr) => {
225                let mut vars: Vec<String> = expr.free_symbols().into_iter().collect();
226                vars.sort();
227                vars
228            }
229
230            #[cfg(not(feature = "symbolic"))]
231            Self::Simple(simple_expr) => Self::variables_simple(simple_expr),
232        }
233    }
234
235    #[cfg(not(feature = "symbolic"))]
236    fn variables_simple(expr: &SimpleExpression) -> Vec<String> {
237        match expr {
238            SimpleExpression::Add(a, b)
239            | SimpleExpression::Sub(a, b)
240            | SimpleExpression::Mul(a, b)
241            | SimpleExpression::Div(a, b)
242            | SimpleExpression::Pow(a, b) => {
243                let mut vars = a.variables();
244                vars.extend(b.variables());
245                vars.sort();
246                vars.dedup();
247                vars
248            }
249            SimpleExpression::Sin(a)
250            | SimpleExpression::Cos(a)
251            | SimpleExpression::Exp(a)
252            | SimpleExpression::Log(a) => a.variables(),
253        }
254    }
255
256    /// Check if the expression is constant (has no variables)
257    pub fn is_constant(&self) -> bool {
258        match self {
259            Self::Constant(_) | Self::ComplexConstant(_) => true,
260            Self::Variable(_) => false,
261
262            #[cfg(feature = "symbolic")]
263            Self::SymEngine(expr) => expr.free_symbols().is_empty(),
264
265            #[cfg(not(feature = "symbolic"))]
266            Self::Simple(_) => false,
267        }
268    }
269
270    /// Substitute variables with expressions
271    pub fn substitute(&self, substitutions: &HashMap<String, Self>) -> QuantRS2Result<Self> {
272        match self {
273            Self::Constant(_) | Self::ComplexConstant(_) => Ok(self.clone()),
274            Self::Variable(name) => Ok(substitutions
275                .get(name)
276                .cloned()
277                .unwrap_or_else(|| self.clone())),
278
279            #[cfg(feature = "symbolic")]
280            Self::SymEngine(expr) => {
281                let mut result = expr.clone();
282                for (name, replacement) in substitutions {
283                    let var_expr = SymEngine::symbol(name);
284                    let value_expr = replacement.to_symengine_expr()?;
285                    result = result.substitute(&var_expr, &value_expr);
286                }
287                Ok(Self::SymEngine(result))
288            }
289
290            #[cfg(not(feature = "symbolic"))]
291            Self::Simple(_) => {
292                // Would implement simple expression substitution
293                Err(QuantRS2Error::UnsupportedOperation(
294                    "Simple expression substitution not yet implemented".to_string(),
295                ))
296            }
297        }
298    }
299
300    /// Convert this `SymbolicExpression` into a `quantrs2_symengine_pure::Expression`.
301    ///
302    /// Used internally for routing operations through the SymEngine backend.
303    ///
304    /// # Errors
305    /// Returns `UnsupportedOperation` when the variant cannot be losslessly converted.
306    #[cfg(feature = "symbolic")]
307    pub fn to_symengine_expr(&self) -> QuantRS2Result<SymEngine> {
308        match self {
309            Self::SymEngine(e) => Ok(e.clone()),
310            Self::Constant(c) => Ok(SymEngine::from(*c)),
311            Self::Variable(name) => Ok(SymEngine::symbol(name)),
312            Self::ComplexConstant(c) => Ok(SymEngine::from_complex64(*c)),
313        }
314    }
315
316    /// Parse an expression string using the SymEngine backend (requires `symbolic` feature).
317    ///
318    /// Falls back gracefully to a `Variable` node when parsing fails.
319    #[cfg(feature = "symbolic")]
320    pub fn from_symengine_str(input: &str) -> Self {
321        match quantrs2_symengine_pure::parser::parse(input) {
322            Ok(expr) => Self::SymEngine(expr),
323            Err(_) => {
324                Self::parse_simple(input).unwrap_or_else(|_| Self::Variable(input.to_string()))
325            }
326        }
327    }
328}
329
330// Arithmetic operations for SymbolicExpression
331impl std::ops::Add for SymbolicExpression {
332    type Output = Self;
333
334    fn add(self, rhs: Self) -> Self::Output {
335        #[cfg(feature = "symbolic")]
336        {
337            match (self, rhs) {
338                // Optimize constant addition
339                (Self::Constant(a), Self::Constant(b)) => Self::Constant(a + b),
340                (Self::SymEngine(a), Self::SymEngine(b)) => Self::SymEngine(a + b),
341                (a, b) => {
342                    // Convert to SymEngine if possible
343                    let a_sym = match a {
344                        Self::Constant(val) => SymEngine::from(val),
345                        Self::Variable(name) => SymEngine::symbol(&name),
346                        Self::SymEngine(expr) => expr,
347                        _ => return Self::Constant(0.0), // Fallback
348                    };
349                    let b_sym = match b {
350                        Self::Constant(val) => SymEngine::from(val),
351                        Self::Variable(name) => SymEngine::symbol(&name),
352                        Self::SymEngine(expr) => expr,
353                        _ => return Self::Constant(0.0), // Fallback
354                    };
355                    Self::SymEngine(a_sym + b_sym)
356                }
357            }
358        }
359
360        #[cfg(not(feature = "symbolic"))]
361        {
362            match (self, rhs) {
363                (Self::Constant(a), Self::Constant(b)) => Self::Constant(a + b),
364                (a, b) => Self::Simple(SimpleExpression::Add(Box::new(a), Box::new(b))),
365            }
366        }
367    }
368}
369
370impl std::ops::Sub for SymbolicExpression {
371    type Output = Self;
372
373    fn sub(self, rhs: Self) -> Self::Output {
374        #[cfg(feature = "symbolic")]
375        {
376            match (self, rhs) {
377                // Optimize constant subtraction
378                (Self::Constant(a), Self::Constant(b)) => Self::Constant(a - b),
379                (Self::SymEngine(a), Self::SymEngine(b)) => Self::SymEngine(a - b),
380                (a, b) => {
381                    let a_sym = match a {
382                        Self::Constant(val) => SymEngine::from(val),
383                        Self::Variable(name) => SymEngine::symbol(&name),
384                        Self::SymEngine(expr) => expr,
385                        _ => return Self::Constant(0.0),
386                    };
387                    let b_sym = match b {
388                        Self::Constant(val) => SymEngine::from(val),
389                        Self::Variable(name) => SymEngine::symbol(&name),
390                        Self::SymEngine(expr) => expr,
391                        _ => return Self::Constant(0.0),
392                    };
393                    Self::SymEngine(a_sym - b_sym)
394                }
395            }
396        }
397
398        #[cfg(not(feature = "symbolic"))]
399        {
400            match (self, rhs) {
401                (Self::Constant(a), Self::Constant(b)) => Self::Constant(a - b),
402                (a, b) => Self::Simple(SimpleExpression::Sub(Box::new(a), Box::new(b))),
403            }
404        }
405    }
406}
407
408impl std::ops::Mul for SymbolicExpression {
409    type Output = Self;
410
411    fn mul(self, rhs: Self) -> Self::Output {
412        #[cfg(feature = "symbolic")]
413        {
414            match (self, rhs) {
415                // Optimize constant multiplication
416                (Self::Constant(a), Self::Constant(b)) => Self::Constant(a * b),
417                (Self::SymEngine(a), Self::SymEngine(b)) => Self::SymEngine(a * b),
418                (a, b) => {
419                    let a_sym = match a {
420                        Self::Constant(val) => SymEngine::from(val),
421                        Self::Variable(name) => SymEngine::symbol(&name),
422                        Self::SymEngine(expr) => expr,
423                        _ => return Self::Constant(0.0),
424                    };
425                    let b_sym = match b {
426                        Self::Constant(val) => SymEngine::from(val),
427                        Self::Variable(name) => SymEngine::symbol(&name),
428                        Self::SymEngine(expr) => expr,
429                        _ => return Self::Constant(0.0),
430                    };
431                    Self::SymEngine(a_sym * b_sym)
432                }
433            }
434        }
435
436        #[cfg(not(feature = "symbolic"))]
437        {
438            match (self, rhs) {
439                (Self::Constant(a), Self::Constant(b)) => Self::Constant(a * b),
440                (a, b) => Self::Simple(SimpleExpression::Mul(Box::new(a), Box::new(b))),
441            }
442        }
443    }
444}
445
446impl std::ops::Div for SymbolicExpression {
447    type Output = Self;
448
449    fn div(self, rhs: Self) -> Self::Output {
450        #[cfg(feature = "symbolic")]
451        {
452            match (self, rhs) {
453                // Optimize constant division
454                (Self::Constant(a), Self::Constant(b)) => {
455                    if b.abs() < 1e-12 {
456                        Self::Constant(f64::INFINITY)
457                    } else {
458                        Self::Constant(a / b)
459                    }
460                }
461                (Self::SymEngine(a), Self::SymEngine(b)) => Self::SymEngine(a / b),
462                (a, b) => {
463                    let a_sym = match a {
464                        Self::Constant(val) => SymEngine::from(val),
465                        Self::Variable(name) => SymEngine::symbol(&name),
466                        Self::SymEngine(expr) => expr,
467                        _ => return Self::Constant(0.0),
468                    };
469                    let b_sym = match b {
470                        Self::Constant(val) => SymEngine::from(val),
471                        Self::Variable(name) => SymEngine::symbol(&name),
472                        Self::SymEngine(expr) => expr,
473                        _ => return Self::Constant(1.0),
474                    };
475                    Self::SymEngine(a_sym / b_sym)
476                }
477            }
478        }
479
480        #[cfg(not(feature = "symbolic"))]
481        {
482            match (self, rhs) {
483                (Self::Constant(a), Self::Constant(b)) => {
484                    if b.abs() < 1e-12 {
485                        Self::Constant(f64::INFINITY)
486                    } else {
487                        Self::Constant(a / b)
488                    }
489                }
490                (a, b) => Self::Simple(SimpleExpression::Div(Box::new(a), Box::new(b))),
491            }
492        }
493    }
494}
495
496impl fmt::Display for SymbolicExpression {
497    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
498        match self {
499            Self::Constant(value) => write!(f, "{value}"),
500            Self::ComplexConstant(value) => {
501                if value.im == 0.0 {
502                    write!(f, "{}", value.re)
503                } else if value.re == 0.0 {
504                    write!(f, "{}*I", value.im)
505                } else {
506                    write!(f, "{} + {}*I", value.re, value.im)
507                }
508            }
509            Self::Variable(name) => write!(f, "{name}"),
510
511            #[cfg(feature = "symbolic")]
512            Self::SymEngine(expr) => write!(f, "{expr}"),
513
514            #[cfg(not(feature = "symbolic"))]
515            Self::Simple(expr) => Self::display_simple(expr, f),
516        }
517    }
518}
519
520#[cfg(not(feature = "symbolic"))]
521impl SymbolicExpression {
522    fn display_simple(expr: &SimpleExpression, f: &mut fmt::Formatter<'_>) -> fmt::Result {
523        match expr {
524            SimpleExpression::Add(a, b) => write!(f, "({a} + {b})"),
525            SimpleExpression::Sub(a, b) => write!(f, "({a} - {b})"),
526            SimpleExpression::Mul(a, b) => write!(f, "({a} * {b})"),
527            SimpleExpression::Div(a, b) => write!(f, "({a} / {b})"),
528            SimpleExpression::Pow(a, b) => write!(f, "({a} ^ {b})"),
529            SimpleExpression::Sin(a) => write!(f, "sin({a})"),
530            SimpleExpression::Cos(a) => write!(f, "cos({a})"),
531            SimpleExpression::Exp(a) => write!(f, "exp({a})"),
532            SimpleExpression::Log(a) => write!(f, "log({a})"),
533        }
534    }
535}
536
537impl From<f64> for SymbolicExpression {
538    fn from(value: f64) -> Self {
539        Self::Constant(value)
540    }
541}
542
543impl From<Complex64> for SymbolicExpression {
544    fn from(value: Complex64) -> Self {
545        if value.im == 0.0 {
546            Self::Constant(value.re)
547        } else {
548            Self::ComplexConstant(value)
549        }
550    }
551}
552
553impl From<&str> for SymbolicExpression {
554    fn from(name: &str) -> Self {
555        Self::Variable(name.to_string())
556    }
557}
558
559impl Zero for SymbolicExpression {
560    fn zero() -> Self {
561        Self::Constant(0.0)
562    }
563
564    fn is_zero(&self) -> bool {
565        match self {
566            Self::Constant(val) => *val == 0.0,
567            Self::ComplexConstant(val) => val.is_zero(),
568            _ => false,
569        }
570    }
571}
572
573impl One for SymbolicExpression {
574    fn one() -> Self {
575        Self::Constant(1.0)
576    }
577
578    fn is_one(&self) -> bool {
579        match self {
580            Self::Constant(val) => *val == 1.0,
581            Self::ComplexConstant(val) => val.is_one(),
582            _ => false,
583        }
584    }
585}
586
/// Symbolic calculus operations (requires the `symbolic` feature).
#[cfg(feature = "symbolic")]
pub mod calculus {
    use super::*;

    /// Differentiate an expression with respect to a variable.
    ///
    /// Constants and bare variables are promoted to SymEngine expressions first, so
    /// `diff(x, "x")` or `diff(2.0, "x")` now succeed instead of being rejected.
    ///
    /// # Errors
    /// Propagates any error from [`SymbolicExpression::to_symengine_expr`].
    pub fn diff(expr: &SymbolicExpression, var: &str) -> QuantRS2Result<SymbolicExpression> {
        let sym_expr = expr.to_symengine_expr()?;
        let var_expr = SymEngine::symbol(var);
        Ok(SymbolicExpression::SymEngine(sym_expr.diff(&var_expr)))
    }

    /// Integrate an expression with respect to a variable (indefinite, constant omitted).
    ///
    /// The pure-Rust backend has no general symbolic integrator. Only the trivial case is
    /// handled: an expression that does not depend on `var` integrates to `expr * var`.
    /// Anything else is reported as unsupported. Previously the integrand itself was returned
    /// as if it were the integral.
    ///
    /// # Errors
    /// `UnsupportedOperation` when `expr` depends on `var`; otherwise propagates conversion
    /// errors from [`SymbolicExpression::to_symengine_expr`].
    pub fn integrate(expr: &SymbolicExpression, var: &str) -> QuantRS2Result<SymbolicExpression> {
        let sym_expr = expr.to_symengine_expr()?;
        let depends_on_var = sym_expr
            .free_symbols()
            .into_iter()
            .any(|name| name == var);

        if depends_on_var {
            return Err(QuantRS2Error::UnsupportedOperation(format!(
                "Symbolic integration of expressions depending on '{var}' is not supported"
            )));
        }

        // ∫ c d(var) = c * var when c does not depend on var.
        Ok(SymbolicExpression::SymEngine(sym_expr * SymEngine::symbol(var)))
    }

    /// "Limit" of an expression as `var` approaches `value`, computed by direct substitution.
    ///
    /// This is exact only when the expression is continuous at `value`. Removable
    /// singularities (e.g. `sin(x)/x` at `0`) are not resolved, and no numerical approximation
    /// is attempted.
    ///
    /// # Errors
    /// Propagates any error from [`SymbolicExpression::to_symengine_expr`].
    pub fn limit(
        expr: &SymbolicExpression,
        var: &str,
        value: f64,
    ) -> QuantRS2Result<SymbolicExpression> {
        let sym_expr = expr.to_symengine_expr()?;
        let var_expr = SymEngine::symbol(var);
        let value_expr = SymEngine::from(value);
        Ok(SymbolicExpression::SymEngine(
            sym_expr.substitute(&var_expr, &value_expr),
        ))
    }

    /// Expand a SymEngine expression; other variants are returned unchanged.
    pub fn expand(expr: &SymbolicExpression) -> QuantRS2Result<SymbolicExpression> {
        match expr {
            SymbolicExpression::SymEngine(sym_expr) => {
                Ok(SymbolicExpression::SymEngine(sym_expr.expand()))
            }
            // Constants and bare variables have nothing to expand.
            _ => Ok(expr.clone()),
        }
    }

    /// Simplify a SymEngine expression; other variants are returned unchanged.
    pub fn simplify(expr: &SymbolicExpression) -> QuantRS2Result<SymbolicExpression> {
        match expr {
            SymbolicExpression::SymEngine(sym_expr) => {
                Ok(SymbolicExpression::SymEngine(sym_expr.simplify()))
            }
            _ => Ok(expr.clone()),
        }
    }
}
669
670/// Symbolic matrix operations for quantum gates
671pub mod matrix {
672    use super::*;
673    use scirs2_core::ndarray::Array2;
674
675    /// A symbolic matrix for representing quantum gates
676    #[derive(Debug, Clone)]
677    pub struct SymbolicMatrix {
678        pub rows: usize,
679        pub cols: usize,
680        pub elements: Vec<Vec<SymbolicExpression>>,
681    }
682
683    impl SymbolicMatrix {
684        /// Create a new symbolic matrix
685        pub fn new(rows: usize, cols: usize) -> Self {
686            let elements = vec![vec![SymbolicExpression::zero(); cols]; rows];
687            Self {
688                rows,
689                cols,
690                elements,
691            }
692        }
693
694        /// Create an identity matrix
695        pub fn identity(size: usize) -> Self {
696            let mut matrix = Self::new(size, size);
697            for i in 0..size {
698                matrix.elements[i][i] = SymbolicExpression::one();
699            }
700            matrix
701        }
702
703        /// Create a symbolic rotation matrix around X-axis
704        #[allow(unused_variables)]
705        pub fn rotation_x(theta: SymbolicExpression) -> Self {
706            let mut matrix = Self::new(2, 2);
707
708            #[cfg(feature = "symbolic")]
709            {
710                let half_theta = theta / SymbolicExpression::constant(2.0);
711                let inner_expr = match &half_theta {
712                    SymbolicExpression::SymEngine(expr) => expr.clone(),
713                    _ => return matrix,
714                };
715                let cos_expr = SymbolicExpression::SymEngine(
716                    quantrs2_symengine_pure::ops::trig::cos(&inner_expr),
717                );
718                let sin_expr = SymbolicExpression::SymEngine(
719                    quantrs2_symengine_pure::ops::trig::sin(&inner_expr),
720                );
721
722                matrix.elements[0][0] = cos_expr.clone();
723                matrix.elements[0][1] =
724                    SymbolicExpression::complex_constant(Complex64::new(0.0, -1.0))
725                        * sin_expr.clone();
726                matrix.elements[1][0] =
727                    SymbolicExpression::complex_constant(Complex64::new(0.0, -1.0)) * sin_expr;
728                matrix.elements[1][1] = cos_expr;
729            }
730
731            #[cfg(not(feature = "symbolic"))]
732            {
733                // Simplified representation
734                matrix.elements[0][0] = SymbolicExpression::parse("cos(theta/2)")
735                    .unwrap_or_else(|_| SymbolicExpression::one());
736                matrix.elements[0][1] = SymbolicExpression::parse("-i*sin(theta/2)")
737                    .unwrap_or_else(|_| SymbolicExpression::zero());
738                matrix.elements[1][0] = SymbolicExpression::parse("-i*sin(theta/2)")
739                    .unwrap_or_else(|_| SymbolicExpression::zero());
740                matrix.elements[1][1] = SymbolicExpression::parse("cos(theta/2)")
741                    .unwrap_or_else(|_| SymbolicExpression::one());
742            }
743
744            matrix
745        }
746
747        /// Evaluate the matrix with given variable values
748        pub fn evaluate(
749            &self,
750            variables: &HashMap<String, f64>,
751        ) -> QuantRS2Result<Array2<Complex64>> {
752            let mut result = Array2::<Complex64>::zeros((self.rows, self.cols));
753
754            for i in 0..self.rows {
755                for j in 0..self.cols {
756                    let complex_vars: HashMap<String, Complex64> = variables
757                        .iter()
758                        .map(|(k, v)| (k.clone(), Complex64::new(*v, 0.0)))
759                        .collect();
760
761                    let value = self.elements[i][j].evaluate_complex(&complex_vars)?;
762                    result[[i, j]] = value;
763                }
764            }
765
766            Ok(result)
767        }
768
769        /// Matrix multiplication
770        pub fn multiply(&self, other: &Self) -> QuantRS2Result<Self> {
771            if self.cols != other.rows {
772                return Err(QuantRS2Error::InvalidInput(
773                    "Matrix dimensions don't match for multiplication".to_string(),
774                ));
775            }
776
777            let mut result = Self::new(self.rows, other.cols);
778
779            for i in 0..self.rows {
780                for j in 0..other.cols {
781                    let mut sum = SymbolicExpression::zero();
782                    for k in 0..self.cols {
783                        let product = self.elements[i][k].clone() * other.elements[k][j].clone();
784                        sum = sum + product;
785                    }
786                    result.elements[i][j] = sum;
787                }
788            }
789
790            Ok(result)
791        }
792    }
793
794    impl fmt::Display for SymbolicMatrix {
795        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
796            writeln!(f, "SymbolicMatrix[{}x{}]:", self.rows, self.cols)?;
797            for row in &self.elements {
798                write!(f, "[")?;
799                for (j, elem) in row.iter().enumerate() {
800                    if j > 0 {
801                        write!(f, ", ")?;
802                    }
803                    write!(f, "{elem}")?;
804                }
805                writeln!(f, "]")?;
806            }
807            Ok(())
808        }
809    }
810}
811
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_symbolic_expression_creation() {
        let const_expr = SymbolicExpression::constant(std::f64::consts::PI);
        assert!(const_expr.is_constant());

        let var_expr = SymbolicExpression::variable("x");
        assert!(!var_expr.is_constant());
        assert_eq!(var_expr.variables(), vec!["x"]);
    }

    #[test]
    fn test_symbolic_arithmetic() {
        let a = SymbolicExpression::constant(2.0);
        let b = SymbolicExpression::constant(3.0);

        // Adding two constants is expected to fold into a single constant.
        match a + b {
            SymbolicExpression::Constant(value) => assert_eq!(value, 5.0),
            other => panic!("Expected constant result, got: {:?}", other),
        }
    }

    #[test]
    fn test_symbolic_evaluation() {
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 2.0);

        let var_expr = SymbolicExpression::variable("x");
        let result = var_expr
            .evaluate(&vars)
            .expect("Failed to evaluate expression in test_symbolic_evaluation");
        assert_eq!(result, 2.0);
    }

    #[test]
    fn test_symbolic_matrix() {
        let matrix = matrix::SymbolicMatrix::identity(2);
        assert_eq!(matrix.rows, 2);
        assert_eq!(matrix.cols, 2);
        assert!(matrix.elements[0][0].is_one());
        assert!(matrix.elements[1][1].is_one());
        // Both off-diagonal entries of an identity matrix must be zero.
        assert!(matrix.elements[0][1].is_zero());
        assert!(matrix.elements[1][0].is_zero());
    }

    #[test]
    fn test_symbolic_matrix_multiply_dimensions() {
        let two = matrix::SymbolicMatrix::identity(2);
        let three = matrix::SymbolicMatrix::identity(3);

        // Inner dimensions (2 vs 3) disagree, so multiplication must be rejected.
        assert!(
            two.multiply(&three).is_err(),
            "2x2 * 3x3 should fail with a dimension error"
        );

        // A compatible product has shape (lhs.rows x rhs.cols).
        let square = two
            .multiply(&two)
            .expect("2x2 * 2x2 multiplication should succeed");
        assert_eq!(square.rows, 2);
        assert_eq!(square.cols, 2);
    }

    #[cfg(feature = "symbolic")]
    #[test]
    fn test_symengine_integration() {
        // `parse` may yield a SymEngine-backed expression or fall back to simple
        // parsing; in either case an expression containing a variable is
        // non-constant.
        let expr = SymbolicExpression::parse("x^2")
            .expect("Failed to parse expression in test_symengine_integration");
        assert!(
            !expr.is_constant(),
            "x^2 should not be constant, got: {:?}",
            expr
        );
    }

    #[cfg(feature = "symbolic")]
    #[test]
    fn test_symengine_evaluate() {
        // Build x^2 + 2*x + 1 symbolically and evaluate at x=3  (expected: 16)
        let expr = SymbolicExpression::from_symengine_str("x^2 + 2*x + 1");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), 3.0);
        let result = expr
            .evaluate(&vars)
            .expect("evaluate should succeed for x^2+2x+1 at x=3");
        assert!((result - 16.0).abs() < 1e-10, "expected 16.0, got {result}");
    }

    #[cfg(feature = "symbolic")]
    #[test]
    fn test_symengine_evaluate_complex() {
        // Evaluate I*x at x=1  =>  0 + 1i
        let expr = SymbolicExpression::from_symengine_str("I*x");
        let mut vars = HashMap::new();
        vars.insert("x".to_string(), Complex64::new(1.0, 0.0));
        let result = expr
            .evaluate_complex(&vars)
            .expect("evaluate_complex should succeed for I*x at x=1");
        assert!(
            result.re.abs() < 1e-10,
            "real part should be 0, got {}",
            result.re
        );
        assert!(
            (result.im - 1.0).abs() < 1e-10,
            "imaginary part should be 1, got {}",
            result.im
        );
    }

    #[cfg(feature = "symbolic")]
    #[test]
    fn test_symengine_variables() {
        let expr = SymbolicExpression::from_symengine_str("x + y");
        let vars = expr.variables();
        assert_eq!(vars.len(), 2, "expected 2 variables, got {:?}", vars);
        assert!(vars.contains(&"x".to_string()));
        assert!(vars.contains(&"y".to_string()));
    }

    #[cfg(feature = "symbolic")]
    #[test]
    fn test_symengine_is_constant() {
        let const_expr = SymbolicExpression::from_symengine_str("42");
        assert!(
            const_expr.is_constant(),
            "numeric literal should be constant"
        );

        let var_expr = SymbolicExpression::from_symengine_str("x + 1");
        assert!(
            !var_expr.is_constant(),
            "expression with variable should not be constant"
        );
    }

    #[cfg(feature = "symbolic")]
    #[test]
    fn test_symengine_substitute() {
        // Substitute x=2 into x+1, expect 3
        let expr = SymbolicExpression::from_symengine_str("x + 1");
        let mut subs = HashMap::new();
        subs.insert("x".to_string(), SymbolicExpression::constant(2.0));
        let substituted = expr.substitute(&subs).expect("substitute should succeed");
        let result = substituted
            .evaluate(&HashMap::new())
            .expect("evaluate should succeed after substitution");
        assert!(
            (result - 3.0).abs() < 1e-10,
            "expected 3.0 after substituting x=2 in x+1, got {result}"
        );
    }
}