Skip to main content

quantrs2_core/state_visualization_3d/
wigner.rs

1//! Discrete Wigner function visualization for n=1 and n=2 qubit states.
2//!
3//! Implements Wootters' discrete phase-space Wigner function using
4//! the displacement-operator basis {I, X, Z, Y} for n=1 and its
5//! tensor-product extension for n=2.
6//!
7//! # Scope
8//! Strictly limited to n=1 (4 phase-space points) and n=2 (16 points).
9//! Returns an error for n ≥ 3 because the GF(2^n) construction is
10//! research-grade and has multiple inequivalent definitions.
11
12use scirs2_core::ndarray::{Array1, Array2};
13use scirs2_core::Complex64;
14use serde_json::{json, Value};
15
16use crate::error::{QuantRS2Error, QuantRS2Result};
17
18/// Pauli matrices as 2×2 complex arrays.
19fn pauli_i() -> Array2<Complex64> {
20    let mut m = Array2::zeros((2, 2));
21    m[[0, 0]] = Complex64::new(1.0, 0.0);
22    m[[1, 1]] = Complex64::new(1.0, 0.0);
23    m
24}
25
26fn pauli_x() -> Array2<Complex64> {
27    let mut m = Array2::zeros((2, 2));
28    m[[0, 1]] = Complex64::new(1.0, 0.0);
29    m[[1, 0]] = Complex64::new(1.0, 0.0);
30    m
31}
32
33fn pauli_y() -> Array2<Complex64> {
34    let mut m = Array2::zeros((2, 2));
35    m[[0, 1]] = Complex64::new(0.0, -1.0);
36    m[[1, 0]] = Complex64::new(0.0, 1.0);
37    m
38}
39
40fn pauli_z() -> Array2<Complex64> {
41    let mut m = Array2::zeros((2, 2));
42    m[[0, 0]] = Complex64::new(1.0, 0.0);
43    m[[1, 1]] = Complex64::new(-1.0, 0.0);
44    m
45}
46
47/// Phase-space point operator A(q, p) for a single qubit.
48///
49/// Uses the Wootters (1987) definition where the 4 operators form a
50/// complete orthogonal set satisfying:
51///   Σ_{q,p} A(q,p) = 2·I   (enabling Wigner normalization Σ W = 1)
52///   Tr(A(q,p)) = 1
53///
54/// The operators are:
55///   A(q,p) = ½(I + (-1)^p X + (-1)^{q+p} Y + (-1)^q Z)
56///
57/// Explicitly:
58///   A(0,0) = ½(I + X + Y + Z)
59///   A(1,0) = ½(I + X - Y - Z)
60///   A(0,1) = ½(I - X - Y + Z)
61///   A(1,1) = ½(I - X + Y - Z)
62fn displacement_op_1(q: usize, p: usize) -> Array2<Complex64> {
63    let i = pauli_i();
64    let x = pauli_x();
65    let y = pauli_y();
66    let z = pauli_z();
67
68    let sx = if p % 2 == 0 { 1.0f64 } else { -1.0f64 };
69    let sy = if (q + p) % 2 == 0 { 1.0f64 } else { -1.0f64 };
70    let sz = if q % 2 == 0 { 1.0f64 } else { -1.0f64 };
71
72    let half = 0.5;
73    let mut result = Array2::zeros((2, 2));
74    for row in 0..2 {
75        for col in 0..2 {
76            result[[row, col]] = Complex64::new(half, 0.0)
77                * (i[[row, col]]
78                    + Complex64::new(sx, 0.0) * x[[row, col]]
79                    + Complex64::new(sy, 0.0) * y[[row, col]]
80                    + Complex64::new(sz, 0.0) * z[[row, col]]);
81        }
82    }
83    result
84}
85
86/// Tensor product of two 2×2 matrices → 4×4 matrix.
87fn tensor_product_2x2(a: &Array2<Complex64>, b: &Array2<Complex64>) -> Array2<Complex64> {
88    let mut out = Array2::zeros((4, 4));
89    for i in 0..2 {
90        for j in 0..2 {
91            for k in 0..2 {
92                for l in 0..2 {
93                    out[[2 * i + k, 2 * j + l]] = a[[i, j]] * b[[k, l]];
94                }
95            }
96        }
97    }
98    out
99}
100
101/// Trace of a square matrix.
102fn matrix_trace(m: &Array2<Complex64>) -> Complex64 {
103    let n = m.nrows().min(m.ncols());
104    (0..n).map(|i| m[[i, i]]).sum()
105}
106
107/// Compute the density matrix ρ = |ψ⟩⟨ψ| from a state vector.
108fn density_matrix(state: &Array1<Complex64>) -> Array2<Complex64> {
109    let d = state.len();
110    let mut rho = Array2::zeros((d, d));
111    for i in 0..d {
112        for j in 0..d {
113            rho[[i, j]] = state[i] * state[j].conj();
114        }
115    }
116    rho
117}
118
119/// Matrix–matrix multiply for square complex matrices.
120fn mat_mul(a: &Array2<Complex64>, b: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
121    let n = a.nrows();
122    if a.ncols() != b.nrows() || b.ncols() != n {
123        return Err(QuantRS2Error::InvalidInput(
124            "Incompatible matrix dimensions for multiplication".to_string(),
125        ));
126    }
127    let mut out = Array2::zeros((n, n));
128    for i in 0..n {
129        for j in 0..n {
130            let mut s = Complex64::new(0.0, 0.0);
131            for k in 0..n {
132                s += a[[i, k]] * b[[k, j]];
133            }
134            out[[i, j]] = s;
135        }
136    }
137    Ok(out)
138}
139
140/// Compute the discrete Wigner function for n=1.
141///
142/// W(q,p) = (1/d) Tr(A(q,p) ρ)  where d=2.
143///
144/// Returns a 2×2 array indexed by (q, p) ∈ {0,1}².
145fn wigner_n1(state: &Array1<Complex64>) -> QuantRS2Result<[[f64; 2]; 2]> {
146    let rho = density_matrix(state);
147    let mut w = [[0.0f64; 2]; 2];
148    for q in 0..2usize {
149        for p in 0..2usize {
150            let a = displacement_op_1(q, p);
151            let ap = mat_mul(&a, &rho)?;
152            let tr = matrix_trace(&ap);
153            w[q][p] = tr.re / 2.0; // d = 2
154        }
155    }
156    Ok(w)
157}
158
159/// Compute the discrete Wigner function for n=2.
160///
161/// Uses the tensor-product displacement operator:
162/// A⊗(q₁q₂, p₁p₂) = A₁(q₁,p₁) ⊗ A₂(q₂,p₂)
163///
164/// W(q₁,q₂; p₁,p₂) = (1/4) Tr(A⊗ ρ)
165///
166/// Returns a 4×4 array indexed by (q, p) ∈ {0..3}×{0..3},
167/// where q = 2·q₁ + q₂ and p = 2·p₁ + p₂.
168fn wigner_n2(state: &Array1<Complex64>) -> QuantRS2Result<[[f64; 4]; 4]> {
169    let rho = density_matrix(state);
170    let mut w = [[0.0f64; 4]; 4];
171
172    for q in 0..4usize {
173        let q1 = q >> 1;
174        let q2 = q & 1;
175        for p in 0..4usize {
176            let p1 = p >> 1;
177            let p2 = p & 1;
178
179            let a1 = displacement_op_1(q1, p1);
180            let a2 = displacement_op_1(q2, p2);
181            let a_tensor = tensor_product_2x2(&a1, &a2);
182            let ap = mat_mul(&a_tensor, &rho)?;
183            let tr = matrix_trace(&ap);
184            w[q][p] = tr.re / 4.0; // d = 4
185        }
186    }
187    Ok(w)
188}
189
190/// Discrete Wigner function for n=1 (4-point) or n=2 (16-point) states.
191///
192/// Returns an `Err` for n ≥ 3 — the GF(2^n) construction is out of
193/// scope for this version.
194pub fn wigner_plotly_json(state: &Array1<Complex64>, n_qubits: usize) -> QuantRS2Result<String> {
195    match n_qubits {
196        0 => Err(QuantRS2Error::InvalidInput(
197            "n_qubits must be ≥ 1".to_string(),
198        )),
199        1 => {
200            if state.len() != 2 {
201                return Err(QuantRS2Error::InvalidInput(format!(
202                    "State length {} does not match 2^1 = 2",
203                    state.len()
204                )));
205            }
206            let w = wigner_n1(state)?;
207            build_wigner_heatmap_n1(&w)
208        }
209        2 => {
210            if state.len() != 4 {
211                return Err(QuantRS2Error::InvalidInput(format!(
212                    "State length {} does not match 2^2 = 4",
213                    state.len()
214                )));
215            }
216            let w = wigner_n2(state)?;
217            build_wigner_heatmap_n2(&w)
218        }
219        _ => Err(QuantRS2Error::UnsupportedOperation(format!(
220            "Discrete Wigner for n={} requires GF(2^n) phase space — \
221             only n=1 and n=2 are supported in this version",
222            n_qubits
223        ))),
224    }
225}
226
227/// Build a Plotly heatmap for the n=1 Wigner function (2×2 grid).
228fn build_wigner_heatmap_n1(w: &[[f64; 2]; 2]) -> QuantRS2Result<String> {
229    let labels = ["(0,0)", "(1,0)", "(0,1)", "(1,1)"];
230
231    // Arrange as a 2×2 grid: rows = q, columns = p
232    let z: Vec<Vec<f64>> = (0..2).map(|q| (0..2).map(|p| w[q][p]).collect()).collect();
233
234    let x_labels: Vec<&str> = vec!["p=0", "p=1"];
235    let y_labels: Vec<&str> = vec!["q=0", "q=1"];
236
237    let hovertext: Vec<Vec<String>> = (0..2)
238        .map(|q| {
239            (0..2)
240                .map(|p| format!("{} W={:.4}", labels[2 * q + p], w[q][p]))
241                .collect()
242        })
243        .collect();
244
245    let figure = json!({
246        "data": [{
247            "type": "heatmap",
248            "z": z,
249            "x": x_labels,
250            "y": y_labels,
251            "colorscale": "RdBu",
252            "zmid": 0.0,
253            "text": hovertext,
254            "hoverinfo": "text",
255            "colorbar": {"title": "W(q,p)"}
256        }],
257        "layout": {
258            "title": "Discrete Wigner Function (n=1)",
259            "xaxis": {"title": "p"},
260            "yaxis": {"title": "q"},
261            "height": 450
262        }
263    });
264
265    serde_json::to_string(&figure).map_err(QuantRS2Error::from)
266}
267
268/// Build a Plotly heatmap for the n=2 Wigner function (4×4 grid).
269fn build_wigner_heatmap_n2(w: &[[f64; 4]; 4]) -> QuantRS2Result<String> {
270    let coord_labels = ["(0,0)", "(0,1)", "(1,0)", "(1,1)"];
271
272    let z: Vec<Vec<f64>> = (0..4).map(|q| (0..4).map(|p| w[q][p]).collect()).collect();
273
274    let x_labels: Vec<String> = (0..4usize)
275        .map(|p| format!("p={}", coord_labels[p]))
276        .collect();
277    let y_labels: Vec<String> = (0..4usize)
278        .map(|q| format!("q={}", coord_labels[q]))
279        .collect();
280
281    let hovertext: Vec<Vec<String>> = (0..4)
282        .map(|q| {
283            (0..4)
284                .map(|p| {
285                    format!(
286                        "q={} p={} W={:.4}",
287                        coord_labels[q], coord_labels[p], w[q][p]
288                    )
289                })
290                .collect()
291        })
292        .collect();
293
294    let figure = json!({
295        "data": [{
296            "type": "heatmap",
297            "z": z,
298            "x": x_labels,
299            "y": y_labels,
300            "colorscale": "RdBu",
301            "zmid": 0.0,
302            "text": hovertext,
303            "hoverinfo": "text",
304            "colorbar": {"title": "W(q,p)"}
305        }],
306        "layout": {
307            "title": "Discrete Wigner Function (n=2)",
308            "xaxis": {"title": "p (phase-space momentum)"},
309            "yaxis": {"title": "q (phase-space position)"},
310            "height": 550
311        }
312    });
313
314    serde_json::to_string(&figure).map_err(QuantRS2Error::from)
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Absolute tolerance for floating-point comparisons of Wigner values.
    const TOL: f64 = 1e-10;

    fn state_zero_1q() -> Array1<Complex64> {
        Array1::from(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)])
    }

    fn state_plus_1q() -> Array1<Complex64> {
        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
        Array1::from(vec![
            Complex64::new(inv_sqrt2, 0.0),
            Complex64::new(inv_sqrt2, 0.0),
        ])
    }

    fn state_bell_2q() -> Array1<Complex64> {
        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
        Array1::from(vec![
            Complex64::new(inv_sqrt2, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(inv_sqrt2, 0.0),
        ])
    }

    /// Assert two single-qubit Wigner grids agree entry-by-entry within `TOL`.
    fn assert_grid_close(actual: &[[f64; 2]; 2], expected: &[[f64; 2]; 2]) {
        for (q, (a_row, e_row)) in actual.iter().zip(expected).enumerate() {
            for (p, (a, e)) in a_row.iter().zip(e_row).enumerate() {
                assert!(
                    (a - e).abs() < TOL,
                    "W({},{}) should be {}, got {}",
                    q,
                    p,
                    e,
                    a
                );
            }
        }
    }

    #[test]
    fn test_wigner_n1_zero_state() {
        let w = wigner_n1(&state_zero_1q()).expect("wigner_n1 failed");

        // |0⟩ has ⟨X⟩ = ⟨Y⟩ = 0 and ⟨Z⟩ = 1, so
        // W(q,p) = ½ Tr(A(q,p) ρ) = ¼ (1 + (−1)^q):
        // ½ across the q=0 row, 0 across the q=1 row.
        assert_grid_close(&w, &[[0.5, 0.5], [0.0, 0.0]]);

        // Normalization: Σ W = Tr ρ = 1
        let sum: f64 = w.iter().flatten().sum();
        assert!(
            (sum - 1.0).abs() < TOL,
            "Wigner normalization should be 1, got {}",
            sum
        );
    }

    #[test]
    fn test_wigner_n1_plus_state() {
        // |+⟩ has ⟨X⟩ = 1 and ⟨Y⟩ = ⟨Z⟩ = 0, so W(q,p) = ¼ (1 + (−1)^p):
        // ½ down the p=0 column, 0 down the p=1 column. Together with the |0⟩
        // case this pins which index is q and which is p.
        let w = wigner_n1(&state_plus_1q()).expect("wigner_n1 failed");
        assert_grid_close(&w, &[[0.5, 0.0], [0.5, 0.0]]);
    }

    #[test]
    fn test_wigner_n2_normalization() {
        let w = wigner_n2(&state_bell_2q()).expect("wigner_n2 failed");

        let sum: f64 = w.iter().flatten().sum();
        assert!(
            (sum - 1.0).abs() < TOL,
            "n=2 Wigner normalization should be 1, got {}",
            sum
        );
    }

    #[test]
    fn test_wigner_n3_returns_err() {
        // Build a valid 3-qubit state (|000⟩)
        let mut state = Array1::zeros(8);
        state[0] = Complex64::new(1.0, 0.0);
        let result = wigner_plotly_json(&state, 3);
        assert!(
            matches!(result, Err(QuantRS2Error::UnsupportedOperation(_))),
            "n=3 should return UnsupportedOperation, got {:?}",
            result
        );
    }

    #[test]
    fn test_wigner_rejects_invalid_inputs() {
        let state = state_zero_1q();
        assert!(
            matches!(
                wigner_plotly_json(&state, 0),
                Err(QuantRS2Error::InvalidInput(_))
            ),
            "n_qubits = 0 should be rejected"
        );
        // A 2-amplitude state declared as 2 qubits must fail the length check.
        assert!(
            matches!(
                wigner_plotly_json(&state, 2),
                Err(QuantRS2Error::InvalidInput(_))
            ),
            "length mismatch should be rejected"
        );
    }

    #[test]
    fn test_wigner_json_valid() {
        let json_str = wigner_plotly_json(&state_zero_1q(), 1).expect("Wigner JSON failed");
        let parsed: serde_json::Value =
            serde_json::from_str(&json_str).expect("Output should be valid JSON");
        assert_eq!(parsed["data"][0]["type"], "heatmap");
        let rows = parsed["data"][0]["z"]
            .as_array()
            .expect("z should be an array");
        assert_eq!(rows.len(), 2, "n=1 heatmap should have 2 rows");
    }
}