quantrs2_ml/torchquantum/measurement.rs

1//! Quantum measurement functions (TorchQuantum-compatible)
2//!
3//! This module provides measurement operations:
4//! - gen_bitstrings: Generate all bitstrings for n qubits
5//! - measure: Sample measurements from quantum state
6//! - expval_joint_analytical: Compute expectation value analytically
7//! - expval_joint_sampling: Compute expectation value by sampling
8//! - TQMeasureAll: Module wrapper for measurement
9
10use super::gates::{TQHadamard, TQS};
11use super::{CType, TQDevice, TQModule, TQOperator, TQParameter};
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::{Array1, Array2};
14use std::collections::HashMap;
15
/// Generate every computational-basis bitstring for `n_wires` qubits.
///
/// Each string is zero-padded to `n_wires` characters, most significant bit first.
/// The strings are returned in increasing numeric order, e.g. `"00"`, `"01"`, `"10"`,
/// `"11"` for two wires. For `n_wires == 0` the single entry `"0"` is returned.
pub fn gen_bitstrings(n_wires: usize) -> Vec<String> {
    let n_states = 1 << n_wires;
    let mut bitstrings = Vec::new();
    for index in 0..n_states {
        bitstrings.push(format!("{:0w$b}", index, w = n_wires));
    }
    bitstrings
}
22
23/// Measure the quantum state and return bitstring distribution
24pub fn measure(qdev: &TQDevice, n_shots: usize) -> Vec<HashMap<String, usize>> {
25    let bitstring_candidates = gen_bitstrings(qdev.n_wires);
26    let probs = qdev.get_probs_1d();
27
28    let mut distributions = Vec::with_capacity(qdev.bsz);
29
30    for batch in 0..qdev.bsz {
31        let mut counts = HashMap::new();
32
33        // Initialize all bitstrings with 0 counts
34        for bs in &bitstring_candidates {
35            counts.insert(bs.clone(), 0);
36        }
37
38        // Sample from distribution
39        for _ in 0..n_shots {
40            let r: f64 = fastrand::f64();
41            let mut cumsum = 0.0;
42
43            for (i, &prob) in probs.row(batch).iter().enumerate() {
44                cumsum += prob;
45                if r < cumsum {
46                    *counts.entry(bitstring_candidates[i].clone()).or_insert(0) += 1;
47                    break;
48                }
49            }
50        }
51
52        distributions.push(counts);
53    }
54
55    distributions
56}
57
58/// Compute expectation value analytically
59pub fn expval_joint_analytical(qdev: &TQDevice, observable: &str) -> Array1<f64> {
60    let observable = observable.to_uppercase();
61    let n_wires = qdev.n_wires;
62
63    assert_eq!(
64        observable.len(),
65        n_wires,
66        "Observable length must match n_wires"
67    );
68
69    let states_1d = qdev.get_states_1d();
70
71    // Build Hamiltonian matrix
72    let pauli_x = Array2::from_shape_vec(
73        (2, 2),
74        vec![
75            CType::new(0.0, 0.0),
76            CType::new(1.0, 0.0),
77            CType::new(1.0, 0.0),
78            CType::new(0.0, 0.0),
79        ],
80    )
81    .unwrap_or_else(|_| Array2::eye(2).mapv(|x| CType::new(x, 0.0)));
82
83    let pauli_y = Array2::from_shape_vec(
84        (2, 2),
85        vec![
86            CType::new(0.0, 0.0),
87            CType::new(0.0, -1.0),
88            CType::new(0.0, 1.0),
89            CType::new(0.0, 0.0),
90        ],
91    )
92    .unwrap_or_else(|_| Array2::eye(2).mapv(|x| CType::new(x, 0.0)));
93
94    let pauli_z = Array2::from_shape_vec(
95        (2, 2),
96        vec![
97            CType::new(1.0, 0.0),
98            CType::new(0.0, 0.0),
99            CType::new(0.0, 0.0),
100            CType::new(-1.0, 0.0),
101        ],
102    )
103    .unwrap_or_else(|_| Array2::eye(2).mapv(|x| CType::new(x, 0.0)));
104
105    let identity = Array2::eye(2).mapv(|x| CType::new(x, 0.0));
106
107    // Build tensor product of Pauli matrices
108    let mut hamiltonian = match observable.chars().next().unwrap_or('I') {
109        'X' => pauli_x.clone(),
110        'Y' => pauli_y.clone(),
111        'Z' => pauli_z.clone(),
112        _ => identity.clone(),
113    };
114
115    for c in observable.chars().skip(1) {
116        let op = match c {
117            'X' => &pauli_x,
118            'Y' => &pauli_y,
119            'Z' => &pauli_z,
120            _ => &identity,
121        };
122        hamiltonian = kron(&hamiltonian, op);
123    }
124
125    // Compute <psi|H|psi> for each batch
126    let mut expvals = Array1::zeros(qdev.bsz);
127
128    for batch in 0..qdev.bsz {
129        let state = states_1d.row(batch);
130        let mut result = CType::new(0.0, 0.0);
131
132        for i in 0..state.len() {
133            for j in 0..state.len() {
134                result += state[i].conj() * hamiltonian[[i, j]] * state[j];
135            }
136        }
137
138        expvals[batch] = result.re;
139    }
140
141    expvals
142}
143
144/// Compute expectation value via sampling
145pub fn expval_joint_sampling(qdev: &TQDevice, observable: &str, n_shots: usize) -> Array1<f64> {
146    let observable = observable.to_uppercase();
147    let n_wires = qdev.n_wires;
148
149    // Create a clone for measurement basis rotation
150    let mut qdev_clone = qdev.clone();
151
152    // Apply rotation to measurement basis
153    for (wire, c) in observable.chars().enumerate() {
154        match c {
155            'X' => {
156                // H gate to rotate X basis to Z
157                let mut h = TQHadamard::new();
158                let _ = h.apply(&mut qdev_clone, &[wire]);
159            }
160            'Y' => {
161                // S†H to rotate Y basis to Z
162                let mut s = TQS::new();
163                s.set_inverse(true);
164                let _ = s.apply(&mut qdev_clone, &[wire]);
165                let mut h = TQHadamard::new();
166                let _ = h.apply(&mut qdev_clone, &[wire]);
167            }
168            _ => {} // Z and I don't need rotation
169        }
170    }
171
172    // Measure
173    let distributions = measure(&qdev_clone, n_shots);
174
175    // Compute expectation values
176    let mut expvals = Array1::zeros(qdev.bsz);
177
178    // Create mask for non-identity positions
179    let mask: Vec<bool> = observable.chars().map(|c| c != 'I').collect();
180
181    for (batch, distri) in distributions.iter().enumerate() {
182        let mut n_eigen_one = 0;
183        let mut n_eigen_minus_one = 0;
184
185        for (bitstring, &count) in distri {
186            // Count parity of masked bits
187            let parity: usize = bitstring
188                .chars()
189                .zip(mask.iter())
190                .filter_map(|(c, &m)| {
191                    if m {
192                        c.to_digit(2).map(|d| d as usize)
193                    } else {
194                        None
195                    }
196                })
197                .sum();
198
199            if parity % 2 == 0 {
200                n_eigen_one += count;
201            } else {
202                n_eigen_minus_one += count;
203            }
204        }
205
206        expvals[batch] = (n_eigen_one as f64 - n_eigen_minus_one as f64) / n_shots as f64;
207    }
208
209    expvals
210}
211
212/// Kronecker product of two matrices
213fn kron(a: &Array2<CType>, b: &Array2<CType>) -> Array2<CType> {
214    let (m, n) = (a.nrows(), a.ncols());
215    let (p, q) = (b.nrows(), b.ncols());
216
217    let mut result = Array2::zeros((m * p, n * q));
218
219    for i in 0..m {
220        for j in 0..n {
221            for k in 0..p {
222                for l in 0..q {
223                    result[[i * p + k, j * q + l]] = a[[i, j]] * b[[k, l]];
224                }
225            }
226        }
227    }
228
229    result
230}
231
232/// MeasureAll module for measuring all qubits
233#[derive(Debug, Clone)]
234pub struct TQMeasureAll {
235    /// Observable to measure (default: PauliZ)
236    pub observable: String,
237    static_mode: bool,
238}
239
240impl TQMeasureAll {
241    pub fn new(observable: impl Into<String>) -> Self {
242        Self {
243            observable: observable.into(),
244            static_mode: false,
245        }
246    }
247
248    /// Measure with PauliZ on all qubits
249    pub fn pauli_z() -> Self {
250        Self::new("Z")
251    }
252
253    /// Measure with PauliX on all qubits
254    pub fn pauli_x() -> Self {
255        Self::new("X")
256    }
257
258    /// Measure expectation values for all qubits
259    pub fn measure(&self, qdev: &TQDevice) -> Array2<f64> {
260        let n_wires = qdev.n_wires;
261        let mut results = Array2::zeros((qdev.bsz, n_wires));
262
263        for wire in 0..n_wires {
264            // Create observable string with observable at this wire, I elsewhere
265            let obs: String = (0..n_wires)
266                .map(|w| {
267                    if w == wire {
268                        self.observable.chars().next().unwrap_or('Z')
269                    } else {
270                        'I'
271                    }
272                })
273                .collect();
274
275            let expval = expval_joint_analytical(qdev, &obs);
276
277            for (batch, &val) in expval.iter().enumerate() {
278                results[[batch, wire]] = val;
279            }
280        }
281
282        results
283    }
284}
285
286impl TQModule for TQMeasureAll {
287    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
288        Ok(())
289    }
290
291    fn parameters(&self) -> Vec<TQParameter> {
292        Vec::new()
293    }
294
295    fn n_wires(&self) -> Option<usize> {
296        None
297    }
298
299    fn set_n_wires(&mut self, _n_wires: usize) {}
300
301    fn is_static_mode(&self) -> bool {
302        self.static_mode
303    }
304
305    fn static_on(&mut self) {
306        self.static_mode = true;
307    }
308
309    fn static_off(&mut self) {
310        self.static_mode = false;
311    }
312
313    fn name(&self) -> &str {
314        "MeasureAll"
315    }
316}