quantrs2_sim/shot_sampling.rs

//! Shot-based sampling with statistical analysis for quantum simulation.
//!
//! This module implements comprehensive shot-based sampling methods for quantum
//! circuits, including measurement statistics, error analysis, and convergence
//! detection for realistic quantum device simulation.
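//!
//! A minimal usage sketch is shown below; the `quantrs2_sim::shot_sampling` import
//! path is assumed from this file's location, and the block is marked `ignore` so
//! it is not compiled as a doctest:
//!
//! ```ignore
//! use quantrs2_sim::shot_sampling::{QuantumSampler, SamplingConfig};
//! use scirs2_core::ndarray::Array1;
//! use scirs2_core::Complex64;
//!
//! // Sample a single-qubit uniform superposition with a fixed seed.
//! let config = SamplingConfig {
//!     num_shots: 1000,
//!     seed: Some(7),
//!     ..Default::default()
//! };
//! let mut sampler = QuantumSampler::new(config);
//! let amp = 1.0 / 2.0_f64.sqrt();
//! let state = Array1::from_vec(vec![Complex64::new(amp, 0.0), Complex64::new(amp, 0.0)]);
//! let result = sampler.sample_state(&state).unwrap();
//! assert_eq!(result.outcomes.len(), 1000);
//! ```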

use scirs2_core::random::prelude::*;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::Complex64;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::random::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::{Result, SimulatorError};
use crate::pauli::{PauliOperatorSum, PauliString};
use crate::statevector::StateVectorSimulator;

/// Shot-based measurement result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShotResult {
    /// Measurement outcomes for each shot
    pub outcomes: Vec<BitString>,
    /// Total number of shots
    pub num_shots: usize,
    /// Measurement statistics
    pub statistics: MeasurementStatistics,
    /// Sampling configuration used
    pub config: SamplingConfig,
}

/// Bit string representation of measurement outcome
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BitString {
    /// Bit values (0 or 1)
    pub bits: Vec<u8>,
}

impl BitString {
    /// Create from vector of booleans
    pub fn from_bools(bools: &[bool]) -> Self {
        Self {
            bits: bools.iter().map(|&b| if b { 1 } else { 0 }).collect(),
        }
    }

    /// Convert to vector of booleans
    pub fn to_bools(&self) -> Vec<bool> {
        self.bits.iter().map(|&b| b == 1).collect()
    }

    /// Convert to integer (little-endian)
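    ///
    /// Bit 0 is the least significant bit: for example, bits `[1, 0, 1]`
    /// correspond to the integer 5 (1·1 + 0·2 + 1·4).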
    pub fn to_int(&self) -> usize {
        self.bits
            .iter()
            .enumerate()
            .map(|(i, &bit)| (bit as usize) << i)
            .sum()
    }

    /// Create from integer (little-endian)
    pub fn from_int(mut value: usize, num_bits: usize) -> Self {
        let mut bits = Vec::with_capacity(num_bits);
        for _ in 0..num_bits {
            bits.push((value & 1) as u8);
            value >>= 1;
        }
        Self { bits }
    }

    /// Number of bits
    pub fn len(&self) -> usize {
        self.bits.len()
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.bits.is_empty()
    }

    /// Hamming weight (number of 1s)
    pub fn weight(&self) -> usize {
        self.bits.iter().map(|&b| b as usize).sum()
    }

    /// Hamming distance to another bit string
    pub fn distance(&self, other: &BitString) -> usize {
        if self.len() != other.len() {
            return usize::MAX; // Invalid comparison
        }
        self.bits
            .iter()
            .zip(&other.bits)
            .map(|(&a, &b)| (a ^ b) as usize)
            .sum()
    }
}

impl std::fmt::Display for BitString {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for &bit in &self.bits {
            write!(f, "{}", bit)?;
        }
        Ok(())
    }
}

/// Measurement statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasurementStatistics {
    /// Frequency count for each outcome
    pub counts: HashMap<BitString, usize>,
    /// Most frequent outcome
    pub mode: BitString,
    /// Probability estimates for each outcome
    pub probabilities: HashMap<BitString, f64>,
    /// Variance in the probability estimates
    pub probability_variance: f64,
    /// Statistical confidence intervals
    pub confidence_intervals: HashMap<BitString, (f64, f64)>,
    /// Entropy of the measurement distribution
    pub entropy: f64,
    /// Purity of the measurement distribution
    pub purity: f64,
}

/// Sampling configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Number of shots to take
    pub num_shots: usize,
    /// Random seed for reproducibility
    pub seed: Option<u64>,
    /// Confidence level for intervals (e.g., 0.95 for 95%)
    pub confidence_level: f64,
    /// Whether to compute full statistics
    pub compute_statistics: bool,
    /// Whether to estimate convergence
    pub estimate_convergence: bool,
    /// Convergence check interval (number of shots)
    pub convergence_check_interval: usize,
    /// Convergence tolerance
    pub convergence_tolerance: f64,
    /// Maximum number of shots for convergence
    pub max_shots_for_convergence: usize,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            num_shots: 1024,
            seed: None,
            confidence_level: 0.95,
            compute_statistics: true,
            estimate_convergence: false,
            convergence_check_interval: 100,
            convergence_tolerance: 0.01,
            max_shots_for_convergence: 10000,
        }
    }
}

/// Shot-based quantum sampler
pub struct QuantumSampler {
    /// Random number generator
    rng: ChaCha8Rng,
    /// Current configuration
    config: SamplingConfig,
}

impl QuantumSampler {
    /// Create new sampler with configuration
    pub fn new(config: SamplingConfig) -> Self {
        let rng = if let Some(seed) = config.seed {
            ChaCha8Rng::seed_from_u64(seed)
        } else {
            ChaCha8Rng::from_rng(&mut thread_rng())
        };

        Self { rng, config }
    }

    /// Sample measurements from a quantum state
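    ///
    /// The state must have length 2^n (an n-qubit register) and be normalized to
    /// within 1e-10 in total probability; otherwise an `InvalidInput` error is returned.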
    pub fn sample_state(&mut self, state: &Array1<Complex64>) -> Result<ShotResult> {
        let num_qubits = (state.len() as f64).log2() as usize;
        if 1 << num_qubits != state.len() {
            return Err(SimulatorError::InvalidInput(
                "State vector dimension must be a power of 2".to_string(),
            ));
        }

        // Compute probability distribution
        let probabilities: Vec<f64> = state.iter().map(|amp| amp.norm_sqr()).collect();

        // Validate normalization
        let total_prob: f64 = probabilities.iter().sum();
        if (total_prob - 1.0).abs() > 1e-10 {
            return Err(SimulatorError::InvalidInput(format!(
                "State vector not normalized: total probability = {}",
                total_prob
            )));
        }

        // Sample outcomes
        let mut outcomes = Vec::with_capacity(self.config.num_shots);
        for _ in 0..self.config.num_shots {
            let sample = self.sample_from_distribution(&probabilities)?;
            outcomes.push(BitString::from_int(sample, num_qubits));
        }

        // Compute statistics if requested
        let statistics = if self.config.compute_statistics {
            self.compute_statistics(&outcomes)?
        } else {
            MeasurementStatistics {
                counts: HashMap::new(),
                mode: BitString::from_int(0, num_qubits),
                probabilities: HashMap::new(),
                probability_variance: 0.0,
                confidence_intervals: HashMap::new(),
                entropy: 0.0,
                purity: 0.0,
            }
        };

        Ok(ShotResult {
            outcomes,
            num_shots: self.config.num_shots,
            statistics,
            config: self.config.clone(),
        })
    }

    /// Sample measurements from a state with noise
    pub fn sample_state_with_noise(
        &mut self,
        state: &Array1<Complex64>,
        noise_model: &dyn NoiseModel,
    ) -> Result<ShotResult> {
        // Apply noise model to the state
        let noisy_state = noise_model.apply_readout_noise(state)?;
        self.sample_state(&noisy_state)
    }

    /// Sample expectation value of an observable
    pub fn sample_expectation(
        &mut self,
        state: &Array1<Complex64>,
        observable: &PauliOperatorSum,
    ) -> Result<ExpectationResult> {
        let mut expectation_values = Vec::new();
        let mut variances = Vec::new();

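        // Each term is sampled independently and only the real part of its
        // coefficient is used (a Hermitian observable with real weights is
        // assumed), so the per-term variances are simply summed below.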
        // Sample each Pauli term separately
        for term in &observable.terms {
            let term_result = self.sample_pauli_expectation(state, term)?;
            expectation_values.push(term_result.expectation * term.coefficient.re);
            variances.push(term_result.variance * term.coefficient.re.powi(2));
        }

        // Combine results
        let total_expectation: f64 = expectation_values.iter().sum();
        let total_variance: f64 = variances.iter().sum();
        let standard_error = (total_variance / self.config.num_shots as f64).sqrt();

        // Confidence interval
        let z_score = self.get_z_score(self.config.confidence_level);
        let confidence_interval = (
            total_expectation - z_score * standard_error,
            total_expectation + z_score * standard_error,
        );

        Ok(ExpectationResult {
            expectation: total_expectation,
            variance: total_variance,
            standard_error,
            confidence_interval,
            num_shots: self.config.num_shots,
        })
    }

    /// Sample expectation value of a single Pauli string
    fn sample_pauli_expectation(
        &mut self,
        state: &Array1<Complex64>,
        pauli_string: &PauliString,
    ) -> Result<ExpectationResult> {
        // For Pauli measurements, eigenvalues are ±1
        // We need to measure in the appropriate basis

        let num_qubits = pauli_string.num_qubits;
        let mut measurements = Vec::with_capacity(self.config.num_shots);

        for _ in 0..self.config.num_shots {
            // Measure each qubit in the appropriate Pauli basis
            let outcome = self.measure_pauli_basis(state, pauli_string)?;
            measurements.push(outcome);
        }

        // Compute statistics
        let mean = measurements.iter().sum::<f64>() / measurements.len() as f64;
        let variance = measurements.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
            / measurements.len() as f64;

        let standard_error = (variance / measurements.len() as f64).sqrt();
        let z_score = self.get_z_score(self.config.confidence_level);
        let confidence_interval = (
            mean - z_score * standard_error,
            mean + z_score * standard_error,
        );

        Ok(ExpectationResult {
            expectation: mean,
            variance,
            standard_error,
            confidence_interval,
            num_shots: measurements.len(),
        })
    }

    /// Measure in Pauli basis (simplified implementation)
    fn measure_pauli_basis(
        &mut self,
        _state: &Array1<Complex64>,
        _pauli_string: &PauliString,
    ) -> Result<f64> {
        // Simplified implementation - return random ±1
        // In practice, would need to transform state to measurement basis
        if self.rng.gen::<f64>() < 0.5 {
            Ok(1.0)
        } else {
            Ok(-1.0)
        }
    }

    /// Sample from discrete probability distribution
    fn sample_from_distribution(&mut self, probabilities: &[f64]) -> Result<usize> {
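        // Inverse-transform sampling: draw u uniformly in [0, 1) and return the
        // first index whose cumulative probability reaches u.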
        let random_value = self.rng.gen::<f64>();
        let mut cumulative = 0.0;

        for (i, &prob) in probabilities.iter().enumerate() {
            cumulative += prob;
            if random_value <= cumulative {
                return Ok(i);
            }
        }

        // Handle numerical errors - return last index
        Ok(probabilities.len() - 1)
    }

    /// Compute measurement statistics
    fn compute_statistics(&self, outcomes: &[BitString]) -> Result<MeasurementStatistics> {
        if outcomes.is_empty() {
            return Err(SimulatorError::InvalidInput(
                "Cannot compute statistics from zero shots".to_string(),
            ));
        }

        let mut counts = HashMap::new();
        let total_shots = outcomes.len() as f64;

        // Count frequencies
        for outcome in outcomes {
            *counts.entry(outcome.clone()).or_insert(0) += 1;
        }

        // Find mode
        let mode = counts
            .iter()
            .max_by_key(|(_, &count)| count)
            .map(|(outcome, _)| outcome.clone())
            .unwrap_or_else(|| BitString::from_int(0, outcomes[0].len()));

        // Compute probabilities
        let mut probabilities = HashMap::new();
        let mut confidence_intervals = HashMap::new();
        let z_score = self.get_z_score(self.config.confidence_level);

        for (outcome, &count) in &counts {
            let prob = count as f64 / total_shots;
            probabilities.insert(outcome.clone(), prob);

            // Binomial confidence interval
            let std_error = (prob * (1.0 - prob) / total_shots).sqrt();
            let margin = z_score * std_error;
            confidence_intervals.insert(
                outcome.clone(),
                ((prob - margin).max(0.0), (prob + margin).min(1.0)),
            );
        }

        // Compute entropy
        let entropy = probabilities
            .values()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum::<f64>();

        // Compute purity (sum of squared probabilities)
        let purity = probabilities.values().map(|&p| p * p).sum::<f64>();

        // Compute overall probability variance
        let mean_prob = 1.0 / probabilities.len() as f64;
        let probability_variance = probabilities
            .values()
            .map(|&p| (p - mean_prob).powi(2))
            .sum::<f64>()
            / probabilities.len() as f64;

        Ok(MeasurementStatistics {
            counts,
            mode,
            probabilities,
            probability_variance,
            confidence_intervals,
            entropy,
            purity,
        })
    }

    /// Get z-score for confidence level
    fn get_z_score(&self, confidence_level: f64) -> f64 {
        // Simplified lookup of common two-sided critical values; round before the
        // integer cast so floating-point error cannot drop a level to the default
        match (confidence_level * 100.0).round() as i32 {
            90 => 1.645,
            95 => 1.96,
            99 => 2.576,
            _ => 1.96, // Default to 95%
        }
    }

    /// Estimate convergence of sampling
    pub fn estimate_convergence(
        &mut self,
        state: &Array1<Complex64>,
        observable: &PauliOperatorSum,
    ) -> Result<ConvergenceResult> {
        let mut expectation_history = Vec::new();
        let mut variance_history = Vec::new();
        let mut shots_taken = 0;
        let mut converged = false;

        while shots_taken < self.config.max_shots_for_convergence && !converged {
            // Take a batch of measurements
            let batch_shots = self
                .config
                .convergence_check_interval
                .min(self.config.max_shots_for_convergence - shots_taken);

            // Temporarily adjust shot count for this batch
            let original_shots = self.config.num_shots;
            self.config.num_shots = batch_shots;

            let result = self.sample_expectation(state, observable)?;

            // Restore original shot count
            self.config.num_shots = original_shots;

            expectation_history.push(result.expectation);
            variance_history.push(result.variance);
            shots_taken += batch_shots;

            // Check convergence
            if expectation_history.len() >= 3 {
                let recent_values = &expectation_history[expectation_history.len() - 3..];
                let max_diff = recent_values
                    .iter()
                    .zip(recent_values.iter().skip(1))
                    .map(|(a, b)| (a - b).abs())
                    .fold(0.0, f64::max);

                if max_diff < self.config.convergence_tolerance {
                    converged = true;
                }
            }
        }

        // Compute final estimates
        let final_expectation = expectation_history.last().copied().unwrap_or(0.0);
        let expectation_std = if expectation_history.len() > 1 {
            let mean = expectation_history.iter().sum::<f64>() / expectation_history.len() as f64;
            (expectation_history
                .iter()
                .map(|x| (x - mean).powi(2))
                .sum::<f64>()
                / (expectation_history.len() - 1) as f64)
                .sqrt()
        } else {
            0.0
        };

        Ok(ConvergenceResult {
            converged,
            shots_taken,
            final_expectation,
            expectation_history,
            variance_history,
            convergence_rate: expectation_std,
        })
    }
}

/// Result of expectation value sampling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectationResult {
    /// Expectation value estimate
    pub expectation: f64,
    /// Variance estimate
    pub variance: f64,
    /// Standard error
    pub standard_error: f64,
    /// Confidence interval
    pub confidence_interval: (f64, f64),
    /// Number of shots used
    pub num_shots: usize,
}

/// Result of convergence estimation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceResult {
    /// Whether convergence was achieved
    pub converged: bool,
    /// Total shots taken
    pub shots_taken: usize,
    /// Final expectation value
    pub final_expectation: f64,
    /// History of expectation values
    pub expectation_history: Vec<f64>,
    /// History of variances
    pub variance_history: Vec<f64>,
    /// Convergence rate proxy: sample standard deviation of the batch expectation estimates
    pub convergence_rate: f64,
}

/// Noise model trait for realistic sampling
pub trait NoiseModel: Send + Sync {
    /// Apply readout noise to measurements
    fn apply_readout_noise(&self, state: &Array1<Complex64>) -> Result<Array1<Complex64>>;

    /// Get readout error probability for qubit
    fn readout_error_probability(&self, qubit: usize) -> f64;
}

/// Simple readout noise model
#[derive(Debug, Clone)]
pub struct SimpleReadoutNoise {
    /// Error probability for each qubit
    pub error_probs: Vec<f64>,
}

impl SimpleReadoutNoise {
    /// Create uniform readout noise
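    ///
    /// For example, `SimpleReadoutNoise::uniform(3, 0.02)` assigns every one of
    /// the three qubits a 2% readout error probability.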
    pub fn uniform(num_qubits: usize, error_prob: f64) -> Self {
        Self {
            error_probs: vec![error_prob; num_qubits],
        }
    }
}

impl NoiseModel for SimpleReadoutNoise {
    fn apply_readout_noise(&self, state: &Array1<Complex64>) -> Result<Array1<Complex64>> {
        // Simplified implementation - in practice would need proper POVM modeling
        Ok(state.clone())
    }

    fn readout_error_probability(&self, qubit: usize) -> f64 {
        self.error_probs.get(qubit).copied().unwrap_or(0.0)
    }
}

/// Utility functions for shot sampling analysis
pub mod analysis {
    use super::*;

    /// Compute statistical power for detecting effect
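    ///
    /// This is a simplified normal-approximation estimate: the standard error is
    /// taken as 1/sqrt(num_shots) and the standardized effect is compared against
    /// the tabulated critical value.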
    pub fn statistical_power(effect_size: f64, num_shots: usize, significance_level: f64) -> f64 {
        // Simplified power analysis
        let standard_error = 1.0 / (num_shots as f64).sqrt();
        let z_critical = match (significance_level * 100.0).round() as i32 {
            1 => 2.576,
            5 => 1.96,
            10 => 1.645,
            _ => 1.96,
        };

        let z_beta = (effect_size / standard_error) - z_critical;
        normal_cdf(z_beta)
    }

    /// Estimate required shots for desired precision
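    ///
    /// For example, a ±0.01 estimate at 95% confidence needs
    /// ceil(1.96² / (4 · 0.01²)) = 9604 shots.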
    pub fn required_shots_for_precision(desired_error: f64, confidence_level: f64) -> usize {
        let z_score = match (confidence_level * 100.0).round() as i32 {
            90 => 1.645,
            95 => 1.96,
            99 => 2.576,
            _ => 1.96,
        };

        // For a binomial proportion: n ≥ z² / (4ε²) in the worst case p = 0.5
        let n = (z_score * z_score) / (4.0 * desired_error * desired_error);
        n.ceil() as usize
    }

    /// Compare two shot results statistically
    pub fn compare_shot_results(
        result1: &ShotResult,
        result2: &ShotResult,
        significance_level: f64,
    ) -> ComparisonResult {
        // Chi-square test for distribution comparison
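        // Expected counts are pooled from both results (a 2 x k contingency table);
        // cells whose expected counts are 5 or fewer are skipped, following the usual
        // rule of thumb for the chi-square approximation.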
        let mut chi_square = 0.0;
        let mut degrees_of_freedom: usize = 0;

        // Get all unique outcomes
        let mut all_outcomes = std::collections::HashSet::new();
        all_outcomes.extend(result1.statistics.counts.keys());
        all_outcomes.extend(result2.statistics.counts.keys());

        for outcome in &all_outcomes {
            let count1 = result1.statistics.counts.get(*outcome).copied().unwrap_or(0) as f64;
            let count2 = result2.statistics.counts.get(*outcome).copied().unwrap_or(0) as f64;

            let total1 = result1.num_shots as f64;
            let total2 = result2.num_shots as f64;

            let expected1 = (count1 + count2) * total1 / (total1 + total2);
            let expected2 = (count1 + count2) * total2 / (total1 + total2);

            if expected1 > 5.0 && expected2 > 5.0 {
                chi_square += (count1 - expected1).powi(2) / expected1;
                chi_square += (count2 - expected2).powi(2) / expected2;
                degrees_of_freedom += 1;
            }
        }

        degrees_of_freedom = degrees_of_freedom.saturating_sub(1);

        // Chi-square critical values for 1 degree of freedom; only a rough
        // threshold when degrees_of_freedom > 1
        let critical_value = match (significance_level * 100.0).round() as i32 {
            1 => 6.635,
            5 => 3.841,
            10 => 2.706,
            _ => 3.841,
        };

        ComparisonResult {
            chi_square,
            degrees_of_freedom,
            p_value: if chi_square > critical_value {
                0.01
            } else {
                0.1
            }, // Rough placeholder, not a true chi-square tail probability
            significant: chi_square > critical_value,
        }
    }

    /// Normal CDF approximation
    fn normal_cdf(x: f64) -> f64 {
        // Simplified approximation
        0.5 * (1.0 + erf(x / 2.0_f64.sqrt()))
    }

    /// Error function approximation
    fn erf(x: f64) -> f64 {
        // Abramowitz and Stegun approximation
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

        sign * y
    }
}

/// Result of statistical comparison
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// Chi-square statistic
    pub chi_square: f64,
    /// Degrees of freedom
    pub degrees_of_freedom: usize,
    /// P-value
    pub p_value: f64,
    /// Whether difference is significant
    pub significant: bool,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bit_string() {
        let bs = BitString::from_int(5, 4); // 5 = 0b0101; little-endian bit order gives [1, 0, 1, 0]
        assert_eq!(bs.bits, vec![1, 0, 1, 0]);
        assert_eq!(bs.to_int(), 5);
        assert_eq!(bs.weight(), 2);
    }

    #[test]
    fn test_sampler_creation() {
        let config = SamplingConfig::default();
        let sampler = QuantumSampler::new(config);
        assert_eq!(sampler.config.num_shots, 1024);
    }

    #[test]
    fn test_uniform_state_sampling() {
        let mut config = SamplingConfig::default();
        config.num_shots = 100;
        config.seed = Some(42);

        let mut sampler = QuantumSampler::new(config);

        // Create uniform superposition |+> = (|0> + |1>)/√2
        let state = Array1::from_vec(vec![
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
        ]);

        let result = sampler.sample_state(&state).unwrap();
        assert_eq!(result.num_shots, 100);
        assert_eq!(result.outcomes.len(), 100);

        // Check that we got both |0> and |1> outcomes
        let has_zero = result.outcomes.iter().any(|bs| bs.to_int() == 0);
        let has_one = result.outcomes.iter().any(|bs| bs.to_int() == 1);
        assert!(has_zero && has_one);
    }

    #[test]
    fn test_required_shots_calculation() {
        let shots = analysis::required_shots_for_precision(0.01, 0.95);
        assert!(shots > 9000); // Should need many shots for 1% precision
    }
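
    // With a fixed seed the sampler's ChaCha8-based RNG is deterministic, so two
    // independent samplers built from the same config should produce identical
    // shot sequences for the same state.
    #[test]
    fn test_seeded_sampling_reproducible() {
        let make_result = || {
            let config = SamplingConfig {
                num_shots: 50,
                seed: Some(1234),
                ..Default::default()
            };
            let mut sampler = QuantumSampler::new(config);
            let amp = 1.0 / 2.0_f64.sqrt();
            let state = Array1::from_vec(vec![
                Complex64::new(amp, 0.0),
                Complex64::new(amp, 0.0),
            ]);
            sampler.sample_state(&state).unwrap()
        };

        let r1 = make_result();
        let r2 = make_result();
        assert_eq!(r1.outcomes, r2.outcomes);
    }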
}