use std::fmt;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::Complex64;
use crate::error::{SymEngineError, SymEngineResult};
use crate::expr::Expression;
/// A dense matrix whose entries are symbolic [`Expression`]s.
///
/// Entries are stored in a single flat vector in row-major order: entry
/// `(i, j)` lives at index `i * cols + j`.
#[derive(Clone, Debug)]
pub struct SymbolicMatrix {
    /// Flattened entries in row-major order. Intended invariant (maintained by
    /// the constructors in this file): `elements.len() == rows * cols`.
    elements: Vec<Expression>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
impl SymbolicMatrix {
    /// Number of entries in a `rows x cols` matrix.
    ///
    /// Returns a dimension error instead of silently wrapping (release) or
    /// panicking (debug) when the product does not fit in `usize`.
    fn checked_len(rows: usize, cols: usize) -> SymEngineResult<usize> {
        rows.checked_mul(cols).ok_or_else(|| {
            SymEngineError::dimension(format!(
                "Matrix dimensions {rows}x{cols} overflow usize"
            ))
        })
    }

    /// Row-major offset of entry `(i, j)` into `elements`.
    ///
    /// # Panics
    ///
    /// Panics with `"Index out of bounds"` if `i >= rows` or `j >= cols`.
    fn flat_index(&self, i: usize, j: usize) -> usize {
        assert!(i < self.rows && j < self.cols, "Index out of bounds");
        i * self.cols + j
    }

    /// Builds a same-shaped matrix by applying `f` to every entry.
    fn map_elements<F>(&self, f: F) -> Self
    where
        F: FnMut(&Expression) -> Expression,
    {
        Self {
            elements: self.elements.iter().map(f).collect(),
            rows: self.rows,
            cols: self.cols,
        }
    }

    /// Combines `self` and `other` entry by entry with `f`.
    ///
    /// Callers must have already checked that both shapes are equal.
    fn zip_elements<F>(&self, other: &Self, mut f: F) -> Self
    where
        F: FnMut(&Expression, &Expression) -> Expression,
    {
        Self {
            elements: self
                .elements
                .iter()
                .zip(&other.elements)
                .map(|(a, b)| f(a, b))
                .collect(),
            rows: self.rows,
            cols: self.cols,
        }
    }

    /// Builds a matrix from a list of rows.
    ///
    /// Rows are flattened in order, so `elements[i][j]` becomes entry `(i, j)`.
    ///
    /// # Errors
    ///
    /// Returns a dimension error if `elements` has no rows, or if any row's
    /// length differs from the first row's length.
    pub fn new(elements: Vec<Vec<Expression>>) -> SymEngineResult<Self> {
        if elements.is_empty() {
            return Err(SymEngineError::dimension("Matrix cannot be empty"));
        }
        let rows = elements.len();
        let cols = elements[0].len();
        if let Some((i, row)) = elements
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != cols)
        {
            return Err(SymEngineError::dimension(format!(
                "Row {i} has {} columns, expected {cols}",
                row.len()
            )));
        }
        Ok(Self {
            elements: elements.into_iter().flatten().collect(),
            rows,
            cols,
        })
    }

    /// Builds a `rows x cols` matrix from entries given in row-major order.
    ///
    /// # Errors
    ///
    /// Returns a dimension error if `rows * cols` overflows `usize`, or if
    /// `elements.len()` is not exactly `rows * cols`.
    pub fn from_flat(elements: Vec<Expression>, rows: usize, cols: usize) -> SymEngineResult<Self> {
        let expected = Self::checked_len(rows, cols)?;
        if elements.len() != expected {
            return Err(SymEngineError::dimension(format!(
                "Expected {} elements for {}x{} matrix, got {}",
                expected,
                rows,
                cols,
                elements.len()
            )));
        }
        Ok(Self {
            elements,
            rows,
            cols,
        })
    }

    /// A `rows x cols` matrix with every entry equal to `Expression::zero()`.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    #[must_use]
    pub fn zeros(rows: usize, cols: usize) -> Self {
        let len = Self::checked_len(rows, cols).expect("matrix dimensions overflow usize");
        Self {
            elements: vec![Expression::zero(); len],
            rows,
            cols,
        }
    }

    /// The `n x n` identity matrix (ones on the diagonal, zeros elsewhere).
    ///
    /// # Panics
    ///
    /// Panics if `n * n` overflows `usize`.
    #[must_use]
    pub fn identity(n: usize) -> Self {
        Self::diagonal(vec![Expression::one(); n])
    }

