use crate::circuit::ZERO_MARGIN;
use crate::Complex;
use crate::QuantrError;
use crate::{complex_Re, COMPLEX_ZERO};
use std::collections::HashMap;
use std::hash::Hash;
/// A single qubit that is in one of the two computational basis states.
///
/// The type is `Copy`, so it is passed by value throughout this module.
#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
pub enum Qubit {
/// The basis state rendered as `"0"` by `ProductState::to_string`.
Zero,
/// The basis state rendered as `"1"` by `ProductState::to_string`.
One,
}
impl Qubit {
pub fn kronecker_prod(self, other: Qubit) -> ProductState {
ProductState::new_unchecked(&[self, other])
}
pub fn into_state(self) -> ProductState {
ProductState::new_unchecked(&[self])
}
}
/// An ordered list of qubits describing one computational basis state.
///
/// The type derives `Hash` and `Eq`, which lets it be used as the key of the
/// `HashMap`s that this module accepts and returns.
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
pub struct ProductState {
/// The qubits of the state. The first element is the most significant
/// binary digit and the last element the least significant one when the
/// state is converted to an index into the amplitude list.
pub qubits: Vec<Qubit>,
}
impl ProductState {
    /// Creates a product state from a non-empty slice of qubits.
    ///
    /// # Errors
    ///
    /// Returns a [`QuantrError`] if `product_state` is empty.
    pub fn new(product_state: &[Qubit]) -> Result<ProductState, QuantrError> {
        if product_state.is_empty() {
            return Err(QuantrError {
                message: String::from(
                    "The slice of qubits is empty, it needs to at least have one element.",
                ),
            });
        }
        Ok(Self::new_unchecked(product_state))
    }

    /// Creates a product state without checking that the slice is non-empty.
    pub(super) fn new_unchecked(product_state: &[Qubit]) -> ProductState {
        ProductState {
            qubits: product_state.to_vec(),
        }
    }

    /// Returns a copy of this state in which the qubit at `pos[i]` has been
    /// replaced by `qubits[i]`, for every `i`.
    ///
    /// # Panics
    ///
    /// Panics if `qubits` and `pos` differ in length, or if any position is
    /// not a valid index into this state.
    pub(super) fn insert_qubits(&self, qubits: &[Qubit], pos: &[usize]) -> ProductState {
        if qubits.len() != pos.len() {
            panic!("Size of qubits and positions must be equal.")
        }
        let mut edited_qubits = self.qubits.clone();
        for (&position, &qubit) in pos.iter().zip(qubits) {
            edited_qubits[position] = qubit;
        }
        // The edited vector is moved into the new state; no second copy is made.
        ProductState {
            qubits: edited_qubits,
        }
    }

    /// Number of qubits in this product state.
    fn num_qubits(&self) -> usize {
        self.qubits.len()
    }

    /// Flips the qubit at index `place_num` in place (`Zero` becomes `One`
    /// and vice versa).
    ///
    /// # Errors
    ///
    /// Returns a [`QuantrError`] if `place_num` is not strictly less than the
    /// number of qubits in the state.
    pub fn invert_digit(&mut self, place_num: usize) -> Result<(), QuantrError> {
        // Read the dimension first; `get_mut` borrows `self.qubits` mutably.
        let dimension = self.num_qubits();
        match self.qubits.get_mut(place_num) {
            Some(digit) => {
                *digit = match *digit {
                    Qubit::Zero => Qubit::One,
                    Qubit::One => Qubit::Zero,
                };
                Ok(())
            }
            None => Err(QuantrError {
                message: format!("The position of the binary digit, {}, is out of bounds. The product dimension is {}, and so the position must be strictly less.", place_num, dimension),
            }),
        }
    }

    /// Appends `other` as the last qubit of this state and returns the
    /// extended state.
    pub fn kronecker_prod(mut self, other: Qubit) -> ProductState {
        self.qubits.push(other);
        self
    }

    /// Returns the qubit stored at index `qubit_number`.
    ///
    /// # Panics
    ///
    /// Panics if `qubit_number` is out of bounds.
    pub(super) fn get(&self, qubit_number: usize) -> Qubit {
        self.qubits[qubit_number]
    }

    /// Renders the state as a string of binary digits, first qubit first,
    /// e.g. `[One, Zero]` becomes `"10"`.
    pub fn to_string(&self) -> String {
        self.qubits
            .iter()
            .map(|qubit| match qubit {
                Qubit::Zero => '0',
                Qubit::One => '1',
            })
            .collect()
    }

