quantrs2_ml/torchquantum/
measurement.rs1use super::gates::{TQHadamard, TQS};
11use super::{CType, TQDevice, TQModule, TQOperator, TQParameter};
12use crate::error::{MLError, Result};
13use scirs2_core::ndarray::{Array1, Array2};
14use std::collections::HashMap;
15
/// Enumerate every computational-basis bitstring for `n_wires` qubits.
///
/// Each string is zero-padded to `n_wires` characters with the most significant bit
/// first, and the list is in increasing numeric order, so entry `k` labels basis
/// state index `k`. For `n_wires == 0` the single entry is `"0"`.
///
/// # Panics
/// Panics if `n_wires` is not smaller than the bit width of `usize`, because the number
/// of basis states (`2^n_wires`) would not be representable.
pub fn gen_bitstrings(n_wires: usize) -> Vec<String> {
    assert!(
        n_wires < usize::BITS as usize,
        "gen_bitstrings: n_wires = {} exceeds the usize bit width",
        n_wires
    );
    // Explicit `usize`: an unsuffixed `1 << n_wires` defaults to `i32`, which silently
    // produces an empty range at 31 wires and overflows the shift beyond that.
    (0..(1usize << n_wires))
        .map(|k| format!("{:0width$b}", k, width = n_wires))
        .collect()
}
22
23pub fn measure(qdev: &TQDevice, n_shots: usize) -> Vec<HashMap<String, usize>> {
25 let bitstring_candidates = gen_bitstrings(qdev.n_wires);
26 let probs = qdev.get_probs_1d();
27
28 let mut distributions = Vec::with_capacity(qdev.bsz);
29
30 for batch in 0..qdev.bsz {
31 let mut counts = HashMap::new();
32
33 for bs in &bitstring_candidates {
35 counts.insert(bs.clone(), 0);
36 }
37
38 for _ in 0..n_shots {
40 let r: f64 = fastrand::f64();
41 let mut cumsum = 0.0;
42
43 for (i, &prob) in probs.row(batch).iter().enumerate() {
44 cumsum += prob;
45 if r < cumsum {
46 *counts.entry(bitstring_candidates[i].clone()).or_insert(0) += 1;
47 break;
48 }
49 }
50 }
51
52 distributions.push(counts);
53 }
54
55 distributions
56}
57
58pub fn expval_joint_analytical(qdev: &TQDevice, observable: &str) -> Array1<f64> {
60 let observable = observable.to_uppercase();
61 let n_wires = qdev.n_wires;
62
63 assert_eq!(
64 observable.len(),
65 n_wires,
66 "Observable length must match n_wires"
67 );
68
69 let states_1d = qdev.get_states_1d();
70
71 let pauli_x = Array2::from_shape_vec(
73 (2, 2),
74 vec![
75 CType::new(0.0, 0.0),
76 CType::new(1.0, 0.0),
77 CType::new(1.0, 0.0),
78 CType::new(0.0, 0.0),
79 ],
80 )
81 .unwrap_or_else(|_| Array2::eye(2).mapv(|x| CType::new(x, 0.0)));
82
83 let pauli_y = Array2::from_shape_vec(
84 (2, 2),
85 vec![
86 CType::new(0.0, 0.0),
87 CType::new(0.0, -1.0),
88 CType::new(0.0, 1.0),
89 CType::new(0.0, 0.0),
90 ],
91 )
92 .unwrap_or_else(|_| Array2::eye(2).mapv(|x| CType::new(x, 0.0)));
93
94 let pauli_z = Array2::from_shape_vec(
95 (2, 2),
96 vec![
97 CType::new(1.0, 0.0),
98 CType::new(0.0, 0.0),
99 CType::new(0.0, 0.0),
100 CType::new(-1.0, 0.0),
101 ],
102 )
103 .unwrap_or_else(|_| Array2::eye(2).mapv(|x| CType::new(x, 0.0)));
104
105 let identity = Array2::eye(2).mapv(|x| CType::new(x, 0.0));
106
107 let mut hamiltonian = match observable.chars().next().unwrap_or('I') {
109 'X' => pauli_x.clone(),
110 'Y' => pauli_y.clone(),
111 'Z' => pauli_z.clone(),
112 _ => identity.clone(),
113 };
114
115 for c in observable.chars().skip(1) {
116 let op = match c {
117 'X' => &pauli_x,
118 'Y' => &pauli_y,
119 'Z' => &pauli_z,
120 _ => &identity,
121 };
122 hamiltonian = kron(&hamiltonian, op);
123 }
124
125 let mut expvals = Array1::zeros(qdev.bsz);
127
128 for batch in 0..qdev.bsz {
129 let state = states_1d.row(batch);
130 let mut result = CType::new(0.0, 0.0);
131
132 for i in 0..state.len() {
133 for j in 0..state.len() {
134 result += state[i].conj() * hamiltonian[[i, j]] * state[j];
135 }
136 }
137
138 expvals[batch] = result.re;
139 }
140
141 expvals
142}
143
144pub fn expval_joint_sampling(qdev: &TQDevice, observable: &str, n_shots: usize) -> Array1<f64> {
146 let observable = observable.to_uppercase();
147 let n_wires = qdev.n_wires;
148
149 let mut qdev_clone = qdev.clone();
151
152 for (wire, c) in observable.chars().enumerate() {
154 match c {
155 'X' => {
156 let mut h = TQHadamard::new();
158 let _ = h.apply(&mut qdev_clone, &[wire]);
159 }
160 'Y' => {
161 let mut s = TQS::new();
163 s.set_inverse(true);
164 let _ = s.apply(&mut qdev_clone, &[wire]);
165 let mut h = TQHadamard::new();
166 let _ = h.apply(&mut qdev_clone, &[wire]);
167 }
168 _ => {} }
170 }
171
172 let distributions = measure(&qdev_clone, n_shots);
174
175 let mut expvals = Array1::zeros(qdev.bsz);
177
178 let mask: Vec<bool> = observable.chars().map(|c| c != 'I').collect();
180
181 for (batch, distri) in distributions.iter().enumerate() {
182 let mut n_eigen_one = 0;
183 let mut n_eigen_minus_one = 0;
184
185 for (bitstring, &count) in distri {
186 let parity: usize = bitstring
188 .chars()
189 .zip(mask.iter())
190 .filter_map(|(c, &m)| {
191 if m {
192 c.to_digit(2).map(|d| d as usize)
193 } else {
194 None
195 }
196 })
197 .sum();
198
199 if parity % 2 == 0 {
200 n_eigen_one += count;
201 } else {
202 n_eigen_minus_one += count;
203 }
204 }
205
206 expvals[batch] = (n_eigen_one as f64 - n_eigen_minus_one as f64) / n_shots as f64;
207 }
208
209 expvals
210}
211
212fn kron(a: &Array2<CType>, b: &Array2<CType>) -> Array2<CType> {
214 let (m, n) = (a.nrows(), a.ncols());
215 let (p, q) = (b.nrows(), b.ncols());
216
217 let mut result = Array2::zeros((m * p, n * q));
218
219 for i in 0..m {
220 for j in 0..n {
221 for k in 0..p {
222 for l in 0..q {
223 result[[i * p + k, j * q + l]] = a[[i, j]] * b[[k, l]];
224 }
225 }
226 }
227 }
228
229 result
230}
231
232#[derive(Debug, Clone)]
234pub struct TQMeasureAll {
235 pub observable: String,
237 static_mode: bool,
238}
239
240impl TQMeasureAll {
241 pub fn new(observable: impl Into<String>) -> Self {
242 Self {
243 observable: observable.into(),
244 static_mode: false,
245 }
246 }
247
248 pub fn pauli_z() -> Self {
250 Self::new("Z")
251 }
252
253 pub fn pauli_x() -> Self {
255 Self::new("X")
256 }
257
258 pub fn measure(&self, qdev: &TQDevice) -> Array2<f64> {
260 let n_wires = qdev.n_wires;
261 let mut results = Array2::zeros((qdev.bsz, n_wires));
262
263 for wire in 0..n_wires {
264 let obs: String = (0..n_wires)
266 .map(|w| {
267 if w == wire {
268 self.observable.chars().next().unwrap_or('Z')
269 } else {
270 'I'
271 }
272 })
273 .collect();
274
275 let expval = expval_joint_analytical(qdev, &obs);
276
277 for (batch, &val) in expval.iter().enumerate() {
278 results[[batch, wire]] = val;
279 }
280 }
281
282 results
283 }
284}
285
286impl TQModule for TQMeasureAll {
287 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
288 Ok(())
289 }
290
291 fn parameters(&self) -> Vec<TQParameter> {
292 Vec::new()
293 }
294
295 fn n_wires(&self) -> Option<usize> {
296 None
297 }
298
299 fn set_n_wires(&mut self, _n_wires: usize) {}
300
301 fn is_static_mode(&self) -> bool {
302 self.static_mode
303 }
304
305 fn static_on(&mut self) {
306 self.static_mode = true;
307 }
308
309 fn static_off(&mut self) {
310 self.static_mode = false;
311 }
312
313 fn name(&self) -> &str {
314 "MeasureAll"
315 }
316}