use crate::error::{SpatialError, SpatialResult};
use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Complex64;
use scirs2_core::random::{Rng, RngExt};
use std::f64::consts::SQRT_2;
/// Complex probability amplitude of a single computational-basis state.
pub type QuantumAmplitude = Complex64;

/// Pure state of a qubit register, stored as a dense state vector.
///
/// Basis state `k` is indexed by its bit pattern. Qubit `q` corresponds to bit
/// `q` (mask `1 << q`) of the index.
#[derive(Debug, Clone)]
pub struct QuantumState {
    /// One amplitude per computational-basis state. The constructors guarantee
    /// the length is `2^numqubits`.
    pub amplitudes: Array1<QuantumAmplitude>,
    /// Number of qubits in the register.
    pub numqubits: usize,
}

impl QuantumState {
    /// Returns `2^numqubits`, the number of basis states of a register.
    ///
    /// # Panics
    ///
    /// Panics if `2^numqubits` does not fit in a `usize`. A plain `1 << n`
    /// would either panic with a shift-overflow message (debug builds) or
    /// silently wrap to a wrong size (release builds).
    fn basis_size(numqubits: usize) -> usize {
        assert!(
            numqubits < usize::BITS as usize,
            "a register of {numqubits} qubits has too many basis states to represent"
        );
        1usize << numqubits
    }

    /// Returns an `InvalidInput` error unless `qubit` addresses a qubit of this register.
    fn check_qubit(&self, qubit: usize) -> SpatialResult<()> {
        if qubit >= self.numqubits {
            return Err(SpatialError::InvalidInput(format!(
                "Qubit index {qubit} out of range"
            )));
        }
        Ok(())
    }

    /// Sum of `|amplitude|^2` over all basis states (1.0 for a normalized state).
    fn total_probability(&self) -> f64 {
        self.amplitudes.iter().map(|amp| amp.norm_sqr()).sum()
    }

    /// Wraps an explicit amplitude vector as a quantum state.
    ///
    /// The amplitudes are stored as given. They are not normalized; call
    /// [`QuantumState::normalize`] if needed.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if the number of amplitudes is not
    /// a power of two. An empty vector is also rejected.
    pub fn new(amplitudes: Array1<QuantumAmplitude>) -> SpatialResult<Self> {
        let num_states = amplitudes.len();
        if !num_states.is_power_of_two() {
            return Err(SpatialError::InvalidInput(
                "Number of amplitudes must be a power of 2".to_string(),
            ));
        }
        // Exact integer log2 of a power of two; no f64 round-trip.
        let numqubits = num_states.trailing_zeros() as usize;
        Ok(Self {
            amplitudes,
            numqubits,
        })
    }

    /// Creates the all-zeros basis state `|0...0>` of `numqubits` qubits.
    ///
    /// # Panics
    ///
    /// Panics if `2^numqubits` does not fit in a `usize`.
    pub fn zero_state(numqubits: usize) -> Self {
        let mut amplitudes = Array1::zeros(Self::basis_size(numqubits));
        amplitudes[0] = Complex64::new(1.0, 0.0);
        Self {
            amplitudes,
            numqubits,
        }
    }

    /// Creates the equal superposition of all `2^numqubits` basis states.
    ///
    /// Every amplitude is the real value `1/sqrt(2^numqubits)`.
    ///
    /// # Panics
    ///
    /// Panics if `2^numqubits` does not fit in a `usize`.
    pub fn uniform_superposition(numqubits: usize) -> Self {
        let num_states = Self::basis_size(numqubits);
        let amplitude = Complex64::new(1.0 / (num_states as f64).sqrt(), 0.0);
        Self {
            amplitudes: Array1::from_elem(num_states, amplitude),
            numqubits,
        }
    }

    /// Samples a basis-state index with probability `|amplitude|^2 / total`.
    ///
    /// The state is not modified (no collapse). Dividing by the total
    /// probability means an unnormalized state is still sampled in proportion
    /// to its squared magnitudes. A basis state with zero probability is never
    /// returned, unless every amplitude is zero; in that case index 0 is
    /// returned.
    pub fn measure(&self) -> usize {
        let mut rng = scirs2_core::random::rng();
        let draw: f64 = rng.random_range(0.0..1.0);
        let threshold = draw * self.total_probability();
        let mut cumulative = 0.0;
        let mut last_nonzero = 0;
        for (index, amp) in self.amplitudes.iter().enumerate() {
            let prob = amp.norm_sqr();
            if prob > 0.0 {
                last_nonzero = index;
                cumulative += prob;
                // Strict `<` so that a draw of exactly 0.0 cannot select a
                // leading zero-probability state.
                if threshold < cumulative {
                    return index;
                }
            }
        }
        // Rounding can leave `cumulative` a hair below `threshold`. Fall back
        // to the last state that actually carries probability.
        last_nonzero
    }