    /// Converts the state into a super position in which this state has an
    /// amplitude of one and every other state an amplitude of zero.
    pub fn into_super_position(self) -> SuperPosition {
        SuperPosition::new(self.num_qubits())
            .set_amplitudes_from_states_unchecked(&HashMap::from([(self, complex_Re!(1f64))]))
    }

    /// Returns the index of this state in the computational basis, reading
    /// the qubits as a binary number whose first qubit is the most
    /// significant digit.
    ///
    /// The value is accumulated in `usize`, so it does not overflow at 32
    /// qubits the way a `u32` intermediate would. Digits beyond the width of
    /// `usize` are shifted out and lost.
    fn comp_basis(&self) -> usize {
        self.qubits.iter().fold(0usize, |index, qubit| {
            let bit: usize = match qubit {
                Qubit::Zero => 0,
                Qubit::One => 1,
            };
            (index << 1) | bit
        })
    }

    /// Builds the product state of `basis_size` qubits whose binary digits
    /// spell out `index`, most significant digit first.
    ///
    /// Bits of `index` above position `basis_size - 1` are ignored.
    fn binary_basis(index: usize, basis_size: usize) -> ProductState {
        let qubits: Vec<Qubit> = (0..basis_size)
            .rev()
            .map(|bit| {
                if (index >> bit) & 1 == 1 {
                    Qubit::One
                } else {
                    Qubit::Zero
                }
            })
            .collect();
        ProductState { qubits }
    }
}
/// A linear combination of product states, stored as one complex amplitude
/// per computational basis state.
#[derive(PartialEq, Debug, Clone)]
pub struct SuperPosition {
/// Amplitude of every basis state. `SuperPosition::new` allocates
/// `2^product_dim` entries, and the entry for a product state sits at the
/// index obtained by reading its qubits as a binary number.
pub amplitudes: Vec<Complex<f64>>,
/// Number of qubits of each product state in the super position.
pub product_dim: usize,
// NOTE(review): every constructor in this file sets this to 0 and nothing in
// this file reads it (iteration uses `SuperPositionIterator::index`). It still
// takes part in the derived `PartialEq`. Possibly a leftover; confirm before
// removing.
index: usize,
}
/// Iterator over the `(ProductState, amplitude)` pairs of a borrowed
/// `SuperPosition`, produced by `IntoIterator for &SuperPosition`.
pub struct SuperPositionIterator<'a> {
// The super position being walked; it is only read, never modified.
super_position: &'a SuperPosition,
// Position in `super_position.amplitudes` of the next pair to yield.
index: usize,
}
/// Allows `for (state, amplitude) in &super_position { .. }`.
impl<'a> IntoIterator for &'a SuperPosition {
    type Item = (ProductState, Complex<f64>);
    type IntoIter = SuperPositionIterator<'a>;

    /// Starts an iteration at the first entry of the amplitude list.
    fn into_iter(self) -> SuperPositionIterator<'a> {
        let first_entry: usize = 0;
        SuperPositionIterator {
            index: first_entry,
            super_position: self,
        }
    }
}
impl<'a> Iterator for SuperPositionIterator<'a> {
    type Item = (ProductState, Complex<f64>);

    /// Yields the next basis state together with its amplitude, in the order
    /// of the amplitude list.
    ///
    /// Once the list is exhausted `None` is returned and the position is
    /// reset to the start, so a further call begins the sequence again.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.index;
        match self.super_position.amplitudes.get(current) {
            Some(&amplitude) => {
                let state = ProductState::binary_basis(current, self.super_position.product_dim);
                self.index = current + 1;
                Some((state, amplitude))
            }
            None => {
                self.index = 0;
                None
            }
        }
    }
}
impl SuperPosition {
    /// Creates a super position of `num_qubits` qubits in which the first
    /// basis state (all qubits `Zero`) has amplitude one and every other
    /// state has amplitude zero.
    pub fn new(num_qubits: usize) -> SuperPosition {
        let mut new_amps: Vec<Complex<f64>> = vec![COMPLEX_ZERO; 2usize.pow(num_qubits as u32)];
        new_amps[0] = complex_Re!(1f64);
        SuperPosition {
            amplitudes: new_amps,
            product_dim: num_qubits,
            index: 0,
        }
    }