    /// A square matrix with `diag` on its main diagonal and zeros elsewhere.
    ///
    /// The result is `diag.len() x diag.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `diag.len()` squared overflows `usize`.
    #[must_use]
    pub fn diagonal(diag: Vec<Expression>) -> Self {
        let n = diag.len();
        let len = Self::checked_len(n, n).expect("matrix dimensions overflow usize");
        let mut elements = vec![Expression::zero(); len];
        for (i, d) in diag.into_iter().enumerate() {
            elements[i * n + i] = d;
        }
        Self {
            elements,
            rows: n,
            cols: n,
        }
    }

    /// Converts a numeric `f64` array into a symbolic matrix.
    ///
    /// Each value is wrapped with `Expression::float_unchecked`, so no
    /// validation of the values happens here. Entries are taken from
    /// `arr.iter()`, which visits them in logical (row-major) order.
    #[must_use]
    pub fn from_array(arr: &Array2<f64>) -> Self {
        Self {
            elements: arr.iter().map(|&v| Expression::float_unchecked(v)).collect(),
            rows: arr.nrows(),
            cols: arr.ncols(),
        }
    }

    /// Converts a numeric complex array into a symbolic matrix.
    ///
    /// Each value is wrapped with `Expression::from_complex64`, with entries
    /// taken in logical (row-major) order.
    #[must_use]
    pub fn from_complex_array(arr: &Array2<Complex64>) -> Self {
        Self {
            elements: arr.iter().map(|&c| Expression::from_complex64(c)).collect(),
            rows: arr.nrows(),
            cols: arr.ncols(),
        }
    }

    /// Number of rows.
    #[must_use]
    pub const fn nrows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    #[must_use]
    pub const fn ncols(&self) -> usize {
        self.cols
    }

    /// `(rows, cols)`.
    #[must_use]
    pub const fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Whether the matrix has as many rows as columns.
    #[must_use]
    pub const fn is_square(&self) -> bool {
        self.rows == self.cols
    }

    /// Borrows entry `(i, j)`.
    ///
    /// # Panics
    ///
    /// Panics with `"Index out of bounds"` if `i >= nrows()` or `j >= ncols()`.
    #[must_use]
    pub fn get(&self, i: usize, j: usize) -> &Expression {
        &self.elements[self.flat_index(i, j)]
    }

    /// Mutably borrows entry `(i, j)`.
    ///
    /// # Panics
    ///
    /// Panics with `"Index out of bounds"` if `i >= nrows()` or `j >= ncols()`.
    pub fn get_mut(&mut self, i: usize, j: usize) -> &mut Expression {
        let idx = self.flat_index(i, j);
        &mut self.elements[idx]
    }

    /// Overwrites entry `(i, j)` with `value`.
    ///
    /// # Panics
    ///
    /// Panics with `"Index out of bounds"` if `i >= nrows()` or `j >= ncols()`.
    pub fn set(&mut self, i: usize, j: usize, value: Expression) {
        let idx = self.flat_index(i, j);
        self.elements[idx] = value;
    }

    /// Clones row `i` into a new vector.
    ///
    /// # Panics
    ///
    /// Panics with `"Row index out of bounds"` if `i >= nrows()`.
    #[must_use]
    pub fn row(&self, i: usize) -> Vec<Expression> {
        assert!(i < self.rows, "Row index out of bounds");
        let start = i * self.cols;
        self.elements[start..start + self.cols].to_vec()
    }

    /// Clones column `j` into a new vector.
    ///
    /// # Panics
    ///
    /// Panics with `"Column index out of bounds"` if `j >= ncols()`.
    #[must_use]
    pub fn col(&self, j: usize) -> Vec<Expression> {
        assert!(j < self.cols, "Column index out of bounds");
        (0..self.rows).map(|i| self.get(i, j).clone()).collect()
    }

    /// The transpose: entry `(i, j)` of the result is entry `(j, i)` of `self`.
    #[must_use]
    pub fn transpose(&self) -> Self {
        let mut elements = Vec::with_capacity(self.elements.len());
        // Each column of `self` becomes one row of the result.
        for j in 0..self.cols {
            elements.extend((0..self.rows).map(|i| self.get(i, j).clone()));
        }
        Self {
            elements,
            rows: self.cols,
            cols: self.rows,
        }
    }