    /// Returns `|amplitude|^2` of basis state `state`, or 0.0 if the index is out of range.
    pub fn probability(&self, state: usize) -> f64 {
        self.amplitudes.get(state).map_or(0.0, |amp| amp.norm_sqr())
    }

    /// Applies the Hadamard gate `H = [[1, 1], [1, -1]] / sqrt(2)` to `qubit`.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if `qubit` is out of range.
    pub fn hadamard(&mut self, qubit: usize) -> SpatialResult<()> {
        self.check_qubit(qubit)?;
        let qubit_mask = 1 << qubit;
        // Visit each (bit=0, bit=1) pair once. Pairs are disjoint, so in-place
        // updates are safe.
        for low in 0..self.amplitudes.len() {
            if (low & qubit_mask) == 0 {
                let high = low | qubit_mask;
                let (amp_low, amp_high) = (self.amplitudes[low], self.amplitudes[high]);
                self.amplitudes[low] = (amp_low + amp_high) / SQRT_2;
                self.amplitudes[high] = (amp_low - amp_high) / SQRT_2;
            }
        }
        Ok(())
    }

    /// Multiplies every amplitude whose `qubit` bit is 1 by `e^{i*angle}`.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if `qubit` is out of range.
    pub fn phase_rotation(&mut self, qubit: usize, angle: f64) -> SpatialResult<()> {
        self.check_qubit(qubit)?;
        let phase = Complex64::new(0.0, angle).exp();
        let qubit_mask = 1 << qubit;
        for (index, amp) in self.amplitudes.iter_mut().enumerate() {
            if (index & qubit_mask) != 0 {
                *amp *= phase;
            }
        }
        Ok(())
    }

    /// Applies `RX(angle)` to `target` on the subspace where `control` is 1.
    ///
    /// `RX(t) = [[cos(t/2), -i sin(t/2)], [-i sin(t/2), cos(t/2)]]`. This
    /// matrix is unitary, so a normalized state stays normalized.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if either index is out of range,
    /// or if `control == target`. The same-qubit case was previously a silent
    /// no-op.
    pub fn controlled_rotation(
        &mut self,
        control: usize,
        target: usize,
        angle: f64,
    ) -> SpatialResult<()> {
        if control >= self.numqubits || target >= self.numqubits {
            return Err(SpatialError::InvalidInput(
                "Qubit indices out of range".to_string(),
            ));
        }
        if control == target {
            return Err(SpatialError::InvalidInput(
                "Control and target qubits must differ".to_string(),
            ));
        }
        let control_mask = 1 << control;
        let target_mask = 1 << target;
        let (sin_half, cos_half) = (angle / 2.0).sin_cos();
        let diagonal = Complex64::new(cos_half, 0.0);
        // Both off-diagonal entries of RX are -i*sin(t/2). Using +i on one of
        // them makes the matrix non-unitary.
        let off_diagonal = Complex64::new(0.0, -sin_half);
        for low in 0..self.amplitudes.len() {
            if (low & control_mask) != 0 && (low & target_mask) == 0 {
                let high = low | target_mask;
                let (amp_low, amp_high) = (self.amplitudes[low], self.amplitudes[high]);
                self.amplitudes[low] = diagonal * amp_low + off_diagonal * amp_high;
                self.amplitudes[high] = off_diagonal * amp_low + diagonal * amp_high;
            }
        }
        Ok(())
    }

    /// Applies the bit-flip gate `X` to `qubit`, swapping each amplitude pair that differs only in that bit.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if `qubit` is out of range.
    pub fn pauli_x(&mut self, qubit: usize) -> SpatialResult<()> {
        self.check_qubit(qubit)?;
        let qubit_mask = 1 << qubit;
        for low in 0..self.amplitudes.len() {
            if (low & qubit_mask) == 0 {
                self.amplitudes.swap(low, low | qubit_mask);
            }
        }
        Ok(())
    }

