use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::Complex64;
use serde_json::{json, Value};
use crate::error::{QuantRS2Error, QuantRS2Result};
/// Returns the 2×2 identity matrix with complex entries.
fn pauli_i() -> Array2<Complex64> {
    // Ones on the diagonal, zeros elsewhere.
    Array2::eye(2)
}
/// Returns the 2×2 Pauli-X matrix: ones off the diagonal, zeros on it.
fn pauli_x() -> Array2<Complex64> {
    Array2::from_shape_fn((2, 2), |(row, col)| {
        if row == col {
            Complex64::new(0.0, 0.0)
        } else {
            Complex64::new(1.0, 0.0)
        }
    })
}
/// Returns the 2×2 Pauli-Y matrix: `-i` at (0,1), `+i` at (1,0), zeros on the diagonal.
fn pauli_y() -> Array2<Complex64> {
    Array2::from_shape_fn((2, 2), |position| match position {
        (0, 1) => Complex64::new(0.0, -1.0),
        (1, 0) => Complex64::new(0.0, 1.0),
        _ => Complex64::new(0.0, 0.0),
    })
}
/// Returns the 2×2 Pauli-Z matrix: `+1` at (0,0), `-1` at (1,1), zeros off the diagonal.
fn pauli_z() -> Array2<Complex64> {
    Array2::from_shape_fn((2, 2), |position| match position {
        (0, 0) => Complex64::new(1.0, 0.0),
        (1, 1) => Complex64::new(-1.0, 0.0),
        _ => Complex64::new(0.0, 0.0),
    })
}
/// Builds the single-qubit operator `A(q, p) = ½ · (I + sx·X + sy·Y + sz·Z)`.
///
/// The signs depend only on parities:
/// * `sx` is `+1` when `p` is even, `-1` otherwise;
/// * `sy` is `+1` when `q + p` is even, `-1` otherwise;
/// * `sz` is `+1` when `q` is even, `-1` otherwise.
///
/// The result is a 2×2 complex matrix.
fn displacement_op_1(q: usize, p: usize) -> Array2<Complex64> {
    // Maps a parity to a real ±1 factor embedded in the complex plane.
    let sign = |value: usize| {
        let s = if value % 2 == 0 { 1.0f64 } else { -1.0f64 };
        Complex64::new(s, 0.0)
    };
    let coeff_x = sign(p);
    let coeff_y = sign(q + p);
    let coeff_z = sign(q);
    let scale = Complex64::new(0.5, 0.0);

    let identity = pauli_i();
    let sigma_x = pauli_x();
    let sigma_y = pauli_y();
    let sigma_z = pauli_z();

    Array2::from_shape_fn((2, 2), |(row, col)| {
        scale
            * (identity[[row, col]]
                + coeff_x * sigma_x[[row, col]]
                + coeff_y * sigma_y[[row, col]]
                + coeff_z * sigma_z[[row, col]])
    })
}
/// Computes the Kronecker product `a ⊗ b` of two 2×2 matrices as a 4×4 matrix.
///
/// Entry `(2·i + k, 2·j + l)` of the result equals `a[i, j] · b[k, l]`.
/// Only the top-left 2×2 corner of each input is read; indexing panics if an
/// input is smaller than 2×2.
fn tensor_product_2x2(a: &Array2<Complex64>, b: &Array2<Complex64>) -> Array2<Complex64> {
    // The high bit of each output index selects the entry of `a`, the low bit
    // selects the entry of `b`.
    Array2::from_shape_fn((4, 4), |(row, col)| {
        a[[row / 2, col / 2]] * b[[row % 2, col % 2]]
    })
}
fn matrix_trace(m: &Array2<Complex64>) -> Complex64 {
let n = m.nrows().min(m.ncols());
(0..n).map(|i| m[[i, i]]).sum()
}
/// Forms the outer product `|ψ⟩⟨ψ|` of a state vector with itself.
///
/// Entry `(i, j)` of the result is `state[i] · conj(state[j])`. The vector is
/// used as given; no normalization is applied.
fn density_matrix(state: &Array1<Complex64>) -> Array2<Complex64> {
    let dim = state.len();
    Array2::from_shape_fn((dim, dim), |(row, col)| state[row] * state[col].conj())
}
/// Multiplies two complex matrices whose product is square.
///
/// `a` must be `n × m` and `b` must be `m × n`; the result is the `n × n`
/// matrix `a · b`.
///
/// # Errors
///
/// Returns `QuantRS2Error::InvalidInput` when `a.ncols() != b.nrows()` or when
/// `b.ncols() != a.nrows()`.
fn mat_mul(a: &Array2<Complex64>, b: &Array2<Complex64>) -> QuantRS2Result<Array2<Complex64>> {
    let n = a.nrows();
    let inner = a.ncols();
    if inner != b.nrows() || b.ncols() != n {
        return Err(QuantRS2Error::InvalidInput(
            "Incompatible matrix dimensions for multiplication".to_string(),
        ));
    }
    let mut out = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            let mut s = Complex64::new(0.0, 0.0);
            // Contract over the shared inner dimension. Using the output size
            // here would read out of bounds (inner < n) or drop terms
            // (inner > n) for rectangular inputs that pass the check above.
            for k in 0..inner {
                s += a[[i, k]] * b[[k, j]];
            }
            out[[i, j]] = s;
        }
    }
    Ok(out)
}
/// Evaluates the single-qubit discrete Wigner function on the 2×2 grid.
///
/// For each `(q, p)` in `{0, 1}²` the entry `w[q][p]` is
/// `Re(Tr(A(q, p) · ρ)) / 2`, where `ρ` is the density matrix built from
/// `state` and `A(q, p)` comes from `displacement_op_1`.
///
/// # Errors
///
/// Propagates any error returned by `mat_mul`.
fn wigner_n1(state: &Array1<Complex64>) -> QuantRS2Result<[[f64; 2]; 2]> {
    let rho = density_matrix(state);
    let mut table = [[0.0f64; 2]; 2];
    for (q, row) in table.iter_mut().enumerate() {
        for (p, cell) in row.iter_mut().enumerate() {
            let point_op = displacement_op_1(q, p);
            let product = mat_mul(&point_op, &rho)?;
            // Only the real part of the trace is kept.
            *cell = matrix_trace(&product).re / 2.0;
        }
    }
    Ok(table)
}
/// Evaluates the two-qubit discrete Wigner function on the 4×4 grid.
///
/// Each index `q` and `p` in `0..4` is split into a high bit (first factor) and
/// a low bit (second factor). The operator for a grid point is the tensor
/// product of the two single-qubit operators from `displacement_op_1`, and
/// `w[q][p] = Re(Tr(A · ρ)) / 4`.
///
/// # Errors
///
/// Propagates any error returned by `mat_mul`, e.g. when `state` does not have
/// four entries.
fn wigner_n2(state: &Array1<Complex64>) -> QuantRS2Result<[[f64; 4]; 4]> {
    let rho = density_matrix(state);
    let mut table = [[0.0f64; 4]; 4];
    for (q, row) in table.iter_mut().enumerate() {
        let (q_hi, q_lo) = (q >> 1, q & 1);
        for (p, cell) in row.iter_mut().enumerate() {
            let (p_hi, p_lo) = (p >> 1, p & 1);
            let first = displacement_op_1(q_hi, p_hi);
            let second = displacement_op_1(q_lo, p_lo);
            let joint = tensor_product_2x2(&first, &second);
            let product = mat_mul(&joint, &rho)?;
            // Only the real part of the trace is kept.
            *cell = matrix_trace(&product).re / 4.0;
        }
    }
    Ok(table)
}
/// Computes the discrete Wigner function of `state` and returns it as a
/// JSON-serialized Plotly heatmap figure.
///
/// Only `n_qubits == 1` (state of length 2) and `n_qubits == 2` (state of
/// length 4) are handled.
///
/// # Errors
///
/// * `QuantRS2Error::InvalidInput` when `n_qubits` is 0, or when the state
///   length does not equal `2^n_qubits` for a supported `n_qubits`.
/// * `QuantRS2Error::UnsupportedOperation` when `n_qubits` is 3 or more.
/// * Any error propagated from the Wigner computation or JSON serialization.
pub fn wigner_plotly_json(state: &Array1<Complex64>, n_qubits: usize) -> QuantRS2Result<String> {
    // The qubit count is matched first, so n_qubits == 0 and n_qubits >= 3 are
    // reported regardless of the state length.
    match (n_qubits, state.len()) {
        (0, _) => Err(QuantRS2Error::InvalidInput(
            "n_qubits must be ≥ 1".to_string(),
        )),
        (1, 2) => {
            let grid = wigner_n1(state)?;
            build_wigner_heatmap_n1(&grid)
        }
        (1, len) => Err(QuantRS2Error::InvalidInput(format!(
            "State length {} does not match 2^1 = 2",
            len
        ))),
        (2, 4) => {
            let grid = wigner_n2(state)?;
            build_wigner_heatmap_n2(&grid)
        }
        (2, len) => Err(QuantRS2Error::InvalidInput(format!(
            "State length {} does not match 2^2 = 4",
            len
        ))),
        (n, _) => Err(QuantRS2Error::UnsupportedOperation(format!(
            "Discrete Wigner for n={} requires GF(2^n) phase space — \
             only n=1 and n=2 are supported in this version",
            n
        ))),
    }
}
/// Serializes a 2×2 Wigner table as a Plotly heatmap figure in JSON.
///
/// Rows of `w` are indexed by `q` (heatmap y axis) and columns by `p`
/// (heatmap x axis). Each cell's hover text is `"(q,p) W=<value>"` with the
/// value formatted to four decimals.
///
/// # Errors
///
/// Returns the converted `serde_json` error if serialization fails.
fn build_wigner_heatmap_n1(w: &[[f64; 2]; 2]) -> QuantRS2Result<String> {
    // Hover labels are looked up with the flat index `2 * q + p`, so the table
    // must list coordinates in (q,p) order to agree with the "W(q,p)" colorbar
    // title and with the axis labels below.
    let labels = ["(0,0)", "(0,1)", "(1,0)", "(1,1)"];
    let z: Vec<Vec<f64>> = w.iter().map(|row| row.to_vec()).collect();
    let x_labels: Vec<&str> = vec!["p=0", "p=1"];
    let y_labels: Vec<&str> = vec!["q=0", "q=1"];
    let hovertext: Vec<Vec<String>> = (0..2)
        .map(|q| {
            (0..2)
                .map(|p| format!("{} W={:.4}", labels[2 * q + p], w[q][p]))
                .collect()
        })
        .collect();
    let figure = json!({
        "data": [{
            "type": "heatmap",
            "z": z,
            "x": x_labels,
            "y": y_labels,
            "colorscale": "RdBu",
            // Center the diverging colorscale on zero.
            "zmid": 0.0,
            "text": hovertext,
            "hoverinfo": "text",
            "colorbar": {"title": "W(q,p)"}
        }],
        "layout": {
            "title": "Discrete Wigner Function (n=1)",
            "xaxis": {"title": "p"},
            "yaxis": {"title": "q"},
            "height": 450
        }
    });
    serde_json::to_string(&figure).map_err(QuantRS2Error::from)
}
/// Serializes a 4×4 Wigner table as a Plotly heatmap figure in JSON.
///
/// Rows of `w` are indexed by `q` (heatmap y axis) and columns by `p`
/// (heatmap x axis). Each index `0..4` is displayed as a two-bit pair such as
/// `(0,1)`. Each cell's hover text is `"q=<pair> p=<pair> W=<value>"` with the
/// value formatted to four decimals.
///
/// # Errors
///
/// Returns the converted `serde_json` error if serialization fails.
fn build_wigner_heatmap_n2(w: &[[f64; 4]; 4]) -> QuantRS2Result<String> {
    // Display form of the indices 0..4, high bit first.
    let coord_labels = ["(0,0)", "(0,1)", "(1,0)", "(1,1)"];

    let z: Vec<Vec<f64>> = w.iter().map(|row| row.to_vec()).collect();
    let x_labels: Vec<String> = coord_labels
        .iter()
        .map(|pair| format!("p={}", pair))
        .collect();
    let y_labels: Vec<String> = coord_labels
        .iter()
        .map(|pair| format!("q={}", pair))
        .collect();

    // Pair each row with its q label and each cell with its p label.
    let hovertext: Vec<Vec<String>> = w
        .iter()
        .zip(coord_labels.iter())
        .map(|(row, q_pair)| {
            row.iter()
                .zip(coord_labels.iter())
                .map(|(value, p_pair)| format!("q={} p={} W={:.4}", q_pair, p_pair, value))
                .collect()
        })
        .collect();

    let figure = json!({
        "data": [{
            "type": "heatmap",
            "z": z,
            "x": x_labels,
            "y": y_labels,
            "colorscale": "RdBu",
            "zmid": 0.0,
            "text": hovertext,
            "hoverinfo": "text",
            "colorbar": {"title": "W(q,p)"}
        }],
        "layout": {
            "title": "Discrete Wigner Function (n=2)",
            "xaxis": {"title": "p (phase-space momentum)"},
            "yaxis": {"title": "q (phase-space position)"},
            "height": 550
        }
    });
    let serialized = serde_json::to_string(&figure)?;
    Ok(serialized)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Single-qubit basis state with amplitude 1 on index 0.
    fn ket_zero() -> Array1<Complex64> {
        let amplitudes = vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)];
        Array1::from(amplitudes)
    }

    /// Two-qubit state with equal real amplitudes 1/√2 on indices 0 and 3.
    fn ket_bell() -> Array1<Complex64> {
        let amp = Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0);
        let nil = Complex64::new(0.0, 0.0);
        Array1::from(vec![amp, nil, nil, amp])
    }

    /// The n=1 table has W(0,0) = 0.5 for the zero state and sums to one.
    #[test]
    fn test_wigner_n1_zero_state() {
        let w = wigner_n1(&ket_zero()).expect("wigner_n1 failed");
        let corner = w[0][0];
        assert!(
            (corner - 0.5).abs() < TOL,
            "W(0,0) should be 0.5, got {}",
            corner
        );
        let total: f64 = w.iter().flatten().sum();
        assert!(
            (total - 1.0).abs() < TOL,
            "Wigner normalization should be 1, got {}",
            total
        );
    }

    /// The n=2 table sums to one.
    #[test]
    fn test_wigner_n2_normalization() {
        let w = wigner_n2(&ket_bell()).expect("wigner_n2 failed");
        let total: f64 = w.iter().flatten().sum();
        assert!(
            (total - 1.0).abs() < TOL,
            "n=2 Wigner normalization should be 1, got {}",
            total
        );
    }

    /// Three qubits are rejected with `UnsupportedOperation`.
    #[test]
    fn test_wigner_n3_returns_err() {
        let mut amplitudes = vec![Complex64::new(0.0, 0.0); 8];
        amplitudes[0] = Complex64::new(1.0, 0.0);
        let state = Array1::from(amplitudes);
        match wigner_plotly_json(&state, 3) {
            Ok(_) => panic!("n=3 should return Err"),
            Err(e) => assert!(
                matches!(e, QuantRS2Error::UnsupportedOperation(_)),
                "Error should be UnsupportedOperation, got {:?}",
                e
            ),
        }
    }

    /// The public entry point emits a string that parses as JSON.
    #[test]
    fn test_wigner_json_valid() {
        let json_str = wigner_plotly_json(&ket_zero(), 1).expect("Wigner JSON failed");
        let parsed: Result<serde_json::Value, _> = serde_json::from_str(&json_str);
        parsed.expect("Output should be valid JSON");
    }
}