    /// Returns the amplitude stored at index `pos` of the amplitude list.
    ///
    /// # Errors
    ///
    /// Returns a [`QuantrError`] if `pos` is not a valid index.
    pub fn get_amplitude(&self, pos: usize) -> Result<Complex<f64>, QuantrError> {
        self.amplitudes.get(pos).copied().ok_or_else(|| {
            let length = self.amplitudes.len();
            QuantrError { message: format!("Failed to retrieve amplitude from list. Index given was, {pos}, which is greater than length of list, {length}."),
            }
        })
    }

    /// Returns the amplitude that belongs to the product state `prod_state`.
    ///
    /// # Errors
    ///
    /// Returns a [`QuantrError`] if the number of qubits in `prod_state` does
    /// not match the size of the amplitude list.
    pub fn get_amplitude_from_state(
        &self,
        prod_state: ProductState,
    ) -> Result<Complex<f64>, QuantrError> {
        if 2usize.pow(prod_state.qubits.len() as u32) != self.amplitudes.len() {
            // The state is formatted with `{}`; `{:?}` on a `String` would wrap
            // the digits in quotation marks inside the ket.
            return Err(QuantrError { message: format!("Unable to retrieve product state, |{}> with dimension {}. The superposition is a linear combination of states with different dimension. These dimensions should be equal.", prod_state.to_string(), prod_state.num_qubits()),});
        }
        // The check above guarantees that the index is within bounds.
        Ok(self.amplitudes[prod_state.comp_basis()])
    }

    /// Replaces all amplitudes with the ones given in `amplitudes`.
    ///
    /// # Errors
    ///
    /// Returns a [`QuantrError`] if `amplitudes` does not have the same
    /// length as the current amplitude list, or if the sum of the absolute
    /// squares of `amplitudes` is not one within `ZERO_MARGIN`.
    pub fn set_amplitudes(self, amplitudes: &[Complex<f64>]) -> Result<SuperPosition, QuantrError> {
        if amplitudes.len() != self.amplitudes.len() {
            return Err(QuantrError {
                message: format!("The slice given to set the amplitudes in the computational basis has length {}, when it should have length {}.", amplitudes.len(), self.amplitudes.len()),
            });
        }
        let total_probability: f64 = amplitudes.iter().map(|x| x.abs_square()).sum::<f64>();
        if !Self::equal_within_error(total_probability, 1f64) {
            return Err(QuantrError {
                message: String::from("Slice given to set amplitudes in super position does not conserve probability, the absolute square sum of the coefficents must be one."),
            });
        }
        // The lengths are equal, so the new list is simply a copy of the slice.
        Ok(SuperPosition {
            amplitudes: amplitudes.to_vec(),
            product_dim: self.product_dim,
            index: 0,
        })
    }

    /// Overwrites the first `slice.len()` entries of `vector` with `slice`.
    ///
    /// # Panics
    ///
    /// Panics if `slice` is longer than `vector`.
    fn copy_slice_to_vec(vector: &mut Vec<Complex<f64>>, slice: &[Complex<f64>]) {
        vector[..slice.len()].copy_from_slice(slice);
    }

    /// Returns `true` if `num` lies strictly within `ZERO_MARGIN` of
    /// `compare_num`.
    fn equal_within_error(num: f64, compare_num: f64) -> bool {
        num < compare_num + ZERO_MARGIN && num > compare_num - ZERO_MARGIN
    }

    /// Overwrites the leading amplitudes with `amplitudes` without checking
    /// the length or the conservation of probability. Entries beyond
    /// `amplitudes.len()` keep their previous value.
    ///
    /// The result is always `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if `amplitudes` is longer than the current amplitude list.
    pub(crate) fn set_amplitudes_unchecked(
        self,
        amplitudes: &[Complex<f64>],
    ) -> Result<SuperPosition, QuantrError> {
        // `self` is owned, so its vector is reused instead of being cloned.
        let SuperPosition {
            amplitudes: mut new_amps,
            product_dim,
            ..
        } = self;
        Self::copy_slice_to_vec(&mut new_amps, amplitudes);
        Ok(SuperPosition {
            amplitudes: new_amps,
            product_dim,
            index: 0,
        })
    }

