use crate::complex::Complex64;
use crate::error::{Error, Result};
/// A pure quantum state on `n` qubits, stored densely as `2^n` complex
/// amplitudes where index `i` holds the amplitude of basis state `|i⟩`
/// (bit `q` of `i` is the value of qubit `q`).
#[derive(Debug, Clone, PartialEq)]
pub struct State {
    /// One amplitude per computational-basis state; the length is always `1 << n`.
    amps: Vec<Complex64>,
    /// Qubit count, kept in sync with `amps.len() == 1 << n`.
    n: usize,
}
impl State {
    /// Returns the all-zeros computational-basis state `|0…0⟩` on `n` qubits.
    ///
    /// # Panics
    ///
    /// Panics if `n >= usize::BITS`, because `2^n` amplitudes could not be
    /// counted in a `usize`.
    #[must_use]
    pub fn zero(n: usize) -> Self {
        assert!(n < usize::BITS as usize, "2^{n} amplitudes overflows usize");
        let mut amps = vec![Complex64::ZERO; 1usize << n];
        amps[0] = Complex64::ONE;
        Self { amps, n }
    }

    /// Builds a state from an explicit amplitude vector.
    ///
    /// The qubit count is inferred as `log2(amps.len())`. Amplitudes are stored
    /// exactly as given and are NOT normalized; call [`State::normalize`] if
    /// needed.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DimensionMismatch`] if `amps.len()` is not a power of
    /// two (an empty vector included). `expected` is the smallest power of two
    /// that is `>= len`.
    pub fn from_amplitudes(amps: Vec<Complex64>) -> Result<Self> {
        let len = amps.len();
        if !len.is_power_of_two() {
            return Err(Error::DimensionMismatch {
                len,
                expected: len.next_power_of_two(),
            });
        }
        // `len` is a power of two, so its trailing-zero count is exactly log2(len).
        let n = len.trailing_zeros() as usize;
        Ok(Self { amps, n })
    }

    /// Number of qubits `n`.
    #[inline]
    #[must_use]
    pub fn num_qubits(&self) -> usize {
        self.n
    }

    /// Number of amplitudes, i.e. the Hilbert-space dimension `2^n`.
    #[inline]
    #[must_use]
    pub fn dim(&self) -> usize {
        self.amps.len()
    }

    /// Read-only view of the amplitudes; index `i` is basis state `|i⟩`.
    #[inline]
    #[must_use]
    pub fn amplitudes(&self) -> &[Complex64] {
        &self.amps
    }

    /// Mutable view of the amplitudes.
    ///
    /// The slice length cannot change, so the `2^n` dimension is preserved,
    /// but writes are not renormalized.
    #[inline]
    #[must_use]
    pub fn amplitudes_mut(&mut self) -> &mut [Complex64] {
        &mut self.amps
    }

    /// Squared 2-norm `Σ |a_i|²`; equals 1 for a normalized state.
    #[must_use]
    pub fn norm_sqr(&self) -> f64 {
        self.amps.iter().map(|a| a.norm_sqr()).sum()
    }

    /// Rescales the amplitudes so that [`State::norm_sqr`] is 1 (up to rounding).
    ///
    /// A state whose norm is zero (or NaN) is left unchanged, since there is
    /// no direction to rescale towards.
    pub fn normalize(&mut self) {
        let norm = self.norm_sqr().sqrt();
        if norm > 0.0 {
            let inv = 1.0 / norm;
            for a in &mut self.amps {
                *a *= inv;
            }
        }
    }

    /// Returns `|a_basis|²`.
    ///
    /// This is not divided by the norm, so it is a true measurement
    /// probability only for a normalized state.
    ///
    /// # Panics
    ///
    /// Panics if `basis >= self.dim()`.
    #[must_use]
    pub fn probability(&self, basis: usize) -> f64 {
        assert!(basis < self.amps.len(), "basis state {basis} out of range");
        self.amps[basis].norm_sqr()
    }

