use crate::error::{IntegrateError, IntegrateResult as Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Complex64;
/// A register of `n_particles` two-level particles (qubits) described by a pure-state
/// amplitude vector.
///
/// The gate methods (`apply_cnot`, `apply_hadamard`) and `create_w_state` treat qubit `0`
/// as the most significant bit of a basis index.
#[derive(Debug, Clone)]
pub struct MultiParticleEntanglement {
/// Number of particles (qubits) in the register.
pub n_particles: usize,
/// Dimension of the state space; `new` sets it to `2^n_particles`.
pub hilbert_dim: usize,
/// Complex amplitudes, indexed by computational basis state.
pub state: Array1<Complex64>,
/// Per-particle masses as passed to `new`; not read by any method shown here.
pub masses: Array1<f64>,
/// `n_particles x n_particles` coupling matrix, zero-initialised by `new`; not read by
/// any method shown here.
pub interactions: Array2<f64>,
}
impl MultiParticleEntanglement {
pub fn new(nparticles: usize, masses: Array1<f64>) -> Self {
let hilbert_dim = 2_usize.pow(nparticles as u32); let state = Array1::zeros(hilbert_dim);
let interactions = Array2::zeros((nparticles, nparticles));
Self {
n_particles: nparticles,
hilbert_dim,
state,
masses,
interactions,
}
}
pub fn create_bell_state(&mut self, belltype: BellState) -> Result<()> {
if self.n_particles != 2 {
return Err(IntegrateError::InvalidInput(
"Bell states require exactly 2 particles".to_string(),
));
}
let inv_sqrt2 = 1.0 / (2.0_f64).sqrt();
self.state = Array1::zeros(4);
match belltype {
BellState::PhiPlus => {
self.state[0] = Complex64::new(inv_sqrt2, 0.0); self.state[3] = Complex64::new(inv_sqrt2, 0.0); }
BellState::PhiMinus => {
self.state[0] = Complex64::new(inv_sqrt2, 0.0); self.state[3] = Complex64::new(-inv_sqrt2, 0.0); }
BellState::PsiPlus => {
self.state[1] = Complex64::new(inv_sqrt2, 0.0); self.state[2] = Complex64::new(inv_sqrt2, 0.0); }
BellState::PsiMinus => {
self.state[1] = Complex64::new(inv_sqrt2, 0.0); self.state[2] = Complex64::new(-inv_sqrt2, 0.0); }
}
Ok(())
}
pub fn create_ghz_state(&mut self) -> Result<()> {
if self.n_particles < 3 {
return Err(IntegrateError::InvalidInput(
"GHZ states require at least 3 particles".to_string(),
));
}
let inv_sqrt2 = 1.0 / (2.0_f64).sqrt();
self.state = Array1::zeros(self.hilbert_dim);
self.state[0] = Complex64::new(inv_sqrt2, 0.0); self.state[self.hilbert_dim - 1] = Complex64::new(inv_sqrt2, 0.0);
Ok(())
}
pub fn create_w_state(&mut self) -> Result<()> {
if self.n_particles < 3 {
return Err(IntegrateError::InvalidInput(
"W states require at least 3 particles".to_string(),
));
}
let inv_sqrt_n = 1.0 / (self.n_particles as f64).sqrt();
self.state = Array1::zeros(self.hilbert_dim);
for i in 0..self.n_particles {
let state_index = 1 << (self.n_particles - 1 - i);
self.state[state_index] = Complex64::new(inv_sqrt_n, 0.0);
}
Ok(())
}
pub fn calculate_entanglement_entropy(&self, subsystemqubits: &[usize]) -> Result<f64> {
let rho_sub = self.reduced_density_matrix(subsystemqubits)?;
let eigenvalues = self.compute_eigenvalues(&rho_sub)?;
let mut entropy = 0.0;
for &lambda in &eigenvalues {
if lambda > 1e-12 {
entropy += -lambda * lambda.ln();
}
}
Ok(entropy)
}
fn reduced_density_matrix(&self, subsystemqubits: &[usize]) -> Result<Array2<Complex64>> {
let subsystem_size = subsystemqubits.len();
let subsystem_dim = 1 << subsystem_size;
let mut rho_sub = Array2::zeros((subsystem_dim, subsystem_dim));
for i in 0..subsystem_dim {
for j in 0..subsystem_dim {
let mut sum = Complex64::new(0.0, 0.0);
let env_size = self.n_particles - subsystem_size;
let env_dim = 1 << env_size;
for env_config in 0..env_dim {
let full_i = self.combine_subsystem_env(i, env_config, subsystemqubits);
let full_j = self.combine_subsystem_env(j, env_config, subsystemqubits);
if full_i < self.hilbert_dim && full_j < self.hilbert_dim {
sum += self.state[full_i].conj() * self.state[full_j];
}
}
rho_sub[[i, j]] = sum;
}
}
Ok(rho_sub)
}
fn combine_subsystem_env(
&self,
sub_config: usize,
env_config: usize,
subsystem_qubits: &[usize],
) -> usize {
let mut full_config = 0;
let mut env_bit = 0;
for qubit in 0..self.n_particles {
if subsystem_qubits.contains(&qubit) {
let sub_bit_pos = subsystem_qubits
.iter()
.position(|&x| x == qubit)
.expect("Operation failed");
if (sub_config >> sub_bit_pos) & 1 == 1 {
full_config |= 1 << qubit;
}
} else {
if (env_config >> env_bit) & 1 == 1 {
full_config |= 1 << qubit;
}
env_bit += 1;
}
}
full_config
}
fn compute_eigenvalues(&self, rho: &Array2<Complex64>) -> Result<Vec<f64>> {
let n = rho.nrows();
let mut eigenvalues = Vec::new();
for i in 0..n {
eigenvalues.push(rho[[i, i]].re);
}
Ok(eigenvalues)
}
pub fn calculate_concurrence(&self) -> Result<f64> {
if self.n_particles != 2 {
return Err(IntegrateError::InvalidInput(
"Concurrence is defined only for two-qubit systems".to_string(),
));
}
let mut rho = Array2::zeros((4, 4));
for i in 0..4 {
for j in 0..4 {
rho[[i, j]] = self.state[i].conj() * self.state[j];
}
}
let a = self.state[0];
let b = self.state[1];
let c = self.state[2];
let d = self.state[3];
let concurrence = 2.0 * (a * d - b * c).norm();
Ok(concurrence)
}
pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
if control >= self.n_particles || target >= self.n_particles {
return Err(IntegrateError::InvalidInput(
"Qubit indices out of range".to_string(),
));
}
let mut new_state = Array1::zeros(self.hilbert_dim);
for i in 0..self.hilbert_dim {
let control_bit = (i >> (self.n_particles - 1 - control)) & 1;
let target_bit = (i >> (self.n_particles - 1 - target)) & 1;
let new_i = if control_bit == 1 {
i ^ (1 << (self.n_particles - 1 - target))
} else {
i
};
new_state[new_i] = self.state[i];
}
self.state = new_state;
Ok(())
}
pub fn apply_hadamard(&mut self, qubit: usize) -> Result<()> {
if qubit >= self.n_particles {
return Err(IntegrateError::InvalidInput(
"Qubit index out of range".to_string(),
));
}
let mut new_state = Array1::zeros(self.hilbert_dim);
let inv_sqrt2 = Complex64::new(1.0 / (2.0_f64).sqrt(), 0.0);
for i in 0..self.hilbert_dim {
let bit = (i >> (self.n_particles - 1 - qubit)) & 1;
let flipped_i = i ^ (1 << (self.n_particles - 1 - qubit));
if bit == 0 {
new_state[i] += inv_sqrt2 * self.state[i];
new_state[flipped_i] += inv_sqrt2 * self.state[i];
} else {
new_state[flipped_i] += inv_sqrt2 * self.state[i];
new_state[i] -= inv_sqrt2 * self.state[i];
}
}
self.state = new_state;
Ok(())
}
pub fn measure_entanglement_witness(
&self,
witness_operator: &Array2<Complex64>,
) -> Result<f64> {
if witness_operator.nrows() != self.hilbert_dim
|| witness_operator.ncols() != self.hilbert_dim
{
return Err(IntegrateError::InvalidInput(
"Witness operator dimension mismatch".to_string(),
));
}
let mut expectation = Complex64::new(0.0, 0.0);
for i in 0..self.hilbert_dim {
for j in 0..self.hilbert_dim {
expectation += self.state[i].conj() * witness_operator[[i, j]] * self.state[j];
}
}
Ok(expectation.re)
}
pub fn normalize(&mut self) {
let norm_squared: f64 = self.state.iter().map(|&c| (c.conj() * c).re).sum();
let norm = norm_squared.sqrt();
if norm > 1e-12 {
self.state.mapv_inplace(|c| c / norm);
}
}
pub fn get_state(&self) -> &Array1<Complex64> {
&self.state
}
pub fn set_state(&mut self, newstate: Array1<Complex64>) -> Result<()> {
if newstate.len() != self.hilbert_dim {
return Err(IntegrateError::InvalidInput(
"State dimension mismatch".to_string(),
));
}
self.state = newstate;
self.normalize();
Ok(())
}
}
/// Selects one of the four two-qubit Bell states prepared by
/// `MultiParticleEntanglement::create_bell_state`.
///
/// Kets are written `|q0 q1>`, with qubit 0 as the high bit of the basis index
/// (index 1 = `|01>`, index 2 = `|10>`).
#[derive(Debug, Clone, Copy)]
pub enum BellState {
/// `(|00> + |11>) / sqrt(2)`
PhiPlus,
/// `(|00> - |11>) / sqrt(2)`
PhiMinus,
/// `(|01> + |10>) / sqrt(2)`
PsiPlus,
/// `(|01> - |10>) / sqrt(2)`
PsiMinus,
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Absolute tolerance shared by every comparison in this module.
    const EPS: f64 = 1e-10;

