1use super::{BatchMeasurementResult, BatchStateVector};
4use crate::{
5 error::{QuantRS2Error, QuantRS2Result},
6 qubit::QubitId,
7};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::prelude::*;
10use scirs2_core::Complex64;
11use crate::parallel_ops_stubs::*;
13use std::collections::HashMap;
14
/// Options controlling [`measure_batch`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MeasurementConfig {
    /// Number of shots requested.
    ///
    /// `measure_batch` samples a single outcome per state and does not read
    /// this field; multi-shot sampling goes through
    /// [`measure_batch_with_statistics`], which takes its own `shots` argument.
    pub shots: usize,
    /// When `true`, `measure_batch` fills the result's `post_measurement_states`.
    pub return_states: bool,
    /// Seed for reproducible sampling; `None` draws a fresh seed from the thread RNG.
    pub seed: Option<u64>,
    /// Allow measuring rows in parallel. `measure_batch` only parallelises
    /// batches holding more than 16 states.
    pub parallel: bool,
}

impl Default for MeasurementConfig {
    /// 1024 shots, no returned states, unseeded, parallel enabled.
    fn default() -> Self {
        Self {
            shots: 1024,
            return_states: false,
            seed: None,
            parallel: true,
        }
    }
}
38
39pub fn measure_batch(
41 batch: &BatchStateVector,
42 qubits_to_measure: &[QubitId],
43 config: MeasurementConfig,
44) -> QuantRS2Result<BatchMeasurementResult> {
45 let batch_size = batch.batch_size();
46 let n_qubits = batch.n_qubits;
47 let num_measurements = qubits_to_measure.len();
48
49 for &qubit in qubits_to_measure {
51 if qubit.0 as usize >= n_qubits {
52 return Err(QuantRS2Error::InvalidQubitId(qubit.0));
53 }
54 }
55
56 let mut outcomes = Array2::zeros((batch_size, num_measurements));
58 let mut probabilities = Array2::zeros((batch_size, num_measurements));
59 let post_measurement_states = if config.return_states {
60 Some(batch.states.clone())
61 } else {
62 None
63 };
64
65 if config.parallel && batch_size > 16 {
67 let results: Vec<(Vec<u8>, Vec<f64>)> = (0..batch_size)
69 .into_par_iter()
70 .map(|i| {
71 let state = batch.states.row(i);
72 measure_single_state(&state.to_owned(), qubits_to_measure, &config)
73 })
74 .collect();
75
76 for (i, (outcome, probs)) in results.into_iter().enumerate() {
78 for (j, &val) in outcome.iter().enumerate() {
79 outcomes[[i, j]] = val;
80 }
81 for (j, &prob) in probs.iter().enumerate() {
82 probabilities[[i, j]] = prob;
83 }
84 }
85 } else {
86 for i in 0..batch_size {
88 let state = batch.states.row(i);
89 let (outcome, probs) =
90 measure_single_state(&state.to_owned(), qubits_to_measure, &config);
91
92 for (j, &val) in outcome.iter().enumerate() {
93 outcomes[[i, j]] = val;
94 }
95 for (j, &prob) in probs.iter().enumerate() {
96 probabilities[[i, j]] = prob;
97 }
98 }
99 }
100
101 Ok(BatchMeasurementResult {
102 outcomes,
103 probabilities,
104 post_measurement_states,
105 })
106}
107
108fn measure_single_state(
110 state: &Array1<Complex64>,
111 qubits_to_measure: &[QubitId],
112 config: &MeasurementConfig,
113) -> (Vec<u8>, Vec<f64>) {
114 let mut rng = config.seed.map_or_else(
115 || StdRng::from_seed(thread_rng().gen()),
116 StdRng::seed_from_u64,
117 );
118
119 let mut outcomes = Vec::with_capacity(qubits_to_measure.len());
120 let mut probabilities = Vec::with_capacity(qubits_to_measure.len());
121
122 for &qubit in qubits_to_measure {
123 let (outcome, prob) = measure_qubit(state, qubit, &mut rng);
124 outcomes.push(outcome);
125 probabilities.push(prob);
126 }
127
128 (outcomes, probabilities)
129}
130
131fn measure_qubit(state: &Array1<Complex64>, qubit: QubitId, rng: &mut StdRng) -> (u8, f64) {
133 let qubit_idx = qubit.0 as usize;
134 let state_size = state.len();
135 let _n_qubits = (state_size as f64).log2() as usize;
136
137 let mut prob_zero = 0.0;
139 let qubit_mask = 1 << qubit_idx;
140
141 for i in 0..state_size {
142 if i & qubit_mask == 0 {
143 prob_zero += state[i].norm_sqr();
144 }
145 }
146
147 let outcome = u8::from(rng.random::<f64>() >= prob_zero);
149 let probability = if outcome == 0 {
150 prob_zero
151 } else {
152 1.0 - prob_zero
153 };
154
155 (outcome, probability)
156}
157
158pub fn measure_batch_with_statistics(
160 batch: &BatchStateVector,
161 qubits_to_measure: &[QubitId],
162 shots: usize,
163) -> QuantRS2Result<BatchMeasurementStatistics> {
164 let batch_size = batch.batch_size();
165 let measurement_size = qubits_to_measure.len();
166
167 let statistics: Vec<_> = (0..batch_size)
169 .into_par_iter()
170 .map(|i| {
171 let state = batch.states.row(i);
172 compute_measurement_statistics(&state.to_owned(), qubits_to_measure, shots)
173 })
174 .collect();
175
176 Ok(BatchMeasurementStatistics {
177 statistics,
178 batch_size,
179 measurement_size,
180 shots,
181 })
182}
183
/// Per-state measurement summaries for a whole batch.
#[derive(Clone, Debug)]
pub struct BatchMeasurementStatistics {
    /// One summary per state, in batch order.
    pub statistics: Vec<MeasurementStatistics>,
    /// Number of states that were sampled.
    pub batch_size: usize,
    /// Number of qubits measured in each shot (length of every outcome string).
    pub measurement_size: usize,
    /// Shots drawn per state.
    pub shots: usize,
}