    /// Builds a super position with the same dimension as `self` from a map
    /// of product states to amplitudes. States missing from the map get an
    /// amplitude of zero.
    ///
    /// # Errors
    ///
    /// Returns a [`QuantrError`] if the map is empty, if a key does not have
    /// the dimension of this super position, or if the sum of the absolute
    /// squares of the amplitudes is not one within `ZERO_MARGIN`.
    pub fn set_amplitudes_from_states(
        &self,
        amplitudes: &HashMap<ProductState, Complex<f64>>,
    ) -> Result<SuperPosition, QuantrError> {
        if amplitudes.is_empty() {
            return Err(QuantrError { message: String::from("An empty HashMap was given. A superposition must have at least one non-zero state.") });
        }
        // The list holds 2^n amplitudes, so the number of trailing zero bits
        // of its length is the number of qubits n.
        let product_size: usize = self.amplitudes.len().trailing_zeros() as usize;
        let mut total_amplitude: f64 = 0f64;
        for (states, amplitude) in amplitudes {
            if states.num_qubits() != product_size {
                return Err(QuantrError { message: format!("The superposition has product dimension of {}, whilst the state, |{}>, found as a key in the HashMap has dimension {}.", product_size, states.to_string(), states.num_qubits()) });
            }
            total_amplitude += amplitude.abs_square();
        }
        // Probability is conserved when the total is one; comparing against
        // `ZERO_MARGIN` here would reject every normalised map.
        if !Self::equal_within_error(total_amplitude, 1f64) {
            return Err(QuantrError { message: String::from("The total sum of the absolute square of all amplitudes does not equal 1. That is, the superpositon does not conserve probability.") });
        }
        let mut new_amps: Vec<Complex<f64>> = vec![COMPLEX_ZERO; 2usize.pow(product_size as u32)];
        Self::from_hash_to_array(amplitudes, &mut new_amps);
        Ok(SuperPosition {
            amplitudes: new_amps,
            product_dim: self.product_dim,
            index: 0,
        })
    }

    /// Same as `set_amplitudes_from_states`, but without any validation. The
    /// length of the amplitude list is taken from the dimension of the first
    /// key that the map yields, while `product_dim` is copied from `self`.
    ///
    /// # Panics
    ///
    /// Panics if `amplitudes` is empty, or if a key has more qubits than the
    /// key used to size the amplitude list.
    pub(super) fn set_amplitudes_from_states_unchecked(
        &self,
        amplitudes: &HashMap<ProductState, Complex<f64>>,
    ) -> SuperPosition {
        let product_size: usize = amplitudes.keys().next().unwrap().num_qubits();
        let mut new_amps: Vec<Complex<f64>> = vec![COMPLEX_ZERO; 2usize.pow(product_size as u32)];
        Self::from_hash_to_array(amplitudes, &mut new_amps);
        SuperPosition {
            amplitudes: new_amps,
            product_dim: self.product_dim,
            index: 0,
        }
    }

    /// Returns the states whose amplitude has an absolute square that is not
    /// zero within `ZERO_MARGIN`, mapped to their amplitudes.
    pub fn to_hash_map(&self) -> HashMap<ProductState, Complex<f64>> {
        self.amplitudes
            .iter()
            .enumerate()
            .filter(|(_, amp)| !Self::equal_within_error(amp.abs_square(), 0f64))
            .map(|(i, amp)| (ProductState::binary_basis(i, self.product_dim), *amp))
            .collect()
    }