    /// An `n`-particle system in which every particle has unit mass.
    fn unit_mass_system(n: usize) -> MultiParticleEntanglement {
        MultiParticleEntanglement::new(n, Array1::from_vec(vec![1.0; n]))
    }

    /// The computational basis state `|index>` of a `dim`-dimensional register.
    fn basis_state(dim: usize, index: usize) -> Array1<Complex64> {
        Array1::from_shape_fn(dim, |i| {
            if i == index {
                Complex64::new(1.0, 0.0)
            } else {
                Complex64::new(0.0, 0.0)
            }
        })
    }

    /// Checks that `state` has real amplitude `amplitude` at every index in `support`
    /// and a vanishing amplitude everywhere else.
    fn assert_supported_on(state: &Array1<Complex64>, support: &[usize], amplitude: f64) {
        for (i, value) in state.iter().enumerate() {
            if support.contains(&i) {
                assert_relative_eq!(value.re, amplitude, epsilon = EPS);
            } else {
                assert_relative_eq!(value.norm(), 0.0, epsilon = EPS);
            }
        }
    }

    #[test]
    fn test_bell_state_creation() {
        let mut system = unit_mass_system(2);
        system
            .create_bell_state(BellState::PhiPlus)
            .expect("Operation failed");
        assert_supported_on(system.get_state(), &[0, 3], 1.0 / 2.0_f64.sqrt());
    }

