quantrs2_sim/shot_sampling.rs

//! Shot-based sampling with statistical analysis for quantum simulation.
//!
//! This module implements comprehensive shot-based sampling methods for quantum
//! circuits, including measurement statistics, error analysis, and convergence
//! detection for realistic quantum device simulation.
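//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`; it assumes this module is exposed as
//! `quantrs2_sim::shot_sampling` and that `scirs2_core` provides the re-exports shown):
//!
//! ```rust,ignore
//! use quantrs2_sim::shot_sampling::{QuantumSampler, SamplingConfig};
//! use scirs2_core::ndarray::Array1;
//! use scirs2_core::Complex64;
//!
//! // Sample 1000 shots from the |+> state with a fixed seed for reproducibility.
//! let config = SamplingConfig {
//!     num_shots: 1000,
//!     seed: Some(7),
//!     ..Default::default()
//! };
//! let mut sampler = QuantumSampler::new(config);
//! let amp = 1.0 / 2.0_f64.sqrt();
//! let state = Array1::from_vec(vec![Complex64::new(amp, 0.0), Complex64::new(amp, 0.0)]);
//! let result = sampler.sample_state(&state).expect("sampling failed");
//! println!("estimated probabilities: {:?}", result.statistics.probabilities);
//! ```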

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::ChaCha8Rng;
use scirs2_core::random::{Rng, SeedableRng};
use scirs2_core::Complex64;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use crate::error::{Result, SimulatorError};
use crate::pauli::{PauliOperatorSum, PauliString};
use crate::statevector::StateVectorSimulator;

/// Shot-based measurement result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShotResult {
    /// Measurement outcomes for each shot
    pub outcomes: Vec<BitString>,
    /// Total number of shots
    pub num_shots: usize,
    /// Measurement statistics
    pub statistics: MeasurementStatistics,
    /// Sampling configuration used
    pub config: SamplingConfig,
}

/// Bit string representation of measurement outcome
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct BitString {
    /// Bit values (0 or 1)
    pub bits: Vec<u8>,
}

impl BitString {
    /// Create from vector of booleans
    #[must_use]
    pub fn from_bools(bools: &[bool]) -> Self {
        Self {
            bits: bools.iter().map(|&b| u8::from(b)).collect(),
        }
    }

    /// Convert to vector of booleans
    #[must_use]
    pub fn to_bools(&self) -> Vec<bool> {
        self.bits.iter().map(|&b| b == 1).collect()
    }

    /// Convert to integer (little-endian)
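    ///
    /// Little-endian convention: `bits[0]` is the least significant bit. An
    /// illustrative check (marked `ignore` since `BitString` lives in this crate):
    ///
    /// ```rust,ignore
    /// let bs = BitString { bits: vec![1, 0, 1] }; // 1*1 + 0*2 + 1*4
    /// assert_eq!(bs.to_int(), 5);
    /// ```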
    #[must_use]
    pub fn to_int(&self) -> usize {
        self.bits
            .iter()
            .enumerate()
            .map(|(i, &bit)| (bit as usize) << i)
            .sum()
    }

    /// Create from integer (little-endian)
    #[must_use]
    pub fn from_int(mut value: usize, num_bits: usize) -> Self {
        let mut bits = Vec::with_capacity(num_bits);
        for _ in 0..num_bits {
            bits.push((value & 1) as u8);
            value >>= 1;
        }
        Self { bits }
    }

    /// Number of bits
    #[must_use]
    pub fn len(&self) -> usize {
        self.bits.len()
    }

    /// Check if empty
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.bits.is_empty()
    }

    /// Hamming weight (number of 1s)
    #[must_use]
    pub fn weight(&self) -> usize {
        self.bits.iter().map(|&b| b as usize).sum()
    }

    /// Hamming distance to another bit string
    #[must_use]
    pub fn distance(&self, other: &Self) -> usize {
        if self.len() != other.len() {
            return usize::MAX; // Invalid comparison
        }
        self.bits
            .iter()
            .zip(&other.bits)
            .map(|(&a, &b)| (a ^ b) as usize)
            .sum()
    }
}

impl std::fmt::Display for BitString {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        for &bit in &self.bits {
            write!(f, "{bit}")?;
        }
        Ok(())
    }
}

/// Measurement statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasurementStatistics {
    /// Frequency count for each outcome
    pub counts: HashMap<BitString, usize>,
    /// Most frequent outcome
    pub mode: BitString,
    /// Probability estimates for each outcome
    pub probabilities: HashMap<BitString, f64>,
    /// Variance in the probability estimates
    pub probability_variance: f64,
    /// Statistical confidence intervals
    pub confidence_intervals: HashMap<BitString, (f64, f64)>,
    /// Entropy of the measurement distribution
    pub entropy: f64,
    /// Purity of the measurement distribution
    pub purity: f64,
}

/// Sampling configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Number of shots to take
    pub num_shots: usize,
    /// Random seed for reproducibility
    pub seed: Option<u64>,
    /// Confidence level for intervals (e.g., 0.95 for 95%)
    pub confidence_level: f64,
    /// Whether to compute full statistics
    pub compute_statistics: bool,
    /// Whether to estimate convergence
    pub estimate_convergence: bool,
    /// Convergence check interval (number of shots)
    pub convergence_check_interval: usize,
    /// Convergence tolerance
    pub convergence_tolerance: f64,
    /// Maximum number of shots for convergence
    pub max_shots_for_convergence: usize,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            num_shots: 1024,
            seed: None,
            confidence_level: 0.95,
            compute_statistics: true,
            estimate_convergence: false,
            convergence_check_interval: 100,
            convergence_tolerance: 0.01,
            max_shots_for_convergence: 10_000,
        }
    }
}

/// Shot-based quantum sampler
pub struct QuantumSampler {
    /// Random number generator
    rng: ChaCha8Rng,
    /// Current configuration
    config: SamplingConfig,
}

impl QuantumSampler {
    /// Create new sampler with configuration
    #[must_use]
    pub fn new(config: SamplingConfig) -> Self {
        let rng = if let Some(seed) = config.seed {
            ChaCha8Rng::seed_from_u64(seed)
        } else {
            ChaCha8Rng::from_rng(&mut thread_rng())
        };

        Self { rng, config }
    }

    /// Sample measurements from a quantum state
    pub fn sample_state(&mut self, state: &Array1<Complex64>) -> Result<ShotResult> {
        let num_qubits = (state.len() as f64).log2() as usize;
        if 1 << num_qubits != state.len() {
            return Err(SimulatorError::InvalidInput(
                "State vector dimension must be a power of 2".to_string(),
            ));
        }

        // Compute probability distribution
        let probabilities: Vec<f64> = state.iter().map(scirs2_core::Complex::norm_sqr).collect();

        // Validate normalization
        let total_prob: f64 = probabilities.iter().sum();
        if (total_prob - 1.0).abs() > 1e-10 {
            return Err(SimulatorError::InvalidInput(format!(
                "State vector not normalized: total probability = {total_prob}"
            )));
        }

        // Sample outcomes
        let mut outcomes = Vec::with_capacity(self.config.num_shots);
        for _ in 0..self.config.num_shots {
            let sample = self.sample_from_distribution(&probabilities)?;
            outcomes.push(BitString::from_int(sample, num_qubits));
        }

        // Compute statistics if requested
        let statistics = if self.config.compute_statistics {
            self.compute_statistics(&outcomes)?
        } else {
            MeasurementStatistics {
                counts: HashMap::new(),
                mode: BitString::from_int(0, num_qubits),
                probabilities: HashMap::new(),
                probability_variance: 0.0,
                confidence_intervals: HashMap::new(),
                entropy: 0.0,
                purity: 0.0,
            }
        };

        Ok(ShotResult {
            outcomes,
            num_shots: self.config.num_shots,
            statistics,
            config: self.config.clone(),
        })
    }

    /// Sample measurements from a state with noise
    pub fn sample_state_with_noise(
        &mut self,
        state: &Array1<Complex64>,
        noise_model: &dyn NoiseModel,
    ) -> Result<ShotResult> {
        // Apply noise model to the state
        let noisy_state = noise_model.apply_readout_noise(state)?;
        self.sample_state(&noisy_state)
    }

    /// Sample expectation value of an observable
    pub fn sample_expectation(
        &mut self,
        state: &Array1<Complex64>,
        observable: &PauliOperatorSum,
    ) -> Result<ExpectationResult> {
        let mut expectation_values = Vec::new();
        let mut variances = Vec::new();

        // Sample each Pauli term separately
        for term in &observable.terms {
            let term_result = self.sample_pauli_expectation(state, term)?;
            expectation_values.push(term_result.expectation * term.coefficient.re);
            variances.push(term_result.variance * term.coefficient.re.powi(2));
        }

        // Combine results
        let total_expectation: f64 = expectation_values.iter().sum();
        let total_variance: f64 = variances.iter().sum();
        let standard_error = (total_variance / self.config.num_shots as f64).sqrt();

        // Confidence interval
        let z_score = self.get_z_score(self.config.confidence_level);
        let confidence_interval = (
            z_score.mul_add(-standard_error, total_expectation),
            z_score.mul_add(standard_error, total_expectation),
        );

        Ok(ExpectationResult {
            expectation: total_expectation,
            variance: total_variance,
            standard_error,
            confidence_interval,
            num_shots: self.config.num_shots,
        })
    }

    /// Sample expectation value of a single Pauli string
    fn sample_pauli_expectation(
        &mut self,
        state: &Array1<Complex64>,
        pauli_string: &PauliString,
    ) -> Result<ExpectationResult> {
        // For Pauli measurements, eigenvalues are ±1
        // We need to measure in the appropriate basis

        let num_qubits = pauli_string.num_qubits;
        let mut measurements = Vec::with_capacity(self.config.num_shots);

        for _ in 0..self.config.num_shots {
            // Measure each qubit in the appropriate Pauli basis
            let outcome = self.measure_pauli_basis(state, pauli_string)?;
            measurements.push(outcome);
        }

        // Compute statistics
        let mean = measurements.iter().sum::<f64>() / measurements.len() as f64;
        let variance = measurements.iter().map(|x| (x - mean).powi(2)).sum::<f64>()
            / measurements.len() as f64;
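        // (Population variance, i.e. dividing by N; adequate for the large shot counts used here.)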

        let standard_error = (variance / measurements.len() as f64).sqrt();
        let z_score = self.get_z_score(self.config.confidence_level);
        let confidence_interval = (
            z_score.mul_add(-standard_error, mean),
            z_score.mul_add(standard_error, mean),
        );

        Ok(ExpectationResult {
            expectation: mean,
            variance,
            standard_error,
            confidence_interval,
            num_shots: measurements.len(),
        })
    }

    /// Measure in Pauli basis (simplified implementation)
    fn measure_pauli_basis(
        &mut self,
        _state: &Array1<Complex64>,
        _pauli_string: &PauliString,
    ) -> Result<f64> {
        // Simplified implementation - return random ±1
        // In practice, would need to transform state to measurement basis
        if self.rng.gen::<f64>() < 0.5 {
            Ok(1.0)
        } else {
            Ok(-1.0)
        }
    }

    /// Sample from discrete probability distribution
    fn sample_from_distribution(&mut self, probabilities: &[f64]) -> Result<usize> {
        let random_value = self.rng.gen::<f64>();
        let mut cumulative = 0.0;

        for (i, &prob) in probabilities.iter().enumerate() {
            cumulative += prob;
            if random_value <= cumulative {
                return Ok(i);
            }
        }

        // Handle numerical errors - return last index
        Ok(probabilities.len() - 1)
    }

    /// Compute measurement statistics
    fn compute_statistics(&self, outcomes: &[BitString]) -> Result<MeasurementStatistics> {
        let mut counts = HashMap::new();
        let total_shots = outcomes.len() as f64;

        // Count frequencies
        for outcome in outcomes {
            *counts.entry(outcome.clone()).or_insert(0) += 1;
        }

        // Find mode
        let mode = counts.iter().max_by_key(|(_, &count)| count).map_or_else(
            || BitString::from_int(0, outcomes[0].len()),
            |(outcome, _)| outcome.clone(),
        );

        // Compute probabilities
        let mut probabilities = HashMap::new();
        let mut confidence_intervals = HashMap::new();
        let z_score = self.get_z_score(self.config.confidence_level);

        for (outcome, &count) in &counts {
            let prob = count as f64 / total_shots;
            probabilities.insert(outcome.clone(), prob);

            // Binomial confidence interval
            let std_error = (prob * (1.0 - prob) / total_shots).sqrt();
            let margin = z_score * std_error;
            confidence_intervals.insert(
                outcome.clone(),
                ((prob - margin).max(0.0), (prob + margin).min(1.0)),
            );
        }

        // Compute entropy
        let entropy = probabilities
            .values()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum::<f64>();

        // Compute purity (sum of squared probabilities)
        let purity = probabilities.values().map(|&p| p * p).sum::<f64>();
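        // Illustrative check: a uniform distribution over K observed outcomes
        // gives entropy = ln(K) (in nats) and purity = 1/K.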

        // Compute overall probability variance
        let mean_prob = 1.0 / probabilities.len() as f64;
        let probability_variance = probabilities
            .values()
            .map(|&p| (p - mean_prob).powi(2))
            .sum::<f64>()
            / probabilities.len() as f64;

        Ok(MeasurementStatistics {
            counts,
            mode,
            probabilities,
            probability_variance,
            confidence_intervals,
            entropy,
            purity,
        })
    }

    /// Get z-score for confidence level
    fn get_z_score(&self, confidence_level: f64) -> f64 {
        // Simplified - use common values
        match (confidence_level * 100.0) as i32 {
            90 => 1.645,
            95 => 1.96,
            99 => 2.576,
            _ => 1.96, // Default to 95%
        }
    }

    /// Estimate convergence of sampling
    pub fn estimate_convergence(
        &mut self,
        state: &Array1<Complex64>,
        observable: &PauliOperatorSum,
    ) -> Result<ConvergenceResult> {
        let mut expectation_history = Vec::new();
        let mut variance_history = Vec::new();
        let mut shots_taken = 0;
        let mut converged = false;

        while shots_taken < self.config.max_shots_for_convergence && !converged {
            // Take a batch of measurements
            let batch_shots = self
                .config
                .convergence_check_interval
                .min(self.config.max_shots_for_convergence - shots_taken);

            // Temporarily adjust shot count for this batch
            let original_shots = self.config.num_shots;
            self.config.num_shots = batch_shots;

            let result = self.sample_expectation(state, observable)?;

            // Restore original shot count
            self.config.num_shots = original_shots;

            expectation_history.push(result.expectation);
            variance_history.push(result.variance);
            shots_taken += batch_shots;

            // Check convergence
            if expectation_history.len() >= 3 {
                let recent_values = &expectation_history[expectation_history.len() - 3..];
                let max_diff = recent_values
                    .iter()
                    .zip(recent_values.iter().skip(1))
                    .map(|(a, b)| (a - b).abs())
                    .fold(0.0, f64::max);

                if max_diff < self.config.convergence_tolerance {
                    converged = true;
                }
            }
        }

        // Compute final estimates
        let final_expectation = expectation_history.last().copied().unwrap_or(0.0);
        let expectation_std = if expectation_history.len() > 1 {
            let mean = expectation_history.iter().sum::<f64>() / expectation_history.len() as f64;
            (expectation_history
                .iter()
                .map(|x| (x - mean).powi(2))
                .sum::<f64>()
                / (expectation_history.len() - 1) as f64)
                .sqrt()
        } else {
            0.0
        };

        Ok(ConvergenceResult {
            converged,
            shots_taken,
            final_expectation,
            expectation_history,
            variance_history,
            convergence_rate: expectation_std,
        })
    }
}

/// Result of expectation value sampling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectationResult {
    /// Expectation value estimate
    pub expectation: f64,
    /// Variance estimate
    pub variance: f64,
    /// Standard error
    pub standard_error: f64,
    /// Confidence interval
    pub confidence_interval: (f64, f64),
    /// Number of shots used
    pub num_shots: usize,
}

/// Result of convergence estimation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceResult {
    /// Whether convergence was achieved
    pub converged: bool,
    /// Total shots taken
    pub shots_taken: usize,
    /// Final expectation value
    pub final_expectation: f64,
    /// History of expectation values
    pub expectation_history: Vec<f64>,
    /// History of variances
    pub variance_history: Vec<f64>,
    /// Convergence rate (standard deviation of recent estimates)
    pub convergence_rate: f64,
}

/// Noise model trait for realistic sampling
pub trait NoiseModel: Send + Sync {
    /// Apply readout noise to measurements
    fn apply_readout_noise(&self, state: &Array1<Complex64>) -> Result<Array1<Complex64>>;

    /// Get readout error probability for qubit
    fn readout_error_probability(&self, qubit: usize) -> f64;
}

/// Simple readout noise model
#[derive(Debug, Clone)]
pub struct SimpleReadoutNoise {
    /// Error probability for each qubit
    pub error_probs: Vec<f64>,
}

impl SimpleReadoutNoise {
    /// Create uniform readout noise
    #[must_use]
    pub fn uniform(num_qubits: usize, error_prob: f64) -> Self {
        Self {
            error_probs: vec![error_prob; num_qubits],
        }
    }
}

impl NoiseModel for SimpleReadoutNoise {
    fn apply_readout_noise(&self, state: &Array1<Complex64>) -> Result<Array1<Complex64>> {
        // Simplified implementation - in practice would need proper POVM modeling
        Ok(state.clone())
    }

    fn readout_error_probability(&self, qubit: usize) -> f64 {
        self.error_probs.get(qubit).copied().unwrap_or(0.0)
    }
}

/// Utility functions for shot sampling analysis
pub mod analysis {
    use super::{ComparisonResult, ShotResult};

    /// Compute statistical power for detecting effect
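    ///
    /// Illustrative numbers under this simplified model (unit-variance estimator):
    /// with `effect_size = 0.1`, `num_shots = 1000`, and `significance_level = 0.05`,
    /// the test statistic is 0.1 * sqrt(1000) - 1.96 ≈ 1.20, so the returned power is ≈ 0.88.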
    #[must_use]
    pub fn statistical_power(effect_size: f64, num_shots: usize, significance_level: f64) -> f64 {
        // Simplified power analysis
        let standard_error = 1.0 / (num_shots as f64).sqrt();
        let z_critical = match (significance_level * 100.0) as i32 {
            1 => 2.576,
            5 => 1.96,
            10 => 1.645,
            _ => 1.96,
        };

        let z_beta = (effect_size / standard_error) - z_critical;
        normal_cdf(z_beta)
    }

    /// Estimate required shots for desired precision
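    ///
    /// Worked example: for a half-width of ε = 0.01 at 95% confidence, the worst-case
    /// binomial bound gives n = 1.96² / (4 · 0.01²) = 9604 shots.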
    #[must_use]
    pub fn required_shots_for_precision(desired_error: f64, confidence_level: f64) -> usize {
        let z_score = match (confidence_level * 100.0) as i32 {
            90 => 1.645,
            95 => 1.96,
            99 => 2.576,
            _ => 1.96,
        };

        // For binomial: n ≥ (z²/4ε²) for worst case p=0.5
        let n = (z_score * z_score) / (4.0 * desired_error * desired_error);
        n.ceil() as usize
    }

    /// Compare two shot results statistically
    #[must_use]
    pub fn compare_shot_results(
        result1: &ShotResult,
        result2: &ShotResult,
        significance_level: f64,
    ) -> ComparisonResult {
        // Chi-square test for distribution comparison
        let mut chi_square = 0.0;
        let mut degrees_of_freedom: usize = 0;

        // Get all unique outcomes
        let mut all_outcomes = std::collections::HashSet::new();
        all_outcomes.extend(result1.statistics.counts.keys());
        all_outcomes.extend(result2.statistics.counts.keys());

        for &outcome in &all_outcomes {
            let count1 = result1.statistics.counts.get(outcome).copied().unwrap_or(0) as f64;
            let count2 = result2.statistics.counts.get(outcome).copied().unwrap_or(0) as f64;

            let total1 = result1.num_shots as f64;
            let total2 = result2.num_shots as f64;

            let expected1 = (count1 + count2) * total1 / (total1 + total2);
            let expected2 = (count1 + count2) * total2 / (total1 + total2);

            if expected1 > 5.0 && expected2 > 5.0 {
                chi_square += (count1 - expected1).powi(2) / expected1;
                chi_square += (count2 - expected2).powi(2) / expected2;
                degrees_of_freedom += 1;
            }
        }

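        // Two-sample homogeneity test: df = (number of counted categories) - 1.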
        degrees_of_freedom = degrees_of_freedom.saturating_sub(1);

        // Critical values for a chi-square distribution with 1 degree of freedom
        // (simplification: the computed degrees_of_freedom is not used for this lookup)
        let critical_value = match (significance_level * 100.0) as i32 {
            1 => 6.635,
            5 => 3.841,
            10 => 2.706,
            _ => 3.841,
        };

        ComparisonResult {
            chi_square,
            degrees_of_freedom,
            p_value: if chi_square > critical_value {
                0.01
            } else {
                0.1
            }, // Rough
            significant: chi_square > critical_value,
        }
    }

    /// Normal CDF approximation
    fn normal_cdf(x: f64) -> f64 {
        // Simplified approximation
        0.5 * (1.0 + erf(x / 2.0_f64.sqrt()))
    }

    /// Error function approximation
    fn erf(x: f64) -> f64 {
        // Abramowitz and Stegun approximation
        let a1 = 0.254_829_592;
        let a2 = -0.284_496_736;
        let a3 = 1.421_413_741;
        let a4 = -1.453_152_027;
        let a5 = 1.061_405_429;
        let p = 0.327_591_1;

        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        let t = 1.0 / (1.0 + p * x);
        let y = ((a5 * t + a4).mul_add(t, a3).mul_add(t, a2).mul_add(t, a1) * t)
            .mul_add(-(-x * x).exp(), 1.0);

        sign * y
    }
}

/// Result of statistical comparison
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// Chi-square statistic
    pub chi_square: f64,
    /// Degrees of freedom
    pub degrees_of_freedom: usize,
    /// P-value
    pub p_value: f64,
    /// Whether difference is significant
    pub significant: bool,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bit_string() {
        let bs = BitString::from_int(5, 4); // 5 = 0b0101; little-endian bit vector [1, 0, 1, 0]
        assert_eq!(bs.bits, vec![1, 0, 1, 0]);
        assert_eq!(bs.to_int(), 5);
        assert_eq!(bs.weight(), 2);
    }
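
    #[test]
    fn test_bit_string_roundtrip_and_distance() {
        // Supplementary test sketch using only the APIs defined above: round-trip
        // the little-endian encoding and check a simple Hamming distance.
        for value in 0..16 {
            let bs = BitString::from_int(value, 4);
            assert_eq!(bs.to_int(), value);
            assert_eq!(bs.len(), 4);
        }

        // 5 = 0b0101 and 6 = 0b0110 differ in the two lowest bits.
        let a = BitString::from_int(5, 4);
        let b = BitString::from_int(6, 4);
        assert_eq!(a.distance(&b), 2);
        assert_eq!(a.distance(&a), 0);
    }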

    #[test]
    fn test_sampler_creation() {
        let config = SamplingConfig::default();
        let sampler = QuantumSampler::new(config);
        assert_eq!(sampler.config.num_shots, 1024);
    }

    #[test]
    fn test_uniform_state_sampling() {
        let config = SamplingConfig {
            num_shots: 100,
            seed: Some(42),
            ..Default::default()
        };

        let mut sampler = QuantumSampler::new(config);

        // Create uniform superposition |+> = (|0> + |1>)/√2
        let state = Array1::from_vec(vec![
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
            Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
        ]);

        let result = sampler
            .sample_state(&state)
            .expect("Failed to sample state");
        assert_eq!(result.num_shots, 100);
        assert_eq!(result.outcomes.len(), 100);

        // Check that we got both |0> and |1> outcomes
        let has_zero = result.outcomes.iter().any(|bs| bs.to_int() == 0);
        let has_one = result.outcomes.iter().any(|bs| bs.to_int() == 1);
        assert!(has_zero && has_one);
    }

    #[test]
    fn test_required_shots_calculation() {
        let shots = analysis::required_shots_for_precision(0.01, 0.95);
        assert!(shots > 9000); // Should need many shots for 1% precision
    }
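
    #[test]
    fn test_compare_identical_results_not_significant() {
        // Supplementary test sketch: comparing a shot result with itself must give a
        // zero chi-square statistic and no significant difference.
        let config = SamplingConfig {
            num_shots: 200,
            seed: Some(123),
            ..Default::default()
        };
        let mut sampler = QuantumSampler::new(config);
        let amp = 1.0 / 2.0_f64.sqrt();
        let state = Array1::from_vec(vec![Complex64::new(amp, 0.0), Complex64::new(amp, 0.0)]);
        let result = sampler
            .sample_state(&state)
            .expect("Failed to sample state");

        let comparison = analysis::compare_shot_results(&result, &result, 0.05);
        assert!(comparison.chi_square.abs() < 1e-12);
        assert!(!comparison.significant);
    }

    #[test]
    fn test_simple_readout_noise_probabilities() {
        // Supplementary test sketch: uniform readout noise reports the configured error
        // probability for in-range qubits and 0.0 for out-of-range indices.
        let noise = SimpleReadoutNoise::uniform(3, 0.02);
        assert!((noise.readout_error_probability(0) - 0.02).abs() < 1e-12);
        assert!((noise.readout_error_probability(2) - 0.02).abs() < 1e-12);
        assert!(noise.readout_error_probability(5).abs() < 1e-12);
    }

    #[test]
    fn test_statistical_power_increases_with_shots() {
        // Supplementary test sketch: under the simplified power model above, more shots
        // give more power for the same effect size and significance level.
        let low = analysis::statistical_power(0.1, 100, 0.05);
        let high = analysis::statistical_power(0.1, 10_000, 0.05);
        assert!(low < high);
        assert!((0.0..=1.0).contains(&low));
        assert!((0.0..=1.0).contains(&high));
    }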
}