use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::Complex64;
/// Applies a square matrix to a state vector in place (`state <- unitary * state`).
///
/// The product is accumulated into a scratch buffer and copied back only once it
/// is complete, so every output amplitude is computed from the original input
/// amplitudes. The matrix is not checked for unitarity; only its shape is
/// validated.
///
/// # Errors
///
/// Returns an error message when `unitary` is not `n x n`, where `n` is
/// `state.len()`. `state` is left untouched in that case.
pub fn apply_unitary(
    unitary: &ArrayView2<Complex64>,
    state: &mut [Complex64],
) -> Result<(), String> {
    let n = state.len();
    if unitary.shape() != [n, n] {
        return Err(format!(
            "Unitary matrix shape {:?} doesn't match state dimension {}",
            unitary.shape(),
            n
        ));
    }

    // Plain row-by-row accumulation. This computation does not depend on any
    // Cargo feature, so it is written once rather than duplicated per `cfg`.
    let result: Vec<Complex64> = (0..n)
        .map(|i| {
            (0..n).fold(Complex64::new(0.0, 0.0), |acc, j| {
                acc + unitary[[i, j]] * state[j]
            })
        })
        .collect();

    state.copy_from_slice(&result);
    Ok(())
}
/// Computes the Kronecker (tensor) product of `a` and `b`.
///
/// For `a` of shape `(m, n)` and `b` of shape `(p, q)` the result has shape
/// `(m * p, n * q)`, and its entry at `[i * p + k, j * q + l]` equals
/// `a[[i, j]] * b[[k, l]]`.
#[must_use]
pub fn tensor_product(a: &ArrayView2<Complex64>, b: &ArrayView2<Complex64>) -> Array2<Complex64> {
    let (a_rows, a_cols) = a.dim();
    let (b_rows, b_cols) = b.dim();

    // Every output cell corresponds to exactly one pair of input entries: the
    // quotient of the output index selects the entry of `a`, the remainder the
    // entry of `b`. When `b` has a zero-sized axis the output is empty and the
    // closure is never invoked, so the divisions cannot hit a zero divisor.
    Array2::from_shape_fn((a_rows * b_rows, a_cols * b_cols), |(row, col)| {
        a[[row / b_rows, col / b_cols]] * b[[row % b_rows, col % b_cols]]
    })
}
/// Computes the reduced density matrix obtained by tracing out the listed qubits.
///
/// Qubit `q` is identified with bit `q` of the basis-state index (qubit 0 is the
/// least-significant bit). The kept qubits retain their relative order in the
/// result, so tracing out the highest-numbered qubits leaves the remaining
/// indices unchanged.
/// NOTE(review): confirm this bit ordering matches the convention used by callers.
///
/// # Errors
///
/// Returns an error message when:
/// - `total_qubits` is too large for `2^total_qubits` to fit in a `usize`,
/// - `density_matrix` is not `2^total_qubits x 2^total_qubits`,
/// - an entry of `qubits_to_trace` is not below `total_qubits`,
/// - a qubit appears more than once in `qubits_to_trace`.
pub fn partial_trace(
    density_matrix: &ArrayView2<Complex64>,
    qubits_to_trace: &[usize],
    total_qubits: usize,
) -> Result<Array2<Complex64>, String> {
    // Moves bit `k` of `compact` to bit `positions[k]` of the returned index.
    fn scatter_bits(compact: usize, positions: &[usize]) -> usize {
        positions
            .iter()
            .enumerate()
            .fold(0, |index, (k, &pos)| index | (((compact >> k) & 1) << pos))
    }

    // `1 << total_qubits` overflows once the shift reaches the width of `usize`.
    if total_qubits >= usize::BITS as usize {
        return Err(format!(
            "Qubit count {} is too large for the index type",
            total_qubits
        ));
    }
    let dim = 1usize << total_qubits;
    if density_matrix.shape() != [dim, dim] {
        return Err(format!(
            "Density matrix shape {:?} doesn't match {} qubits",
            density_matrix.shape(),
            total_qubits
        ));
    }

    // Collect the traced qubits into a bit mask, rejecting invalid or repeated
    // entries up front so the subtraction-free bookkeeping below cannot go wrong.
    let mut traced_mask = 0usize;
    for &qubit in qubits_to_trace {
        if qubit >= total_qubits {
            return Err(format!(
                "Qubit index {} is out of range for {} qubits",
                qubit, total_qubits
            ));
        }
        let bit = 1usize << qubit;
        if (traced_mask & bit) != 0 {
            return Err(format!(
                "Qubit index {} appears more than once in the qubits to trace",
                qubit
            ));
        }
        traced_mask |= bit;
    }

    // Bit positions, in ascending order, of the traced and the kept qubits.
    let (traced_bits, kept_bits): (Vec<usize>, Vec<usize>) =
        (0..total_qubits).partition(|&q| ((traced_mask >> q) & 1) == 1);

    let remaining_dim = 1usize << kept_bits.len();
    let traced_dim = 1usize << traced_bits.len();

    // Full-system index contributions of every reduced index and of every
    // traced-subsystem basis state; the two never share a bit, so they combine
    // with a bitwise OR.
    let kept_offsets: Vec<usize> = (0..remaining_dim)
        .map(|compact| scatter_bits(compact, &kept_bits))
        .collect();
    let traced_offsets: Vec<usize> = (0..traced_dim)
        .map(|compact| scatter_bits(compact, &traced_bits))
        .collect();

    let mut result = Array2::zeros((remaining_dim, remaining_dim));
    for (i, &row_base) in kept_offsets.iter().enumerate() {
        for (j, &col_base) in kept_offsets.iter().enumerate() {
            // Sum the entries that are diagonal in the traced subsystem.
            let mut sum = Complex64::new(0.0, 0.0);
            for &traced in &traced_offsets {
                sum += density_matrix[[row_base | traced, col_base | traced]];
            }
            result[[i, j]] = sum;
        }
    }
    Ok(result)
}
/// Reports whether `matrix` is unitary to within `tolerance`.
///
/// The check forms the product of the conjugate transpose of `matrix` with
/// `matrix` and requires every entry to lie within `tolerance` (in complex
/// modulus) of the corresponding identity-matrix entry.
///
/// Returns `false` for non-square input. A deviation that is NaN (for example
/// because `matrix` contains NaN) or a NaN `tolerance` also yields `false`,
/// because an entry only passes when `deviation <= tolerance` holds.
#[must_use]
pub fn is_unitary(matrix: &ArrayView2<Complex64>, tolerance: f64) -> bool {
    let n = matrix.nrows();
    if matrix.ncols() != n {
        return false;
    }

    // With `advanced_math`, delegate the product to the array library's `dot`.
    #[cfg(feature = "advanced_math")]
    let product: Array2<Complex64> = {
        let conjugate_transpose = matrix.t().mapv(|x| x.conj());
        conjugate_transpose.dot(matrix)
    };
    // Otherwise accumulate the same product entry by entry.
    #[cfg(not(feature = "advanced_math"))]
    let product: Array2<Complex64> = {
        let mut gram: Array2<Complex64> = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..n {
                for k in 0..n {
                    gram[[i, j]] += matrix[[k, i]].conj() * matrix[[k, j]];
                }
            }
        }
        gram
    };

    (0..n).all(|i| {
        (0..n).all(|j| {
            let expected = if i == j {
                Complex64::new(1.0, 0.0)
            } else {
                Complex64::new(0.0, 0.0)
            };
            let diff: Complex64 = product[[i, j]] - expected;
            // Phrased as an acceptance test so that a NaN deviation fails:
            // every ordered comparison involving NaN evaluates to false.
            diff.norm() <= tolerance
        })
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Shorthand for a complex number with zero imaginary part.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Applying the Hadamard matrix to `[1, 0]` gives two equal amplitudes.
    #[test]
    fn test_apply_unitary() {
        let amp = 1.0 / std::f64::consts::SQRT_2;
        let h = arr2(&[[re(amp), re(amp)], [re(amp), re(-amp)]]);

        let mut state = vec![re(1.0), re(0.0)];
        apply_unitary(&h.view(), &mut state).expect("unitary application should succeed");

        let expected_0 = re(amp);
        for amplitude in &state {
            assert!((*amplitude - expected_0).norm() < 1e-10);
        }
    }

    /// Spot-checks the shape and a few entries of a 2x2 (x) 2x2 product.
    #[test]
    fn test_tensor_product() {
        let a = arr2(&[[re(1.0), re(2.0)], [re(3.0), re(4.0)]]);
        let b = arr2(&[[re(5.0), re(6.0)], [re(7.0), re(8.0)]]);

        let result = tensor_product(&a.view(), &b.view());

        assert_eq!(result.dim(), (4, 4));
        assert_eq!(result[[0, 0]], re(5.0));
        assert_eq!(result[[0, 1]], re(6.0));
        assert_eq!(result[[3, 3]], re(32.0));
    }

    /// Pauli-X is accepted; an upper-triangular matrix of ones is rejected.
    #[test]
    fn test_is_unitary() {
        let x = arr2(&[[re(0.0), re(1.0)], [re(1.0), re(0.0)]]);
        assert!(is_unitary(&x.view(), 1e-10));

        let non_unitary = arr2(&[[re(1.0), re(1.0)], [re(0.0), re(1.0)]]);
        assert!(!is_unitary(&non_unitary.view(), 1e-10));
    }
}