    #[test]
    fn test_ghz_state() {
        let mut system = unit_mass_system(3);
        system.create_ghz_state().expect("Operation failed");
        assert_supported_on(system.get_state(), &[0, 7], 1.0 / 2.0_f64.sqrt());
    }

    #[test]
    fn test_w_state() {
        let mut system = unit_mass_system(3);
        system.create_w_state().expect("Operation failed");
        assert_supported_on(system.get_state(), &[1, 2, 4], 1.0 / 3.0_f64.sqrt());
    }

    #[test]
    fn test_concurrence() {
        let mut system = unit_mass_system(2);
        system
            .create_bell_state(BellState::PhiPlus)
            .expect("Operation failed");
        let entangled = system.calculate_concurrence().expect("Operation failed");
        assert_relative_eq!(entangled, 1.0, epsilon = EPS);

        system
            .set_state(basis_state(4, 0))
            .expect("Operation failed");
        let separable = system.calculate_concurrence().expect("Operation failed");
        assert_relative_eq!(separable, 0.0, epsilon = EPS);
    }

    #[test]
    fn test_quantum_gates() {
        let half = 1.0 / 2.0_f64.sqrt();
        let mut system = unit_mass_system(2);
        system
            .set_state(basis_state(4, 0))
            .expect("Operation failed");

        system.apply_hadamard(0).expect("Operation failed");
        let after_hadamard = system.get_state();
        assert_relative_eq!(after_hadamard[0].re, half, epsilon = EPS);
        assert_relative_eq!(after_hadamard[2].re, half, epsilon = EPS);

        system.apply_cnot(0, 1).expect("Operation failed");
        assert_supported_on(system.get_state(), &[0, 3], half);
    }
}