quantrs2_symengine_pure/matrix/
mod.rs

1//! Symbolic matrix operations.
2//!
3//! This module provides a proper matrix type where each element is a symbolic expression.
4//! This is essential for quantum computing where we work with parameterized gates and
5//! symbolic Hamiltonians.
6
7use std::fmt;
8
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::Complex64;
11
12use crate::error::{SymEngineError, SymEngineResult};
13use crate::expr::Expression;
14
15/// A symbolic matrix where each element is an Expression.
16///
17/// This is useful for representing parameterized quantum gates,
18/// symbolic Hamiltonians, and other matrix expressions.
19#[derive(Clone, Debug)]
20pub struct SymbolicMatrix {
21    /// The matrix elements (row-major order)
22    elements: Vec<Expression>,
23    /// Number of rows
24    rows: usize,
25    /// Number of columns
26    cols: usize,
27}
28
29impl SymbolicMatrix {
30    /// Create a new symbolic matrix from a 2D array of expressions.
31    ///
32    /// # Errors
33    /// Returns error if the input is empty
34    pub fn new(elements: Vec<Vec<Expression>>) -> SymEngineResult<Self> {
35        if elements.is_empty() {
36            return Err(SymEngineError::dimension("Matrix cannot be empty"));
37        }
38
39        let rows = elements.len();
40        let cols = elements[0].len();
41
42        // Verify all rows have the same length
43        for (i, row) in elements.iter().enumerate() {
44            if row.len() != cols {
45                return Err(SymEngineError::dimension(format!(
46                    "Row {i} has {} columns, expected {cols}",
47                    row.len()
48                )));
49            }
50        }
51
52        let flat: Vec<Expression> = elements.into_iter().flatten().collect();
53
54        Ok(Self {
55            elements: flat,
56            rows,
57            cols,
58        })
59    }
60
61    /// Create a matrix from a flat vector with specified dimensions.
62    ///
63    /// # Errors
64    /// Returns error if the dimensions don't match the vector length
65    pub fn from_flat(elements: Vec<Expression>, rows: usize, cols: usize) -> SymEngineResult<Self> {
66        if elements.len() != rows * cols {
67            return Err(SymEngineError::dimension(format!(
68                "Expected {} elements for {}x{} matrix, got {}",
69                rows * cols,
70                rows,
71                cols,
72                elements.len()
73            )));
74        }
75
76        Ok(Self {
77            elements,
78            rows,
79            cols,
80        })
81    }
82
83    /// Create a zero matrix.
84    #[must_use]
85    pub fn zeros(rows: usize, cols: usize) -> Self {
86        Self {
87            elements: vec![Expression::zero(); rows * cols],
88            rows,
89            cols,
90        }
91    }
92
93    /// Create an identity matrix.
94    #[must_use]
95    pub fn identity(n: usize) -> Self {
96        let mut elements = vec![Expression::zero(); n * n];
97        for i in 0..n {
98            elements[i * n + i] = Expression::one();
99        }
100        Self {
101            elements,
102            rows: n,
103            cols: n,
104        }
105    }
106
107    /// Create a diagonal matrix from a vector of diagonal elements.
108    #[must_use]
109    pub fn diagonal(diag: Vec<Expression>) -> Self {
110        let n = diag.len();
111        let mut elements = vec![Expression::zero(); n * n];
112        for (i, d) in diag.into_iter().enumerate() {
113            elements[i * n + i] = d;
114        }
115        Self {
116            elements,
117            rows: n,
118            cols: n,
119        }
120    }
121
122    /// Create a matrix from a numeric Array2.
123    #[must_use]
124    pub fn from_array(arr: &Array2<f64>) -> Self {
125        let rows = arr.nrows();
126        let cols = arr.ncols();
127        let elements: Vec<Expression> = arr
128            .iter()
129            .map(|&v| Expression::float_unchecked(v))
130            .collect();
131        Self {
132            elements,
133            rows,
134            cols,
135        }
136    }
137
138    /// Create a matrix from a complex Array2.
139    #[must_use]
140    pub fn from_complex_array(arr: &Array2<Complex64>) -> Self {
141        let rows = arr.nrows();
142        let cols = arr.ncols();
143        let elements: Vec<Expression> =
144            arr.iter().map(|&c| Expression::from_complex64(c)).collect();
145        Self {
146            elements,
147            rows,
148            cols,
149        }
150    }
151
152    // =========================================================================
153    // Accessors
154    // =========================================================================
155
156    /// Get the number of rows.
157    #[must_use]
158    pub const fn nrows(&self) -> usize {
159        self.rows
160    }
161
162    /// Get the number of columns.
163    #[must_use]
164    pub const fn ncols(&self) -> usize {
165        self.cols
166    }
167
168    /// Get the dimensions as (rows, cols).
169    #[must_use]
170    pub const fn shape(&self) -> (usize, usize) {
171        (self.rows, self.cols)
172    }
173
174    /// Check if the matrix is square.
175    #[must_use]
176    pub const fn is_square(&self) -> bool {
177        self.rows == self.cols
178    }
179
180    /// Get element at (i, j).
181    ///
182    /// # Panics
183    /// Panics if indices are out of bounds.
184    #[must_use]
185    pub fn get(&self, i: usize, j: usize) -> &Expression {
186        assert!(i < self.rows && j < self.cols, "Index out of bounds");
187        &self.elements[i * self.cols + j]
188    }
189
190    /// Get mutable reference to element at (i, j).
191    ///
192    /// # Panics
193    /// Panics if indices are out of bounds.
194    pub fn get_mut(&mut self, i: usize, j: usize) -> &mut Expression {
195        assert!(i < self.rows && j < self.cols, "Index out of bounds");
196        &mut self.elements[i * self.cols + j]
197    }
198
199    /// Set element at (i, j).
200    ///
201    /// # Panics
202    /// Panics if indices are out of bounds.
203    pub fn set(&mut self, i: usize, j: usize, value: Expression) {
204        assert!(i < self.rows && j < self.cols, "Index out of bounds");
205        self.elements[i * self.cols + j] = value;
206    }
207
208    /// Get a row as a vector of expressions.
209    #[must_use]
210    pub fn row(&self, i: usize) -> Vec<Expression> {
211        assert!(i < self.rows, "Row index out of bounds");
212        let start = i * self.cols;
213        self.elements[start..start + self.cols].to_vec()
214    }
215
216    /// Get a column as a vector of expressions.
217    #[must_use]
218    pub fn col(&self, j: usize) -> Vec<Expression> {
219        assert!(j < self.cols, "Column index out of bounds");
220        (0..self.rows).map(|i| self.get(i, j).clone()).collect()
221    }
222
223    // =========================================================================
224    // Matrix Operations
225    // =========================================================================
226
227    /// Matrix transpose.
228    #[must_use]
229    pub fn transpose(&self) -> Self {
230        let mut elements = Vec::with_capacity(self.rows * self.cols);
231        for j in 0..self.cols {
232            for i in 0..self.rows {
233                elements.push(self.get(i, j).clone());
234            }
235        }
236        Self {
237            elements,
238            rows: self.cols,
239            cols: self.rows,
240        }
241    }
242
243    /// Complex conjugate of all elements.
244    #[must_use]
245    pub fn conjugate(&self) -> Self {
246        Self {
247            elements: self.elements.iter().map(Expression::conjugate).collect(),
248            rows: self.rows,
249            cols: self.cols,
250        }
251    }
252
253    /// Hermitian conjugate (conjugate transpose).
254    #[must_use]
255    pub fn dagger(&self) -> Self {
256        self.transpose().conjugate()
257    }
258
259    /// Matrix addition.
260    ///
261    /// # Errors
262    /// Returns error if dimensions don't match.
263    pub fn add(&self, other: &Self) -> SymEngineResult<Self> {
264        if self.rows != other.rows || self.cols != other.cols {
265            return Err(SymEngineError::dimension(format!(
266                "Cannot add {}x{} matrix with {}x{} matrix",
267                self.rows, self.cols, other.rows, other.cols
268            )));
269        }
270
271        let elements: Vec<Expression> = self
272            .elements
273            .iter()
274            .zip(other.elements.iter())
275            .map(|(a, b)| a.clone() + b.clone())
276            .collect();
277
278        Ok(Self {
279            elements,
280            rows: self.rows,
281            cols: self.cols,
282        })
283    }
284
285    /// Matrix subtraction.
286    ///
287    /// # Errors
288    /// Returns error if dimensions don't match.
289    pub fn sub(&self, other: &Self) -> SymEngineResult<Self> {
290        if self.rows != other.rows || self.cols != other.cols {
291            return Err(SymEngineError::dimension(format!(
292                "Cannot subtract {}x{} matrix from {}x{} matrix",
293                other.rows, other.cols, self.rows, self.cols
294            )));
295        }
296
297        let elements: Vec<Expression> = self
298            .elements
299            .iter()
300            .zip(other.elements.iter())
301            .map(|(a, b)| a.clone() - b.clone())
302            .collect();
303
304        Ok(Self {
305            elements,
306            rows: self.rows,
307            cols: self.cols,
308        })
309    }
310
311    /// Matrix multiplication.
312    ///
313    /// # Errors
314    /// Returns error if inner dimensions don't match.
315    pub fn matmul(&self, other: &Self) -> SymEngineResult<Self> {
316        if self.cols != other.rows {
317            return Err(SymEngineError::dimension(format!(
318                "Cannot multiply {}x{} matrix with {}x{} matrix",
319                self.rows, self.cols, other.rows, other.cols
320            )));
321        }
322
323        let mut elements = Vec::with_capacity(self.rows * other.cols);
324
325        for i in 0..self.rows {
326            for j in 0..other.cols {
327                let mut sum = Expression::zero();
328                for k in 0..self.cols {
329                    sum = sum + self.get(i, k).clone() * other.get(k, j).clone();
330                }
331                elements.push(sum);
332            }
333        }
334
335        Ok(Self {
336            elements,
337            rows: self.rows,
338            cols: other.cols,
339        })
340    }
341
342    /// Scalar multiplication.
343    #[must_use]
344    pub fn scale(&self, scalar: &Expression) -> Self {
345        Self {
346            elements: self
347                .elements
348                .iter()
349                .map(|e| e.clone() * scalar.clone())
350                .collect(),
351            rows: self.rows,
352            cols: self.cols,
353        }
354    }
355
356    /// Kronecker (tensor) product.
357    #[must_use]
358    pub fn kron(&self, other: &Self) -> Self {
359        let new_rows = self.rows * other.rows;
360        let new_cols = self.cols * other.cols;
361        let mut elements = Vec::with_capacity(new_rows * new_cols);
362
363        for i1 in 0..self.rows {
364            for i2 in 0..other.rows {
365                for j1 in 0..self.cols {
366                    for j2 in 0..other.cols {
367                        let a = self.get(i1, j1).clone();
368                        let b = other.get(i2, j2).clone();
369                        elements.push(a * b);
370                    }
371                }
372            }
373        }
374
375        Self {
376            elements,
377            rows: new_rows,
378            cols: new_cols,
379        }
380    }
381
382    /// Matrix trace (sum of diagonal elements).
383    ///
384    /// # Errors
385    /// Returns error if matrix is not square.
386    pub fn trace(&self) -> SymEngineResult<Expression> {
387        if !self.is_square() {
388            return Err(SymEngineError::dimension(
389                "Trace is only defined for square matrices",
390            ));
391        }
392
393        let mut sum = Expression::zero();
394        for i in 0..self.rows {
395            sum = sum + self.get(i, i).clone();
396        }
397        Ok(sum)
398    }
399
400    /// Commutator [A, B] = AB - BA.
401    ///
402    /// # Errors
403    /// Returns error if dimensions don't match or matrices are not square.
404    pub fn commutator(&self, other: &Self) -> SymEngineResult<Self> {
405        let ab = self.matmul(other)?;
406        let ba = other.matmul(self)?;
407        ab.sub(&ba)
408    }
409
410    /// Anticommutator {A, B} = AB + BA.
411    ///
412    /// # Errors
413    /// Returns error if dimensions don't match.
414    pub fn anticommutator(&self, other: &Self) -> SymEngineResult<Self> {
415        let ab = self.matmul(other)?;
416        let ba = other.matmul(self)?;
417        ab.add(&ba)
418    }
419
420    // =========================================================================
421    // Simplification
422    // =========================================================================
423
424    /// Simplify all matrix elements.
425    #[must_use]
426    pub fn simplify(&self) -> Self {
427        Self {
428            elements: self.elements.iter().map(Expression::simplify).collect(),
429            rows: self.rows,
430            cols: self.cols,
431        }
432    }
433
434    /// Expand all matrix elements.
435    #[must_use]
436    pub fn expand(&self) -> Self {
437        Self {
438            elements: self.elements.iter().map(Expression::expand).collect(),
439            rows: self.rows,
440            cols: self.cols,
441        }
442    }
443
444    // =========================================================================
445    // Evaluation
446    // =========================================================================
447
448    /// Evaluate all matrix elements with given variable values.
449    ///
450    /// # Errors
451    /// Returns error if any element evaluation fails.
452    pub fn eval(
453        &self,
454        values: &std::collections::HashMap<String, f64>,
455    ) -> SymEngineResult<Array2<f64>> {
456        let mut result = Array2::zeros((self.rows, self.cols));
457        for i in 0..self.rows {
458            for j in 0..self.cols {
459                result[[i, j]] = self.get(i, j).eval(values)?;
460            }
461        }
462        Ok(result)
463    }
464
465    /// Substitute a variable with an expression in all elements.
466    #[must_use]
467    pub fn substitute(&self, var: &Expression, value: &Expression) -> Self {
468        Self {
469            elements: self
470                .elements
471                .iter()
472                .map(|e| e.substitute(var, value))
473                .collect(),
474            rows: self.rows,
475            cols: self.cols,
476        }
477    }
478
479    // =========================================================================
480    // Differentiation
481    // =========================================================================
482
483    /// Compute the derivative of all elements with respect to a variable.
484    #[must_use]
485    pub fn diff(&self, var: &Expression) -> Self {
486        Self {
487            elements: self.elements.iter().map(|e| e.diff(var)).collect(),
488            rows: self.rows,
489            cols: self.cols,
490        }
491    }
492}
493
494impl fmt::Display for SymbolicMatrix {
495    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
496        writeln!(f, "[")?;
497        for i in 0..self.rows {
498            write!(f, "  [")?;
499            for j in 0..self.cols {
500                if j > 0 {
501                    write!(f, ", ")?;
502                }
503                write!(f, "{}", self.get(i, j))?;
504            }
505            writeln!(f, "]")?;
506        }
507        write!(f, "]")
508    }
509}
510
511impl std::ops::Index<(usize, usize)> for SymbolicMatrix {
512    type Output = Expression;
513
514    fn index(&self, index: (usize, usize)) -> &Self::Output {
515        self.get(index.0, index.1)
516    }
517}
518
519impl std::ops::IndexMut<(usize, usize)> for SymbolicMatrix {
520    fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
521        self.get_mut(index.0, index.1)
522    }
523}
524
525// =========================================================================
526// Quantum Gate Matrices
527// =========================================================================
528
529/// Create the Pauli X gate matrix.
530#[must_use]
531pub fn pauli_x() -> SymbolicMatrix {
532    SymbolicMatrix::from_flat(
533        vec![
534            Expression::zero(),
535            Expression::one(),
536            Expression::one(),
537            Expression::zero(),
538        ],
539        2,
540        2,
541    )
542    .expect("valid 2x2 matrix")
543}
544
545/// Create the Pauli Y gate matrix.
546#[must_use]
547pub fn pauli_y() -> SymbolicMatrix {
548    let i = Expression::i();
549    SymbolicMatrix::from_flat(
550        vec![Expression::zero(), i.clone().neg(), i, Expression::zero()],
551        2,
552        2,
553    )
554    .expect("valid 2x2 matrix")
555}
556
557/// Create the Pauli Z gate matrix.
558#[must_use]
559pub fn pauli_z() -> SymbolicMatrix {
560    SymbolicMatrix::from_flat(
561        vec![
562            Expression::one(),
563            Expression::zero(),
564            Expression::zero(),
565            Expression::one().neg(),
566        ],
567        2,
568        2,
569    )
570    .expect("valid 2x2 matrix")
571}
572
573/// Create the Hadamard gate matrix.
574#[must_use]
575pub fn hadamard() -> SymbolicMatrix {
576    let sqrt2_inv = Expression::one() / crate::ops::trig::sqrt(&Expression::int(2));
577    SymbolicMatrix::from_flat(
578        vec![
579            sqrt2_inv.clone(),
580            sqrt2_inv.clone(),
581            sqrt2_inv.clone(),
582            sqrt2_inv.neg(),
583        ],
584        2,
585        2,
586    )
587    .expect("valid 2x2 matrix")
588}
589
590/// Create the S (phase) gate matrix.
591#[must_use]
592pub fn s_gate() -> SymbolicMatrix {
593    SymbolicMatrix::from_flat(
594        vec![
595            Expression::one(),
596            Expression::zero(),
597            Expression::zero(),
598            Expression::i(),
599        ],
600        2,
601        2,
602    )
603    .expect("valid 2x2 matrix")
604}
605
606/// Create the T gate matrix.
607#[must_use]
608pub fn t_gate() -> SymbolicMatrix {
609    let exp_i_pi_4 =
610        crate::ops::trig::exp(&(Expression::i() * Expression::pi() / Expression::int(4)));
611    SymbolicMatrix::from_flat(
612        vec![
613            Expression::one(),
614            Expression::zero(),
615            Expression::zero(),
616            exp_i_pi_4,
617        ],
618        2,
619        2,
620    )
621    .expect("valid 2x2 matrix")
622}
623
624/// Create a rotation gate Rx(θ) around the X axis.
625#[must_use]
626pub fn rx(theta: &Expression) -> SymbolicMatrix {
627    let half = Expression::float_unchecked(0.5);
628    let half_theta = theta.clone() * half;
629    let cos_half = crate::ops::trig::cos(&half_theta);
630    let sin_half = crate::ops::trig::sin(&half_theta);
631    let i = Expression::i();
632
633    SymbolicMatrix::from_flat(
634        vec![
635            cos_half.clone(),
636            i.clone().neg() * sin_half.clone(),
637            i.neg() * sin_half,
638            cos_half,
639        ],
640        2,
641        2,
642    )
643    .expect("valid 2x2 matrix")
644}
645
646/// Create a rotation gate Ry(θ) around the Y axis.
647#[must_use]
648pub fn ry(theta: &Expression) -> SymbolicMatrix {
649    let half = Expression::float_unchecked(0.5);
650    let half_theta = theta.clone() * half;
651    let cos_half = crate::ops::trig::cos(&half_theta);
652    let sin_half = crate::ops::trig::sin(&half_theta);
653
654    SymbolicMatrix::from_flat(
655        vec![cos_half.clone(), sin_half.clone().neg(), sin_half, cos_half],
656        2,
657        2,
658    )
659    .expect("valid 2x2 matrix")
660}
661
662/// Create a rotation gate Rz(θ) around the Z axis.
663#[must_use]
664pub fn rz(theta: &Expression) -> SymbolicMatrix {
665    let half = Expression::float_unchecked(0.5);
666    let i = Expression::i();
667    let half_theta = theta.clone() * half;
668    let exp_neg = crate::ops::trig::exp(&(i.neg() * half_theta.clone()));
669    let exp_pos = crate::ops::trig::exp(&(Expression::i() * half_theta));
670
671    SymbolicMatrix::from_flat(
672        vec![exp_neg, Expression::zero(), Expression::zero(), exp_pos],
673        2,
674        2,
675    )
676    .expect("valid 2x2 matrix")
677}
678
679/// Create the CNOT (CX) gate matrix.
680#[must_use]
681pub fn cnot() -> SymbolicMatrix {
682    SymbolicMatrix::from_flat(
683        vec![
684            Expression::one(),
685            Expression::zero(),
686            Expression::zero(),
687            Expression::zero(),
688            Expression::zero(),
689            Expression::one(),
690            Expression::zero(),
691            Expression::zero(),
692            Expression::zero(),
693            Expression::zero(),
694            Expression::zero(),
695            Expression::one(),
696            Expression::zero(),
697            Expression::zero(),
698            Expression::one(),
699            Expression::zero(),
700        ],
701        4,
702        4,
703    )
704    .expect("valid 4x4 matrix")
705}
706
707/// Create the SWAP gate matrix.
708#[must_use]
709pub fn swap() -> SymbolicMatrix {
710    SymbolicMatrix::from_flat(
711        vec![
712            Expression::one(),
713            Expression::zero(),
714            Expression::zero(),
715            Expression::zero(),
716            Expression::zero(),
717            Expression::zero(),
718            Expression::one(),
719            Expression::zero(),
720            Expression::zero(),
721            Expression::one(),
722            Expression::zero(),
723            Expression::zero(),
724            Expression::zero(),
725            Expression::zero(),
726            Expression::zero(),
727            Expression::one(),
728        ],
729        4,
730        4,
731    )
732    .expect("valid 4x4 matrix")
733}
734
735/// Create a controlled-U gate matrix.
736#[must_use]
737pub fn controlled(u: &SymbolicMatrix) -> SymbolicMatrix {
738    assert!(u.is_square() && u.nrows() == 2, "U must be a 2x2 matrix");
739
740    let n = 4;
741    let mut elements = vec![Expression::zero(); n * n];
742
743    // Top-left 2x2 block is identity (control = |0⟩)
744    elements[0] = Expression::one();
745    elements[5] = Expression::one();
746
747    // Bottom-right 2x2 block is U (control = |1⟩)
748    elements[10] = u.get(0, 0).clone();
749    elements[11] = u.get(0, 1).clone();
750    elements[14] = u.get(1, 0).clone();
751    elements[15] = u.get(1, 1).clone();
752
753    SymbolicMatrix::from_flat(elements, n, n).expect("valid 4x4 matrix")
754}
755
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Evaluate a matrix that contains no free symbols.
    fn eval_numeric(m: &SymbolicMatrix) -> Array2<f64> {
        m.eval(&HashMap::new())
            .expect("numeric matrix should evaluate")
    }

    /// Assert that `actual` matches `expected` entry-by-entry within 1e-10.
    fn assert_matrix_close<const N: usize>(actual: &Array2<f64>, expected: [[f64; N]; N]) {
        assert_eq!(actual.dim(), (N, N));
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                let got = actual[[i, j]];
                assert!(
                    (got - want).abs() < 1e-10,
                    "entry ({i}, {j}): got {got}, expected {want}"
                );
            }
        }
    }

    #[test]
    fn test_matrix_creation() {
        let m = SymbolicMatrix::identity(2);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);
        assert!(m.get(0, 0).is_one());
        assert!(m.get(0, 1).is_zero());
        assert!(m.get(1, 0).is_zero());
        assert!(m.get(1, 1).is_one());
    }

    #[test]
    fn test_matrix_construction_errors() {
        let x = Expression::symbol("x");
        // No rows at all.
        assert!(SymbolicMatrix::new(vec![]).is_err());
        // Ragged rows.
        assert!(SymbolicMatrix::new(vec![vec![x.clone()], vec![x.clone(), x.clone()]]).is_err());
        // Flat length does not match the requested shape.
        assert!(SymbolicMatrix::from_flat(vec![x.clone()], 2, 2).is_err());
    }

    #[test]
    fn test_matrix_transpose() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        let z = Expression::symbol("z");
        let w = Expression::symbol("w");

        let m = SymbolicMatrix::new(vec![vec![x.clone(), y.clone()], vec![z.clone(), w.clone()]])
            .expect("valid matrix");

        let mt = m.transpose();
        assert_eq!(mt.get(0, 0).as_symbol(), Some("x"));
        assert_eq!(mt.get(0, 1).as_symbol(), Some("z"));
        assert_eq!(mt.get(1, 0).as_symbol(), Some("y"));
        assert_eq!(mt.get(1, 1).as_symbol(), Some("w"));
    }

    #[test]
    fn test_matrix_multiplication() {
        // Test with identity
        let i = SymbolicMatrix::identity(2);
        let x = Expression::symbol("x");
        let m = SymbolicMatrix::new(vec![
            vec![x.clone(), Expression::zero()],
            vec![Expression::zero(), x.clone()],
        ])
        .expect("valid matrix");

        let result = i.matmul(&m).expect("valid matmul");

        // I * M = M - verify by evaluation
        let mut values = HashMap::new();
        values.insert("x".to_string(), 5.0);

        // result[0,0] should evaluate to x = 5.0
        let r00 = result.get(0, 0).eval(&values).expect("valid eval");
        assert!((r00 - 5.0).abs() < 1e-10);

        // result[0,1] should evaluate to 0
        let r01 = result.get(0, 1).eval(&values).expect("valid eval");
        assert!(r01.abs() < 1e-10);

        // result[1,1] should evaluate to x = 5.0
        let r11 = result.get(1, 1).eval(&values).expect("valid eval");
        assert!((r11 - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_kronecker_product() {
        let xz = pauli_x().kron(&pauli_z());
        assert_eq!(xz.nrows(), 4);
        assert_eq!(xz.ncols(), 4);

        // X ⊗ Z = [[0, Z], [Z, 0]]
        assert_matrix_close(
            &eval_numeric(&xz),
            [
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, -1.0],
                [1.0, 0.0, 0.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
            ],
        );
    }

    #[test]
    fn test_permutation_gates() {
        let cnot_expected = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0],
        ];
        assert_matrix_close(&eval_numeric(&cnot()), cnot_expected);

        assert_matrix_close(
            &eval_numeric(&swap()),
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
        );

        // A controlled Pauli-X is exactly CNOT.
        assert_matrix_close(&eval_numeric(&controlled(&pauli_x())), cnot_expected);
    }

    #[test]
    fn test_trace() {
        let theta = Expression::symbol("theta");
        let m = SymbolicMatrix::diagonal(vec![theta.clone(), theta.clone()]);
        let tr = m.trace().expect("valid trace");

        // tr(diag(θ, θ)) = 2θ
        let mut values = HashMap::new();
        values.insert("theta".to_string(), 3.0);
        let result = tr.eval(&values).expect("valid eval");
        assert!((result - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_rotation_gates() {
        let theta = Expression::symbol("theta");

        // Rx(θ) should be unitary: Rx†Rx = I
        let rx_gate = rx(&theta);
        let rx_dag = rx_gate.dagger();

        // Just verify structure - full verification would require complex eval
        assert_eq!(rx_gate.nrows(), 2);
        assert_eq!(rx_dag.nrows(), 2);
    }

    #[test]
    fn test_ry_is_orthogonal() {
        // Ry is real, so unitarity reduces to Ryᵀ·Ry = I, which `eval` can check.
        let theta = Expression::symbol("theta");
        let ry_gate = ry(&theta);
        let product = ry_gate
            .transpose()
            .matmul(&ry_gate)
            .expect("valid matmul");

        let mut values = HashMap::new();
        values.insert("theta".to_string(), 0.7);
        let result = product.eval(&values).expect("valid eval");
        assert_matrix_close(&result, [[1.0, 0.0], [0.0, 1.0]]);
    }

    #[test]
    fn test_pauli_commutation() {
        let x = pauli_x();
        let y = pauli_y();

        // [X, Y] = 2iZ (entries are complex, so only the shape is checked here)
        let comm = x.commutator(&y).expect("valid commutator");
        assert_eq!(comm.nrows(), 2);
    }

    #[test]
    fn test_matrix_diff() {
        let theta = Expression::symbol("theta");
        let m = SymbolicMatrix::diagonal(vec![
            crate::ops::trig::sin(&theta),
            crate::ops::trig::cos(&theta),
        ]);

        let dm = m.diff(&theta);

        // d/dθ sin(θ) = cos(θ), d/dθ cos(θ) = -sin(θ)
        let mut values = HashMap::new();
        values.insert("theta".to_string(), 0.0);

        // At θ=0: d/dθ sin(θ)|θ=0 = cos(0) = 1
        let result = dm.eval(&values).expect("valid eval");
        assert!((result[[0, 0]] - 1.0).abs() < 1e-10);
        // d/dθ cos(θ)|θ=0 = -sin(0) = 0
        assert!(result[[1, 1]].abs() < 1e-10);
    }
}