    /// Applies `Expression::conjugate` to every entry.
    #[must_use]
    pub fn conjugate(&self) -> Self {
        self.map_elements(Expression::conjugate)
    }

    /// Conjugate transpose: the transpose with every entry conjugated.
    #[must_use]
    pub fn dagger(&self) -> Self {
        self.transpose().conjugate()
    }

    /// Entry-wise sum `self + other`.
    ///
    /// # Errors
    ///
    /// Returns a dimension error if the two shapes differ.
    pub fn add(&self, other: &Self) -> SymEngineResult<Self> {
        if self.shape() != other.shape() {
            return Err(SymEngineError::dimension(format!(
                "Cannot add {}x{} matrix with {}x{} matrix",
                self.rows, self.cols, other.rows, other.cols
            )));
        }
        Ok(self.zip_elements(other, |a, b| a.clone() + b.clone()))
    }

    /// Entry-wise difference `self - other`.
    ///
    /// # Errors
    ///
    /// Returns a dimension error if the two shapes differ.
    pub fn sub(&self, other: &Self) -> SymEngineResult<Self> {
        if self.shape() != other.shape() {
            return Err(SymEngineError::dimension(format!(
                "Cannot subtract {}x{} matrix from {}x{} matrix",
                other.rows, other.cols, self.rows, self.cols
            )));
        }
        Ok(self.zip_elements(other, |a, b| a.clone() - b.clone()))
    }

    /// Matrix product `self * other`.
    ///
    /// Each result entry is accumulated left to right, starting from
    /// `Expression::zero()`.
    ///
    /// # Errors
    ///
    /// Returns a dimension error if `self.ncols() != other.nrows()`.
    pub fn matmul(&self, other: &Self) -> SymEngineResult<Self> {
        if self.cols != other.rows {
            return Err(SymEngineError::dimension(format!(
                "Cannot multiply {}x{} matrix with {}x{} matrix",
                self.rows, self.cols, other.rows, other.cols
            )));
        }
        let mut elements = Vec::with_capacity(self.rows * other.cols);
        for i in 0..self.rows {
            for j in 0..other.cols {
                let dot = (0..self.cols).fold(Expression::zero(), |acc, k| {
                    acc + self.get(i, k).clone() * other.get(k, j).clone()
                });
                elements.push(dot);
            }
        }
        Ok(Self {
            elements,
            rows: self.rows,
            cols: other.cols,
        })
    }

    /// Multiplies every entry on the right by `scalar` (`entry * scalar`).
    #[must_use]
    pub fn scale(&self, scalar: &Expression) -> Self {
        self.map_elements(|e| e.clone() * scalar.clone())
    }

    /// Kronecker product `self ⊗ other`.
    ///
    /// With `p = other.nrows()` and `q = other.ncols()`, entry
    /// `(i1 * p + i2, j1 * q + j2)` of the result is
    /// `self[(i1, j1)] * other[(i2, j2)]`.
    ///
    /// # Panics
    ///
    /// Panics if the result's dimensions or entry count overflow `usize`.
    #[must_use]
    pub fn kron(&self, other: &Self) -> Self {
        let new_rows = self
            .rows
            .checked_mul(other.rows)
            .expect("matrix dimensions overflow usize");
        let new_cols = self
            .cols
            .checked_mul(other.cols)
            .expect("matrix dimensions overflow usize");
        let len = Self::checked_len(new_rows, new_cols).expect("matrix dimensions overflow usize");
        let mut elements = Vec::with_capacity(len);
        // Emit result rows in order: (i1, i2) picks the row, (j1, j2) walks it.
        for i1 in 0..self.rows {
            for i2 in 0..other.rows {
                for j1 in 0..self.cols {
                    for j2 in 0..other.cols {
                        elements.push(self.get(i1, j1).clone() * other.get(i2, j2).clone());
                    }
                }
            }
        }
        Self {
            elements,
            rows: new_rows,
            cols: new_cols,
        }
    }

    /// Sum of the diagonal entries, accumulated from `Expression::zero()`.
    ///
    /// # Errors
    ///
    /// Returns a dimension error if the matrix is not square.
    pub fn trace(&self) -> SymEngineResult<Expression> {
        if !self.is_square() {
            return Err(SymEngineError::dimension(
                "Trace is only defined for square matrices",
            ));
        }
        Ok((0..self.rows).fold(Expression::zero(), |acc, i| acc + self.get(i, i).clone()))
    }

