quantrs2_core/state_visualization_3d/
wigner.rs1use scirs2_core::ndarray::{Array1, Array2};
13use scirs2_core::Complex64;
14use serde_json::{json, Value};
15
16use crate::error::{QuantRS2Error, QuantRS2Result};
17
18fn pauli_i() -> Array2<Complex64> {
20 let mut m = Array2::zeros((2, 2));
21 m[[0, 0]] = Complex64::new(1.0, 0.0);
22 m[[1, 1]] = Complex64::new(1.0, 0.0);
23 m
24}
25
26fn pauli_x() -> Array2<Complex64> {
27 let mut m = Array2::zeros((2, 2));
28 m[[0, 1]] = Complex64::new(1.0, 0.0);
29 m[[1, 0]] = Complex64::new(1.0, 0.0);
30 m
31}
32
33fn pauli_y() -> Array2<Complex64> {
34 let mut m = Array2::zeros((2, 2));
35 m[[0, 1]] = Complex64::new(0.0, -1.0);
36 m[[1, 0]] = Complex64::new(0.0, 1.0);
37 m
38}
39
40fn pauli_z() -> Array2<Complex64> {
41 let mut m = Array2::zeros((2, 2));
42 m[[0, 0]] = Complex64::new(1.0, 0.0);
43 m[[1, 1]] = Complex64::new(-1.0, 0.0);
44 m
45}
46
47fn displacement_op_1(q: usize, p: usize) -> Array2<Complex64> {
63 let i = pauli_i();
64 let x = pauli_x();
65 let y = pauli_y();
66 let z = pauli_z();
67
68 let sx = if p % 2 == 0 { 1.0f64 } else { -1.0f64 };
69 let sy = if (q + p) % 2 == 0 { 1.0f64 } else { -1.0f64 };
70 let sz = if q % 2 == 0 { 1.0f64 } else { -1.0f64 };
71
72 let half = 0.5;
73 let mut result = Array2::zeros((2, 2));
74 for row in 0..2 {
75 for col in 0..2 {
76 result[[row, col]] = Complex64::new(half, 0.0)
77 * (i[[row, col]]
78 + Complex64::new(sx, 0.0) * x[[row, col]]
79 + Complex64::new(sy, 0.0) * y[[row, col]]
80 + Complex64::new(sz, 0.0) * z[[row, col]]);
81 }
82 }
83 result
84}
85
86fn tensor_product_2x2(a: &Array2<Complex64>, b: &Array2<Complex64>) -> Array2<Complex64> {
88 let mut out = Array2::zeros((4, 4));
89 for i in 0..2 {
90 for j in 0..2 {
91 for k in 0..2 {
92 for l in 0..2 {
93 out[[2 * i + k, 2 * j + l]] = a[[i, j]] * b[[k, l]];
94 }
95 }
96 }
97 }
98 out
99}
100
101fn matrix_trace(m: &Array2<Complex64>) -> Complex64 {
103 let n = m.nrows().min(m.ncols());
104 (0..n).map(|i| m[[i, i]]).sum()
105}
106
107fn density_matrix(state: &Array1<Complex64>) -> Array2<Complex64> {
109 let d = state.len();
110 let mut rho = Array2::zeros((d, d));
111 for i in 0..d {
112 for j in 0..d {
113 rho[[i, j]] = state[i] * state[j].conj();
114 }
115 }
116 rho
117}
118
119fn mat_mul(a: &Array2<Complex64>, b: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
121 let n = a.nrows();
122 if a.ncols() != b.nrows() || b.ncols() != n {
123 return Err(QuantRS2Error::InvalidInput(
124 "Incompatible matrix dimensions for multiplication".to_string(),
125 ));
126 }
127 let mut out = Array2::zeros((n, n));
128 for i in 0..n {
129 for j in 0..n {
130 let mut s = Complex64::new(0.0, 0.0);
131 for k in 0..n {
132 s += a[[i, k]] * b[[k, j]];
133 }
134 out[[i, j]] = s;
135 }
136 }
137 Ok(out)
138}
139
140fn wigner_n1(state: &Array1<Complex64>) -> QuantRS2Result<[[f64; 2]; 2]> {
146 let rho = density_matrix(state);
147 let mut w = [[0.0f64; 2]; 2];
148 for q in 0..2usize {
149 for p in 0..2usize {
150 let a = displacement_op_1(q, p);
151 let ap = mat_mul(&a, &rho)?;
152 let tr = matrix_trace(&ap);
153 w[q][p] = tr.re / 2.0; }
155 }
156 Ok(w)
157}
158
159fn wigner_n2(state: &Array1<Complex64>) -> QuantRS2Result<[[f64; 4]; 4]> {
169 let rho = density_matrix(state);
170 let mut w = [[0.0f64; 4]; 4];
171
172 for q in 0..4usize {
173 let q1 = q >> 1;
174 let q2 = q & 1;
175 for p in 0..4usize {
176 let p1 = p >> 1;
177 let p2 = p & 1;
178
179 let a1 = displacement_op_1(q1, p1);
180 let a2 = displacement_op_1(q2, p2);
181 let a_tensor = tensor_product_2x2(&a1, &a2);
182 let ap = mat_mul(&a_tensor, &rho)?;
183 let tr = matrix_trace(&ap);
184 w[q][p] = tr.re / 4.0; }
186 }
187 Ok(w)
188}
189
190pub fn wigner_plotly_json(state: &Array1<Complex64>, n_qubits: usize) -> QuantRS2Result<String> {
195 match n_qubits {
196 0 => Err(QuantRS2Error::InvalidInput(
197 "n_qubits must be ≥ 1".to_string(),
198 )),
199 1 => {
200 if state.len() != 2 {
201 return Err(QuantRS2Error::InvalidInput(format!(
202 "State length {} does not match 2^1 = 2",
203 state.len()
204 )));
205 }
206 let w = wigner_n1(state)?;
207 build_wigner_heatmap_n1(&w)
208 }
209 2 => {
210 if state.len() != 4 {
211 return Err(QuantRS2Error::InvalidInput(format!(
212 "State length {} does not match 2^2 = 4",
213 state.len()
214 )));
215 }
216 let w = wigner_n2(state)?;
217 build_wigner_heatmap_n2(&w)
218 }
219 _ => Err(QuantRS2Error::UnsupportedOperation(format!(
220 "Discrete Wigner for n={} requires GF(2^n) phase space — \
221 only n=1 and n=2 are supported in this version",
222 n_qubits
223 ))),
224 }
225}
226
227fn build_wigner_heatmap_n1(w: &[[f64; 2]; 2]) -> QuantRS2Result<String> {
229 let labels = ["(0,0)", "(1,0)", "(0,1)", "(1,1)"];
230
231 let z: Vec<Vec<f64>> = (0..2).map(|q| (0..2).map(|p| w[q][p]).collect()).collect();
233
234 let x_labels: Vec<&str> = vec!["p=0", "p=1"];
235 let y_labels: Vec<&str> = vec!["q=0", "q=1"];
236
237 let hovertext: Vec<Vec<String>> = (0..2)
238 .map(|q| {
239 (0..2)
240 .map(|p| format!("{} W={:.4}", labels[2 * q + p], w[q][p]))
241 .collect()
242 })
243 .collect();
244
245 let figure = json!({
246 "data": [{
247 "type": "heatmap",
248 "z": z,
249 "x": x_labels,
250 "y": y_labels,
251 "colorscale": "RdBu",
252 "zmid": 0.0,
253 "text": hovertext,
254 "hoverinfo": "text",
255 "colorbar": {"title": "W(q,p)"}
256 }],
257 "layout": {
258 "title": "Discrete Wigner Function (n=1)",
259 "xaxis": {"title": "p"},
260 "yaxis": {"title": "q"},
261 "height": 450
262 }
263 });
264
265 serde_json::to_string(&figure).map_err(QuantRS2Error::from)
266}
267
268fn build_wigner_heatmap_n2(w: &[[f64; 4]; 4]) -> QuantRS2Result<String> {
270 let coord_labels = ["(0,0)", "(0,1)", "(1,0)", "(1,1)"];
271
272 let z: Vec<Vec<f64>> = (0..4).map(|q| (0..4).map(|p| w[q][p]).collect()).collect();
273
274 let x_labels: Vec<String> = (0..4usize)
275 .map(|p| format!("p={}", coord_labels[p]))
276 .collect();
277 let y_labels: Vec<String> = (0..4usize)
278 .map(|q| format!("q={}", coord_labels[q]))
279 .collect();
280
281 let hovertext: Vec<Vec<String>> = (0..4)
282 .map(|q| {
283 (0..4)
284 .map(|p| {
285 format!(
286 "q={} p={} W={:.4}",
287 coord_labels[q], coord_labels[p], w[q][p]
288 )
289 })
290 .collect()
291 })
292 .collect();
293
294 let figure = json!({
295 "data": [{
296 "type": "heatmap",
297 "z": z,
298 "x": x_labels,
299 "y": y_labels,
300 "colorscale": "RdBu",
301 "zmid": 0.0,
302 "text": hovertext,
303 "hoverinfo": "text",
304 "colorbar": {"title": "W(q,p)"}
305 }],
306 "layout": {
307 "title": "Discrete Wigner Function (n=2)",
308 "xaxis": {"title": "p (phase-space momentum)"},
309 "yaxis": {"title": "q (phase-space position)"},
310 "height": 550
311 }
312 });
313
314 serde_json::to_string(&figure).map_err(QuantRS2Error::from)
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-10;

    /// Single-qubit state with amplitudes `[1, 0]`.
    fn state_zero_1q() -> Array1<Complex64> {
        let amplitudes = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        Array1::from(amplitudes)
    }

    /// Two-qubit state with amplitude `1/√2` on the first and last basis states.
    fn state_bell_2q() -> Array1<Complex64> {
        let amp = Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0);
        let zero = Complex64::new(0.0, 0.0);
        Array1::from(vec![amp, zero, zero, amp])
    }

    #[test]
    fn test_wigner_n1_zero_state() {
        let w = wigner_n1(&state_zero_1q()).expect("wigner_n1 failed");

        let w00 = w[0][0];
        assert!(
            (w00 - 0.5).abs() < TOL,
            "W(0,0) should be 0.5, got {}",
            w00
        );

        // The four phase-space values must add up to one.
        let sum: f64 = w.iter().flatten().sum();
        assert!(
            (sum - 1.0).abs() < TOL,
            "Wigner normalization should be 1, got {}",
            sum
        );
    }

    #[test]
    fn test_wigner_n2_normalization() {
        let w = wigner_n2(&state_bell_2q()).expect("wigner_n2 failed");

        let sum: f64 = w.iter().flatten().sum();
        assert!(
            (sum - 1.0).abs() < TOL,
            "n=2 Wigner normalization should be 1, got {}",
            sum
        );
    }

    #[test]
    fn test_wigner_n3_returns_err() {
        let mut state: Array1<Complex64> = Array1::zeros(8);
        state[0] = Complex64::new(1.0, 0.0);

        match wigner_plotly_json(&state, 3) {
            Ok(_) => panic!("n=3 should return Err"),
            Err(e) => assert!(
                matches!(e, QuantRS2Error::UnsupportedOperation(_)),
                "Error should be UnsupportedOperation, got {:?}",
                e
            ),
        }
    }

    #[test]
    fn test_wigner_json_valid() {
        let json_str = wigner_plotly_json(&state_zero_1q(), 1).expect("Wigner JSON failed");
        let _parsed = serde_json::from_str::<serde_json::Value>(&json_str)
            .expect("Output should be valid JSON");
    }
}