    /// Inner product `Σ_i conj(other_i) · self_i`, i.e. `⟨other|self⟩` in
    /// bra-ket notation. This is the complex conjugate of `⟨self|other⟩`.
    ///
    /// # Panics
    ///
    /// Panics if the two states have different qubit counts.
    #[must_use]
    pub fn overlap(&self, other: &Self) -> Complex64 {
        assert_eq!(self.n, other.n, "states must have equal qubit counts");
        let mut acc = Complex64::ZERO;
        for (a, b) in self.amps.iter().zip(&other.amps) {
            acc += b.conj() * *a;
        }
        acc
    }

    /// Pure-state fidelity `|⟨other|self⟩|²`.
    ///
    /// It is symmetric in its arguments and insensitive to a global phase on
    /// either state.
    ///
    /// # Panics
    ///
    /// Panics if the two states have different qubit counts.
    #[must_use]
    pub fn fidelity(&self, other: &Self) -> f64 {
        self.overlap(other).norm_sqr()
    }

    /// Bloch vector `[⟨X⟩, ⟨Y⟩, ⟨Z⟩]` of the reduced single-qubit state of `qubit`.
    ///
    /// All other qubits are traced out. The components are computed from the
    /// unnormalized amplitudes, so callers should normalize first. For a
    /// normalized state the vector has length at most 1.
    ///
    /// # Panics
    ///
    /// Panics if `qubit >= self.num_qubits()`.
    #[must_use]
    pub fn bloch_vector(&self, qubit: usize) -> [f64; 3] {
        assert!(qubit < self.n, "qubit {qubit} out of range");
        let bit = 1usize << qubit;
        // Entries of the reduced density matrix ρ for `qubit`:
        //   r00 = ρ₀₀, r11 = ρ₁₁, r01 = ρ₀₁ = Σ ψ(…0…) · conj(ψ(…1…)),
        // pairing each index whose `qubit` bit is clear with its bit-set partner.
        let mut r00 = 0.0;
        let mut r11 = 0.0;
        let mut r01 = Complex64::ZERO;
        for (j, amp) in self.amps.iter().enumerate() {
            if j & bit == 0 {
                r00 += amp.norm_sqr();
                let partner = self.amps[j | bit];
                r01 += *amp * partner.conj();
            } else {
                r11 += amp.norm_sqr();
            }
        }
        // ⟨X⟩ = Tr(ρX) = ρ₀₁ + ρ₁₀       =  2·Re ρ₀₁
        // ⟨Y⟩ = Tr(ρY) = i·ρ₀₁ − i·ρ₁₀   = −2·Im ρ₀₁  (so |+i⟩ maps to y = +1)
        // ⟨Z⟩ = ρ₀₀ − ρ₁₁
        [2.0 * r01.re, -2.0 * r01.im, r00 - r11]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point check in this module.
    const EPS: f64 = 1e-12;

    /// True when `lhs` and `rhs` differ by less than [`EPS`].
    fn approx_eq(lhs: f64, rhs: f64) -> bool {
        (lhs - rhs).abs() < EPS
    }

    #[test]
    fn zero_state_is_normalized() {
        assert!((0..6).all(|qubits| approx_eq(State::zero(qubits).norm_sqr(), 1.0)));
    }

    #[test]
    fn from_amplitudes_rejects_non_power_of_two() {
        let result = State::from_amplitudes(vec![Complex64::ONE; 3]);
        assert!(matches!(
            result,
            Err(Error::DimensionMismatch { len: 3, .. })
        ));
    }

    #[test]
    fn fidelity_ignores_global_phase() {
        let reference = State::zero(2);
        let phase = Complex64::expi(0.7);
        let mut rotated = reference.clone();
        rotated
            .amplitudes_mut()
            .iter_mut()
            .for_each(|amp| *amp *= phase);
        assert!(approx_eq(reference.fidelity(&rotated), 1.0));
    }

    #[test]
    fn bloch_vector_of_zero_state_points_up() {
        let [x, y, z] = State::zero(1).bloch_vector(0);
        assert!(approx_eq(x, 0.0) && approx_eq(y, 0.0));
        assert!(approx_eq(z, 1.0));
    }

    #[test]
    fn normalize_fixes_drift() {
        let amps = vec![Complex64::new(2.0, 0.0), Complex64::new(0.0, 2.0)];
        let mut state = State::from_amplitudes(amps).unwrap();
        state.normalize();
        assert!(approx_eq(state.norm_sqr(), 1.0));
    }
}