quantrs2_core/batch/
measurement.rs

1//! Batch measurement operations using SciRS2 parallel algorithms
2
3use super::{BatchMeasurementResult, BatchStateVector};
4use crate::{
5    error::{QuantRS2Error, QuantRS2Result},
6    qubit::QubitId,
7};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::prelude::*;
10use scirs2_core::Complex64;
11// use scirs2_core::parallel_ops::*;
12use crate::parallel_ops_stubs::*;
13use std::collections::HashMap;
14
/// Options controlling how [`measure_batch`] samples a batch of states.
#[derive(Clone, Debug)]
pub struct MeasurementConfig {
    /// Requested number of measurement shots per state.
    pub shots: usize,
    /// When `true`, the measurement result also carries the post-measurement states.
    pub return_states: bool,
    /// Optional RNG seed; `None` draws a fresh seed from the thread-local generator.
    pub seed: Option<u64>,
    /// When `true`, large batches are measured with the parallel iterator.
    pub parallel: bool,
}
27
28impl Default for MeasurementConfig {
29    fn default() -> Self {
30        Self {
31            shots: 1024,
32            return_states: false,
33            seed: None,
34            parallel: true,
35        }
36    }
37}
38
39/// Perform batch measurements on multiple quantum states
40pub fn measure_batch(
41    batch: &BatchStateVector,
42    qubits_to_measure: &[QubitId],
43    config: MeasurementConfig,
44) -> QuantRS2Result<BatchMeasurementResult> {
45    let batch_size = batch.batch_size();
46    let n_qubits = batch.n_qubits;
47    let num_measurements = qubits_to_measure.len();
48
49    // Validate qubits
50    for &qubit in qubits_to_measure {
51        if qubit.0 as usize >= n_qubits {
52            return Err(QuantRS2Error::InvalidQubitId(qubit.0));
53        }
54    }
55
56    // Initialize results
57    let mut outcomes = Array2::zeros((batch_size, num_measurements));
58    let mut probabilities = Array2::zeros((batch_size, num_measurements));
59    let post_measurement_states = if config.return_states {
60        Some(batch.states.clone())
61    } else {
62        None
63    };
64
65    // Perform measurements
66    if config.parallel && batch_size > 16 {
67        // Parallel measurement
68        let results: Vec<(Vec<u8>, Vec<f64>)> = (0..batch_size)
69            .into_par_iter()
70            .map(|i| {
71                let state = batch.states.row(i);
72                measure_single_state(&state.to_owned(), qubits_to_measure, &config)
73            })
74            .collect();
75
76        // Collect results
77        for (i, (outcome, probs)) in results.into_iter().enumerate() {
78            for (j, &val) in outcome.iter().enumerate() {
79                outcomes[[i, j]] = val;
80            }
81            for (j, &prob) in probs.iter().enumerate() {
82                probabilities[[i, j]] = prob;
83            }
84        }
85    } else {
86        // Sequential measurement
87        for i in 0..batch_size {
88            let state = batch.states.row(i);
89            let (outcome, probs) =
90                measure_single_state(&state.to_owned(), qubits_to_measure, &config);
91
92            for (j, &val) in outcome.iter().enumerate() {
93                outcomes[[i, j]] = val;
94            }
95            for (j, &prob) in probs.iter().enumerate() {
96                probabilities[[i, j]] = prob;
97            }
98        }
99    }
100
101    Ok(BatchMeasurementResult {
102        outcomes,
103        probabilities,
104        post_measurement_states,
105    })
106}
107
108/// Measure a single state
109fn measure_single_state(
110    state: &Array1<Complex64>,
111    qubits_to_measure: &[QubitId],
112    config: &MeasurementConfig,
113) -> (Vec<u8>, Vec<f64>) {
114    let mut rng = config.seed.map_or_else(
115        || StdRng::from_seed(thread_rng().gen()),
116        StdRng::seed_from_u64,
117    );
118
119    let mut outcomes = Vec::with_capacity(qubits_to_measure.len());
120    let mut probabilities = Vec::with_capacity(qubits_to_measure.len());
121
122    for &qubit in qubits_to_measure {
123        let (outcome, prob) = measure_qubit(state, qubit, &mut rng);
124        outcomes.push(outcome);
125        probabilities.push(prob);
126    }
127
128    (outcomes, probabilities)
129}
130
131/// Measure a single qubit
132fn measure_qubit(state: &Array1<Complex64>, qubit: QubitId, rng: &mut StdRng) -> (u8, f64) {
133    let qubit_idx = qubit.0 as usize;
134    let state_size = state.len();
135    let _n_qubits = (state_size as f64).log2() as usize;
136
137    // Calculate probability of measuring |0>
138    let mut prob_zero = 0.0;
139    let qubit_mask = 1 << qubit_idx;
140
141    for i in 0..state_size {
142        if i & qubit_mask == 0 {
143            prob_zero += state[i].norm_sqr();
144        }
145    }
146
147    // Perform measurement
148    let outcome = u8::from(rng.random::<f64>() >= prob_zero);
149    let probability = if outcome == 0 {
150        prob_zero
151    } else {
152        1.0 - prob_zero
153    };
154
155    (outcome, probability)
156}
157
158/// Perform batch measurements with statistics
159pub fn measure_batch_with_statistics(
160    batch: &BatchStateVector,
161    qubits_to_measure: &[QubitId],
162    shots: usize,
163) -> QuantRS2Result<BatchMeasurementStatistics> {
164    let batch_size = batch.batch_size();
165    let measurement_size = qubits_to_measure.len();
166
167    // Collect measurement statistics for each state in parallel
168    let statistics: Vec<_> = (0..batch_size)
169        .into_par_iter()
170        .map(|i| {
171            let state = batch.states.row(i);
172            compute_measurement_statistics(&state.to_owned(), qubits_to_measure, shots)
173        })
174        .collect();
175
176    Ok(BatchMeasurementStatistics {
177        statistics,
178        batch_size,
179        measurement_size,
180        shots,
181    })
182}
183
184/// Measurement statistics for a batch
185#[derive(Debug, Clone)]
186pub struct BatchMeasurementStatistics {
187    /// Statistics for each state in the batch
188    pub statistics: Vec<MeasurementStatistics>,
189    /// Batch size
190    pub batch_size: usize,
191    /// Number of qubits measured
192    pub measurement_size: usize,
193    /// Number of shots
194    pub shots: usize,
195}
196
/// Sampled measurement statistics for one state.
#[derive(Clone, Debug)]
pub struct MeasurementStatistics {
    /// Number of times each outcome bit string was observed.
    pub counts: HashMap<String, usize>,
    /// Empirical frequency of each observed bit string.
    pub probabilities: HashMap<String, f64>,
    /// The bit string observed most often.
    pub most_likely: String,
    /// Shannon entropy (base 2) of the empirical outcome distribution.
    pub entropy: f64,
}
209
210/// Compute measurement statistics for a single state
211fn compute_measurement_statistics(
212    state: &Array1<Complex64>,
213    qubits_to_measure: &[QubitId],
214    shots: usize,
215) -> MeasurementStatistics {
216    let mut rng = StdRng::from_seed(thread_rng().gen());
217    let mut counts: HashMap<String, usize> = HashMap::new();
218
219    // Perform measurements
220    for _ in 0..shots {
221        let mut outcome = String::new();
222        for &qubit in qubits_to_measure {
223            let (bit, _) = measure_qubit(state, qubit, &mut rng);
224            outcome.push(if bit == 0 { '0' } else { '1' });
225        }
226        *counts.entry(outcome).or_insert(0) += 1;
227    }
228
229    // Compute probabilities
230    let mut probabilities = HashMap::new();
231    let mut most_likely = String::new();
232    let mut max_count = 0;
233
234    for (outcome, &count) in &counts {
235        let prob = count as f64 / shots as f64;
236        probabilities.insert(outcome.clone(), prob);
237
238        if count > max_count {
239            max_count = count;
240            most_likely.clone_from(outcome);
241        }
242    }
243
244    // Compute entropy
245    let entropy = -probabilities
246        .values()
247        .filter(|&&p| p > 0.0)
248        .map(|&p| p * p.log2())
249        .sum::<f64>();
250
251    MeasurementStatistics {
252        counts,
253        probabilities,
254        most_likely,
255        entropy,
256    }
257}
258
259/// Batch expectation value measurement
260pub fn measure_expectation_batch(
261    batch: &BatchStateVector,
262    observable_qubits: &[(QubitId, Array2<Complex64>)],
263) -> QuantRS2Result<Vec<f64>> {
264    let batch_size = batch.batch_size();
265
266    // Compute expectation values in parallel
267    let expectations: Vec<_> = (0..batch_size)
268        .into_par_iter()
269        .map(|i| {
270            let state = batch.states.row(i);
271            compute_observable_expectation(&state.to_owned(), observable_qubits, batch.n_qubits)
272        })
273        .collect::<QuantRS2Result<Vec<_>>>()?;
274
275    Ok(expectations)
276}
277
278/// Compute expectation value of an observable
279fn compute_observable_expectation(
280    state: &Array1<Complex64>,
281    observable_qubits: &[(QubitId, Array2<Complex64>)],
282    n_qubits: usize,
283) -> QuantRS2Result<f64> {
284    // For simplicity, compute single-qubit observable expectations
285    // and multiply them (assuming they commute)
286    let mut total_expectation = 1.0;
287
288    for (qubit, observable) in observable_qubits {
289        let qubit_idx = qubit.0 as usize;
290        if qubit_idx >= n_qubits {
291            return Err(QuantRS2Error::InvalidQubitId(qubit.0));
292        }
293
294        // Compute expectation for this qubit
295        let exp = compute_single_qubit_expectation(state, *qubit, observable, n_qubits)?;
296        total_expectation *= exp;
297    }
298
299    Ok(total_expectation)
300}
301
302/// Compute single-qubit expectation value
303fn compute_single_qubit_expectation(
304    state: &Array1<Complex64>,
305    qubit: QubitId,
306    observable: &Array2<Complex64>,
307    n_qubits: usize,
308) -> QuantRS2Result<f64> {
309    if observable.shape() != [2, 2] {
310        return Err(QuantRS2Error::InvalidInput(
311            "Observable must be a 2x2 matrix".to_string(),
312        ));
313    }
314
315    let qubit_idx = qubit.0 as usize;
316    let state_size = 1 << n_qubits;
317    let qubit_mask = 1 << qubit_idx;
318
319    let mut expectation = Complex64::new(0.0, 0.0);
320
321    for i in 0..state_size {
322        for j in 0..state_size {
323            // Check if states differ only in the target qubit
324            if (i ^ j) == qubit_mask {
325                let qi = (i >> qubit_idx) & 1;
326                let qj = (j >> qubit_idx) & 1;
327
328                expectation += state[i].conj() * observable[[qi, qj]] * state[j];
329            } else if i == j {
330                let qi = (i >> qubit_idx) & 1;
331                expectation += state[i].conj() * observable[[qi, qi]] * state[i];
332            }
333        }
334    }
335
336    Ok(expectation.re)
337}
338
339/// Batch tomography measurements
340pub fn measure_tomography_batch(
341    batch: &BatchStateVector,
342    qubits: &[QubitId],
343    basis: TomographyBasis,
344) -> QuantRS2Result<BatchTomographyResult> {
345    let measurements = match basis {
346        TomographyBasis::Pauli => get_pauli_measurements(qubits),
347        TomographyBasis::Computational => get_computational_measurements(qubits),
348        TomographyBasis::Custom(ref bases) => bases.clone(),
349    };
350
351    let mut results = Vec::new();
352
353    for (name, observable_qubits) in measurements {
354        let expectations = measure_expectation_batch(batch, &observable_qubits)?;
355        results.push((name, expectations));
356    }
357
358    Ok(BatchTomographyResult {
359        measurements: results,
360        basis,
361        qubits: qubits.to_vec(),
362    })
363}
364
365/// Type alias for custom measurement basis
366pub type CustomMeasurementBasis = Vec<(String, Vec<(QubitId, Array2<Complex64>)>)>;
367
368/// Tomography basis
369#[derive(Debug, Clone)]
370pub enum TomographyBasis {
371    /// Pauli basis (X, Y, Z)
372    Pauli,
373    /// Computational basis (|0>, |1>)
374    Computational,
375    /// Custom measurement basis
376    Custom(CustomMeasurementBasis),
377}
378
379/// Batch tomography result
380#[derive(Debug, Clone)]
381pub struct BatchTomographyResult {
382    /// Measurement results (name, expectations for each state)
383    pub measurements: Vec<(String, Vec<f64>)>,
384    /// Basis used
385    pub basis: TomographyBasis,
386    /// Qubits measured
387    pub qubits: Vec<QubitId>,
388}
389
390/// Get Pauli basis measurements
391fn get_pauli_measurements(qubits: &[QubitId]) -> Vec<(String, Vec<(QubitId, Array2<Complex64>)>)> {
392    use scirs2_core::ndarray::array;
393
394    let pauli_x = array![
395        [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
396        [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]
397    ];
398
399    let pauli_y = array![
400        [Complex64::new(0.0, 0.0), Complex64::new(0.0, -1.0)],
401        [Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)]
402    ];
403
404    let pauli_z = array![
405        [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
406        [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
407    ];
408
409    let mut measurements = Vec::new();
410
411    for &qubit in qubits {
412        measurements.push((format!("X{}", qubit.0), vec![(qubit, pauli_x.clone())]));
413        measurements.push((format!("Y{}", qubit.0), vec![(qubit, pauli_y.clone())]));
414        measurements.push((format!("Z{}", qubit.0), vec![(qubit, pauli_z.clone())]));
415    }
416
417    measurements
418}
419
420/// Get computational basis measurements
421fn get_computational_measurements(
422    qubits: &[QubitId],
423) -> Vec<(String, Vec<(QubitId, Array2<Complex64>)>)> {
424    use scirs2_core::ndarray::array;
425
426    let proj_0 = array![
427        [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
428        [Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)]
429    ];
430
431    let proj_1 = array![
432        [Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)],
433        [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]
434    ];
435
436    let mut measurements = Vec::new();
437
438    for &qubit in qubits {
439        measurements.push((format!("|0⟩{}", qubit.0), vec![(qubit, proj_0.clone())]));
440        measurements.push((format!("|1⟩{}", qubit.0), vec![(qubit, proj_1.clone())]));
441    }
442
443    measurements
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_batch_measurement() {
        let batch = BatchStateVector::new(5, 2, Default::default())
            .expect("Failed to create batch state vector");
        let config = MeasurementConfig {
            shots: 100,
            return_states: false,
            seed: Some(42),
            parallel: false,
        };
        let qubits = [QubitId(0), QubitId(1)];

        let result = measure_batch(&batch, &qubits, config).expect("Batch measurement failed");

        assert_eq!(result.outcomes.shape(), &[5, 2]);
        assert_eq!(result.probabilities.shape(), &[5, 2]);

        // Every state starts as |00>, so each qubit reads 0 with certainty.
        for &bit in result.outcomes.iter() {
            assert_eq!(bit, 0);
        }
        for &probability in result.probabilities.iter() {
            assert!((probability - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_measurement_statistics() {
        // Equal superposition (|0> + |1>)/sqrt(2) on a single qubit.
        let amplitude = Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0);
        let states = Array2::from_elem((1, 2), amplitude);

        let batch = BatchStateVector::from_states(states, Default::default())
            .expect("Failed to create batch from states");

        let stats = measure_batch_with_statistics(&batch, &[QubitId(0)], 1000)
            .expect("Failed to measure batch statistics");

        assert_eq!(stats.batch_size, 1);
        assert_eq!(stats.measurement_size, 1);

        let single = &stats.statistics[0];
        // Both outcomes must show up across 1000 shots.
        assert!(["0", "1"].iter().all(|key| single.counts.contains_key(*key)));

        // A fair coin has one bit of entropy.
        assert!((single.entropy - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_expectation_measurement() {
        let batch = BatchStateVector::new(3, 1, Default::default())
            .expect("Failed to create batch state vector");

        let pauli_z = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
        ];

        let expectations = measure_expectation_batch(&batch, &[(QubitId(0), pauli_z)])
            .expect("Expectation value measurement failed");

        assert_eq!(expectations.len(), 3);
        // Every state is |0>, an eigenstate of Z with eigenvalue +1.
        assert!(expectations.iter().all(|exp| (exp - 1.0).abs() < 1e-10));
    }
}
522}