use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use crate::new_modules::quantum::statevector::StateVector;
use num_traits::Float;
use scirs2_core::random::{Rng, RngExt, SeedableRng, StdRng};
use scirs2_core::Complex;
use std::collections::HashMap;
use std::fmt::Debug;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Outcome of measuring every qubit of a register in the computational basis.
///
/// The measured value of qubit `k` is bit `k` of `outcome`, counting from the
/// least-significant bit.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MeasurementResult {
    /// Index of the observed basis state.
    pub outcome: usize,
    /// Number of qubits in the measured register.
    pub num_qubits: usize,
}

impl MeasurementResult {
    /// Formats `outcome` in binary, left-padded with zeros to at least
    /// `num_qubits` digits, so qubit 0 appears as the rightmost character.
    pub fn as_binary_string(&self) -> String {
        let digits = self.num_qubits;
        format!("{:0width$b}", self.outcome, width = digits)
    }

    /// Returns the measured value (0 or 1) of `qubit`.
    ///
    /// # Errors
    /// Returns `NumRs2Error::IndexOutOfBounds` when `qubit >= num_qubits`.
    pub fn get_bit(&self, qubit: usize) -> Result<u8> {
        if qubit < self.num_qubits {
            let is_set = (self.outcome >> qubit) & 1 == 1;
            Ok(u8::from(is_set))
        } else {
            Err(NumRs2Error::IndexOutOfBounds(format!(
                "Qubit {} out of bounds",
                qubit
            )))
        }
    }
}
/// Histogram of measurement outcomes accumulated over repeated shots.
#[derive(Clone, Debug)]
pub struct MeasurementStatistics {
    /// Number of times each outcome (basis-state index) was recorded.
    pub counts: HashMap<usize, usize>,
    /// Total number of recorded shots.
    pub total_shots: usize,
    /// Number of qubits in the measured register.
    pub num_qubits: usize,
}

impl MeasurementStatistics {
    /// Creates an empty histogram for a register of `num_qubits` qubits.
    pub fn new(num_qubits: usize) -> Self {
        Self {
            counts: HashMap::new(),
            total_shots: 0,
            num_qubits,
        }
    }

    /// Records one shot that produced `outcome`.
    pub fn add_result(&mut self, outcome: usize) {
        *self.counts.entry(outcome).or_insert(0) += 1;
        self.total_shots += 1;
    }

    /// Empirical frequency of `outcome`; `0.0` when no shots have been recorded.
    pub fn get_probability(&self, outcome: usize) -> f64 {
        if self.total_shots == 0 {
            return 0.0;
        }
        let count = self.counts.get(&outcome).unwrap_or(&0);
        *count as f64 / self.total_shots as f64
    }

    /// Empirical frequency of every recorded outcome.
    ///
    /// Returns an empty map when no shots have been recorded. The guard matters
    /// because the fields are public and `counts` could otherwise be divided by zero.
    pub fn get_probabilities(&self) -> HashMap<usize, f64> {
        if self.total_shots == 0 {
            return HashMap::new();
        }
        let shots = self.total_shots as f64;
        self.counts
            .iter()
            .map(|(&outcome, &count)| (outcome, count as f64 / shots))
            .collect()
    }

    /// Outcome recorded most often, or `None` if nothing has been recorded.
    ///
    /// Ties are broken in favour of the smallest outcome index. `HashMap`
    /// iteration order is randomized per process, so without an explicit
    /// tie-break the answer would change from run to run.
    pub fn most_frequent(&self) -> Option<usize> {
        self.counts
            .iter()
            .max_by(|(outcome_a, count_a), (outcome_b, count_b)| {
                count_a
                    .cmp(count_b)
                    // Reversed so that the smaller outcome compares as "greater".
                    .then_with(|| outcome_b.cmp(outcome_a))
            })
            .map(|(&outcome, _)| outcome)
    }