    /// Applies `Y = [[0, -i], [i, 0]]` to `qubit`: `Y|0> = i|1>`, `Y|1> = -i|0>`.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if `qubit` is out of range.
    pub fn pauli_y(&mut self, qubit: usize) -> SpatialResult<()> {
        self.check_qubit(qubit)?;
        let qubit_mask = 1 << qubit;
        let i_unit = Complex64::new(0.0, 1.0);
        // Each pair is read once and both halves are written together. The old
        // version zeroed entries that the partner iteration had already filled.
        for low in 0..self.amplitudes.len() {
            if (low & qubit_mask) == 0 {
                let high = low | qubit_mask;
                let (amp_low, amp_high) = (self.amplitudes[low], self.amplitudes[high]);
                self.amplitudes[low] = -i_unit * amp_high;
                self.amplitudes[high] = i_unit * amp_low;
            }
        }
        Ok(())
    }

    /// Applies the phase-flip gate `Z` to `qubit`, negating amplitudes whose `qubit` bit is 1.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if `qubit` is out of range.
    pub fn pauli_z(&mut self, qubit: usize) -> SpatialResult<()> {
        self.check_qubit(qubit)?;
        let qubit_mask = 1 << qubit;
        for (index, amp) in self.amplitudes.iter_mut().enumerate() {
            if (index & qubit_mask) != 0 {
                *amp *= -1.0;
            }
        }
        Ok(())
    }

    /// Number of qubits in the register.
    pub fn num_qubits(&self) -> usize {
        self.numqubits
    }

    /// Number of stored amplitudes (basis states).
    pub fn num_states(&self) -> usize {
        self.amplitudes.len()
    }

    /// True when the total probability is within `1e-10` of 1.
    pub fn is_normalized(&self) -> bool {
        (self.total_probability() - 1.0).abs() < 1e-10
    }

    /// Rescales the amplitudes to unit norm.
    ///
    /// A (near-)zero vector, with norm `<= 1e-10`, is left untouched to avoid
    /// dividing by zero.
    pub fn normalize(&mut self) {
        let norm = self.total_probability().sqrt();
        if norm > 1e-10 {
            for amp in self.amplitudes.iter_mut() {
                *amp /= norm;
            }
        }
    }

    /// Returns the amplitude of basis state `state`, or `None` if the index is out of range.
    pub fn amplitude(&self, state: usize) -> Option<QuantumAmplitude> {
        self.amplitudes.get(state).copied()
    }

    /// Overwrites the amplitude of basis state `state`. The state is not renormalized.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidInput`] if `state` is out of range.
    pub fn set_amplitude(
        &mut self,
        state: usize,
        amplitude: QuantumAmplitude,
    ) -> SpatialResult<()> {
        if state >= self.amplitudes.len() {
            return Err(SpatialError::InvalidInput(
                "State index out of range".to_string(),
            ));
        }
        self.amplitudes[state] = amplitude;
        Ok(())
    }
}

#[cfg(test)]
mod gate_regression_tests {
    use super::*;
    use std::f64::consts::FRAC_PI_2;

    #[test]
    fn pauli_y_maps_basis_states_with_phase() {
        let mut state = QuantumState::zero_state(1);
        state.pauli_y(0).expect("valid qubit");
        // Y|0> = i|1>
        assert_eq!(state.amplitude(0), Some(Complex64::new(0.0, 0.0)));
        assert_eq!(state.amplitude(1), Some(Complex64::new(0.0, 1.0)));
        // Y*Y = I, so a second application returns to |0>.
        state.pauli_y(0).expect("valid qubit");
        assert_eq!(state.amplitude(0), Some(Complex64::new(1.0, 0.0)));
        assert_eq!(state.amplitude(1), Some(Complex64::new(0.0, 0.0)));
    }

    #[test]
    fn controlled_rotation_is_unitary() {
        // (|01> + i|11>)/sqrt(2): control (qubit 0) set, target (qubit 1) in a complex superposition.
        let h = 1.0 / SQRT_2;
        let amplitudes = Array1::from_vec(vec![
            Complex64::new(0.0, 0.0),
            Complex64::new(h, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, h),
        ]);
        let mut state = QuantumState::new(amplitudes).expect("4 amplitudes = 2 qubits");
        state
            .controlled_rotation(0, 1, FRAC_PI_2)
            .expect("valid qubits");
        assert!(state.is_normalized());
        assert!((state.probability(1) - 1.0).abs() < 1e-10);
        assert!(state.probability(3) < 1e-10);
    }

    #[test]
    fn controlled_rotation_rejects_identical_qubits() {
        let mut state = QuantumState::zero_state(2);
        assert!(state.controlled_rotation(1, 1, FRAC_PI_2).is_err());
    }

