quantrs2_sim/shot_sampling.rs

//! Shot-based sampling with statistical analysis for quantum simulation.
//!
//! This module implements comprehensive shot-based sampling methods for quantum
//! circuits, including measurement statistics, error analysis, and convergence
//! detection for realistic quantum device simulation.
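//!
//! # Example
//!
//! A minimal usage sketch (assuming this module is exported as
//! `quantrs2_sim::shot_sampling`; adjust paths to your crate layout):
//!
//! ```ignore
//! use quantrs2_sim::shot_sampling::{QuantumSampler, SamplingConfig};
//! use scirs2_core::ndarray::Array1;
//! use scirs2_core::Complex64;
//!
//! // Sample 1000 shots from the single-qubit |+> state with a fixed seed.
//! let config = SamplingConfig {
//!     num_shots: 1000,
//!     seed: Some(42),
//!     ..Default::default()
//! };
//! let mut sampler = QuantumSampler::new(config);
//! let amp = 1.0 / 2.0_f64.sqrt();
//! let state = Array1::from_vec(vec![Complex64::new(amp, 0.0), Complex64::new(amp, 0.0)]);
//! let result = sampler.sample_state(&state).unwrap();
//! assert_eq!(result.outcomes.len(), 1000);
//! ```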

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::ChaCha8Rng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::Complex64;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::{Result, SimulatorError};
use crate::pauli::{PauliOperatorSum, PauliString};
use crate::statevector::StateVectorSimulator;

/// Shot-based measurement result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShotResult {
    /// Measurement outcomes for each shot
    pub outcomes: Vec<BitString>,
    /// Total number of shots
    pub num_shots: usize,
    /// Measurement statistics
    pub statistics: MeasurementStatistics,
    /// Sampling configuration used
    pub config: SamplingConfig,
}

/// Bit string representation of measurement outcome
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BitString {
    /// Bit values (0 or 1)
    pub bits: Vec<u8>,
}

impl BitString {
    /// Create from vector of booleans
    pub fn from_bools(bools: &[bool]) -> Self {
        Self {
            bits: bools.iter().map(|&b| u8::from(b)).collect(),
        }
    }

    /// Convert to vector of booleans
    pub fn to_bools(&self) -> Vec<bool> {
        self.bits.iter().map(|&b| b == 1).collect()
    }

    /// Convert to integer (little-endian)
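    ///
    /// A small illustrative doc-test sketch (assuming this module is exported
    /// as `quantrs2_sim::shot_sampling`):
    ///
    /// ```ignore
    /// use quantrs2_sim::shot_sampling::BitString;
    ///
    /// // Little-endian: bit i contributes 2^i, so [1, 0, 1] encodes 1 + 4 = 5.
    /// let bs = BitString::from_bools(&[true, false, true]);
    /// assert_eq!(bs.to_int(), 5);
    /// assert_eq!(BitString::from_int(5, 3), bs);
    /// ```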
    pub fn to_int(&self) -> usize {
        self.bits
            .iter()
            .enumerate()
            .map(|(i, &bit)| (bit as usize) << i)
            .sum()
    }

    /// Create from integer (little-endian)
    pub fn from_int(mut value: usize, num_bits: usize) -> Self {
        let mut bits = Vec::with_capacity(num_bits);
        for _ in 0..num_bits {
            bits.push((value & 1) as u8);
            value >>= 1;
        }
        Self { bits }
    }

    /// Number of bits
    pub fn len(&self) -> usize {
        self.bits.len()
    }

    /// Check if empty
    pub fn is_empty(&self) -> bool {
        self.bits.is_empty()
    }

    /// Hamming weight (number of 1s)
    pub fn weight(&self) -> usize {
        self.bits.iter().map(|&b| b as usize).sum()
    }

    /// Hamming distance to another bit string
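    ///
    /// Returns `usize::MAX` if the two strings have different lengths.
    /// Illustrative sketch (same path assumption as above):
    ///
    /// ```ignore
    /// use quantrs2_sim::shot_sampling::BitString;
    ///
    /// let a = BitString::from_bools(&[true, false, true]); // 101
    /// let b = BitString::from_bools(&[true, true, true]);  // 111
    /// assert_eq!(a.distance(&b), 1); // strings differ in one position
    /// ```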
    pub fn distance(&self, other: &Self) -> usize {
        if self.len() != other.len() {
            return usize::MAX; // Invalid comparison
        }
        self.bits
            .iter()
            .zip(&other.bits)
            .map(|(&a, &b)| (a ^ b) as usize)
            .sum()
    }
}

impl std::fmt::Display for BitString {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for &bit in &self.bits {
            write!(f, "{bit}")?;
        }
        Ok(())
    }
}

/// Measurement statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasurementStatistics {
    /// Frequency count for each outcome
    pub counts: HashMap<BitString, usize>,
    /// Most frequent outcome
    pub mode: BitString,
    /// Probability estimates for each outcome
    pub probabilities: HashMap<BitString, f64>,
    /// Variance in the probability estimates
    pub probability_variance: f64,
    /// Statistical confidence intervals
    pub confidence_intervals: HashMap<BitString, (f64, f64)>,
    /// Entropy of the measurement distribution
    pub entropy: f64,
    /// Purity of the measurement distribution
    pub purity: f64,
}

/// Sampling configuration
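///
/// The defaults (1024 shots, 95% confidence, statistics enabled) can be
/// overridden selectively; an illustrative sketch:
///
/// ```ignore
/// use quantrs2_sim::shot_sampling::SamplingConfig;
///
/// let config = SamplingConfig {
///     num_shots: 4096,
///     seed: Some(7),
///     ..Default::default()
/// };
/// assert!(config.compute_statistics);
/// ```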
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Number of shots to take
    pub num_shots: usize,
    /// Random seed for reproducibility
    pub seed: Option<u64>,
    /// Confidence level for intervals (e.g., 0.95 for 95%)
    pub confidence_level: f64,
    /// Whether to compute full statistics
    pub compute_statistics: bool,
    /// Whether to estimate convergence
    pub estimate_convergence: bool,
    /// Convergence check interval (number of shots)
    pub convergence_check_interval: usize,
    /// Convergence tolerance
    pub convergence_tolerance: f64,
    /// Maximum number of shots for convergence
    pub max_shots_for_convergence: usize,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            num_shots: 1024,
            seed: None,
            confidence_level: 0.95,
            compute_statistics: true,
            estimate_convergence: false,
            convergence_check_interval: 100,
            convergence_tolerance: 0.01,
            max_shots_for_convergence: 10000,
        }
    }
}

/// Shot-based quantum sampler
pub struct QuantumSampler {
    /// Random number generator
    rng: ChaCha8Rng,
    /// Current configuration
    config: SamplingConfig,
}

impl QuantumSampler {
    /// Create new sampler with configuration
    pub fn new(config: SamplingConfig) -> Self {
        let rng = if let Some(seed) = config.seed {
            ChaCha8Rng::seed_from_u64(seed)
        } else {
            ChaCha8Rng::from_rng(&mut thread_rng())
        };

        Self { rng, config }
    }

    /// Sample measurements from a quantum state
    pub fn sample_state(&mut self, state: &Array1<Complex64>) -> Result<ShotResult> {
        let num_qubits = (state.len() as f64).log2() as usize;
        if 1 << num_qubits != state.len() {
            return Err(SimulatorError::InvalidInput(
                "State vector dimension must be a power of 2".to_string(),
            ));
        }

        // Compute probability distribution
        let probabilities: Vec<f64> = state.iter().map(|amp| amp.norm_sqr()).collect();

        // Validate normalization
        let total_prob: f64 = probabilities.iter().sum();
        if (total_prob - 1.0).abs() > 1e-10 {
            return Err(SimulatorError::InvalidInput(format!(
                "State vector not normalized: total probability = {total_prob}"
            )));
        }

        // Sample outcomes
        let mut outcomes = Vec::with_capacity(self.config.num_shots);
        for _ in 0..self.config.num_shots {
            let sample = self.sample_from_distribution(&probabilities)?;
            outcomes.push(BitString::from_int(sample, num_qubits));
        }

        // Compute statistics if requested
        let statistics = if self.config.compute_statistics {
            self.compute_statistics(&outcomes)?
        } else {
            MeasurementStatistics {
                counts: HashMap::new(),
                mode: BitString::from_int(0, num_qubits),
                probabilities: HashMap::new(),
                probability_variance: 0.0,
                confidence_intervals: HashMap::new(),
                entropy: 0.0,
                purity: 0.0,
            }
        };

        Ok(ShotResult {
            outcomes,
            num_shots: self.config.num_shots,
            statistics,
            config: self.config.clone(),
        })
    }

    /// Sample measurements from a state with noise
    pub fn sample_state_with_noise(
        &mut self,
        state: &Array1<Complex64>,
        noise_model: &dyn NoiseModel,
    ) -> Result<ShotResult> {
        // Apply noise model to the state
        let noisy_state = noise_model.apply_readout_noise(state)?;
        self.sample_state(&noisy_state)
    }

    /// Sample expectation value of an observable
    pub fn sample_expectation(
        &mut self,
        state: &Array1<Complex64>,
        observable: &PauliOperatorSum,
    ) -> Result<ExpectationResult> {
        let mut expectation_values = Vec::new();
        let mut variances = Vec::new();

        // Sample each Pauli term separately
        for term in &observable.terms {
            let term_result = self.sample_pauli_expectation(state, term)?;
            expectation_values.push(term_result.expectation * term.coefficient.re);
            variances.push(term_result.variance * term.coefficient.re.powi(2));
        }

        // Combine results
        let total_expectation: f64 = expectation_values.iter().sum();
        let total_variance: f64 = variances.iter().sum();
        let standard_error = (total_variance / self.config.num_shots as f64).sqrt();

        // Confidence interval
        let z_score = self.get_z_score(self.config.confidence_level);
        let confidence_interval = (
            z_score.mul_add(-standard_error, total_expectation),
            z_score.mul_add(standard_error, total_expectation),
        );

        Ok(ExpectationResult {
            expectation: total_expectation,
            variance: total_variance,
            standard_error,
            confidence_interval,
            num_shots: self.config.num_shots,
        })
    }

    /// Sample expectation value of a single Pauli string
    fn sample_pauli_expectation(
        &mut self,
        state: &Array1<Complex64>,
        pauli_string: &PauliString,
    ) -> Result<ExpectationResult> {
        // For Pauli measurements, eigenvalues are ±1
        // We need to measure in the appropriate basis

        let num_qubits = pauli_string.num_qubits;
        let mut measurements = Vec::with_capacity(self.config.num_shots);

        for _ in 0..self.config.num_shots {
            // Measure each qubit in the appropriate Pauli basis
            let outcome = self.measure_pauli_basis(state, pauli_string)?;
            measurements.push(outcome);
        }

        // Compute statistics
        let mean = measurements.iter().sum::<f64>() / measurements.len() as f64;
        let variance = measurements.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
            / measurements.len() as f64;

        let standard_error = (variance / measurements.len() as f64).sqrt();
        let z_score = self.get_z_score(self.config.confidence_level);
        let confidence_interval = (
            z_score.mul_add(-standard_error, mean),
            z_score.mul_add(standard_error, mean),
        );

        Ok(ExpectationResult {
            expectation: mean,
            variance,
            standard_error,
            confidence_interval,
            num_shots: measurements.len(),
        })
    }

    /// Measure in Pauli basis (simplified implementation)
    fn measure_pauli_basis(
        &mut self,
        _state: &Array1<Complex64>,
        _pauli_string: &PauliString,
    ) -> Result<f64> {
        // Simplified implementation - return random ±1
        // In practice, would need to transform state to measurement basis
        if self.rng.gen::<f64>() < 0.5 {
            Ok(1.0)
        } else {
            Ok(-1.0)
        }
    }

    /// Sample from discrete probability distribution
    fn sample_from_distribution(&mut self, probabilities: &[f64]) -> Result<usize> {
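        // Inverse-transform sampling: draw u ~ Uniform(0, 1) and return the first
        // index whose cumulative probability reaches u.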
        let random_value = self.rng.gen::<f64>();
        let mut cumulative = 0.0;

        for (i, &prob) in probabilities.iter().enumerate() {
            cumulative += prob;
            if random_value <= cumulative {
                return Ok(i);
            }
        }

        // Handle numerical errors - return last index
        Ok(probabilities.len() - 1)
    }

    /// Compute measurement statistics
    fn compute_statistics(&self, outcomes: &[BitString]) -> Result<MeasurementStatistics> {
        let mut counts = HashMap::new();
        let total_shots = outcomes.len() as f64;

        // Count frequencies
        for outcome in outcomes {
            *counts.entry(outcome.clone()).or_insert(0) += 1;
        }

        // Find mode
        let mode = counts.iter().max_by_key(|(_, &count)| count).map_or_else(
            || BitString::from_int(0, outcomes[0].len()),
            |(outcome, _)| outcome.clone(),
        );

        // Compute probabilities
        let mut probabilities = HashMap::new();
        let mut confidence_intervals = HashMap::new();
        let z_score = self.get_z_score(self.config.confidence_level);

        for (outcome, &count) in &counts {
            let prob = count as f64 / total_shots;
            probabilities.insert(outcome.clone(), prob);

            // Binomial confidence interval
            let std_error = (prob * (1.0 - prob) / total_shots).sqrt();
            let margin = z_score * std_error;
            confidence_intervals.insert(
                outcome.clone(),
                ((prob - margin).max(0.0), (prob + margin).min(1.0)),
            );
        }

        // Compute entropy
        let entropy = probabilities
            .values()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum::<f64>();

        // Compute purity (sum of squared probabilities)
        let purity = probabilities.values().map(|&p| p * p).sum::<f64>();

        // Compute variance of the estimated probabilities around the uniform mean 1/K
        let mean_prob = 1.0 / probabilities.len() as f64;
        let probability_variance = probabilities
            .values()
            .map(|&p| (p - mean_prob).powi(2))
            .sum::<f64>()
            / probabilities.len() as f64;

        Ok(MeasurementStatistics {
            counts,
            mode,
            probabilities,
            probability_variance,
            confidence_intervals,
            entropy,
            purity,
        })
    }

    /// Get z-score for confidence level
    fn get_z_score(&self, confidence_level: f64) -> f64 {
        // Simplified - use common values
        match (confidence_level * 100.0) as i32 {
            90 => 1.645,
            95 => 1.96,
            99 => 2.576,
            _ => 1.96, // Default to 95%
        }
    }

    /// Estimate convergence of sampling
    pub fn estimate_convergence(
        &mut self,
        state: &Array1<Complex64>,
        observable: &PauliOperatorSum,
    ) -> Result<ConvergenceResult> {
        let mut expectation_history = Vec::new();
        let mut variance_history = Vec::new();
        let mut shots_taken = 0;
        let mut converged = false;

        while shots_taken < self.config.max_shots_for_convergence && !converged {
            // Take a batch of measurements
            let batch_shots = self
                .config
                .convergence_check_interval
                .min(self.config.max_shots_for_convergence - shots_taken);

            // Temporarily adjust shot count for this batch
            let original_shots = self.config.num_shots;
            self.config.num_shots = batch_shots;

            let result = self.sample_expectation(state, observable)?;

            // Restore original shot count
            self.config.num_shots = original_shots;

            expectation_history.push(result.expectation);
            variance_history.push(result.variance);
            shots_taken += batch_shots;

            // Check convergence
            if expectation_history.len() >= 3 {
                let recent_values = &expectation_history[expectation_history.len() - 3..];
                let max_diff = recent_values
                    .iter()
                    .zip(recent_values.iter().skip(1))
                    .map(|(a, b)| (a - b).abs())
                    .fold(0.0, f64::max);

                if max_diff < self.config.convergence_tolerance {
                    converged = true;
                }
            }
        }

        // Compute final estimates
        let final_expectation = expectation_history.last().copied().unwrap_or(0.0);
        let expectation_std = if expectation_history.len() > 1 {
            let mean = expectation_history.iter().sum::<f64>() / expectation_history.len() as f64;
            (expectation_history
                .iter()
                .map(|x| (x - mean).powi(2))
                .sum::<f64>()
                / (expectation_history.len() - 1) as f64)
                .sqrt()
        } else {
            0.0
        };

        Ok(ConvergenceResult {
            converged,
            shots_taken,
            final_expectation,
            expectation_history,
            variance_history,
            convergence_rate: expectation_std,
        })
    }
}

/// Result of expectation value sampling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectationResult {
    /// Expectation value estimate
    pub expectation: f64,
    /// Variance estimate
    pub variance: f64,
    /// Standard error
    pub standard_error: f64,
    /// Confidence interval
    pub confidence_interval: (f64, f64),
    /// Number of shots used
    pub num_shots: usize,
}

/// Result of convergence estimation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceResult {
    /// Whether convergence was achieved
    pub converged: bool,
    /// Total shots taken
    pub shots_taken: usize,
    /// Final expectation value
    pub final_expectation: f64,
    /// History of expectation values
    pub expectation_history: Vec<f64>,
    /// History of variances
    pub variance_history: Vec<f64>,
    /// Convergence rate (standard deviation of recent estimates)
    pub convergence_rate: f64,
}

/// Noise model trait for realistic sampling
pub trait NoiseModel: Send + Sync {
    /// Apply readout noise to measurements
    fn apply_readout_noise(&self, state: &Array1<Complex64>) -> Result<Array1<Complex64>>;

    /// Get readout error probability for qubit
    fn readout_error_probability(&self, qubit: usize) -> f64;
}

/// Simple readout noise model
#[derive(Debug, Clone)]
pub struct SimpleReadoutNoise {
    /// Error probability for each qubit
    pub error_probs: Vec<f64>,
}

impl SimpleReadoutNoise {
    /// Create uniform readout noise
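    ///
    /// Illustrative sketch (same path assumption as the module example):
    ///
    /// ```ignore
    /// use quantrs2_sim::shot_sampling::{NoiseModel, SimpleReadoutNoise};
    ///
    /// let noise = SimpleReadoutNoise::uniform(3, 0.02);
    /// assert_eq!(noise.readout_error_probability(1), 0.02);
    /// assert_eq!(noise.readout_error_probability(5), 0.0); // out-of-range qubits default to 0
    /// ```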
    pub fn uniform(num_qubits: usize, error_prob: f64) -> Self {
        Self {
            error_probs: vec![error_prob; num_qubits],
        }
    }
}

impl NoiseModel for SimpleReadoutNoise {
    fn apply_readout_noise(&self, state: &Array1<Complex64>) -> Result<Array1<Complex64>> {
        // Simplified implementation - in practice would need proper POVM modeling
        Ok(state.clone())
    }

    fn readout_error_probability(&self, qubit: usize) -> f64 {
        self.error_probs.get(qubit).copied().unwrap_or(0.0)
    }
}

/// Utility functions for shot sampling analysis
pub mod analysis {
    use super::*;

    /// Compute statistical power for detecting effect
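    ///
    /// Power is approximated as Φ(effect/SE − z_crit) with SE = 1/√n.
    /// Illustrative sketch using this module's approximations (values are
    /// approximate, not exact):
    ///
    /// ```ignore
    /// use quantrs2_sim::shot_sampling::analysis::statistical_power;
    ///
    /// // effect 0.1, 1000 shots, α = 0.05: z ≈ 0.1 / 0.0316 − 1.96 ≈ 1.20
    /// let power = statistical_power(0.1, 1000, 0.05);
    /// assert!(power > 0.85 && power < 0.92);
    /// ```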
    pub fn statistical_power(effect_size: f64, num_shots: usize, significance_level: f64) -> f64 {
        // Simplified power analysis
        let standard_error = 1.0 / (num_shots as f64).sqrt();
        let z_critical = match (significance_level * 100.0) as i32 {
            1 => 2.576,
            5 => 1.96,
            10 => 1.645,
            _ => 1.96,
        };

        let z_beta = (effect_size / standard_error) - z_critical;
        normal_cdf(z_beta)
    }

    /// Estimate required shots for desired precision
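    ///
    /// Uses the worst-case binomial bound n ≥ z² / (4ε²). Illustrative sketch:
    ///
    /// ```ignore
    /// use quantrs2_sim::shot_sampling::analysis::required_shots_for_precision;
    ///
    /// // 1% precision at 95% confidence: 1.96² / (4 · 0.01²) ≈ 9604 shots.
    /// let shots = required_shots_for_precision(0.01, 0.95);
    /// assert!((9604..=9605).contains(&shots));
    /// ```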
    pub fn required_shots_for_precision(desired_error: f64, confidence_level: f64) -> usize {
        let z_score = match (confidence_level * 100.0) as i32 {
            90 => 1.645,
            95 => 1.96,
            99 => 2.576,
            _ => 1.96,
        };

        // For binomial sampling: n ≥ z²/(4ε²) in the worst case p = 0.5
        let n = (z_score * z_score) / (4.0 * desired_error * desired_error);
        n.ceil() as usize
    }

    /// Compare two shot results statistically
    pub fn compare_shot_results(
        result1: &ShotResult,
        result2: &ShotResult,
        significance_level: f64,
    ) -> ComparisonResult {
        // Chi-square test for distribution comparison
        let mut chi_square = 0.0;
        let mut degrees_of_freedom: usize = 0;

        // Get all unique outcomes
        let mut all_outcomes = std::collections::HashSet::new();
        all_outcomes.extend(result1.statistics.counts.keys());
        all_outcomes.extend(result2.statistics.counts.keys());

        for &outcome in &all_outcomes {
            let count1 = result1.statistics.counts.get(outcome).copied().unwrap_or(0) as f64;
            let count2 = result2.statistics.counts.get(outcome).copied().unwrap_or(0) as f64;

            let total1 = result1.num_shots as f64;
            let total2 = result2.num_shots as f64;

            let expected1 = (count1 + count2) * total1 / (total1 + total2);
            let expected2 = (count1 + count2) * total2 / (total1 + total2);

            if expected1 > 5.0 && expected2 > 5.0 {
                chi_square += (count1 - expected1).powi(2) / expected1;
                chi_square += (count2 - expected2).powi(2) / expected2;
                degrees_of_freedom += 1;
            }
        }

        degrees_of_freedom = degrees_of_freedom.saturating_sub(1);

        // Critical value for given significance level (simplified)
        let critical_value = match (significance_level * 100.0) as i32 {
            1 => 6.635, // Very rough approximation
            5 => 3.841,
            10 => 2.706,
            _ => 3.841,
        };

        ComparisonResult {
            chi_square,
            degrees_of_freedom,
            p_value: if chi_square > critical_value {
                0.01
            } else {
                0.1
            }, // Rough
            significant: chi_square > critical_value,
        }
    }

    /// Normal CDF approximation
    fn normal_cdf(x: f64) -> f64 {
        // Simplified approximation
        0.5 * (1.0 + erf(x / 2.0_f64.sqrt()))
    }

    /// Error function approximation
    fn erf(x: f64) -> f64 {
        // Abramowitz and Stegun approximation
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        let t = 1.0 / (1.0 + p * x);
        let y = ((a5 * t + a4).mul_add(t, a3).mul_add(t, a2).mul_add(t, a1) * t)
            .mul_add(-(-x * x).exp(), 1.0);

        sign * y
    }
}

/// Result of statistical comparison
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// Chi-square statistic
    pub chi_square: f64,
    /// Degrees of freedom
    pub degrees_of_freedom: usize,
    /// P-value
    pub p_value: f64,
    /// Whether difference is significant
    pub significant: bool,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bit_string() {
        let bs = BitString::from_int(5, 4); // 5 = 0b0101, stored little-endian as [1, 0, 1, 0]
        assert_eq!(bs.bits, vec![1, 0, 1, 0]);
        assert_eq!(bs.to_int(), 5);
        assert_eq!(bs.weight(), 2);
    }

    #[test]
    fn test_sampler_creation() {
        let config = SamplingConfig::default();
        let sampler = QuantumSampler::new(config);
        assert_eq!(sampler.config.num_shots, 1024);
    }

    #[test]
    fn test_uniform_state_sampling() {
        let mut config = SamplingConfig::default();
        config.num_shots = 100;
        config.seed = Some(42);

        let mut sampler = QuantumSampler::new(config);

        // Create uniform superposition |+> = (|0> + |1>)/√2
        let state = Array1::from_vec(vec![
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
        ]);

        let result = sampler.sample_state(&state).unwrap();
        assert_eq!(result.num_shots, 100);
        assert_eq!(result.outcomes.len(), 100);

        // Check that we got both |0> and |1> outcomes
        let has_zero = result.outcomes.iter().any(|bs| bs.to_int() == 0);
        let has_one = result.outcomes.iter().any(|bs| bs.to_int() == 1);
        assert!(has_zero && has_one);
    }

    #[test]
    fn test_required_shots_calculation() {
        let shots = analysis::required_shots_for_precision(0.01, 0.95);
        assert!(shots > 9000); // Should need many shots for 1% precision
    }
}