1use std::fmt;
8
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::Complex64;
11
12use crate::error::{SymEngineError, SymEngineResult};
13use crate::expr::Expression;
14
15#[derive(Clone, Debug)]
20pub struct SymbolicMatrix {
21 elements: Vec<Expression>,
23 rows: usize,
25 cols: usize,
27}
28
29impl SymbolicMatrix {
30 pub fn new(elements: Vec<Vec<Expression>>) -> SymEngineResult<Self> {
35 if elements.is_empty() {
36 return Err(SymEngineError::dimension("Matrix cannot be empty"));
37 }
38
39 let rows = elements.len();
40 let cols = elements[0].len();
41
42 for (i, row) in elements.iter().enumerate() {
44 if row.len() != cols {
45 return Err(SymEngineError::dimension(format!(
46 "Row {i} has {} columns, expected {cols}",
47 row.len()
48 )));
49 }
50 }
51
52 let flat: Vec<Expression> = elements.into_iter().flatten().collect();
53
54 Ok(Self {
55 elements: flat,
56 rows,
57 cols,
58 })
59 }
60
61 pub fn from_flat(elements: Vec<Expression>, rows: usize, cols: usize) -> SymEngineResult<Self> {
66 if elements.len() != rows * cols {
67 return Err(SymEngineError::dimension(format!(
68 "Expected {} elements for {}x{} matrix, got {}",
69 rows * cols,
70 rows,
71 cols,
72 elements.len()
73 )));
74 }
75
76 Ok(Self {
77 elements,
78 rows,
79 cols,
80 })
81 }
82
83 #[must_use]
85 pub fn zeros(rows: usize, cols: usize) -> Self {
86 Self {
87 elements: vec![Expression::zero(); rows * cols],
88 rows,
89 cols,
90 }
91 }
92
93 #[must_use]
95 pub fn identity(n: usize) -> Self {
96 let mut elements = vec![Expression::zero(); n * n];
97 for i in 0..n {
98 elements[i * n + i] = Expression::one();
99 }
100 Self {
101 elements,
102 rows: n,
103 cols: n,
104 }
105 }
106
107 #[must_use]
109 pub fn diagonal(diag: Vec<Expression>) -> Self {
110 let n = diag.len();
111 let mut elements = vec![Expression::zero(); n * n];
112 for (i, d) in diag.into_iter().enumerate() {
113 elements[i * n + i] = d;
114 }
115 Self {
116 elements,
117 rows: n,
118 cols: n,
119 }
120 }
121
122 #[must_use]
124 pub fn from_array(arr: &Array2<f64>) -> Self {
125 let rows = arr.nrows();
126 let cols = arr.ncols();
127 let elements: Vec<Expression> = arr
128 .iter()
129 .map(|&v| Expression::float_unchecked(v))
130 .collect();
131 Self {
132 elements,
133 rows,
134 cols,
135 }
136 }
137
138 #[must_use]
140 pub fn from_complex_array(arr: &Array2<Complex64>) -> Self {
141 let rows = arr.nrows();
142 let cols = arr.ncols();
143 let elements: Vec<Expression> =
144 arr.iter().map(|&c| Expression::from_complex64(c)).collect();
145 Self {
146 elements,
147 rows,
148 cols,
149 }
150 }
151
152 #[must_use]
158 pub const fn nrows(&self) -> usize {
159 self.rows
160 }
161
162 #[must_use]
164 pub const fn ncols(&self) -> usize {
165 self.cols
166 }
167
168 #[must_use]
170 pub const fn shape(&self) -> (usize, usize) {
171 (self.rows, self.cols)
172 }
173
174 #[must_use]
176 pub const fn is_square(&self) -> bool {
177 self.rows == self.cols
178 }
179
180 #[must_use]
185 pub fn get(&self, i: usize, j: usize) -> &Expression {
186 assert!(i < self.rows && j < self.cols, "Index out of bounds");
187 &self.elements[i * self.cols + j]
188 }
189
190 pub fn get_mut(&mut self, i: usize, j: usize) -> &mut Expression {
195 assert!(i < self.rows && j < self.cols, "Index out of bounds");
196 &mut self.elements[i * self.cols + j]
197 }
198
199 pub fn set(&mut self, i: usize, j: usize, value: Expression) {
204 assert!(i < self.rows && j < self.cols, "Index out of bounds");
205 self.elements[i * self.cols + j] = value;
206 }
207
208 #[must_use]
210 pub fn row(&self, i: usize) -> Vec<Expression> {
211 assert!(i < self.rows, "Row index out of bounds");
212 let start = i * self.cols;
213 self.elements[start..start + self.cols].to_vec()
214 }
215
216 #[must_use]
218 pub fn col(&self, j: usize) -> Vec<Expression> {
219 assert!(j < self.cols, "Column index out of bounds");
220 (0..self.rows).map(|i| self.get(i, j).clone()).collect()
221 }
222
223 #[must_use]
229 pub fn transpose(&self) -> Self {
230 let mut elements = Vec::with_capacity(self.rows * self.cols);
231 for j in 0..self.cols {
232 for i in 0..self.rows {
233 elements.push(self.get(i, j).clone());
234 }
235 }
236 Self {
237 elements,
238 rows: self.cols,
239 cols: self.rows,
240 }
241 }
242
243 #[must_use]
245 pub fn conjugate(&self) -> Self {
246 Self {
247 elements: self.elements.iter().map(Expression::conjugate).collect(),
248 rows: self.rows,
249 cols: self.cols,
250 }
251 }
252
253 #[must_use]
255 pub fn dagger(&self) -> Self {
256 self.transpose().conjugate()
257 }
258
259 pub fn add(&self, other: &Self) -> SymEngineResult<Self> {
264 if self.rows != other.rows || self.cols != other.cols {
265 return Err(SymEngineError::dimension(format!(
266 "Cannot add {}x{} matrix with {}x{} matrix",
267 self.rows, self.cols, other.rows, other.cols
268 )));
269 }
270
271 let elements: Vec<Expression> = self
272 .elements
273 .iter()
274 .zip(other.elements.iter())
275 .map(|(a, b)| a.clone() + b.clone())
276 .collect();
277
278 Ok(Self {
279 elements,
280 rows: self.rows,
281 cols: self.cols,
282 })
283 }
284
285 pub fn sub(&self, other: &Self) -> SymEngineResult<Self> {
290 if self.rows != other.rows || self.cols != other.cols {
291 return Err(SymEngineError::dimension(format!(
292 "Cannot subtract {}x{} matrix from {}x{} matrix",
293 other.rows, other.cols, self.rows, self.cols
294 )));
295 }
296
297 let elements: Vec<Expression> = self
298 .elements
299 .iter()
300 .zip(other.elements.iter())
301 .map(|(a, b)| a.clone() - b.clone())
302 .collect();
303
304 Ok(Self {
305 elements,
306 rows: self.rows,
307 cols: self.cols,
308 })
309 }
310
311 pub fn matmul(&self, other: &Self) -> SymEngineResult<Self> {
316 if self.cols != other.rows {
317 return Err(SymEngineError::dimension(format!(
318 "Cannot multiply {}x{} matrix with {}x{} matrix",
319 self.rows, self.cols, other.rows, other.cols
320 )));
321 }
322
323 let mut elements = Vec::with_capacity(self.rows * other.cols);
324
325 for i in 0..self.rows {
326 for j in 0..other.cols {
327 let mut sum = Expression::zero();
328 for k in 0..self.cols {
329 sum = sum + self.get(i, k).clone() * other.get(k, j).clone();
330 }
331 elements.push(sum);
332 }
333 }
334
335 Ok(Self {
336 elements,
337 rows: self.rows,
338 cols: other.cols,
339 })
340 }
341
342 #[must_use]
344 pub fn scale(&self, scalar: &Expression) -> Self {
345 Self {
346 elements: self
347 .elements
348 .iter()
349 .map(|e| e.clone() * scalar.clone())
350 .collect(),
351 rows: self.rows,
352 cols: self.cols,
353 }
354 }
355
356 #[must_use]
358 pub fn kron(&self, other: &Self) -> Self {
359 let new_rows = self.rows * other.rows;
360 let new_cols = self.cols * other.cols;
361 let mut elements = Vec::with_capacity(new_rows * new_cols);
362
363 for i1 in 0..self.rows {
364 for i2 in 0..other.rows {
365 for j1 in 0..self.cols {
366 for j2 in 0..other.cols {
367 let a = self.get(i1, j1).clone();
368 let b = other.get(i2, j2).clone();
369 elements.push(a * b);
370 }
371 }
372 }
373 }
374
375 Self {
376 elements,
377 rows: new_rows,
378 cols: new_cols,
379 }
380 }
381
382 pub fn trace(&self) -> SymEngineResult<Expression> {
387 if !self.is_square() {
388 return Err(SymEngineError::dimension(
389 "Trace is only defined for square matrices",
390 ));
391 }
392
393 let mut sum = Expression::zero();
394 for i in 0..self.rows {
395 sum = sum + self.get(i, i).clone();
396 }
397 Ok(sum)
398 }
399
400 pub fn commutator(&self, other: &Self) -> SymEngineResult<Self> {
405 let ab = self.matmul(other)?;
406 let ba = other.matmul(self)?;
407 ab.sub(&ba)
408 }
409
410 pub fn anticommutator(&self, other: &Self) -> SymEngineResult<Self> {
415 let ab = self.matmul(other)?;
416 let ba = other.matmul(self)?;
417 ab.add(&ba)
418 }
419
420 #[must_use]
426 pub fn simplify(&self) -> Self {
427 Self {
428 elements: self.elements.iter().map(Expression::simplify).collect(),
429 rows: self.rows,
430 cols: self.cols,
431 }
432 }
433
434 #[must_use]
436 pub fn expand(&self) -> Self {
437 Self {
438 elements: self.elements.iter().map(Expression::expand).collect(),
439 rows: self.rows,
440 cols: self.cols,
441 }
442 }
443
444 pub fn eval(
453 &self,
454 values: &std::collections::HashMap<String, f64>,
455 ) -> SymEngineResult<Array2<f64>> {
456 let mut result = Array2::zeros((self.rows, self.cols));
457 for i in 0..self.rows {
458 for j in 0..self.cols {
459 result[[i, j]] = self.get(i, j).eval(values)?;
460 }
461 }
462 Ok(result)
463 }
464
465 #[must_use]
467 pub fn substitute(&self, var: &Expression, value: &Expression) -> Self {
468 Self {
469 elements: self
470 .elements
471 .iter()
472 .map(|e| e.substitute(var, value))
473 .collect(),
474 rows: self.rows,
475 cols: self.cols,
476 }
477 }
478
479 #[must_use]
485 pub fn diff(&self, var: &Expression) -> Self {
486 Self {
487 elements: self.elements.iter().map(|e| e.diff(var)).collect(),
488 rows: self.rows,
489 cols: self.cols,
490 }
491 }
492}
493
494impl fmt::Display for SymbolicMatrix {
495 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
496 writeln!(f, "[")?;
497 for i in 0..self.rows {
498 write!(f, " [")?;
499 for j in 0..self.cols {
500 if j > 0 {
501 write!(f, ", ")?;
502 }
503 write!(f, "{}", self.get(i, j))?;
504 }
505 writeln!(f, "]")?;
506 }
507 write!(f, "]")
508 }
509}
510
511impl std::ops::Index<(usize, usize)> for SymbolicMatrix {
512 type Output = Expression;
513
514 fn index(&self, index: (usize, usize)) -> &Self::Output {
515 self.get(index.0, index.1)
516 }
517}
518
519impl std::ops::IndexMut<(usize, usize)> for SymbolicMatrix {
520 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
521 self.get_mut(index.0, index.1)
522 }
523}
524
525#[must_use]
531pub fn pauli_x() -> SymbolicMatrix {
532 SymbolicMatrix::from_flat(
533 vec![
534 Expression::zero(),
535 Expression::one(),
536 Expression::one(),
537 Expression::zero(),
538 ],
539 2,
540 2,
541 )
542 .expect("valid 2x2 matrix")
543}
544
545#[must_use]
547pub fn pauli_y() -> SymbolicMatrix {
548 let i = Expression::i();
549 SymbolicMatrix::from_flat(
550 vec![Expression::zero(), i.clone().neg(), i, Expression::zero()],
551 2,
552 2,
553 )
554 .expect("valid 2x2 matrix")
555}
556
557#[must_use]
559pub fn pauli_z() -> SymbolicMatrix {
560 SymbolicMatrix::from_flat(
561 vec![
562 Expression::one(),
563 Expression::zero(),
564 Expression::zero(),
565 Expression::one().neg(),
566 ],
567 2,
568 2,
569 )
570 .expect("valid 2x2 matrix")
571}
572
573#[must_use]
575pub fn hadamard() -> SymbolicMatrix {
576 let sqrt2_inv = Expression::one() / crate::ops::trig::sqrt(&Expression::int(2));
577 SymbolicMatrix::from_flat(
578 vec![
579 sqrt2_inv.clone(),
580 sqrt2_inv.clone(),
581 sqrt2_inv.clone(),
582 sqrt2_inv.neg(),
583 ],
584 2,
585 2,
586 )
587 .expect("valid 2x2 matrix")
588}
589
590#[must_use]
592pub fn s_gate() -> SymbolicMatrix {
593 SymbolicMatrix::from_flat(
594 vec![
595 Expression::one(),
596 Expression::zero(),
597 Expression::zero(),
598 Expression::i(),
599 ],
600 2,
601 2,
602 )
603 .expect("valid 2x2 matrix")
604}
605
606#[must_use]
608pub fn t_gate() -> SymbolicMatrix {
609 let exp_i_pi_4 =
610 crate::ops::trig::exp(&(Expression::i() * Expression::pi() / Expression::int(4)));
611 SymbolicMatrix::from_flat(
612 vec![
613 Expression::one(),
614 Expression::zero(),
615 Expression::zero(),
616 exp_i_pi_4,
617 ],
618 2,
619 2,
620 )
621 .expect("valid 2x2 matrix")
622}
623
624#[must_use]
626pub fn rx(theta: &Expression) -> SymbolicMatrix {
627 let half = Expression::float_unchecked(0.5);
628 let half_theta = theta.clone() * half;
629 let cos_half = crate::ops::trig::cos(&half_theta);
630 let sin_half = crate::ops::trig::sin(&half_theta);
631 let i = Expression::i();
632
633 SymbolicMatrix::from_flat(
634 vec![
635 cos_half.clone(),
636 i.clone().neg() * sin_half.clone(),
637 i.neg() * sin_half,
638 cos_half,
639 ],
640 2,
641 2,
642 )
643 .expect("valid 2x2 matrix")
644}
645
646#[must_use]
648pub fn ry(theta: &Expression) -> SymbolicMatrix {
649 let half = Expression::float_unchecked(0.5);
650 let half_theta = theta.clone() * half;
651 let cos_half = crate::ops::trig::cos(&half_theta);
652 let sin_half = crate::ops::trig::sin(&half_theta);
653
654 SymbolicMatrix::from_flat(
655 vec![cos_half.clone(), sin_half.clone().neg(), sin_half, cos_half],
656 2,
657 2,
658 )
659 .expect("valid 2x2 matrix")
660}
661
662#[must_use]
664pub fn rz(theta: &Expression) -> SymbolicMatrix {
665 let half = Expression::float_unchecked(0.5);
666 let i = Expression::i();
667 let half_theta = theta.clone() * half;
668 let exp_neg = crate::ops::trig::exp(&(i.neg() * half_theta.clone()));
669 let exp_pos = crate::ops::trig::exp(&(Expression::i() * half_theta));
670
671 SymbolicMatrix::from_flat(
672 vec![exp_neg, Expression::zero(), Expression::zero(), exp_pos],
673 2,
674 2,
675 )
676 .expect("valid 2x2 matrix")
677}
678
679#[must_use]
681pub fn cnot() -> SymbolicMatrix {
682 SymbolicMatrix::from_flat(
683 vec![
684 Expression::one(),
685 Expression::zero(),
686 Expression::zero(),
687 Expression::zero(),
688 Expression::zero(),
689 Expression::one(),
690 Expression::zero(),
691 Expression::zero(),
692 Expression::zero(),
693 Expression::zero(),
694 Expression::zero(),
695 Expression::one(),
696 Expression::zero(),
697 Expression::zero(),
698 Expression::one(),
699 Expression::zero(),
700 ],
701 4,
702 4,
703 )
704 .expect("valid 4x4 matrix")
705}
706
707#[must_use]
709pub fn swap() -> SymbolicMatrix {
710 SymbolicMatrix::from_flat(
711 vec![
712 Expression::one(),
713 Expression::zero(),
714 Expression::zero(),
715 Expression::zero(),
716 Expression::zero(),
717 Expression::zero(),
718 Expression::one(),
719 Expression::zero(),
720 Expression::zero(),
721 Expression::one(),
722 Expression::zero(),
723 Expression::zero(),
724 Expression::zero(),
725 Expression::zero(),
726 Expression::zero(),
727 Expression::one(),
728 ],
729 4,
730 4,
731 )
732 .expect("valid 4x4 matrix")
733}
734
735#[must_use]
737pub fn controlled(u: &SymbolicMatrix) -> SymbolicMatrix {
738 assert!(u.is_square() && u.nrows() == 2, "U must be a 2x2 matrix");
739
740 let n = 4;
741 let mut elements = vec![Expression::zero(); n * n];
742
743 elements[0] = Expression::one();
745 elements[5] = Expression::one();
746
747 elements[10] = u.get(0, 0).clone();
749 elements[11] = u.get(0, 1).clone();
750 elements[14] = u.get(1, 0).clone();
751 elements[15] = u.get(1, 1).clone();
752
753 SymbolicMatrix::from_flat(elements, n, n).expect("valid 4x4 matrix")
754}
755
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Evaluates `m` with `values` and checks each entry against the
    /// row-major `expected` slice to within `1e-10`.
    fn assert_matrix_close(m: &SymbolicMatrix, values: &HashMap<String, f64>, expected: &[f64]) {
        let actual = m.eval(values).expect("valid eval");
        assert_eq!(actual.len(), expected.len(), "element count mismatch");
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-10, "entry {idx}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_matrix_creation() {
        let m = SymbolicMatrix::identity(2);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);
        assert!(m.get(0, 0).is_one());
        assert!(m.get(0, 1).is_zero());
        assert!(m.get(1, 0).is_zero());
        assert!(m.get(1, 1).is_one());
    }

    #[test]
    fn test_dimension_errors() {
        let zero = Expression::zero;
        assert!(SymbolicMatrix::new(Vec::new()).is_err());
        assert!(SymbolicMatrix::new(vec![vec![zero(), zero()], vec![zero()]]).is_err());
        assert!(SymbolicMatrix::from_flat(vec![zero(); 3], 2, 2).is_err());
        assert!(SymbolicMatrix::zeros(2, 3)
            .matmul(&SymbolicMatrix::zeros(2, 3))
            .is_err());
        assert!(SymbolicMatrix::zeros(2, 3).trace().is_err());
    }

    #[test]
    fn test_matrix_transpose() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        let z = Expression::symbol("z");
        let w = Expression::symbol("w");

        let m = SymbolicMatrix::new(vec![vec![x.clone(), y.clone()], vec![z.clone(), w.clone()]])
            .expect("valid matrix");

        let mt = m.transpose();
        assert_eq!(mt.get(0, 0).as_symbol(), Some("x"));
        assert_eq!(mt.get(0, 1).as_symbol(), Some("z"));
        assert_eq!(mt.get(1, 0).as_symbol(), Some("y"));
        assert_eq!(mt.get(1, 1).as_symbol(), Some("w"));
    }

    #[test]
    fn test_matrix_multiplication() {
        let i = SymbolicMatrix::identity(2);
        let x = Expression::symbol("x");
        let m = SymbolicMatrix::new(vec![
            vec![x.clone(), Expression::zero()],
            vec![Expression::zero(), x.clone()],
        ])
        .expect("valid matrix");

        let result = i.matmul(&m).expect("valid matmul");

        let mut values = HashMap::new();
        values.insert("x".to_string(), 5.0);

        let r00 = result.get(0, 0).eval(&values).expect("valid eval");
        assert!((r00 - 5.0).abs() < 1e-10);

        let r01 = result.get(0, 1).eval(&values).expect("valid eval");
        assert!(r01.abs() < 1e-10);

        let r11 = result.get(1, 1).eval(&values).expect("valid eval");
        assert!((r11 - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_kronecker_product() {
        let x = pauli_x();
        let z = pauli_z();

        let xz = x.kron(&z);
        assert_eq!(xz.nrows(), 4);
        assert_eq!(xz.ncols(), 4);

        // X ⊗ Z = [[0, Z], [Z, 0]] with Z = diag(1, -1).
        #[rustfmt::skip]
        let expected = [
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, -1.0,
            1.0, 0.0, 0.0, 0.0,
            0.0, -1.0, 0.0, 0.0,
        ];
        assert_matrix_close(&xz, &HashMap::new(), &expected);
    }

    #[test]
    fn test_trace() {
        let theta = Expression::symbol("theta");
        let m = SymbolicMatrix::diagonal(vec![theta.clone(), theta.clone()]);
        let tr = m.trace().expect("valid trace");

        let mut values = HashMap::new();
        values.insert("theta".to_string(), 3.0);
        let result = tr.eval(&values).expect("valid eval");
        assert!((result - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_rotation_gates() {
        let theta = Expression::symbol("theta");

        let rx_gate = rx(&theta);
        let rx_dag = rx_gate.dagger();

        assert_eq!(rx_gate.nrows(), 2);
        assert_eq!(rx_dag.nrows(), 2);

        // RY(π) = [[cos(π/2), -sin(π/2)], [sin(π/2), cos(π/2)]] = [[0, -1], [1, 0]].
        let mut values = HashMap::new();
        values.insert("theta".to_string(), std::f64::consts::PI);
        assert_matrix_close(&ry(&theta), &values, &[0.0, -1.0, 1.0, 0.0]);
    }

    #[test]
    fn test_controlled_x_matches_cnot() {
        let values = HashMap::new();
        let expected: Vec<f64> = cnot()
            .eval(&values)
            .expect("valid eval")
            .iter()
            .copied()
            .collect();
        assert_matrix_close(&controlled(&pauli_x()), &values, &expected);
    }

    #[test]
    fn test_swap_squares_to_identity() {
        let s = swap();
        let ss = s.matmul(&s).expect("valid matmul");
        // Flat indices 0, 5, 10, 15 are the diagonal of a 4x4 matrix.
        let identity: Vec<f64> = (0..16)
            .map(|k| if k % 5 == 0 { 1.0 } else { 0.0 })
            .collect();
        assert_matrix_close(&ss, &HashMap::new(), &identity);
    }

    #[test]
    fn test_pauli_commutation() {
        let x = pauli_x();
        let y = pauli_y();

        let comm = x.commutator(&y).expect("valid commutator");
        assert_eq!(comm.nrows(), 2);
        assert_eq!(comm.ncols(), 2);
    }

    #[test]
    fn test_matrix_diff() {
        let theta = Expression::symbol("theta");
        let m = SymbolicMatrix::diagonal(vec![
            crate::ops::trig::sin(&theta),
            crate::ops::trig::cos(&theta),
        ]);

        let dm = m.diff(&theta);

        let mut values = HashMap::new();
        values.insert("theta".to_string(), 0.0);

        let result = dm.eval(&values).expect("valid eval");
        assert!((result[[0, 0]] - 1.0).abs() < 1e-10);
        assert!(result[[1, 1]].abs() < 1e-10);
    }
}