use num_complex::Complex;
use rand::{Rng, RngExt};
use std::f64::consts::PI;
use super::error::{QuantumError, QuantumResult};
/// A single-qubit pure state `alpha|0⟩ + beta|1⟩`.
///
/// The public constructors always produce `|alpha|² + |beta|² = 1`, either by construction or
/// by rescaling in `Qubit::new`. The fields are `pub(crate)`, so code inside this crate can
/// still write an un-normalised pair.
#[derive(Debug, Clone, PartialEq)]
pub struct Qubit {
/// Amplitude of the `|0⟩` basis state.
pub(crate) alpha: Complex<f64>,
/// Amplitude of the `|1⟩` basis state.
pub(crate) beta: Complex<f64>,
}
impl Qubit {
pub fn new(alpha: Complex<f64>, beta: Complex<f64>) -> QuantumResult<Self> {
let norm_sq = alpha.norm_sqr() + beta.norm_sqr();
if norm_sq < 1e-15 {
return Err(QuantumError::ZeroStateVector);
}
let norm = norm_sq.sqrt();
Ok(Self {
alpha: alpha / norm,
beta: beta / norm,
})
}
pub fn new_zero() -> Self {
Self {
alpha: Complex::new(1.0, 0.0),
beta: Complex::new(0.0, 0.0),
}
}
pub fn new_one() -> Self {
Self {
alpha: Complex::new(0.0, 0.0),
beta: Complex::new(1.0, 0.0),
}
}
pub fn new_superposition(theta: f64, phi: f64) -> Self {
let half = theta / 2.0;
Self {
alpha: Complex::new(half.cos(), 0.0),
beta: Complex::from_polar(half.sin(), phi),
}
}
pub fn new_plus() -> Self {
let s = 1.0 / 2.0_f64.sqrt();
Self {
alpha: Complex::new(s, 0.0),
beta: Complex::new(s, 0.0),
}
}
pub fn new_minus() -> Self {
let s = 1.0 / 2.0_f64.sqrt();
Self {
alpha: Complex::new(s, 0.0),
beta: Complex::new(-s, 0.0),
}
}
pub fn alpha(&self) -> Complex<f64> {
self.alpha
}
pub fn beta(&self) -> Complex<f64> {
self.beta
}
pub fn prob_zero(&self) -> f64 {
self.alpha.norm_sqr()
}
pub fn prob_one(&self) -> f64 {
self.beta.norm_sqr()
}
pub fn is_normalised(&self, tol: f64) -> bool {
((self.alpha.norm_sqr() + self.beta.norm_sqr()) - 1.0).abs() < tol
}
pub fn measure<R: Rng>(&self, rng: &mut R) -> (u8, Qubit) {
let p0 = self.prob_zero();
let sample: f64 = rng.random();
if sample < p0 {
(0, Qubit::new_zero())
} else {
(1, Qubit::new_one())
}
}
pub fn bloch_angles(&self) -> (f64, f64) {
let theta = 2.0 * self.alpha.norm().acos().min(PI);
let phi = {
let raw = self.beta.arg() - self.alpha.arg();
let normalised = raw.rem_euclid(2.0 * PI);
normalised
};
(theta, phi)
}
pub fn to_register(&self) -> QubitRegister {
QubitRegister {
amplitudes: vec![self.alpha, self.beta],
n_qubits: 1,
}
}
}
impl Default for Qubit {
fn default() -> Self {
Self::new_zero()
}
}
impl std::fmt::Display for Qubit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"({:.6} + {:.6}i)|0⟩ + ({:.6} + {:.6}i)|1⟩",
self.alpha.re, self.alpha.im, self.beta.re, self.beta.im
)
}
}
/// A pure state of `n_qubits` qubits, stored as a dense vector of basis amplitudes.
///
/// `amplitudes[k]` is the amplitude of basis state `|k⟩`. Bit `q` of `k` (least-significant
/// bit = qubit 0) gives the value of qubit `q`, as read by `measure_qubit` and `measure_all`.
///
/// The constructors in this file produce `amplitudes.len() == 2^n_qubits` with unit norm.
/// `amplitudes_mut` hands out the raw `Vec`, however, so callers can break either property.
#[derive(Debug, Clone, PartialEq)]
pub struct QubitRegister {
/// Amplitude of each computational-basis state, indexed by basis index.
pub(crate) amplitudes: Vec<Complex<f64>>,
/// Number of qubits; the dimension is `2^n_qubits` while the length invariant holds.
pub(crate) n_qubits: usize,
}
impl QubitRegister {
pub fn new(n_qubits: usize, amplitudes: Vec<Complex<f64>>) -> QuantumResult<Self> {
let expected_dim = 1usize
.checked_shl(n_qubits as u32)
.ok_or(QuantumError::TooManyQubits(n_qubits))?;
if amplitudes.len() != expected_dim {
return Err(QuantumError::DimensionMismatch {
expected: expected_dim,
actual: amplitudes.len(),
});
}
let norm_sq: f64 = amplitudes.iter().map(|a| a.norm_sqr()).sum();
if norm_sq < 1e-15 {
return Err(QuantumError::ZeroStateVector);
}
let norm = norm_sq.sqrt();
let normalised = amplitudes.iter().map(|a| a / norm).collect();
Ok(Self {
amplitudes: normalised,
n_qubits,
})
}
pub fn from_qubit(q: &Qubit) -> Self {
Self {
amplitudes: vec![q.alpha, q.beta],
n_qubits: 1,
}
}
pub fn new_zero_state(n_qubits: usize) -> QuantumResult<Self> {
if n_qubits == 0 {
return Err(QuantumError::InvalidQubitCount(n_qubits));
}
let dim = 1usize
.checked_shl(n_qubits as u32)
.ok_or(QuantumError::TooManyQubits(n_qubits))?;
let mut amps = vec![Complex::new(0.0, 0.0); dim];
amps[0] = Complex::new(1.0, 0.0);
Ok(Self {
amplitudes: amps,
n_qubits,
})
}
pub fn new_uniform_superposition(n_qubits: usize) -> QuantumResult<Self> {
if n_qubits == 0 {
return Err(QuantumError::InvalidQubitCount(n_qubits));
}
let dim = 1usize
.checked_shl(n_qubits as u32)
.ok_or(QuantumError::TooManyQubits(n_qubits))?;
let amp = Complex::new(1.0 / (dim as f64).sqrt(), 0.0);
Ok(Self {
amplitudes: vec![amp; dim],
n_qubits,
})
}
pub fn new_basis_state(n_qubits: usize, k: usize) -> QuantumResult<Self> {
if n_qubits == 0 {
return Err(QuantumError::InvalidQubitCount(n_qubits));
}
let dim = 1usize
.checked_shl(n_qubits as u32)
.ok_or(QuantumError::TooManyQubits(n_qubits))?;
if k >= dim {
return Err(QuantumError::BasisIndexOutOfRange { index: k, dim });
}
let mut amps = vec![Complex::new(0.0, 0.0); dim];
amps[k] = Complex::new(1.0, 0.0);
Ok(Self {
amplitudes: amps,
n_qubits,
})
}
pub fn n_qubits(&self) -> usize {
self.n_qubits
}
pub fn dim(&self) -> usize {
self.amplitudes.len()
}
pub fn amplitudes(&self) -> &[Complex<f64>] {
&self.amplitudes
}
pub fn amplitudes_mut(&mut self) -> &mut Vec<Complex<f64>> {
&mut self.amplitudes
}
pub fn amplitude(&self, k: usize) -> QuantumResult<Complex<f64>> {
self.amplitudes
.get(k)
.copied()
.ok_or(QuantumError::BasisIndexOutOfRange {
index: k,
dim: self.dim(),
})
}
pub fn probability(&self, k: usize) -> QuantumResult<f64> {
Ok(self.amplitude(k)?.norm_sqr())
}
pub fn probabilities(&self) -> Vec<f64> {
self.amplitudes.iter().map(|a| a.norm_sqr()).collect()
}
pub fn normalise(&mut self) -> QuantumResult<()> {
let norm_sq: f64 = self.amplitudes.iter().map(|a| a.norm_sqr()).sum();
if norm_sq < 1e-15 {
return Err(QuantumError::ZeroStateVector);
}
let norm = norm_sq.sqrt();
for a in &mut self.amplitudes {
*a /= norm;
}
Ok(())
}
pub fn is_normalised(&self, tol: f64) -> bool {
let norm_sq: f64 = self.amplitudes.iter().map(|a| a.norm_sqr()).sum();
(norm_sq - 1.0).abs() < tol
}
pub fn tensor_product(a: &QubitRegister, b: &QubitRegister) -> QubitRegister {
let n = a.n_qubits + b.n_qubits;
let mut amps = Vec::with_capacity(a.dim() * b.dim());
for &_a in &a.amplitudes {
for &_b in &b.amplitudes {
amps.push(amp_a * amp_b);
}
}
QubitRegister {
amplitudes: amps,
n_qubits: n,
}
}
pub fn measure_qubit<R: Rng>(
&self,
qubit_idx: usize,
rng: &mut R,
) -> QuantumResult<(u8, QubitRegister)> {
if qubit_idx >= self.n_qubits {
return Err(QuantumError::QubitIndexOutOfRange {
index: qubit_idx,
n_qubits: self.n_qubits,
});
}
let mut prob_one: f64 = 0.0;
for (k, amp) in self.amplitudes.iter().enumerate() {
if (k >> qubit_idx) & 1 == 1 {
prob_one += amp.norm_sqr();
}
}
let sample: f64 = rng.random();
let outcome: u8 = if sample < prob_one { 1 } else { 0 };
let mut new_amps = self.amplitudes.clone();
for (k, amp) in new_amps.iter_mut().enumerate() {
let bit = ((k >> qubit_idx) & 1) as u8;
if bit != outcome {
*amp = Complex::new(0.0, 0.0);
}
}
let mut collapsed = QubitRegister {
amplitudes: new_amps,
n_qubits: self.n_qubits,
};
collapsed.normalise()?;
Ok((outcome, collapsed))
}
pub fn measure_all<R: Rng>(&self, rng: &mut R) -> Vec<u8> {
let probs = self.probabilities();
let sample: f64 = rng.random();
let mut cumulative = 0.0;
let mut outcome_index = probs.len().saturating_sub(1);
for (i, &p) in probs.iter().enumerate() {
cumulative += p;
if sample < cumulative {
outcome_index = i;
break;
}
}
(0..self.n_qubits)
.map(|q| ((outcome_index >> q) & 1) as u8)
.collect()
}
pub fn inner_product(&self, other: &QubitRegister) -> QuantumResult<Complex<f64>> {
if self.n_qubits != other.n_qubits {
return Err(QuantumError::DimensionMismatch {
expected: self.dim(),
actual: other.dim(),
});
}
let ip = self
.amplitudes
.iter()
.zip(other.amplitudes.iter())
.map(|(a, b)| b.conj() * a)
.sum();
Ok(ip)
}
pub fn fidelity(&self, other: &QubitRegister) -> QuantumResult<f64> {
let ip = self.inner_product(other)?;
Ok(ip.norm_sqr())
}
pub fn entropy(&self) -> f64 {
let probs = self.probabilities();
-probs
.iter()
.filter(|&&p| p > 1e-15)
.map(|&p| p * p.ln())
.sum::<f64>()
}
}
impl Default for QubitRegister {
fn default() -> Self {
Self {
amplitudes: vec![Complex::new(1.0, 0.0), Complex::new(0.0, 0.0)],
n_qubits: 1,
}
}
}
impl std::fmt::Display for QubitRegister {
    /// Renders `QubitRegister(N qubits) [|b…⟩: re±imi, …]`.
    ///
    /// Each basis index is printed in binary, zero-padded to `N` digits, and each amplitude
    /// component to four decimal places.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "QubitRegister({} qubits) [", self.n_qubits)?;
        for (i, amp) in self.amplitudes.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            // The `+` flag always prints the imaginary part's sign. A negative zero therefore
            // renders as `-0.0000i` rather than the malformed `+-0.0000i` that a manual
            // `im >= 0.0` check produces (since `-0.0 >= 0.0`).
            write!(
                f,
                "|{:0>width$b}⟩: {:.4}{:+.4}i",
                i,
                amp.re,
                amp.im,
                width = self.n_qubits
            )?;
        }
        write!(f, "]")
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for single-qubit states and multi-qubit registers.
    use super::*;
    use num_complex::Complex;
    use rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;
    use std::f64::consts::{FRAC_1_SQRT_2, FRAC_PI_2};

    const TOL: f64 = 1e-12;

    #[test]
    fn test_qubit_zero_normalised() {
        let q = Qubit::new_zero();
        assert!(q.is_normalised(TOL));
        assert!((q.prob_zero() - 1.0).abs() < TOL);
        assert!(q.prob_one().abs() < TOL);
    }

    #[test]
    fn test_qubit_one_normalised() {
        let q = Qubit::new_one();
        assert!(q.is_normalised(TOL));
        assert!(q.prob_zero().abs() < TOL);
        assert!((q.prob_one() - 1.0).abs() < TOL);
    }

    #[test]
    fn test_qubit_superposition() {
        let q = Qubit::new_plus();
        assert!(q.is_normalised(TOL));
        assert!((q.prob_zero() - 0.5).abs() < TOL);
        assert!((q.prob_one() - 0.5).abs() < TOL);
    }

    #[test]
    fn test_bloch_sphere_zero() {
        let q = Qubit::new_superposition(0.0, 0.0);
        assert!((q.prob_zero() - 1.0).abs() < TOL);
    }

    #[test]
    fn test_bloch_sphere_one() {
        let q = Qubit::new_superposition(std::f64::consts::PI, 0.0);
        assert!((q.prob_one() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_qubit_measure_deterministic_zero() {
        let q = Qubit::new_zero();
        let mut rng = ChaCha20Rng::seed_from_u64(0);
        let (outcome, post) = q.measure(&mut rng);
        assert_eq!(outcome, 0);
        assert!(post.is_normalised(TOL));
        assert!((post.prob_zero() - 1.0).abs() < TOL);
    }

    #[test]
    fn test_qubit_measure_deterministic_one() {
        let q = Qubit::new_one();
        let mut rng = ChaCha20Rng::seed_from_u64(0);
        let (outcome, post) = q.measure(&mut rng);
        assert_eq!(outcome, 1);
        assert!(post.is_normalised(TOL));
    }

    #[test]
    fn test_register_zero_state() {
        let reg = QubitRegister::new_zero_state(3).expect("valid");
        assert_eq!(reg.n_qubits(), 3);
        assert_eq!(reg.dim(), 8);
        assert!((reg.probability(0).expect("ok") - 1.0).abs() < TOL);
    }

    #[test]
    fn test_register_uniform_superposition() {
        let reg = QubitRegister::new_uniform_superposition(2).expect("valid");
        let p = reg.probability(0).expect("ok");
        assert!((p - 0.25).abs() < TOL);
    }

    #[test]
    fn test_tensor_product_dims() {
        let r1 = QubitRegister::new_zero_state(2).expect("valid");
        let r2 = QubitRegister::new_zero_state(3).expect("valid");
        let combined = QubitRegister::tensor_product(&r1, &r2);
        assert_eq!(combined.n_qubits(), 5);
        assert_eq!(combined.dim(), 32);
    }

    #[test]
    fn test_measure_all_basis_state() {
        let reg = QubitRegister::new_basis_state(3, 5).expect("valid");
        let mut rng = ChaCha20Rng::seed_from_u64(42);
        let bits = reg.measure_all(&mut rng);
        assert_eq!(bits, vec![1, 0, 1]);
    }

    #[test]
    fn test_fidelity_same_state() {
        let r = QubitRegister::new_zero_state(2).expect("valid");
        let f = r.fidelity(&r).expect("ok");
        assert!((f - 1.0).abs() < TOL);
    }

    #[test]
    fn test_fidelity_orthogonal() {
        let r0 = QubitRegister::new_basis_state(1, 0).expect("valid");
        let r1 = QubitRegister::new_basis_state(1, 1).expect("valid");
        let f = r0.fidelity(&r1).expect("ok");
        assert!(f.abs() < TOL);
    }

    #[test]
    fn test_qubit_new_normalises_input() {
        let q = Qubit::new(Complex::new(3.0, 0.0), Complex::new(0.0, 4.0)).expect("non-zero");
        assert!(q.is_normalised(TOL));
        assert!((q.prob_zero() - 0.36).abs() < TOL);
        assert!((q.prob_one() - 0.64).abs() < TOL);
    }

    #[test]
    fn test_zero_vectors_rejected() {
        let zero = Complex::new(0.0, 0.0);
        assert!(Qubit::new(zero, zero).is_err());
        assert!(QubitRegister::new(1, vec![zero, zero]).is_err());
    }

    #[test]
    fn test_register_new_dimension_mismatch() {
        let one = Complex::new(1.0, 0.0);
        assert!(QubitRegister::new(2, vec![one; 3]).is_err());
    }

    #[test]
    fn test_named_constructors_reject_bad_arguments() {
        assert!(QubitRegister::new_zero_state(0).is_err());
        assert!(QubitRegister::new_uniform_superposition(0).is_err());
        assert!(QubitRegister::new_basis_state(0, 0).is_err());
        assert!(QubitRegister::new_basis_state(2, 4).is_err());
    }

    #[test]
    fn test_tensor_product_amplitude_order() {
        // |1⟩ ⊗ |0⟩: the first operand supplies the high-order bit, giving index 0b10.
        let r1 = QubitRegister::new_basis_state(1, 1).expect("valid");
        let r0 = QubitRegister::new_basis_state(1, 0).expect("valid");
        let combined = QubitRegister::tensor_product(&r1, &r0);
        assert!((combined.probability(2).expect("ok") - 1.0).abs() < TOL);
        assert!(combined.is_normalised(TOL));
    }

    #[test]
    fn test_measure_qubit_collapses_bell_state() {
        let s = Complex::new(FRAC_1_SQRT_2, 0.0);
        let zero = Complex::new(0.0, 0.0);
        let bell = QubitRegister::new(2, vec![s, zero, zero, s]).expect("valid");
        let mut rng = ChaCha20Rng::seed_from_u64(7);
        let (outcome, post) = bell.measure_qubit(0, &mut rng).expect("ok");
        // The qubits are perfectly correlated, so the survivor is |00⟩ or |11⟩.
        let expected_index = if outcome == 1 { 3 } else { 0 };
        assert!((post.probability(expected_index).expect("ok") - 1.0).abs() < TOL);
        assert_eq!(post.measure_all(&mut rng), vec![outcome, outcome]);
    }

    #[test]
    fn test_measure_qubit_index_out_of_range() {
        let reg = QubitRegister::new_zero_state(2).expect("valid");
        let mut rng = ChaCha20Rng::seed_from_u64(0);
        assert!(reg.measure_qubit(2, &mut rng).is_err());
    }

    #[test]
    fn test_inner_product_dimension_mismatch() {
        let a = QubitRegister::new_zero_state(1).expect("valid");
        let b = QubitRegister::new_zero_state(2).expect("valid");
        assert!(a.inner_product(&b).is_err());
    }

    #[test]
    fn test_entropy() {
        let basis = QubitRegister::new_basis_state(2, 3).expect("valid");
        assert!(basis.entropy().abs() < TOL);
        let uniform = QubitRegister::new_uniform_superposition(2).expect("valid");
        assert!((uniform.entropy() - 4.0_f64.ln()).abs() < TOL);
    }

    #[test]
    fn test_bloch_angles_plus() {
        let (theta, phi) = Qubit::new_plus().bloch_angles();
        assert!((theta - FRAC_PI_2).abs() < TOL);
        assert!(phi.abs() < TOL);
    }
}