    /// Commutator `[A, B] = AB - BA` with `A = self`, `B = other`.
    ///
    /// # Errors
    ///
    /// Propagates dimension errors from [`Self::matmul`] and [`Self::sub`].
    pub fn commutator(&self, other: &Self) -> SymEngineResult<Self> {
        let ab = self.matmul(other)?;
        let ba = other.matmul(self)?;
        ab.sub(&ba)
    }

    /// Anticommutator `{A, B} = AB + BA` with `A = self`, `B = other`.
    ///
    /// # Errors
    ///
    /// Propagates dimension errors from [`Self::matmul`] and [`Self::add`].
    pub fn anticommutator(&self, other: &Self) -> SymEngineResult<Self> {
        let ab = self.matmul(other)?;
        let ba = other.matmul(self)?;
        ab.add(&ba)
    }

    /// Applies `Expression::simplify` to every entry.
    #[must_use]
    pub fn simplify(&self) -> Self {
        self.map_elements(Expression::simplify)
    }

    /// Applies `Expression::expand` to every entry.
    #[must_use]
    pub fn expand(&self) -> Self {
        self.map_elements(Expression::expand)
    }

    /// Numerically evaluates every entry, using `values` for symbol values.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `Expression::eval`, visiting
    /// entries in row-major order.
    pub fn eval(
        &self,
        values: &std::collections::HashMap<String, f64>,
    ) -> SymEngineResult<Array2<f64>> {
        let mut result = Array2::<f64>::zeros((self.rows, self.cols));
        for i in 0..self.rows {
            for j in 0..self.cols {
                result[[i, j]] = self.get(i, j).eval(values)?;
            }
        }
        Ok(result)
    }

    /// Replaces `var` with `value` in every entry via `Expression::substitute`.
    #[must_use]
    pub fn substitute(&self, var: &Expression, value: &Expression) -> Self {
        self.map_elements(|e| e.substitute(var, value))
    }