    #[test]
    fn measure_never_picks_zero_probability_state() {
        // Unnormalized: all mass (0.25) on state 0, none on state 1.
        let amplitudes = Array1::from_vec(vec![Complex64::new(0.5, 0.0), Complex64::new(0.0, 0.0)]);
        let state = QuantumState::new(amplitudes).expect("2 amplitudes = 1 qubit");
        for _ in 0..64 {
            assert_eq!(state.measure(), 0);
        }
    }

    #[test]
    fn new_derives_qubit_count() {
        let amplitudes = Array1::from_elem(8, Complex64::new(0.0, 0.0));
        assert_eq!(
            QuantumState::new(amplitudes).expect("power of two").num_qubits(),
            3
        );
        assert!(QuantumState::new(Array1::from_elem(3, Complex64::new(0.0, 0.0))).is_err());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Absolute tolerance for floating-point probability checks.
    const TOLERANCE: f64 = 1e-10;

    /// True when `value` lies within `TOLERANCE` of `target`.
    fn close_to(value: f64, target: f64) -> bool {
        (value - target).abs() < TOLERANCE
    }

    #[test]
    fn test_zero_state_creation() {
        let state = QuantumState::zero_state(2);
        assert_eq!((state.num_qubits(), state.num_states()), (2, 4));
        assert_eq!(state.probability(0), 1.0);
        assert_eq!(state.probability(1), 0.0);
        assert!(state.is_normalized());
    }

    #[test]
    fn test_uniform_superposition() {
        let state = QuantumState::uniform_superposition(2);
        assert_eq!((state.num_qubits(), state.num_states()), (2, 4));
        assert!((0..4).all(|basis| close_to(state.probability(basis), 0.25)));
        assert!(state.is_normalized());
    }

    #[test]
    fn test_hadamard_gate() {
        let mut state = QuantumState::zero_state(1);
        state.hadamard(0).expect("Operation failed");
        assert!(close_to(state.probability(0), 0.5));
        assert!(close_to(state.probability(1), 0.5));
        assert!(state.is_normalized());
    }

    #[test]
    fn test_pauli_x_gate() {
        let mut state = QuantumState::zero_state(1);
        state.pauli_x(0).expect("Operation failed");
        assert_eq!(state.probability(0), 0.0);
        assert_eq!(state.probability(1), 1.0);
        assert!(state.is_normalized());
    }

    #[test]
    fn test_phase_rotation() {
        let mut state = QuantumState::uniform_superposition(1);
        state.phase_rotation(0, PI).expect("Operation failed");
        assert!(state.is_normalized());
        // A pure phase never changes measurement probabilities.
        for basis in [0, 1] {
            assert!(close_to(state.probability(basis), 0.5));
        }
    }

    #[test]
    fn test_controlled_rotation() {
        let mut state = QuantumState::zero_state(2);
        state.hadamard(0).expect("Operation failed");
        state
            .controlled_rotation(0, 1, PI)
            .expect("Operation failed");
        assert!(state.is_normalized());
        assert!(close_to(state.probability(0), 0.5));
        assert!(close_to(state.probability(3), 0.5));
    }

    #[test]
    fn test_measurement() {
        // |00> puts all of its probability on basis state 0.
        assert_eq!(QuantumState::zero_state(2).measure(), 0);
    }

    #[test]
    fn test_invalid_qubit_index() {
        let mut state = QuantumState::zero_state(2);
        let out_of_range = 2;
        assert!(state.hadamard(out_of_range).is_err());
        assert!(state.pauli_x(out_of_range).is_err());
        assert!(state.phase_rotation(out_of_range, PI).is_err());
    }

    #[test]
    fn test_amplitude_access() {
        let state = QuantumState::zero_state(2);
        let cases = [
            (0, Some(Complex64::new(1.0, 0.0))),
            (1, Some(Complex64::new(0.0, 0.0))),
            (10, None),
        ];
        for (index, expected) in cases {
            assert_eq!(state.amplitude(index), expected);
        }
    }

    #[test]
    fn test_normalization() {
        let raw: Vec<Complex64> = [2.0, 2.0, 0.0, 0.0]
            .iter()
            .map(|&re| Complex64::new(re, 0.0))
            .collect();
        let mut state = QuantumState::new(Array1::from_vec(raw)).expect("Operation failed");
        assert!(!state.is_normalized());
        state.normalize();
        assert!(state.is_normalized());
        assert!(close_to(state.probability(0), 0.5));
        assert!(close_to(state.probability(1), 0.5));
    }
}