use quantrs2_circuit::builder::{Circuit, Simulator};
use quantrs2_core::{
error::{QuantRS2Error, QuantRS2Result},
gate::GateOp,
register::Register,
};
use scirs2_core::ndarray::{array, s, Array2, Array3};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{thread_rng, Rng};
use scirs2_core::Complex64;
use std::f64::consts::SQRT_2;
#[derive(Debug, Clone)]
pub struct BasicMPSConfig {
pub max_bond_dim: usize,
pub svd_threshold: f64,
}
impl Default for BasicMPSConfig {
fn default() -> Self {
Self {
max_bond_dim: 64,
svd_threshold: 1e-10,
}
}
}
#[derive(Debug, Clone)]
struct MPSTensor {
data: Array3<Complex64>,
}
impl MPSTensor {
fn zero_state(position: usize, num_qubits: usize) -> Self {
let is_first = position == 0;
let is_last = position == num_qubits - 1;
let data = if is_first && is_last {
let mut tensor = Array3::zeros((1, 2, 1));
tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
tensor
} else if is_first {
let mut tensor = Array3::zeros((1, 2, 2));
tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
tensor
} else if is_last {
let mut tensor = Array3::zeros((2, 2, 1));
tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
tensor
} else {
let mut tensor = Array3::zeros((2, 2, 2));
tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
tensor
};
Self { data }
}
}
pub struct BasicMPS {
tensors: Vec<MPSTensor>,
num_qubits: usize,
config: BasicMPSConfig,
}
impl BasicMPS {
#[must_use]
pub fn new(num_qubits: usize, config: BasicMPSConfig) -> Self {
let tensors = (0..num_qubits)
.map(|i| MPSTensor::zero_state(i, num_qubits))
.collect();
Self {
tensors,
num_qubits,
config,
}
}
pub fn apply_single_qubit_gate(
&mut self,
gate_matrix: &Array2<Complex64>,
qubit: usize,
) -> QuantRS2Result<()> {
if qubit >= self.num_qubits {
return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
}
let tensor = &mut self.tensors[qubit];
let shape = tensor.data.shape();
let (left_dim, _, right_dim) = (shape[0], shape[1], shape[2]);
let mut new_data = Array3::zeros((left_dim, 2, right_dim));
for l in 0..left_dim {
for r in 0..right_dim {
for new_phys in 0..2 {
for old_phys in 0..2 {
new_data[[l, new_phys, r]] +=
gate_matrix[[new_phys, old_phys]] * tensor.data[[l, old_phys, r]];
}
}
}
}
tensor.data = new_data;
Ok(())
}
pub fn apply_two_qubit_gate(
&mut self,
gate_matrix: &Array2<Complex64>,
qubit1: usize,
qubit2: usize,
) -> QuantRS2Result<()> {
if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
return Err(QuantRS2Error::InvalidInput(
"MPS requires adjacent qubits for two-qubit gates".to_string(),
));
}
let (left_q, right_q) = if qubit1 < qubit2 {
(qubit1, qubit2)
} else {
(qubit2, qubit1)
};
let left_shape = self.tensors[left_q].data.shape().to_vec();
let right_shape = self.tensors[right_q].data.shape().to_vec();
let mut combined = Array3::<Complex64>::zeros((left_shape[0], 4, right_shape[2]));
for l in 0..left_shape[0] {
for r in 0..right_shape[2] {
for i in 0..2 {
for j in 0..2 {
for m in 0..left_shape[2] {
combined[[l, i * 2 + j, r]] += self.tensors[left_q].data[[l, i, m]]
* self.tensors[right_q].data[[m, j, r]];
}
}
}
}
}
let mut result = Array3::<Complex64>::zeros((left_shape[0], 4, right_shape[2]));
for l in 0..left_shape[0] {
for r in 0..right_shape[2] {
for out_idx in 0..4 {
for in_idx in 0..4 {
result[[l, out_idx, r]] +=
gate_matrix[[out_idx, in_idx]] * combined[[l, in_idx, r]];
}
}
}
}
let new_bond = 2.min(self.config.max_bond_dim);
let mut left_new = Array3::zeros((left_shape[0], 2, new_bond));
let mut right_new = Array3::zeros((new_bond, 2, right_shape[2]));
for l in 0..left_shape[0] {
for r in 0..right_shape[2] {
for i in 0..2 {
for j in 0..2 {
let bond_idx = (i + j) % new_bond;
left_new[[l, i, bond_idx]] = result[[l, i * 2 + j, r]];
right_new[[bond_idx, j, r]] = Complex64::new(1.0, 0.0);
}
}
}
}
self.tensors[left_q].data = left_new;
self.tensors[right_q].data = right_new;
Ok(())
}
pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
if bitstring.len() != self.num_qubits {
return Err(QuantRS2Error::InvalidInput(format!(
"Bitstring length {} doesn't match qubit count {}",
bitstring.len(),
self.num_qubits
)));
}
let mut result = Array2::from_elem((1, 1), Complex64::new(1.0, 0.0));
for (i, &bit) in bitstring.iter().enumerate() {
let tensor = &self.tensors[i];
let physical_idx = i32::from(bit);
let matrix = tensor.data.slice(s![.., physical_idx, ..]);
result = result.dot(&matrix);
}
Ok(result[[0, 0]])
}
#[must_use]
pub fn sample(&self) -> Vec<bool> {
let mut rng = thread_rng();
let mut result = vec![false; self.num_qubits];
let mut accumulated = Array2::from_elem((1, 1), Complex64::new(1.0, 0.0));
for (i, tensor) in self.tensors.iter().enumerate() {
let matrix0 = tensor.data.slice(s![.., 0, ..]);
let matrix1 = tensor.data.slice(s![.., 1, ..]);
let branch0: Array2<Complex64> = accumulated.dot(&matrix0);
let branch1: Array2<Complex64> = accumulated.dot(&matrix1);
let norm0_sq: f64 = branch0.iter().map(scirs2_core::Complex::norm_sqr).sum();
let norm1_sq: f64 = branch1.iter().map(scirs2_core::Complex::norm_sqr).sum();
let total = norm0_sq + norm1_sq;
let prob0 = norm0_sq / total;
if rng.random::<f64>() < prob0 {
result[i] = false;
accumulated = branch0;
} else {
result[i] = true;
accumulated = branch1;
}
let norm_sq: f64 = accumulated.iter().map(scirs2_core::Complex::norm_sqr).sum();
if norm_sq > 0.0 {
accumulated /= Complex64::new(norm_sq.sqrt(), 0.0);
}
}
result
}
}
pub struct BasicMPSSimulator {
config: BasicMPSConfig,
}
impl BasicMPSSimulator {
#[must_use]
pub const fn new(config: BasicMPSConfig) -> Self {
Self { config }
}
#[must_use]
pub fn default() -> Self {
Self::new(BasicMPSConfig::default())
}
}
impl<const N: usize> Simulator<N> for BasicMPSSimulator {
fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
let mut mps = BasicMPS::new(N, self.config.clone());
for gate in circuit.gates() {
match gate.name() {
"H" => {
let h_matrix = {
let h = 1.0 / SQRT_2;
array![
[Complex64::new(h, 0.), Complex64::new(h, 0.)],
[Complex64::new(h, 0.), Complex64::new(-h, 0.)]
]
};
if let Some(&qubit) = gate.qubits().first() {
mps.apply_single_qubit_gate(&h_matrix, qubit.id() as usize)?;
}
}
"X" => {
let x_matrix = array![
[Complex64::new(0., 0.), Complex64::new(1., 0.)],
[Complex64::new(1., 0.), Complex64::new(0., 0.)]
];
if let Some(&qubit) = gate.qubits().first() {
mps.apply_single_qubit_gate(&x_matrix, qubit.id() as usize)?;
}
}
"CNOT" | "CX" => {
let cnot_matrix = array![
[
Complex64::new(1., 0.),
Complex64::new(0., 0.),
Complex64::new(0., 0.),
Complex64::new(0., 0.)
],
[
Complex64::new(0., 0.),
Complex64::new(1., 0.),
Complex64::new(0., 0.),
Complex64::new(0., 0.)
],
[
Complex64::new(0., 0.),
Complex64::new(0., 0.),
Complex64::new(0., 0.),
Complex64::new(1., 0.)
],
[
Complex64::new(0., 0.),
Complex64::new(0., 0.),
Complex64::new(1., 0.),
Complex64::new(0., 0.)
],
];
let qubits = gate.qubits();
if qubits.len() == 2 {
mps.apply_two_qubit_gate(
&cnot_matrix,
qubits[0].id() as usize,
qubits[1].id() as usize,
)?;
}
}
_ => {
}
}
}
Ok(Register::new())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basic_mps_initialization() {
let mps = BasicMPS::new(4, BasicMPSConfig::default());
let amp = mps
.get_amplitude(&[false, false, false, false])
.expect("Failed to get amplitude for |0000>");
assert!((amp.norm() - 1.0).abs() < 1e-10);
let amp = mps
.get_amplitude(&[true, false, false, false])
.expect("Failed to get amplitude for |1000>");
assert!(amp.norm() < 1e-10);
}
#[test]
fn test_single_qubit_gate() {
let mut mps = BasicMPS::new(3, BasicMPSConfig::default());
let x_matrix = array![
[Complex64::new(0., 0.), Complex64::new(1., 0.)],
[Complex64::new(1., 0.), Complex64::new(0., 0.)]
];
mps.apply_single_qubit_gate(&x_matrix, 0)
.expect("Failed to apply X gate");
let amp = mps
.get_amplitude(&[true, false, false])
.expect("Failed to get amplitude for |100>");
assert!((amp.norm() - 1.0).abs() < 1e-10);
}
}