use crate::{Measurement, QuantumChannel, QuantumState, errors::StateError};
use std::collections::HashMap;
/// Shot-based sampler of measurement outcomes for a [`QuantumState`].
///
/// A sampler optionally carries a [`QuantumChannel`]; when present,
/// `Sampler::run` applies it to a private copy of the input state before
/// the measurement probabilities are computed.
#[derive(Debug, Clone, Default)]
pub struct Sampler {
/// Channel applied to the copied state on the target qubits before
/// measuring; `None` means the state is measured as-is.
pub channel: Option<QuantumChannel>,
}
impl Sampler {
    /// Creates a sampler with no channel attached.
    pub fn new() -> Self {
        Self { channel: None }
    }

    /// Returns the sampler with `channel` installed, replacing any channel
    /// that was set before.
    pub fn with_channel(mut self, channel: QuantumChannel) -> Self {
        self.channel = Some(channel);
        self
    }

    /// Samples `num_shots` measurement outcomes and returns a histogram.
    ///
    /// The input `state` is cloned and never modified. If a channel is
    /// configured it is applied to the clone on `targets` first; the
    /// outcome probabilities are then obtained from `set_measurement` and
    /// each shot is drawn by inverse-CDF sampling with a uniform random
    /// number from `crate::rng::random_f64`.
    ///
    /// The returned map is keyed by the string form of the corresponding
    /// entry in `measurement.values`; outcomes that were never drawn are
    /// omitted. Outcomes whose values stringify identically are merged, so
    /// the counts always add up to `num_shots` (unless there are no
    /// outcomes at all, in which case the map is empty).
    ///
    /// # Errors
    ///
    /// Propagates any [`StateError`] returned by `apply_channel` or
    /// `set_measurement`.
    ///
    /// # Panics
    ///
    /// Indexing `measurement.values` by outcome index may panic if it holds
    /// fewer entries than the number of probabilities returned.
    pub fn run(
        &self,
        state: &QuantumState,
        measurement: &Measurement,
        targets: &[usize],
        num_shots: usize,
    ) -> Result<HashMap<String, usize>, StateError> {
        let mut state_copy = state.clone();
        if let Some(chan) = &self.channel {
            state_copy.apply_channel(chan, targets)?;
        }
        let probs = state_copy.set_measurement(measurement, targets)?;

        // Nothing can be sampled without outcomes; indexing below would panic.
        if probs.is_empty() {
            return Ok(HashMap::new());
        }

        // Build the cumulative distribution and remember the last outcome
        // that can actually occur. It is used when rounding leaves the final
        // CDF entry slightly below the drawn random number.
        let mut cdf: Vec<f64> = Vec::with_capacity(probs.len());
        let mut running_total = 0.0_f64;
        let mut fallback_idx = 0usize;
        for (i, &p) in probs.iter().enumerate() {
            running_total += p;
            cdf.push(running_total);
            if p > 0.0 {
                fallback_idx = i;
            }
        }

        let mut raw_counts = vec![0usize; cdf.len()];
        for _ in 0..num_shots {
            let r: f64 = crate::rng::random_f64();
            // First bucket whose cumulative probability exceeds `r`.
            let outcome_idx = cdf
                .iter()
                .position(|&cumulative| r < cumulative)
                .unwrap_or(fallback_idx);
            raw_counts[outcome_idx] += 1;
        }

        let mut counts = HashMap::new();
        for (idx, &count) in raw_counts.iter().enumerate() {
            if count > 0 {
                // Accumulate rather than insert so that outcomes sharing the
                // same printed value do not overwrite each other.
                *counts
                    .entry(measurement.values[idx].to_string())
                    .or_insert(0) += count;
            }
        }
        Ok(counts)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Gate;

    /// A freshly constructed one-qubit state must yield only "0".
    #[test]
    fn test_sampler_deterministic_zero() {
        let initial = QuantumState::new(1);
        let histogram = Sampler::new()
            .run(&initial, &Measurement::z_basis(), &[0], 100)
            .unwrap();

        assert_eq!(histogram.len(), 1);
        assert_eq!(histogram["0"], 100);
    }

    /// After an X gate every shot must report "1".
    #[test]
    fn test_sampler_deterministic_one() {
        let mut flipped = QuantumState::new(1);
        flipped.apply(&Gate::x(), &[0]).unwrap();

        let histogram = Sampler::new()
            .run(&flipped, &Measurement::z_basis(), &[0], 100)
            .unwrap();

        assert_eq!(histogram.len(), 1);
        assert_eq!(histogram["1"], 100);
    }

    /// After an H gate both outcomes should appear about equally often.
    #[test]
    fn test_sampler_superposition() {
        let mut superposed = QuantumState::new(1);
        superposed.apply(&Gate::h(), &[0]).unwrap();

        let shots = 1000;
        let histogram = Sampler::new()
            .run(&superposed, &Measurement::z_basis(), &[0], shots)
            .unwrap();

        let zeros = histogram.get("0").copied().unwrap_or(0);
        let ones = histogram.get("1").copied().unwrap_or(0);

        // Loose statistical window around the expected 500/500 split.
        assert!(
            (351..650).contains(&zeros),
            "Expected roughly 500, got {}",
            zeros
        );
        assert!(
            (351..650).contains(&ones),
            "Expected roughly 500, got {}",
            ones
        );
        assert_eq!(zeros + ones, shots);
    }

    /// A bit-flip channel with probability 1.0 must turn every shot into "1".
    #[test]
    fn test_sampler_with_bit_flip_channel() {
        let initial = QuantumState::new(1);
        let always_flip = QuantumChannel::bit_flip(1.0);
        let noisy_sampler = Sampler::new().with_channel(always_flip);

        let histogram = noisy_sampler
            .run(&initial, &Measurement::z_basis(), &[0], 100)
            .unwrap();

        assert_eq!(histogram.len(), 1);
        assert_eq!(histogram["1"], 100);
    }

    /// Failures from the underlying state operations must surface as `Err`.
    #[test]
    fn test_sampler_errors_propagated() {
        let single_qubit = QuantumState::new(1);
        let sampler = Sampler::new();

        // Target index 5 on a one-qubit state.
        let bad_target = sampler.run(&single_qubit, &Measurement::z_basis(), &[5], 10);
        assert!(bad_target.is_err());

        // Bell-basis measurement given only a single target.
        let bad_measurement = sampler.run(&single_qubit, &Measurement::bell_basis(), &[0], 10);
        assert!(bad_measurement.is_err());
    }
}