/// Empirical outcome histogram for a single state.
///
/// Outcome keys are bit strings; character `j` is the result of the `j`-th
/// measured qubit.
#[derive(Clone, Debug)]
pub struct MeasurementStatistics {
    /// Number of shots that produced each observed bit string.
    pub counts: HashMap<String, usize>,
    /// Observed frequency of each bit string (`count / shots`).
    pub probabilities: HashMap<String, f64>,
    /// Bit string with the highest count (empty when nothing was sampled).
    pub most_likely: String,
    /// Shannon entropy, in bits, of the observed frequencies (`-Σ p·log2 p`).
    pub entropy: f64,
}
209
210fn compute_measurement_statistics(
212 state: &Array1<Complex64>,
213 qubits_to_measure: &[QubitId],
214 shots: usize,
215) -> MeasurementStatistics {
216 let mut rng = StdRng::from_seed(thread_rng().gen());
217 let mut counts: HashMap<String, usize> = HashMap::new();
218
219 for _ in 0..shots {
221 let mut outcome = String::new();
222 for &qubit in qubits_to_measure {
223 let (bit, _) = measure_qubit(state, qubit, &mut rng);
224 outcome.push(if bit == 0 { '0' } else { '1' });
225 }
226 *counts.entry(outcome).or_insert(0) += 1;
227 }
228
229 let mut probabilities = HashMap::new();
231 let mut most_likely = String::new();
232 let mut max_count = 0;
233
234 for (outcome, &count) in &counts {
235 let prob = count as f64 / shots as f64;
236 probabilities.insert(outcome.clone(), prob);
237
238 if count > max_count {
239 max_count = count;
240 most_likely.clone_from(outcome);
241 }
242 }
243
244 let entropy = -probabilities
246 .values()
247 .filter(|&&p| p > 0.0)
248 .map(|&p| p * p.log2())
249 .sum::<f64>();
250
251 MeasurementStatistics {
252 counts,
253 probabilities,
254 most_likely,
255 entropy,
256 }
257}
258
259pub fn measure_expectation_batch(
261 batch: &BatchStateVector,
262 observable_qubits: &[(QubitId, Array2<Complex64>)],
263) -> QuantRS2Result<Vec<f64>> {
264 let batch_size = batch.batch_size();
265
266 let expectations: Vec<_> = (0..batch_size)
268 .into_par_iter()
269 .map(|i| {
270 let state = batch.states.row(i);
271 compute_observable_expectation(&state.to_owned(), observable_qubits, batch.n_qubits)
272 })
273 .collect::<QuantRS2Result<Vec<_>>>()?;
274
275 Ok(expectations)
276}
277
278fn compute_observable_expectation(
280 state: &Array1<Complex64>,
281 observable_qubits: &[(QubitId, Array2<Complex64>)],
282 n_qubits: usize,
283) -> QuantRS2Result<f64> {
284 let mut total_expectation = 1.0;
287
288 for (qubit, observable) in observable_qubits {
289 let qubit_idx = qubit.0 as usize;
290 if qubit_idx >= n_qubits {
291 return Err(QuantRS2Error::InvalidQubitId(qubit.0));
292 }
293
294 let exp = compute_single_qubit_expectation(state, *qubit, observable, n_qubits)?;
296 total_expectation *= exp;
297 }
298
299 Ok(total_expectation)
300}
301
302fn compute_single_qubit_expectation(
304 state: &Array1<Complex64>,
305 qubit: QubitId,
306 observable: &Array2<Complex64>,
307 n_qubits: usize,
308) -> QuantRS2Result<f64> {
309 if observable.shape() != [2, 2] {
310 return Err(QuantRS2Error::InvalidInput(
311 "Observable must be a 2x2 matrix".to_string(),
312 ));
313 }
314
315 let qubit_idx = qubit.0 as usize;
316 let state_size = 1 << n_qubits;
317 let qubit_mask = 1 << qubit_idx;
318
319 let mut expectation = Complex64::new(0.0, 0.0);
320
321 for i in 0..state_size {
322 for j in 0..state_size {
323 if (i ^ j) == qubit_mask {
325 let qi = (i >> qubit_idx) & 1;
326 let qj = (j >> qubit_idx) & 1;
327
328 expectation += state[i].conj() * observable[[qi, qj]] * state[j];
329 } else if i == j {
330 let qi = (i >> qubit_idx) & 1;
331 expectation += state[i].conj() * observable[[qi, qi]] * state[i];
332 }
333 }
334 }
335
336 Ok(expectation.re)
337}
338
339pub fn measure_tomography_batch(
341 batch: &BatchStateVector,
342 qubits: &[QubitId],
343 basis: TomographyBasis,
344) -> QuantRS2Result<BatchTomographyResult> {
345 let measurements = match basis {
346 TomographyBasis::Pauli => get_pauli_measurements(qubits),
347 TomographyBasis::Computational => get_computational_measurements(qubits),
348 TomographyBasis::Custom(ref bases) => bases.clone(),
349 };
350
351 let mut results = Vec::new();
352
353 for (name, observable_qubits) in measurements {
354 let expectations = measure_expectation_batch(batch, &observable_qubits)?;
355 results.push((name, expectations));
356 }
357
358 Ok(BatchTomographyResult {
359 measurements: results,
360 basis,
361 qubits: qubits.to_vec(),
362 })
363}
364
365pub type CustomMeasurementBasis = Vec<(String, Vec<(QubitId, Array2<Complex64>)>)>;
367
368#[derive(Debug, Clone)]
370pub enum TomographyBasis {
371 Pauli,
373 Computational,
375 Custom(CustomMeasurementBasis),
377}
378
379#[derive(Debug, Clone)]
381pub struct BatchTomographyResult {
382 pub measurements: Vec<(String, Vec<f64>)>,
384 pub basis: TomographyBasis,
386 pub qubits: Vec<QubitId>,
388}
389
390fn get_pauli_measurements(qubits: &[QubitId]) -> Vec<(String, Vec<(QubitId, Array2<Complex64>)>)> {
392 use scirs2_core::ndarray::array;
393
394 let pauli_x = array![
395 [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
396 [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)]
397 ];
398
399 let pauli_y = array![
400 [Complex64::new(0.0, 0.0), Complex64::new(0.0, -1.0)],
401 [Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)]
402 ];
403
404 let pauli_z = array![
405 [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
406 [Complex64::new(0.0, 0.0), Complex64::new(-1.0, 0.0)]
407 ];
408
409 let mut measurements = Vec::new();
410
411 for &qubit in qubits {
412 measurements.push((format!("X{}", qubit.0), vec![(qubit, pauli_x.clone())]));
413 measurements.push((format!("Y{}", qubit.0), vec![(qubit, pauli_y.clone())]));
414 measurements.push((format!("Z{}", qubit.0), vec![(qubit, pauli_z.clone())]));
415 }
416
417 measurements
418}
419
420fn get_computational_measurements(
422 qubits: &[QubitId],
423) -> Vec<(String, Vec<(QubitId, Array2<Complex64>)>)> {
424 use scirs2_core::ndarray::array;
425
426 let proj_0 = array![
427 [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
428 [Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)]
429 ];
430
431 let proj_1 = array![
432 [Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)],
433 [Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]
434 ];
435
436 let mut measurements = Vec::new();
437
438 for &qubit in qubits {
439 measurements.push((format!("|0⟩{}", qubit.0), vec![(qubit, proj_0.clone())]));
440 measurements.push((format!("|1⟩{}", qubit.0), vec![(qubit, proj_1.clone())]));
441 }
442
443 measurements
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Every row of a freshly built batch is expected to be |00⟩, so both
    /// qubits must read 0 with probability 1.
    #[test]
    fn test_batch_measurement() {
        let batch = BatchStateVector::new(5, 2, Default::default())
            .expect("Failed to create batch state vector");
        let config = MeasurementConfig {
            shots: 100,
            return_states: false,
            seed: Some(42),
            parallel: false,
        };
        let measured = [QubitId(0), QubitId(1)];

        let result =
            measure_batch(&batch, &measured, config).expect("Batch measurement failed");

        assert_eq!(result.outcomes.shape(), &[5, 2]);
        assert_eq!(result.probabilities.shape(), &[5, 2]);

        for row in 0..5 {
            for col in 0..2 {
                assert_eq!(result.outcomes[[row, col]], 0);
                assert!((result.probabilities[[row, col]] - 1.0).abs() < 1e-10);
            }
        }
    }

    /// An equal superposition on one qubit should show both outcomes and
    /// roughly one bit of entropy.
    #[test]
    fn test_measurement_statistics() {
        let amplitude = Complex64::new(1.0 / std::f64::consts::SQRT_2, 0.0);
        let mut states = Array2::zeros((1, 2));
        states[[0, 0]] = amplitude;
        states[[0, 1]] = amplitude;

        let batch = BatchStateVector::from_states(states, Default::default())
            .expect("Failed to create batch from states");

        let stats = measure_batch_with_statistics(&batch, &[QubitId(0)], 1000)
            .expect("Failed to measure batch statistics");

        assert_eq!(stats.batch_size, 1);
        assert_eq!(stats.measurement_size, 1);

        let summary = &stats.statistics[0];
        for outcome in ["0", "1"] {
            assert!(summary.counts.contains_key(outcome));
        }

        assert!((summary.entropy - 1.0).abs() < 0.1);
    }

    /// ⟨Z⟩ on each freshly built single-qubit state is expected to be +1.
    #[test]
    fn test_expectation_measurement() {
        let batch = BatchStateVector::new(3, 1, Default::default())
            .expect("Failed to create batch state vector");

        let zero = Complex64::new(0.0, 0.0);
        let pauli_z = array![
            [Complex64::new(1.0, 0.0), zero],
            [zero, Complex64::new(-1.0, 0.0)]
        ];

        let values = measure_expectation_batch(&batch, &[(QubitId(0), pauli_z)])
            .expect("Expectation value measurement failed");

        assert_eq!(values.len(), 3);
        assert!(values.iter().all(|value| (value - 1.0).abs() < 1e-10));
    }
}