    /// Writes every amplitude of `hash_amplitudes` into `vec_amplitudes` at
    /// the computational basis index of its key.
    ///
    /// # Panics
    ///
    /// Panics if the index of a key is out of bounds for `vec_amplitudes`.
    fn from_hash_to_array(
        hash_amplitudes: &HashMap<ProductState, Complex<f64>>,
        vec_amplitudes: &mut Vec<Complex<f64>>,
    ) {
        for (key, val) in hash_amplitudes {
            vec_amplitudes[key.comp_basis()] = *val;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::Qubit::{One, Zero};
    use super::*;
    use crate::complex_Im;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Shorthand for building a product state without validation.
    fn state(qubits: &[Qubit]) -> ProductState {
        ProductState::new_unchecked(qubits)
    }

    /// Two-qubit super position with amplitudes
    /// `[0, Re(1/sqrt(2)), Im(-1/sqrt(2)), 0]`, shared by several tests.
    fn sample_super_position() -> SuperPosition {
        SuperPosition::new(2)
            .set_amplitudes(&[
                COMPLEX_ZERO,
                complex_Re!(FRAC_1_SQRT_2),
                complex_Im!(-FRAC_1_SQRT_2),
                COMPLEX_ZERO,
            ])
            .unwrap()
    }

    #[test]
    fn converts_productstate_to_superpos() {
        let expected = SuperPosition::new(2)
            .set_amplitudes(&[COMPLEX_ZERO, COMPLEX_ZERO, complex_Re!(1f64), COMPLEX_ZERO])
            .unwrap();
        let converted = state(&[One, Zero]).into_super_position();
        assert_eq!(converted, expected)
    }

    #[test]
    fn converts_from_binary_to_comp_basis() {
        let index_of = |qubits: &[Qubit]| state(qubits).comp_basis();
        assert_eq!(index_of(&[One, Zero, One]), 5usize);
        assert_eq!(index_of(&[One, One, One]), 7usize);
        assert_eq!(index_of(&[One, Zero]), 2usize);
        assert_eq!(index_of(&[One, Zero, One, One]), 11usize);
    }

    #[test]
    fn retrieve_amplitude_from_state() {
        let amplitude = sample_super_position()
            .get_amplitude_from_state(state(&[Zero, One]))
            .unwrap();
        assert_eq!(amplitude, complex_Re!(FRAC_1_SQRT_2))
    }

    #[test]
    fn retrieve_amplitude_from_list_pos() {
        let amplitude = sample_super_position().get_amplitude(2).unwrap();
        assert_eq!(amplitude, complex_Im!(-FRAC_1_SQRT_2))
    }

    #[test]
    fn insert_qubits_in_state() {
        let edited = state(&[One, One, One]).insert_qubits(&[Zero, Zero], &[0, 1]);
        assert_eq!(state(&[Zero, Zero, One]).qubits, edited.qubits);
    }

    #[test]
    fn sets_amplitude_from_states() {
        let states: HashMap<ProductState, Complex<f64>> = HashMap::from([
            (state(&[Zero, One]), complex_Re!(FRAC_1_SQRT_2)),
            (state(&[One, Zero]), complex_Im!(-FRAC_1_SQRT_2)),
        ]);
        let from_states = SuperPosition::new(2).set_amplitudes_from_states_unchecked(&states);
        assert_eq!(sample_super_position().amplitudes, from_states.amplitudes)
    }

    #[test]
    #[should_panic]
    fn sets_amplitude_from_states_wrong_dimension() {
        // The second key has three qubits while the super position has two.
        let states: HashMap<ProductState, Complex<f64>> = HashMap::from([
            (state(&[Zero, One]), complex_Re!(FRAC_1_SQRT_2)),
            (state(&[One, Zero, One]), complex_Im!(-FRAC_1_SQRT_2)),
        ]);
        let two_qubits = SuperPosition::new(2);
        two_qubits.set_amplitudes_from_states(&states).unwrap();
    }

    #[test]
    #[should_panic]
    fn sets_amplitude_from_states_breaks_probability() {
        // The absolute squares of these amplitudes do not add up to one.
        let states: HashMap<ProductState, Complex<f64>> = HashMap::from([
            (state(&[Zero, One]), complex_Re!(FRAC_1_SQRT_2)),
            (state(&[One, Zero]), complex_Im!(-FRAC_1_SQRT_2 * 0.5f64)),
        ]);
        let two_qubits = SuperPosition::new(2);
        two_qubits.set_amplitudes_from_states(&states).unwrap();
    }

    #[test]
    #[should_panic]
    fn catches_retrieve_amplitude_from_wrong_state() {
        let super_position = sample_super_position();
        super_position
            .get_amplitude_from_state(state(&[Zero, One, One]))
            .unwrap();
    }

    #[test]
    #[should_panic]
    fn catches_retrieve_amplitude_from_wrong_list_pos() {
        let super_position = sample_super_position();
        super_position.get_amplitude(4).unwrap();
    }

    #[test]
    #[should_panic]
    fn catches_super_position_breaking_conservation() {
        let unnormalised = [
            COMPLEX_ZERO,
            complex_Re!(0.5f64),
            COMPLEX_ZERO,
            complex_Im!(-0.5f64),
        ];
        SuperPosition::new(2).set_amplitudes(&unnormalised).unwrap();
    }

    #[test]
    fn converts_from_integer_to_product_state() {
        let from_index = ProductState::binary_basis(6, 3);
        assert_eq!(state(&[One, One, Zero]), from_index)
    }

    #[test]
    fn inverting_binary_digit() {
        let mut inverted = state(&[One, One, Zero]);
        inverted.invert_digit(2).unwrap();
        assert_eq!(state(&[One, One, One]), inverted)
    }
}