quantrs2_sim/shot_sampling.rs

//! Shot-based sampling with statistical analysis for quantum simulation.
//!
//! This module implements comprehensive shot-based sampling methods for quantum
//! circuits, including measurement statistics, error analysis, and convergence
//! detection for realistic quantum device simulation.
//!
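//! # Example
//!
//! A minimal usage sketch (marked `ignore` because the public re-export path of
//! this module is assumed here rather than taken from this file):
//!
//! ```ignore
//! use ndarray::Array1;
//! use num_complex::Complex64;
//! use quantrs2_sim::shot_sampling::{QuantumSampler, SamplingConfig};
//!
//! // Reproducible 1000-shot sampling of the single-qubit |+> state.
//! let config = SamplingConfig {
//!     num_shots: 1000,
//!     seed: Some(42),
//!     ..Default::default()
//! };
//! let mut sampler = QuantumSampler::new(config);
//!
//! let amp = Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0);
//! let state = Array1::from_vec(vec![amp, amp]);
//!
//! let result = sampler.sample_state(&state).unwrap();
//! for (outcome, count) in &result.statistics.counts {
//!     println!("{}: {}", outcome, count);
//! }
//! ```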
use ndarray::{Array1, Array2};
use num_complex::Complex64;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::{Result, SimulatorError};
use crate::pauli::{PauliOperatorSum, PauliString};
use crate::statevector::StateVectorSimulator;

/// Shot-based measurement result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShotResult {
    /// Measurement outcomes for each shot
    pub outcomes: Vec<BitString>,
    /// Total number of shots
    pub num_shots: usize,
    /// Measurement statistics
    pub statistics: MeasurementStatistics,
    /// Sampling configuration used
    pub config: SamplingConfig,
}

/// Bit string representation of measurement outcome
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BitString {
    /// Bit values (0 or 1)
    pub bits: Vec<u8>,
}

impl BitString {
    /// Create from vector of booleans
    pub fn from_bools(bools: &[bool]) -> Self {
        Self {
            bits: bools.iter().map(|&b| if b { 1 } else { 0 }).collect(),
        }
    }

    /// Convert to vector of booleans
    pub fn to_bools(&self) -> Vec<bool> {
        self.bits.iter().map(|&b| b == 1).collect()
    }

    /// Convert to integer (little-endian)
    pub fn to_int(&self) -> usize {
        self.bits
            .iter()
            .enumerate()
            .map(|(i, &bit)| (bit as usize) << i)
            .sum()
    }

    /// Create from integer (little-endian)
    pub fn from_int(mut value: usize, num_bits: usize) -> Self {
        let mut bits = Vec::with_capacity(num_bits);
        for _ in 0..num_bits {
            bits.push((value & 1) as u8);
            value >>= 1;
        }
        Self { bits }
    }

    /// Number of bits
    pub fn len(&self) -> usize {
        self.bits.len()
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.bits.is_empty()
    }

    /// Hamming weight (number of 1s)
    pub fn weight(&self) -> usize {
        self.bits.iter().map(|&b| b as usize).sum()
    }

    /// Hamming distance to another bit string
    pub fn distance(&self, other: &BitString) -> usize {
        if self.len() != other.len() {
            return usize::MAX; // Invalid comparison
        }
        self.bits
            .iter()
            .zip(&other.bits)
            .map(|(&a, &b)| (a ^ b) as usize)
            .sum()
    }
}
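
// Bit-order note: `bits[0]` is the least-significant bit, so
// `BitString::from_int(6, 3)` stores `[0, 1, 1]`, and a bit string with bits
// `[1, 0, 1]` converts back to the integer 5 via `to_int`.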

impl std::fmt::Display for BitString {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for &bit in &self.bits {
            write!(f, "{}", bit)?;
        }
        Ok(())
    }
}

/// Measurement statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasurementStatistics {
    /// Frequency count for each outcome
    pub counts: HashMap<BitString, usize>,
    /// Most frequent outcome
    pub mode: BitString,
    /// Probability estimates for each outcome
    pub probabilities: HashMap<BitString, f64>,
    /// Variance in the probability estimates
    pub probability_variance: f64,
    /// Statistical confidence intervals
    pub confidence_intervals: HashMap<BitString, (f64, f64)>,
    /// Shannon entropy of the measurement distribution (in nats)
    pub entropy: f64,
    /// Purity of the measurement distribution (sum of squared outcome probabilities)
    pub purity: f64,
}

/// Sampling configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Number of shots to take
    pub num_shots: usize,
    /// Random seed for reproducibility
    pub seed: Option<u64>,
    /// Confidence level for intervals (e.g., 0.95 for 95%)
    pub confidence_level: f64,
    /// Whether to compute full statistics
    pub compute_statistics: bool,
    /// Whether to estimate convergence
    pub estimate_convergence: bool,
    /// Convergence check interval (number of shots)
    pub convergence_check_interval: usize,
    /// Convergence tolerance
    pub convergence_tolerance: f64,
    /// Maximum number of shots for convergence
    pub max_shots_for_convergence: usize,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            num_shots: 1024,
            seed: None,
            confidence_level: 0.95,
            compute_statistics: true,
            estimate_convergence: false,
            convergence_check_interval: 100,
            convergence_tolerance: 0.01,
            max_shots_for_convergence: 10000,
        }
    }
}
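
// Example: building a custom configuration on top of the defaults using
// struct-update syntax (the concrete values below are illustrative only):
//
//     let config = SamplingConfig {
//         num_shots: 4096,
//         seed: Some(1234),
//         estimate_convergence: true,
//         ..SamplingConfig::default()
//     };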

/// Shot-based quantum sampler
pub struct QuantumSampler {
    /// Random number generator
    rng: ChaCha8Rng,
    /// Current configuration
    config: SamplingConfig,
}

impl QuantumSampler {
    /// Create new sampler with configuration
    pub fn new(config: SamplingConfig) -> Self {
        let rng = if let Some(seed) = config.seed {
            ChaCha8Rng::seed_from_u64(seed)
        } else {
            // No explicit seed: draw one from the operating system's entropy source
            ChaCha8Rng::from_entropy()
        };

        Self { rng, config }
    }

    /// Sample measurements from a quantum state
    pub fn sample_state(&mut self, state: &Array1<Complex64>) -> Result<ShotResult> {
        let num_qubits = (state.len() as f64).log2() as usize;
        if 1 << num_qubits != state.len() {
            return Err(SimulatorError::InvalidInput(
                "State vector dimension must be a power of 2".to_string(),
            ));
        }

        // Compute probability distribution
        let probabilities: Vec<f64> = state.iter().map(|amp| amp.norm_sqr()).collect();

        // Validate normalization
        let total_prob: f64 = probabilities.iter().sum();
        if (total_prob - 1.0).abs() > 1e-10 {
            return Err(SimulatorError::InvalidInput(format!(
                "State vector not normalized: total probability = {}",
                total_prob
            )));
        }

        // Sample outcomes
        let mut outcomes = Vec::with_capacity(self.config.num_shots);
        for _ in 0..self.config.num_shots {
            let sample = self.sample_from_distribution(&probabilities)?;
            outcomes.push(BitString::from_int(sample, num_qubits));
        }

        // Compute statistics if requested
        let statistics = if self.config.compute_statistics {
            self.compute_statistics(&outcomes)?
        } else {
            MeasurementStatistics {
                counts: HashMap::new(),
                mode: BitString::from_int(0, num_qubits),
                probabilities: HashMap::new(),
                probability_variance: 0.0,
                confidence_intervals: HashMap::new(),
                entropy: 0.0,
                purity: 0.0,
            }
        };

        Ok(ShotResult {
            outcomes,
            num_shots: self.config.num_shots,
            statistics,
            config: self.config.clone(),
        })
    }

    /// Sample measurements from a state with noise
    pub fn sample_state_with_noise(
        &mut self,
        state: &Array1<Complex64>,
        noise_model: &dyn NoiseModel,
    ) -> Result<ShotResult> {
        // Apply noise model to the state
        let noisy_state = noise_model.apply_readout_noise(state)?;
        self.sample_state(&noisy_state)
    }

    /// Sample expectation value of an observable
    pub fn sample_expectation(
        &mut self,
        state: &Array1<Complex64>,
        observable: &PauliOperatorSum,
    ) -> Result<ExpectationResult> {
        let mut expectation_values = Vec::new();
        let mut variances = Vec::new();

        // Sample each Pauli term separately
        for term in &observable.terms {
            let term_result = self.sample_pauli_expectation(state, term)?;
            expectation_values.push(term_result.expectation * term.coefficient.re);
            variances.push(term_result.variance * term.coefficient.re.powi(2));
        }

        // Combine results
        let total_expectation: f64 = expectation_values.iter().sum();
        let total_variance: f64 = variances.iter().sum();
        let standard_error = (total_variance / self.config.num_shots as f64).sqrt();

        // Confidence interval
        let z_score = self.get_z_score(self.config.confidence_level);
        let confidence_interval = (
            total_expectation - z_score * standard_error,
            total_expectation + z_score * standard_error,
        );

        Ok(ExpectationResult {
            expectation: total_expectation,
            variance: total_variance,
            standard_error,
            confidence_interval,
            num_shots: self.config.num_shots,
        })
    }

    /// Sample expectation value of a single Pauli string
    fn sample_pauli_expectation(
        &mut self,
        state: &Array1<Complex64>,
        pauli_string: &PauliString,
    ) -> Result<ExpectationResult> {
        // For Pauli measurements, eigenvalues are ±1
        // We need to measure in the appropriate basis
        let mut measurements = Vec::with_capacity(self.config.num_shots);

        for _ in 0..self.config.num_shots {
            // Measure each qubit in the appropriate Pauli basis
            let outcome = self.measure_pauli_basis(state, pauli_string)?;
            measurements.push(outcome);
        }

        // Compute statistics
        let mean = measurements.iter().sum::<f64>() / measurements.len() as f64;
        let variance = measurements.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
            / measurements.len() as f64;

        let standard_error = (variance / measurements.len() as f64).sqrt();
        let z_score = self.get_z_score(self.config.confidence_level);
        let confidence_interval = (
            mean - z_score * standard_error,
            mean + z_score * standard_error,
        );

        Ok(ExpectationResult {
            expectation: mean,
            variance,
            standard_error,
            confidence_interval,
            num_shots: measurements.len(),
        })
    }

    /// Measure in Pauli basis (simplified implementation)
    fn measure_pauli_basis(
        &mut self,
        _state: &Array1<Complex64>,
        _pauli_string: &PauliString,
    ) -> Result<f64> {
        // Simplified implementation - return random ±1
        // In practice, would need to transform state to measurement basis
        if self.rng.gen::<f64>() < 0.5 {
            Ok(1.0)
        } else {
            Ok(-1.0)
        }
    }

    /// Sample from discrete probability distribution
    fn sample_from_distribution(&mut self, probabilities: &[f64]) -> Result<usize> {
        let random_value = self.rng.gen::<f64>();
        let mut cumulative = 0.0;

        for (i, &prob) in probabilities.iter().enumerate() {
            cumulative += prob;
            if random_value <= cumulative {
                return Ok(i);
            }
        }

        // Handle numerical errors - return last index
        Ok(probabilities.len() - 1)
    }
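
    // Worked example of the inverse-transform sampling above: for
    // probabilities = [0.25, 0.75], a draw of 0.10 stays below the first
    // cumulative value 0.25 and returns index 0, while a draw of 0.60 exceeds
    // 0.25 but not 1.00 and returns index 1.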

    /// Compute measurement statistics
    fn compute_statistics(&self, outcomes: &[BitString]) -> Result<MeasurementStatistics> {
        let mut counts = HashMap::new();
        let total_shots = outcomes.len() as f64;

        // Count frequencies
        for outcome in outcomes {
            *counts.entry(outcome.clone()).or_insert(0) += 1;
        }

        // Find mode
        let mode = counts
            .iter()
            .max_by_key(|(_, &count)| count)
            .map(|(outcome, _)| outcome.clone())
            .unwrap_or_else(|| BitString::from_int(0, outcomes[0].len()));

        // Compute probabilities
        let mut probabilities = HashMap::new();
        let mut confidence_intervals = HashMap::new();
        let z_score = self.get_z_score(self.config.confidence_level);

        for (outcome, &count) in &counts {
            let prob = count as f64 / total_shots;
            probabilities.insert(outcome.clone(), prob);

            // Binomial confidence interval
            let std_error = (prob * (1.0 - prob) / total_shots).sqrt();
            let margin = z_score * std_error;
            confidence_intervals.insert(
                outcome.clone(),
                ((prob - margin).max(0.0), (prob + margin).min(1.0)),
            );
        }

        // Compute entropy
        let entropy = probabilities
            .values()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum::<f64>();

        // Compute purity (sum of squared probabilities)
        let purity = probabilities.values().map(|&p| p * p).sum::<f64>();

        // Compute overall probability variance
        let mean_prob = 1.0 / probabilities.len() as f64;
        let probability_variance = probabilities
            .values()
            .map(|&p| (p - mean_prob).powi(2))
            .sum::<f64>()
            / probabilities.len() as f64;

        Ok(MeasurementStatistics {
            counts,
            mode,
            probabilities,
            probability_variance,
            confidence_intervals,
            entropy,
            purity,
        })
    }

    /// Get z-score for confidence level
    fn get_z_score(&self, confidence_level: f64) -> f64 {
        // Simplified - use common values
        match (confidence_level * 100.0) as i32 {
            90 => 1.645,
            95 => 1.96,
            99 => 2.576,
            _ => 1.96, // Default to 95%
        }
    }
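
    // Worked example for `compute_statistics`: 100 shots split evenly between
    // two outcomes give probability estimates of 0.5 each, entropy
    // ln(2) ≈ 0.693 nats, and purity 0.5² + 0.5² = 0.5.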

    /// Estimate convergence of sampling
    pub fn estimate_convergence(
        &mut self,
        state: &Array1<Complex64>,
        observable: &PauliOperatorSum,
    ) -> Result<ConvergenceResult> {
        let mut expectation_history = Vec::new();
        let mut variance_history = Vec::new();
        let mut shots_taken = 0;
        let mut converged = false;

        while shots_taken < self.config.max_shots_for_convergence && !converged {
            // Take a batch of measurements
            let batch_shots = self
                .config
                .convergence_check_interval
                .min(self.config.max_shots_for_convergence - shots_taken);

            // Temporarily adjust shot count for this batch
            let original_shots = self.config.num_shots;
            self.config.num_shots = batch_shots;

            let result = self.sample_expectation(state, observable)?;

            // Restore original shot count
            self.config.num_shots = original_shots;

            expectation_history.push(result.expectation);
            variance_history.push(result.variance);
            shots_taken += batch_shots;

            // Check convergence
            if expectation_history.len() >= 3 {
                let recent_values = &expectation_history[expectation_history.len() - 3..];
                let max_diff = recent_values
                    .iter()
                    .zip(recent_values.iter().skip(1))
                    .map(|(a, b)| (a - b).abs())
                    .fold(0.0, f64::max);

                if max_diff < self.config.convergence_tolerance {
                    converged = true;
                }
            }
        }

        // Compute final estimates
        let final_expectation = expectation_history.last().copied().unwrap_or(0.0);
        let expectation_std = if expectation_history.len() > 1 {
            let mean = expectation_history.iter().sum::<f64>() / expectation_history.len() as f64;
            (expectation_history
                .iter()
                .map(|x| (x - mean).powi(2))
                .sum::<f64>()
                / (expectation_history.len() - 1) as f64)
                .sqrt()
        } else {
            0.0
        };

        Ok(ConvergenceResult {
            converged,
            shots_taken,
            final_expectation,
            expectation_history,
            variance_history,
            convergence_rate: expectation_std,
        })
    }
}

/// Result of expectation value sampling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectationResult {
    /// Expectation value estimate
    pub expectation: f64,
    /// Variance estimate
    pub variance: f64,
    /// Standard error
    pub standard_error: f64,
    /// Confidence interval
    pub confidence_interval: (f64, f64),
    /// Number of shots used
    pub num_shots: usize,
}

/// Result of convergence estimation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceResult {
    /// Whether convergence was achieved
    pub converged: bool,
    /// Total shots taken
    pub shots_taken: usize,
    /// Final expectation value
    pub final_expectation: f64,
    /// History of expectation values
    pub expectation_history: Vec<f64>,
    /// History of variances
    pub variance_history: Vec<f64>,
    /// Convergence rate proxy (standard deviation of the batch expectation estimates)
    pub convergence_rate: f64,
}

/// Noise model trait for realistic sampling
pub trait NoiseModel: Send + Sync {
    /// Apply readout noise to measurements
    fn apply_readout_noise(&self, state: &Array1<Complex64>) -> Result<Array1<Complex64>>;

    /// Get readout error probability for qubit
    fn readout_error_probability(&self, qubit: usize) -> f64;
}

/// Simple readout noise model
#[derive(Debug, Clone)]
pub struct SimpleReadoutNoise {
    /// Error probability for each qubit
    pub error_probs: Vec<f64>,
}

impl SimpleReadoutNoise {
    /// Create uniform readout noise
    pub fn uniform(num_qubits: usize, error_prob: f64) -> Self {
        Self {
            error_probs: vec![error_prob; num_qubits],
        }
    }
}

impl NoiseModel for SimpleReadoutNoise {
    fn apply_readout_noise(&self, state: &Array1<Complex64>) -> Result<Array1<Complex64>> {
        // Simplified implementation - in practice would need proper POVM modeling
        Ok(state.clone())
    }

    fn readout_error_probability(&self, qubit: usize) -> f64 {
        self.error_probs.get(qubit).copied().unwrap_or(0.0)
    }
}
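
// Example: attaching the simple readout-noise model to a sampler (a sketch
// only; `sampler` and `state` are assumed to exist as in the module example,
// and `apply_readout_noise` above is currently a pass-through):
//
//     let noise = SimpleReadoutNoise::uniform(2, 0.02);
//     let noisy_result = sampler.sample_state_with_noise(&state, &noise)?;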

/// Utility functions for shot sampling analysis
pub mod analysis {
    use super::*;

    /// Compute statistical power for detecting effect
    pub fn statistical_power(effect_size: f64, num_shots: usize, significance_level: f64) -> f64 {
        // Simplified power analysis for a standardized effect size (per-shot variance of 1)
        let standard_error = 1.0 / (num_shots as f64).sqrt();
        let z_critical = match (significance_level * 100.0) as i32 {
            1 => 2.576,
            5 => 1.96,
            10 => 1.645,
            _ => 1.96,
        };

        let z_beta = (effect_size / standard_error) - z_critical;
        normal_cdf(z_beta)
    }

    /// Estimate required shots for desired precision
    pub fn required_shots_for_precision(desired_error: f64, confidence_level: f64) -> usize {
        let z_score = match (confidence_level * 100.0) as i32 {
            90 => 1.645,
            95 => 1.96,
            99 => 2.576,
            _ => 1.96,
        };

        // Binomial proportion: n ≥ z² p(1−p)/ε², maximized at p = 0.5, giving n ≥ z²/(4ε²)
        let n = (z_score * z_score) / (4.0 * desired_error * desired_error);
        n.ceil() as usize
    }
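
    // Worked example: a ±0.01 margin at 95% confidence needs
    // n ≥ 1.96² / (4 · 0.01²) = 9604 shots (exercised by the test at the
    // bottom of this file).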

    /// Compare two shot results statistically
    pub fn compare_shot_results(
        result1: &ShotResult,
        result2: &ShotResult,
        significance_level: f64,
    ) -> ComparisonResult {
        // Chi-square test for distribution comparison
        let mut chi_square = 0.0;
        let mut degrees_of_freedom: usize = 0;

        // Get all unique outcomes
        let mut all_outcomes = std::collections::HashSet::new();
        all_outcomes.extend(result1.statistics.counts.keys());
        all_outcomes.extend(result2.statistics.counts.keys());

        for &outcome in &all_outcomes {
            let count1 = result1.statistics.counts.get(outcome).copied().unwrap_or(0) as f64;
            let count2 = result2.statistics.counts.get(outcome).copied().unwrap_or(0) as f64;

            let total1 = result1.num_shots as f64;
            let total2 = result2.num_shots as f64;

            let expected1 = (count1 + count2) * total1 / (total1 + total2);
            let expected2 = (count1 + count2) * total2 / (total1 + total2);

            // Only include cells whose expected counts meet the usual chi-square threshold
            if expected1 > 5.0 && expected2 > 5.0 {
                chi_square += (count1 - expected1).powi(2) / expected1;
                chi_square += (count2 - expected2).powi(2) / expected2;
                degrees_of_freedom += 1;
            }
        }

        degrees_of_freedom = degrees_of_freedom.saturating_sub(1);

        // Chi-square critical values for 1 degree of freedom (the actual
        // degrees of freedom are not taken into account here)
        let critical_value = match (significance_level * 100.0) as i32 {
            1 => 6.635,
            5 => 3.841,
            10 => 2.706,
            _ => 3.841,
        };

        ComparisonResult {
            chi_square,
            degrees_of_freedom,
            // Placeholder p-value: a full implementation would evaluate the
            // chi-square CDF at `chi_square` with `degrees_of_freedom`
            p_value: if chi_square > critical_value {
                0.01
            } else {
                0.1
            },
            significant: chi_square > critical_value,
        }
    }

    /// Normal CDF approximation
    fn normal_cdf(x: f64) -> f64 {
        // Standard normal CDF expressed via the error function
        0.5 * (1.0 + erf(x / 2.0_f64.sqrt()))
    }

    /// Error function approximation
    fn erf(x: f64) -> f64 {
        // Abramowitz and Stegun approximation 7.1.26 (maximum absolute error about 1.5e-7)
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

        sign * y
    }
}

/// Result of statistical comparison
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// Chi-square statistic
    pub chi_square: f64,
    /// Degrees of freedom
    pub degrees_of_freedom: usize,
    /// P-value
    pub p_value: f64,
    /// Whether difference is significant
    pub significant: bool,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bit_string() {
        let bs = BitString::from_int(5, 4); // 5 = 0b0101, stored little-endian as [1, 0, 1, 0]
        assert_eq!(bs.bits, vec![1, 0, 1, 0]);
        assert_eq!(bs.to_int(), 5);
        assert_eq!(bs.weight(), 2);
    }

    #[test]
    fn test_sampler_creation() {
        let config = SamplingConfig::default();
        let sampler = QuantumSampler::new(config);
        assert_eq!(sampler.config.num_shots, 1024);
    }

    #[test]
    fn test_uniform_state_sampling() {
        let mut config = SamplingConfig::default();
        config.num_shots = 100;
        config.seed = Some(42);

        let mut sampler = QuantumSampler::new(config);

        // Create uniform superposition |+> = (|0> + |1>)/√2
        let state = Array1::from_vec(vec![
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
        ]);

        let result = sampler.sample_state(&state).unwrap();
        assert_eq!(result.num_shots, 100);
        assert_eq!(result.outcomes.len(), 100);

        // Check that we got both |0> and |1> outcomes
        let has_zero = result.outcomes.iter().any(|bs| bs.to_int() == 0);
        let has_one = result.outcomes.iter().any(|bs| bs.to_int() == 1);
        assert!(has_zero && has_one);
    }

    #[test]
    fn test_required_shots_calculation() {
        let shots = analysis::required_shots_for_precision(0.01, 0.95);
        assert!(shots > 9000); // Worst-case binomial estimate: 1.96²/(4·0.01²) = 9604 shots
    }
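
    // Additional test sketch based on the implementations above: sampling a
    // computational-basis state must return the same outcome on every shot,
    // and the statistics should describe a deterministic distribution
    // (a single counted outcome, zero entropy, unit purity).
    #[test]
    fn test_basis_state_sampling_is_deterministic() {
        let mut config = SamplingConfig::default();
        config.num_shots = 100;
        config.seed = Some(7);

        let mut sampler = QuantumSampler::new(config);

        // Two-qubit state with all amplitude on basis index 2
        let state = Array1::from_vec(vec![
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
        ]);

        let result = sampler.sample_state(&state).unwrap();
        let expected = BitString::from_int(2, 2);
        assert!(result.outcomes.iter().all(|bs| *bs == expected));
        assert_eq!(result.statistics.mode, expected);
        assert_eq!(result.statistics.counts.len(), 1);
        assert!(result.statistics.entropy.abs() < 1e-12);
        assert!((result.statistics.purity - 1.0).abs() < 1e-12);
    }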
}