    /// Shannon entropy, in bits, of the empirical outcome distribution.
    ///
    /// Returns `0.0` when no shots have been recorded.
    pub fn entropy(&self) -> f64 {
        if self.total_shots == 0 {
            return 0.0;
        }
        let mut entropy = 0.0;
        for &count in self.counts.values() {
            // Zero-count entries contribute nothing (p·log p → 0) and would yield NaN.
            if count > 0 {
                let p = count as f64 / self.total_shots as f64;
                entropy -= p * p.log2();
            }
        }
        entropy
    }
}
pub struct Measurement;
impl Measurement {
pub fn measure_all<T>(
state: &StateVector<T>,
seed: Option<u64>,
) -> Result<(MeasurementResult, StateVector<T>)>
where
T: Float + Clone + Debug + Into<f64> + From<f64>,
{
let mut rng = match seed {
Some(s) => StdRng::seed_from_u64(s),
None => {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(1));
StdRng::seed_from_u64(now.as_secs())
}
};
let probs = state.get_probabilities();
let prob_vec = probs.to_vec();
let prob_vec_f64: Vec<f64> = prob_vec.iter().map(|p| (*p).into()).collect();
let outcome = sample_discrete(&mut rng, &prob_vec_f64)?;
let dim = state.dim();
let mut post_amplitudes = vec![Complex::new(T::zero(), T::zero()); dim];
post_amplitudes[outcome] = Complex::new(T::one(), T::zero());
let post_state = StateVector::from_amplitudes(post_amplitudes)?;
Ok((
MeasurementResult {
outcome,
num_qubits: state.num_qubits(),
},
post_state,
))
}
pub fn measure_qubits<T>(
state: &StateVector<T>,
qubits: &[usize],
seed: Option<u64>,
) -> Result<Vec<u8>>
where
T: Float + Clone + Debug + Into<f64> + From<f64>,
{
for &qubit in qubits {
if qubit >= state.num_qubits() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Qubit {} out of bounds",
qubit
)));
}
}
let mut rng = match seed {
Some(s) => StdRng::seed_from_u64(s),
None => {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(1));
StdRng::seed_from_u64(now.as_secs())
}
};
let mut results = Vec::new();
for &qubit in qubits {
let mut prob_one = T::zero();
for i in 0..state.dim() {
if (i >> qubit) & 1 == 1 {
prob_one = prob_one + state.get_probability(i)?;
}
}
let prob_one_f64: f64 = prob_one.into();
let random_val: f64 = rng.random();
results.push(if random_val < prob_one_f64 { 1 } else { 0 });
}
Ok(results)
}
pub fn sample<T>(
state: &StateVector<T>,
num_shots: usize,
seed: Option<u64>,
) -> Result<MeasurementStatistics>
where
T: Float + Clone + Debug + Into<f64> + From<f64>,
{
let mut rng = match seed {
Some(s) => StdRng::seed_from_u64(s),
None => {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(1));
StdRng::seed_from_u64(now.as_secs())
}
};
let mut stats = MeasurementStatistics::new(state.num_qubits());
let probs = state.get_probabilities();
let prob_vec = probs.to_vec();
let prob_vec_f64: Vec<f64> = prob_vec.iter().map(|p| (*p).into()).collect();
for _ in 0..num_shots {
let outcome = sample_discrete(&mut rng, &prob_vec_f64)?;
stats.add_result(outcome);
}
Ok(stats)
}
pub fn measure_x<T>(state: &StateVector<T>, qubit: usize, seed: Option<u64>) -> Result<u8>
where
T: Float + Clone + Debug + Into<f64> + From<f64>,
{
if qubit >= state.num_qubits() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Qubit {} out of bounds",
qubit
)));
}
let mut state_copy = state.clone();
let h_gate = crate::new_modules::quantum::gates::hadamard()?;
state_copy.apply_gate(&h_gate, &[qubit])?;
let results = Self::measure_qubits(&state_copy, &[qubit], seed)?;
Ok(results[0])
}
pub fn measure_y<T>(state: &StateVector<T>, qubit: usize, seed: Option<u64>) -> Result<u8>
where
T: Float + Clone + Debug + Into<f64> + From<f64>,
{
if qubit >= state.num_qubits() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Qubit {} out of bounds",
qubit
)));
}
let mut state_copy = state.clone();
let s_gate = crate::new_modules::quantum::gates::phase_gate()?;
state_copy.apply_gate(&s_gate, &[qubit])?;
state_copy.apply_gate(&s_gate, &[qubit])?;
state_copy.apply_gate(&s_gate, &[qubit])?;
let h_gate = crate::new_modules::quantum::gates::hadamard()?;
state_copy.apply_gate(&h_gate, &[qubit])?;
let results = Self::measure_qubits(&state_copy, &[qubit], seed)?;
Ok(results[0])
}
pub fn measure_z<T>(state: &StateVector<T>, qubit: usize, seed: Option<u64>) -> Result<u8>
where
T: Float + Clone + Debug + Into<f64> + From<f64>,
{
let results = Self::measure_qubits(state, &[qubit], seed)?;
Ok(results[0])
}
pub fn expectation_z<T>(state: &StateVector<T>, qubit: usize) -> Result<f64>
where
T: Float + Clone + Debug + Into<f64> + From<f64>,
{
if qubit >= state.num_qubits() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Qubit {} out of bounds",
qubit
)));
}
let mut prob_zero = 0.0;
let mut prob_one = 0.0;
for i in 0..state.dim() {
let prob: f64 = state.get_probability(i)?.into();
if (i >> qubit) & 1 == 0 {
prob_zero += prob;
} else {
prob_one += prob;
}
}
Ok(prob_zero - prob_one)
}
}
fn sample_discrete<R: Rng>(rng: &mut R, probabilities: &[f64]) -> Result<usize> {
let random_val: f64 = rng.random();
let mut cumulative = 0.0;
for (i, &prob) in probabilities.iter().enumerate() {
cumulative += prob;
if random_val < cumulative {
return Ok(i);
}
}
Ok(probabilities.len() - 1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Single-qubit state with equal real amplitudes on |0> and |1>.
    fn plus_state() -> StateVector<f64> {
        let amp = 1.0 / 2.0_f64.sqrt();
        StateVector::from_amplitudes(vec![Complex::new(amp, 0.0), Complex::new(amp, 0.0)])
            .expect("test: valid amplitudes")
    }

    #[test]
    fn test_measurement_result_binary_string() {
        let five_on_three = MeasurementResult {
            outcome: 5,
            num_qubits: 3,
        };
        assert_eq!(five_on_three.as_binary_string(), "101");
    }

    #[test]
    fn test_measurement_result_get_bit() {
        let six = MeasurementResult {
            outcome: 6,
            num_qubits: 3,
        };
        // 6 = 0b110: qubit 0 is clear, qubits 1 and 2 are set.
        let expected = [0u8, 1, 1];
        for (qubit, &bit) in expected.iter().enumerate() {
            assert_eq!(six.get_bit(qubit).expect("test: valid bit index"), bit);
        }
    }

    #[test]
    fn test_measure_all_zero_state() {
        let ground = StateVector::<f64>::new(2).expect("test: valid qubit count");
        let (measured, _collapsed) =
            Measurement::measure_all(&ground, Some(42)).expect("test: valid measurement");
        assert_eq!(measured.outcome, 0);
        assert_eq!(measured.num_qubits, 2);
    }

    #[test]
    fn test_measure_all_superposition() {
        let state = plus_state();
        let zeros = (0..100u64)
            .filter(|&seed| {
                let (measured, _) = Measurement::measure_all(&state, Some(seed))
                    .expect("test: valid measurement");
                measured.outcome == 0
            })
            .count();
        let ones = 100 - zeros;
        assert!(zeros > 30 && zeros < 70);
        assert!(ones > 30 && ones < 70);
    }

    #[test]
    fn test_sample() {
        let state = plus_state();
        let histogram = Measurement::sample(&state, 1000, Some(42)).expect("test: valid sampling");
        assert_eq!(histogram.total_shots, 1000);
        assert_relative_eq!(histogram.get_probability(0), 0.5, epsilon = 0.1);
        assert_relative_eq!(histogram.get_probability(1), 0.5, epsilon = 0.1);
    }

    #[test]
    fn test_measurement_statistics() {
        let mut histogram = MeasurementStatistics::new(2);
        for outcome in [0, 0, 1, 3] {
            histogram.add_result(outcome);
        }
        assert_eq!(histogram.total_shots, 4);
        assert_eq!(histogram.get_probability(0), 0.5);
        assert_eq!(histogram.get_probability(1), 0.25);
        assert_eq!(histogram.most_frequent(), Some(0));
    }

    #[test]
    fn test_entropy() {
        let mut histogram = MeasurementStatistics::new(2);
        for outcome in 0..4 {
            histogram.add_result(outcome);
        }
        // Uniform over four outcomes carries exactly two bits.
        assert_relative_eq!(histogram.entropy(), 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_measure_qubits() {
        let ground = StateVector::<f64>::new(3).expect("test: valid qubit count");
        let bits = Measurement::measure_qubits(&ground, &[0, 1, 2], Some(42))
            .expect("test: valid qubit measurement");
        assert_eq!(bits.len(), 3);
        assert_eq!(bits, [0u8, 0, 0]);
    }

    #[test]
    fn test_measure_x_basis() {
        let ground = StateVector::<f64>::new(1).expect("test: valid qubit count");
        let count_0 = (0..100u64)
            .filter(|&seed| {
                Measurement::measure_x(&ground, 0, Some(seed))
                    .expect("test: valid x-basis measurement")
                    == 0
            })
            .count();
        assert!(count_0 > 30 && count_0 < 70);
    }

    #[test]
    fn test_expectation_z() {
        let ground = StateVector::<f64>::new(1).expect("test: valid qubit count");
        let ground_exp =
            Measurement::expectation_z(&ground, 0).expect("test: valid expectation value");
        assert_relative_eq!(ground_exp, 1.0, epsilon = 1e-10);

        let excited =
            StateVector::from_amplitudes(vec![Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)])
                .expect("test: valid amplitudes");
        let excited_exp =
            Measurement::expectation_z(&excited, 0).expect("test: valid expectation value");
        assert_relative_eq!(excited_exp, -1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_invalid_qubit_measurement() {
        let state = StateVector::<f64>::new(2).expect("test: valid qubit count");
        assert!(Measurement::measure_qubits(&state, &[5], Some(42)).is_err());
    }
}