    /// Differentiates every entry with respect to `var` via `Expression::diff`.
    #[must_use]
    pub fn diff(&self, var: &Expression) -> Self {
        self.map_elements(|e| e.diff(var))
    }
}
/// Renders the matrix as nested brackets with one row per line:
///
/// ```text
/// [
///  [a, b]
///  [c, d]
/// ]
/// ```
impl fmt::Display for SymbolicMatrix {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("[\n")?;
        for r in 0..self.rows {
            f.write_str(" [")?;
            // Separator is empty before the first entry, ", " afterwards.
            let mut sep = "";
            for c in 0..self.cols {
                write!(f, "{}{}", sep, self.get(r, c))?;
                sep = ", ";
            }
            f.write_str("]\n")?;
        }
        f.write_str("]")
    }
}
/// `m[(row, col)]` borrows an entry; same bounds panic as `SymbolicMatrix::get`.
impl std::ops::Index<(usize, usize)> for SymbolicMatrix {
    type Output = Expression;

    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        self.get(row, col)
    }
}
/// `m[(row, col)] = expr` mutates an entry; same bounds panic as `SymbolicMatrix::get_mut`.
impl std::ops::IndexMut<(usize, usize)> for SymbolicMatrix {
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
        self.get_mut(row, col)
    }
}
/// Pauli-X gate: `[[0, 1], [1, 0]]`.
#[must_use]
pub fn pauli_x() -> SymbolicMatrix {
    let zero = Expression::zero();
    let one = Expression::one();
    SymbolicMatrix::new(vec![
        vec![zero.clone(), one.clone()],
        vec![one, zero],
    ])
    .expect("valid 2x2 matrix")
}
/// Pauli-Y gate: `[[0, -i], [i, 0]]`.
#[must_use]
pub fn pauli_y() -> SymbolicMatrix {
    let imag = Expression::i();
    let minus_imag = imag.clone().neg();
    SymbolicMatrix::new(vec![
        vec![Expression::zero(), minus_imag],
        vec![imag, Expression::zero()],
    ])
    .expect("valid 2x2 matrix")
}
/// Pauli-Z gate: `[[1, 0], [0, -1]]`.
#[must_use]
pub fn pauli_z() -> SymbolicMatrix {
    let minus_one = Expression::one().neg();
    SymbolicMatrix::new(vec![
        vec![Expression::one(), Expression::zero()],
        vec![Expression::zero(), minus_one],
    ])
    .expect("valid 2x2 matrix")
}
/// Hadamard gate: `(1/√2) · [[1, 1], [1, -1]]`, with `1/√2` kept symbolic.
#[must_use]
pub fn hadamard() -> SymbolicMatrix {
    let inv_root_two = Expression::one() / crate::ops::trig::sqrt(&Expression::int(2));
    let neg_inv_root_two = inv_root_two.clone().neg();
    SymbolicMatrix::new(vec![
        vec![inv_root_two.clone(), inv_root_two.clone()],
        vec![inv_root_two, neg_inv_root_two],
    ])
    .expect("valid 2x2 matrix")
}
/// S (phase) gate: `diag(1, i)`.
#[must_use]
pub fn s_gate() -> SymbolicMatrix {
    SymbolicMatrix::new(vec![
        vec![Expression::one(), Expression::zero()],
        vec![Expression::zero(), Expression::i()],
    ])
    .expect("valid 2x2 matrix")
}
/// T gate: `diag(1, exp(iπ/4))`, with the phase kept symbolic.
#[must_use]
pub fn t_gate() -> SymbolicMatrix {
    let i_pi_over_four = Expression::i() * Expression::pi() / Expression::int(4);
    let phase = crate::ops::trig::exp(&i_pi_over_four);
    SymbolicMatrix::new(vec![
        vec![Expression::one(), Expression::zero()],
        vec![Expression::zero(), phase],
    ])
    .expect("valid 2x2 matrix")
}
/// Rotation about the X axis:
/// `RX(θ) = [[cos(θ/2), -i·sin(θ/2)], [-i·sin(θ/2), cos(θ/2)]]`.
///
/// The half angle is the exact symbolic quotient `θ / 2`, not a product with
/// the float literal `0.5`, so the entries stay exact (as `t_gate` does with
/// `/ Expression::int(4)`). Numerical evaluation is unaffected.
#[must_use]
pub fn rx(theta: &Expression) -> SymbolicMatrix {
    let half_theta = theta.clone() / Expression::int(2);
    let cos_half = crate::ops::trig::cos(&half_theta);
    let sin_half = crate::ops::trig::sin(&half_theta);
    let i = Expression::i();
    SymbolicMatrix::from_flat(
        vec![
            cos_half.clone(),
            i.clone().neg() * sin_half.clone(),
            i.neg() * sin_half,
            cos_half,
        ],
        2,
        2,
    )
    .expect("valid 2x2 matrix")
}
/// Rotation about the Y axis:
/// `RY(θ) = [[cos(θ/2), -sin(θ/2)], [sin(θ/2), cos(θ/2)]]`.
///
/// The half angle is the exact symbolic quotient `θ / 2`, not a product with
/// the float literal `0.5`, so the entries stay exact. Numerical evaluation is
/// unaffected.
#[must_use]
pub fn ry(theta: &Expression) -> SymbolicMatrix {
    let half_theta = theta.clone() / Expression::int(2);
    let cos_half = crate::ops::trig::cos(&half_theta);
    let sin_half = crate::ops::trig::sin(&half_theta);
    SymbolicMatrix::from_flat(
        vec![cos_half.clone(), sin_half.clone().neg(), sin_half, cos_half],
        2,
        2,
    )
    .expect("valid 2x2 matrix")
}
/// Rotation about the Z axis: `RZ(θ) = diag(exp(-iθ/2), exp(iθ/2))`.
///
/// The half angle is the exact symbolic quotient `θ / 2`, not a product with
/// the float literal `0.5`, so the phases stay exact. Numerical evaluation is
/// unaffected.
#[must_use]
pub fn rz(theta: &Expression) -> SymbolicMatrix {
    let half_theta = theta.clone() / Expression::int(2);
    let exp_neg = crate::ops::trig::exp(&(Expression::i().neg() * half_theta.clone()));
    let exp_pos = crate::ops::trig::exp(&(Expression::i() * half_theta));
    SymbolicMatrix::from_flat(
        vec![exp_neg, Expression::zero(), Expression::zero(), exp_pos],
        2,
        2,
    )
    .expect("valid 2x2 matrix")
}
#[must_use]
pub fn cnot() -> SymbolicMatrix {
SymbolicMatrix::from_flat(
vec![
Expression::one(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::one(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::one(),
Expression::zero(),
Expression::zero(),
Expression::one(),
Expression::zero(),
],
4,
4,
)
.expect("valid 4x4 matrix")
}
#[must_use]
pub fn swap() -> SymbolicMatrix {
SymbolicMatrix::from_flat(
vec![
Expression::one(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::one(),
Expression::zero(),
Expression::zero(),
Expression::one(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::zero(),
Expression::one(),
],
4,
4,
)
.expect("valid 4x4 matrix")
}
/// Controlled version of a square gate `U`.
///
/// For an `m x m` input this returns the `2m x 2m` block-diagonal matrix
/// `diag(I_m, U)`. It is the identity on the first half of the basis
/// (control `|0⟩`) and applies `U` on the second half (control `|1⟩`).
///
/// For a 2x2 `U` this is the usual two-qubit controlled gate; for example,
/// `controlled(&pauli_x())` has the same entries as `cnot()`. Larger square
/// inputs (e.g. `swap()`) give multi-qubit controlled gates.
///
/// # Panics
///
/// Panics if `u` is not square or has no rows.
#[must_use]
pub fn controlled(u: &SymbolicMatrix) -> SymbolicMatrix {
    assert!(
        u.is_square() && u.nrows() > 0,
        "U must be a non-empty square matrix"
    );
    let m = u.nrows();
    let n = 2 * m;
    let mut elements = vec![Expression::zero(); n * n];
    // Top-left block: identity on the control-|0⟩ half.
    for k in 0..m {
        elements[k * n + k] = Expression::one();
    }
    // Bottom-right block: U on the control-|1⟩ half.
    for i in 0..m {
        for j in 0..m {
            elements[(m + i) * n + (m + j)] = u.get(i, j).clone();
        }
    }
    SymbolicMatrix::from_flat(elements, n, n).expect("valid controlled-gate matrix")
}
#[cfg(test)]
// Symbols are cloned freely below for readability, even where the final use
// could have moved them.
#[allow(clippy::redundant_clone)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Tolerance for comparing numerically evaluated entries.
    const TOL: f64 = 1e-10;

    /// Evaluates a matrix that contains no free symbols.
    fn eval_numeric(m: &SymbolicMatrix) -> Array2<f64> {
        m.eval(&HashMap::new()).expect("valid eval")
    }

    /// Asserts that `actual` matches `expected` entry by entry within `TOL`.
    fn assert_close(actual: &Array2<f64>, expected: &[&[f64]]) {
        assert_eq!(actual.nrows(), expected.len());
        for (i, row) in expected.iter().enumerate() {
            assert_eq!(actual.ncols(), row.len());
            for (j, &want) in row.iter().enumerate() {
                let got = actual[[i, j]];
                assert!(
                    (got - want).abs() < TOL,
                    "entry ({i}, {j}): got {got}, expected {want}"
                );
            }
        }
    }

    #[test]
    fn test_matrix_creation() {
        let m = SymbolicMatrix::identity(2);
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 2);
        assert!(m.get(0, 0).is_one());
        assert!(m.get(0, 1).is_zero());
        assert!(m.get(1, 0).is_zero());
        assert!(m.get(1, 1).is_one());
    }

    #[test]
    fn test_constructors_reject_bad_shapes() {
        assert!(SymbolicMatrix::new(vec![]).is_err());
        assert!(SymbolicMatrix::new(vec![vec![Expression::zero()], vec![]]).is_err());
        assert!(SymbolicMatrix::from_flat(vec![Expression::zero()], 2, 2).is_err());
    }

    #[test]
    fn test_matrix_transpose() {
        let x = Expression::symbol("x");
        let y = Expression::symbol("y");
        let z = Expression::symbol("z");
        let w = Expression::symbol("w");
        let m = SymbolicMatrix::new(vec![vec![x.clone(), y.clone()], vec![z.clone(), w.clone()]])
            .expect("valid matrix");
        let mt = m.transpose();
        assert_eq!(mt.get(0, 0).as_symbol(), Some("x"));
        assert_eq!(mt.get(0, 1).as_symbol(), Some("z"));
        assert_eq!(mt.get(1, 0).as_symbol(), Some("y"));
        assert_eq!(mt.get(1, 1).as_symbol(), Some("w"));
    }

    #[test]
    fn test_matrix_multiplication() {
        let i = SymbolicMatrix::identity(2);
        let x = Expression::symbol("x");
        let m = SymbolicMatrix::new(vec![
            vec![x.clone(), Expression::zero()],
            vec![Expression::zero(), x.clone()],
        ])
        .expect("valid matrix");
        let result = i.matmul(&m).expect("valid matmul");
        let mut values = HashMap::new();
        values.insert("x".to_string(), 5.0);
        let r00 = result.get(0, 0).eval(&values).expect("valid eval");
        assert!((r00 - 5.0).abs() < TOL);
        let r01 = result.get(0, 1).eval(&values).expect("valid eval");
        assert!(r01.abs() < TOL);
        let r11 = result.get(1, 1).eval(&values).expect("valid eval");
        assert!((r11 - 5.0).abs() < TOL);
    }

    #[test]
    fn test_kronecker_product() {
        let x = pauli_x();
        let z = pauli_z();
        let xz = x.kron(&z);
        assert_eq!(xz.nrows(), 4);
        assert_eq!(xz.ncols(), 4);
        // X ⊗ Z = [[0·Z, 1·Z], [1·Z, 0·Z]].
        assert_close(
            &eval_numeric(&xz),
            &[
                &[0.0, 0.0, 1.0, 0.0],
                &[0.0, 0.0, 0.0, -1.0],
                &[1.0, 0.0, 0.0, 0.0],
                &[0.0, -1.0, 0.0, 0.0],
            ],
        );
    }

    #[test]
    fn test_trace() {
        let theta = Expression::symbol("theta");
        let m = SymbolicMatrix::diagonal(vec![theta.clone(), theta.clone()]);
        let tr = m.trace().expect("valid trace");
        let mut values = HashMap::new();
        values.insert("theta".to_string(), 3.0);
        let result = tr.eval(&values).expect("valid eval");
        assert!((result - 6.0).abs() < TOL);
    }

    #[test]
    fn test_rotation_gates() {
        let theta = Expression::symbol("theta");
        let rx_gate = rx(&theta);
        let rx_dag = rx_gate.dagger();
        assert_eq!(rx_gate.nrows(), 2);
        assert_eq!(rx_dag.nrows(), 2);

        // RY(π) = [[cos(π/2), -sin(π/2)], [sin(π/2), cos(π/2)]] = [[0, -1], [1, 0]].
        let mut values = HashMap::new();
        values.insert("theta".to_string(), std::f64::consts::PI);
        let ry_pi = ry(&theta).eval(&values).expect("valid eval");
        assert_close(&ry_pi, &[&[0.0, -1.0], &[1.0, 0.0]]);
    }

    #[test]
    fn test_pauli_commutation() {
        let x = pauli_x();
        let y = pauli_y();
        let comm = x.commutator(&y).expect("valid commutator");
        assert_eq!(comm.nrows(), 2);

        // X and Z are real, so their (anti)commutators can be checked numerically:
        // [X, Z] = XZ - ZX = [[0, -2], [2, 0]] and {X, Z} = 0.
        let z = pauli_z();
        let xz_comm = eval_numeric(&x.commutator(&z).expect("valid commutator"));
        assert_close(&xz_comm, &[&[0.0, -2.0], &[2.0, 0.0]]);
        let xz_anti = eval_numeric(&x.anticommutator(&z).expect("valid anticommutator"));
        assert_close(&xz_anti, &[&[0.0, 0.0], &[0.0, 0.0]]);
    }

    #[test]
    fn test_controlled_x_matches_cnot() {
        let cx = eval_numeric(&controlled(&pauli_x()));
        let reference = eval_numeric(&cnot());
        assert_eq!(cx.dim(), reference.dim());
        for (a, b) in cx.iter().zip(reference.iter()) {
            assert!((a - b).abs() < TOL);
        }
    }

    #[test]
    fn test_matrix_diff() {
        let theta = Expression::symbol("theta");
        let m = SymbolicMatrix::diagonal(vec![
            crate::ops::trig::sin(&theta),
            crate::ops::trig::cos(&theta),
        ]);
        let dm = m.diff(&theta);
        let mut values = HashMap::new();
        values.insert("theta".to_string(), 0.0);
        let result = dm.eval(&values).expect("valid eval");
        assert!((result[[0, 0]] - 1.0).abs() < TOL);
        assert!(result[[1, 1]].abs() < TOL);
    }
}