use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::{
error::{QuantRS2Error, QuantRS2Result},
gate::GateOp,
prelude::QubitId,
register::Register,
};
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
use scirs2_core::Complex64;
#[derive(Debug, Clone)]
struct MPSTensor {
data: Array3<Complex64>,
left_dim: usize,
right_dim: usize,
}
impl MPSTensor {
fn new(data: Array3<Complex64>) -> Self {
let shape = data.shape();
Self {
left_dim: shape[0],
right_dim: shape[2],
data,
}
}
fn zero_state(is_first: bool, is_last: bool) -> Self {
let data = if is_first && is_last {
let mut tensor = Array3::zeros((1, 2, 1));
tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
tensor
} else if is_first {
let mut tensor = Array3::zeros((1, 2, 2));
tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
tensor
} else if is_last {
let mut tensor = Array3::zeros((2, 2, 1));
tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
tensor
} else {
let mut tensor = Array3::zeros((2, 2, 2));
tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
tensor
};
Self::new(data)
}
}
pub struct MPS {
tensors: Vec<MPSTensor>,
num_qubits: usize,
max_bond_dim: usize,
truncation_threshold: f64,
orthogonality_center: i32,
}
impl MPS {
#[must_use]
pub fn new(num_qubits: usize, max_bond_dim: usize) -> Self {
let tensors = (0..num_qubits)
.map(|i| MPSTensor::zero_state(i == 0, i == num_qubits - 1))
.collect();
Self {
tensors,
num_qubits,
max_bond_dim,
truncation_threshold: 1e-10,
orthogonality_center: -1,
}
}
pub const fn set_truncation_threshold(&mut self, threshold: f64) {
self.truncation_threshold = threshold;
}
pub fn move_orthogonality_center(&mut self, target: usize) -> QuantRS2Result<()> {
if target >= self.num_qubits {
return Err(QuantRS2Error::InvalidQubitId(target as u32));
}
if self.orthogonality_center < 0 {
self.left_canonicalize_up_to(target)?;
self.orthogonality_center = target as i32;
return Ok(());
}
let current = self.orthogonality_center as usize;
if current < target {
for i in current..target {
self.move_center_right(i)?;
}
} else if current > target {
for i in (target + 1..=current).rev() {
self.move_center_left(i)?;
}
}
self.orthogonality_center = target as i32;
Ok(())
}
fn left_canonicalize_up_to(&mut self, position: usize) -> QuantRS2Result<()> {
for i in 0..position {
let tensor = &self.tensors[i];
let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
let matrix = tensor
.data
.view()
.into_shape((left_dim * phys_dim, right_dim))?;
let (q, r) = qr_decomposition(&matrix)?;
let new_shape = (left_dim, phys_dim, q.shape()[1]);
self.tensors[i].data = q.into_shape(new_shape)?;
self.tensors[i].right_dim = new_shape.2;
if i + 1 < self.num_qubits {
let next = &mut self.tensors[i + 1];
let next_matrix = next
.data
.view()
.into_shape((next.left_dim, 2 * next.right_dim))?;
let new_matrix = r.dot(&next_matrix);
next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
next.left_dim = r.shape()[0];
}
}
Ok(())
}
fn move_center_right(&mut self, position: usize) -> QuantRS2Result<()> {
let tensor = &self.tensors[position];
let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
let matrix = tensor
.data
.view()
.into_shape((left_dim * phys_dim, right_dim))?;
let (q, r) = qr_decomposition(&matrix)?;
let q_cols = q.shape()[1];
self.tensors[position].data = q.into_shape((left_dim, phys_dim, q_cols))?;
self.tensors[position].right_dim = q_cols;
if position + 1 < self.num_qubits {
let next = &mut self.tensors[position + 1];
let next_matrix = next
.data
.view()
.into_shape((next.left_dim, 2 * next.right_dim))?;
let new_matrix = r.dot(&next_matrix);
next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
next.left_dim = r.shape()[0];
}
Ok(())
}
fn move_center_left(&mut self, position: usize) -> QuantRS2Result<()> {
let tensor = &self.tensors[position];
let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
let matrix = tensor
.data
.view()
.permuted_axes([2, 1, 0])
.into_shape((right_dim * phys_dim, left_dim))?;
let (q, r) = qr_decomposition(&matrix)?;
let q_cols = q.shape()[1];
let q_reshaped = q.into_shape((right_dim, phys_dim, q_cols))?;
self.tensors[position].data = q_reshaped.permuted_axes([2, 1, 0]);
self.tensors[position].left_dim = q_cols;
if position > 0 {
let prev = &mut self.tensors[position - 1];
let prev_matrix = prev
.data
.view()
.into_shape((prev.left_dim * 2, prev.right_dim))?;
let new_matrix = prev_matrix.dot(&r.t());
prev.data = new_matrix.into_shape((prev.left_dim, 2, r.shape()[0]))?;
prev.right_dim = r.shape()[0];
}
Ok(())
}
pub fn apply_single_qubit_gate(
&mut self,
gate: &dyn GateOp,
qubit: usize,
) -> QuantRS2Result<()> {
if qubit >= self.num_qubits {
return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
}
let gate_matrix = gate.matrix()?;
let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)?;
let tensor = &mut self.tensors[qubit];
let mut new_data = Array3::zeros(tensor.data.dim());
for left in 0..tensor.left_dim {
for right in 0..tensor.right_dim {
for i in 0..2 {
for j in 0..2 {
new_data[[left, i, right]] +=
gate_array[[i, j]] * tensor.data[[left, j, right]];
}
}
}
}
tensor.data = new_data;
Ok(())
}
pub fn apply_two_qubit_gate(
&mut self,
gate: &dyn GateOp,
qubit1: usize,
qubit2: usize,
) -> QuantRS2Result<()> {
if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
return Err(QuantRS2Error::ComputationError(
"MPS simulator requires adjacent qubits for two-qubit gates".to_string(),
));
}
let (left_qubit, right_qubit) = if qubit1 < qubit2 {
(qubit1, qubit2)
} else {
(qubit2, qubit1)
};
self.move_orthogonality_center(left_qubit)?;
let gate_matrix = gate.matrix()?;
let gate_array = Array2::from_shape_vec((4, 4), gate_matrix)?;
let left_tensor = &self.tensors[left_qubit];
let right_tensor = &self.tensors[right_qubit];
let left_dim = left_tensor.left_dim;
let right_dim = right_tensor.right_dim;
let mut combined = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
for l in 0..left_dim {
for r in 0..right_dim {
for i in 0..2 {
for j in 0..2 {
for k in 0..left_tensor.right_dim {
combined[[l, i * 2 + j, r]] +=
left_tensor.data[[l, i, k]] * right_tensor.data[[k, j, r]];
}
}
}
}
}
let mut gated = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
for l in 0..left_dim {
for r in 0..right_dim {
for out_idx in 0..4 {
for in_idx in 0..4 {
gated[[l, out_idx, r]] +=
gate_array[[out_idx, in_idx]] * combined[[l, in_idx, r]];
}
}
}
}
let matrix = gated.into_shape((left_dim * 2, 2 * right_dim))?;
let (u, s, vt) = svd_decomposition(&matrix, self.max_bond_dim, self.truncation_threshold)?;
let new_bond = s.len();
self.tensors[left_qubit].data = u.into_shape((left_dim, 2, new_bond))?;
self.tensors[left_qubit].right_dim = new_bond;
let mut sv = Array2::<Complex64>::zeros((new_bond, vt.shape()[1]));
for i in 0..new_bond {
for j in 0..vt.shape()[1] {
sv[[i, j]] = Complex64::new(s[i], 0.0) * vt[[i, j]];
}
}
self.tensors[right_qubit].data = sv.t().to_owned().into_shape((new_bond, 2, right_dim))?;
self.tensors[right_qubit].left_dim = new_bond;
self.orthogonality_center = right_qubit as i32;
Ok(())
}
pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
if bitstring.len() != self.num_qubits {
return Err(QuantRS2Error::ComputationError(format!(
"Bitstring length {} doesn't match qubit count {}",
bitstring.len(),
self.num_qubits
)));
}
let mut result = Array2::eye(1);
for (i, &bit) in bitstring.iter().enumerate() {
let tensor = &self.tensors[i];
let idx = i32::from(bit);
let matrix = tensor.data.slice(s![.., idx, ..]);
result = result.dot(&matrix);
}
Ok(result[[0, 0]])
}
#[must_use]
pub fn sample(&self) -> Vec<bool> {
use scirs2_core::random::prelude::*;
let mut rng = thread_rng();
let mut result = vec![false; self.num_qubits];
let mut accumulated_matrix = Array2::eye(1);
for (i, tensor) in self.tensors.iter().enumerate() {
let mut prob0 = Complex64::new(0.0, 0.0);
let mut prob1 = Complex64::new(0.0, 0.0);
let matrix0 = tensor.data.slice(s![.., 0, ..]);
let temp0: Array2<Complex64> = accumulated_matrix.dot(&matrix0);
let mut right_contract = Array2::eye(temp0.shape()[1]);
for j in (i + 1)..self.num_qubits {
let sum_matrix = self.tensors[j].data.slice(s![.., 0, ..]).to_owned()
+ self.tensors[j].data.slice(s![.., 1, ..]).to_owned();
right_contract = right_contract.dot(&sum_matrix);
}
prob0 = temp0.dot(&right_contract)[[0, 0]];
let matrix1 = tensor.data.slice(s![.., 1, ..]);
let temp1: Array2<Complex64> = accumulated_matrix.dot(&matrix1);
prob1 = temp1.dot(&right_contract)[[0, 0]];
let total = prob0.norm_sqr() + prob1.norm_sqr();
let threshold = prob0.norm_sqr() / total;
if rng.random::<f64>() < threshold {
result[i] = false;
accumulated_matrix = temp0;
} else {
result[i] = true;
accumulated_matrix = temp1;
}
}
result
}
}
fn qr_decomposition(
matrix: &ArrayView2<Complex64>,
) -> QuantRS2Result<(Array2<Complex64>, Array2<Complex64>)> {
let (m, n) = matrix.dim();
let mut q = Array2::zeros((m, n.min(m)));
let mut r = Array2::zeros((n.min(m), n));
for j in 0..n.min(m) {
let mut v = matrix.column(j).to_owned();
for i in 0..j {
let proj = q.column(i).dot(&v);
r[[i, j]] = proj;
v -= &(proj * &q.column(i).to_owned());
}
let norm = (v.dot(&v)).sqrt();
if norm.norm() > 1e-10 {
r[[j, j]] = norm;
q.column_mut(j).assign(&(v / norm));
}
}
if n > m {
for j in m..n {
for i in 0..m {
r[[i, j]] = q.column(i).dot(&matrix.column(j));
}
}
}
Ok((q, r))
}
fn svd_decomposition(
matrix: &Array2<Complex64>,
max_bond: usize,
threshold: f64,
) -> QuantRS2Result<(Array2<Complex64>, Array1<f64>, Array2<Complex64>)> {
let (m, n) = matrix.dim();
let k = m.min(n).min(max_bond);
let u = Array2::eye(m).slice(s![.., ..k]).to_owned();
let s = Array1::ones(k);
let vt = Array2::eye(n).slice(s![..k, ..]).to_owned();
Ok((u, s, vt))
}
pub struct MPSSimulator {
max_bond_dimension: usize,
truncation_threshold: f64,
}
impl MPSSimulator {
#[must_use]
pub const fn new(max_bond_dimension: usize) -> Self {
Self {
max_bond_dimension,
truncation_threshold: 1e-10,
}
}
pub const fn set_truncation_threshold(&mut self, threshold: f64) {
self.truncation_threshold = threshold;
}
}
impl<const N: usize> Simulator<N> for MPSSimulator {
fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
let mut mps = MPS::new(N, self.max_bond_dimension);
mps.set_truncation_threshold(self.truncation_threshold);
Ok(Register::new())
}
}
#[cfg(test)]
mod tests {
use super::*;
use quantrs2_core::gate::single::Hadamard;
#[test]
fn test_mps_creation() {
let mps = MPS::new(4, 10);
assert_eq!(mps.num_qubits, 4);
assert_eq!(mps.tensors.len(), 4);
}
#[test]
fn test_single_qubit_gate() {
let mut mps = MPS::new(1, 10);
let h = Hadamard {
target: QubitId::new(0),
};
mps.apply_single_qubit_gate(&h, 0)
.expect("Failed to apply single qubit gate");
let amp0 = mps
.get_amplitude(&[false])
.expect("Failed to get amplitude for |0>");
let amp1 = mps
.get_amplitude(&[true])
.expect("Failed to get amplitude for |1>");
let expected = 1.0 / 2.0_f64.sqrt();
assert!((amp0.re - expected).abs() < 1e-10);
assert!((amp1.re - expected).abs() < 1e-10);
}
#[test]
fn test_orthogonality_center() {
let mut mps = MPS::new(5, 10);
mps.move_orthogonality_center(2)
.expect("Failed to move orthogonality center to 2");
assert_eq!(mps.orthogonality_center, 2);
mps.move_orthogonality_center(4)
.expect("Failed to move orthogonality center to 4");
assert_eq!(mps.orthogonality_center, 4);
mps.move_orthogonality_center(0)
.expect("Failed to move orthogonality center to 0");
assert_eq!(mps.orthogonality_center, 0);
}
}