Skip to main content

quantrs2_core/state_visualization_3d/
density_bars.rs

//! Density matrix 3D bar-plot visualization.
//!
//! Generates two side-by-side 3D bar charts showing Re(ρ) and Im(ρ)
//! for a quantum state's density matrix ρ = |ψ⟩⟨ψ|.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use serde_json::{json, Value};
9
10use crate::error::{QuantRS2Error, QuantRS2Result};
11
12/// Build ρ = |ψ⟩⟨ψ| from a state vector.
13fn density_matrix(state: &Array1<Complex64>) -> Array2<Complex64> {
14    let d = state.len();
15    let mut rho = Array2::zeros((d, d));
16    for i in 0..d {
17        for j in 0..d {
18            rho[[i, j]] = state[i] * state[j].conj();
19        }
20    }
21    rho
22}
23
/// Generate the ket label of a computational basis state.
///
/// The label is the `n_qubits`-wide binary expansion of `idx`, most
/// significant bit first, wrapped in ket brackets. For example, index 3
/// with `n_qubits = 2` gives `"|11⟩"`.
///
/// Bits of `idx` above position `n_qubits - 1` are ignored. When `n_qubits`
/// exceeds the width of `usize`, the extra leading positions are rendered
/// as `0` instead of triggering a shift overflow.
fn basis_label(idx: usize, n_qubits: usize) -> String {
    let bits: String = (0..n_qubits)
        .rev()
        .map(|b| {
            // `idx >> b` overflows (panics in debug, wraps in release) once `b`
            // reaches the pointer width; those bits of a `usize` are zero anyway.
            let is_set = b < usize::BITS as usize && (idx >> b) & 1 == 1;
            if is_set {
                '1'
            } else {
                '0'
            }
        })
        .collect();
    format!("|{}⟩", bits)
}
34
/// Convert a signed value into an `rgb(r,g,b)` colour string.
///
/// `v` is divided by `vmax` and clamped to [-1, 1], then mapped onto a ramp
/// position in [0, 1]. A `vmax` below `1e-12` places every value at the
/// midpoint of the ramp.
///
/// The ramp runs from pure blue `rgb(0,0,255)` for the most negative values
/// to pure red `rgb(255,0,0)` for the most positive ones. The green channel
/// rises linearly to 180 at the midpoint, so zero maps to `rgb(127,180,127)`.
fn value_to_color(v: f64, vmax: f64) -> String {
    // Normalised value in [-1, 1]; zero when there is no usable scale.
    let scaled = if vmax < 1e-12 {
        0.0
    } else {
        (v / vmax).clamp(-1.0, 1.0)
    };
    // Ramp position: 0.0 = most negative, 0.5 = zero, 1.0 = most positive.
    let ramp = 0.5 + 0.5 * scaled;

    let red = (ramp * 255.0) as u8;
    let green = ((1.0 - (2.0 * ramp - 1.0).abs()) * 180.0) as u8;
    let blue = ((1.0 - ramp) * 255.0) as u8;
    format!("rgb({},{},{})", red, green, blue)
}
51
/// Build Plotly mesh3d geometry for a single rectangular bar.
///
/// The bar spans from `(xi, yi, 0)` to `(xi + width, yi + depth, height)`.
/// A `height` whose magnitude is below `1e-15` is replaced by `1e-15` so the
/// top and bottom faces never coincide.
///
/// Returns `(x, y, z, i, j, k)`: the coordinates of the 8 corner vertices and
/// the vertex indices of 12 triangles (2 per face). Vertices 0–3 are the base
/// rectangle walked around its perimeter starting at `(xi, yi)`; vertices 4–7
/// are the same corners on the top face.
fn bar_mesh(
    xi: f64,
    yi: f64,
    width: f64,
    depth: f64,
    height: f64,
) -> (
    [f64; 8],
    [f64; 8],
    [f64; 8],
    [usize; 12],
    [usize; 12],
    [usize; 12],
) {
    // Each face is split along one diagonal into two triangles, so that all
    // six faces are completely covered with no gaps or duplicates.
    const TRIANGLES: [[usize; 3]; 12] = [
        // bottom (z = 0)
        [0, 1, 2],
        [0, 2, 3],
        // top (z = z_top)
        [4, 5, 6],
        [4, 6, 7],
        // side y = yi
        [0, 1, 5],
        [0, 5, 4],
        // side x = xi + width
        [1, 2, 6],
        [1, 6, 5],
        // side y = yi + depth
        [2, 3, 7],
        [2, 7, 6],
        // side x = xi
        [3, 0, 4],
        [3, 4, 7],
    ];

    let (x0, x1) = (xi, xi + width);
    let (y0, y1) = (yi, yi + depth);
    // Keep the bar non-degenerate for (near-)zero heights.
    let z_top = if height.abs() < 1e-15 { 1e-15 } else { height };

    let x = [x0, x1, x1, x0, x0, x1, x1, x0];
    let y = [y0, y0, y1, y1, y0, y0, y1, y1];
    let z = [0.0, 0.0, 0.0, 0.0, z_top, z_top, z_top, z_top];

    // Plotly expects the three triangle corners in separate parallel arrays.
    let mut i = [0usize; 12];
    let mut j = [0usize; 12];
    let mut k = [0usize; 12];
    for (t, tri) in TRIANGLES.iter().enumerate() {
        i[t] = tri[0];
        j[t] = tri[1];
        k[t] = tri[2];
    }
    (x, y, z, i, j, k)
}
100
101/// Build all mesh3d data for a d×d matrix as accumulated bar mesh.
102///
103/// Returns a Plotly `mesh3d` trace value.
104fn matrix_to_mesh3d(values: &Array2<f64>, scene: &str, title: &str) -> Value {
105    let d = values.nrows();
106    let bar_size = 0.7f64; // width/depth of each bar (leaves gap between bars)
107    let gap = 1.0f64; // spacing between bar centres
108
109    // Find max absolute value for colour scaling
110    let vmax = values.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
111
112    // Accumulate all bars into a single mesh3d trace
113    let mut all_x: Vec<f64> = Vec::new();
114    let mut all_y: Vec<f64> = Vec::new();
115    let mut all_z: Vec<f64> = Vec::new();
116    let mut all_i: Vec<usize> = Vec::new();
117    let mut all_j: Vec<usize> = Vec::new();
118    let mut all_k: Vec<usize> = Vec::new();
119    let mut all_colors: Vec<String> = Vec::new();
120    let mut offset = 0usize;
121
122    for i in 0..d {
123        for j in 0..d {
124            let v = values[[i, j]];
125            let xi = (j as f64) * gap;
126            let yi = (i as f64) * gap;
127
128            let (bx, by, bz, bi, bj, bk) = bar_mesh(xi, yi, bar_size, bar_size, v);
129
130            all_x.extend_from_slice(&bx);
131            all_y.extend_from_slice(&by);
132            all_z.extend_from_slice(&bz);
133
134            for &fi in &bi {
135                all_i.push(fi + offset);
136            }
137            for &fi in &bj {
138                all_j.push(fi + offset);
139            }
140            for &fi in &bk {
141                all_k.push(fi + offset);
142            }
143
144            let color = value_to_color(v, vmax);
145            for _ in 0..8 {
146                all_colors.push(color.clone());
147            }
148            offset += 8;
149        }
150    }
151
152    json!({
153        "type": "mesh3d",
154        "x": all_x,
155        "y": all_y,
156        "z": all_z,
157        "i": all_i,
158        "j": all_j,
159        "k": all_k,
160        "vertexcolor": all_colors,
161        "opacity": 0.9,
162        "scene": scene,
163        "name": title,
164        "hoverinfo": "none"
165    })
166}
167
168/// Two side-by-side 3D bar plots of Re(ρ) and Im(ρ).
169///
170/// Builds ρ = |ψ⟩⟨ψ| as a d×d matrix (d = 2^n_qubits) and
171/// generates a Plotly figure with two 3D subplots.
172pub fn density_matrix_bars_plotly_json(
173    state: &Array1<Complex64>,
174    n_qubits: usize,
175) -> QuantRS2Result<String> {
176    let dim = 1usize << n_qubits;
177    if state.len() != dim {
178        return Err(QuantRS2Error::InvalidInput(format!(
179            "State length {} does not match 2^{} = {}",
180            state.len(),
181            n_qubits,
182            dim
183        )));
184    }
185    if n_qubits == 0 {
186        return Err(QuantRS2Error::InvalidInput(
187            "n_qubits must be > 0".to_string(),
188        ));
189    }
190
191    let rho = density_matrix(state);
192    let labels: Vec<String> = (0..dim).map(|i| basis_label(i, n_qubits)).collect();
193
194    // Extract real and imaginary parts
195    let re_matrix: Array2<f64> = rho.mapv(|c| c.re);
196    let im_matrix: Array2<f64> = rho.mapv(|c| c.im);
197
198    let re_trace = matrix_to_mesh3d(&re_matrix, "scene", "Re(ρ)");
199    let im_trace = matrix_to_mesh3d(&im_matrix, "scene2", "Im(ρ)");
200
201    // Tick configuration for axes
202    let tick_vals: Vec<f64> = (0..dim).map(|k| (k as f64) * 1.0 + 0.35).collect();
203    let tick_text: Vec<String> = labels.clone();
204
205    let axis_def = json!({
206        "tickvals": tick_vals,
207        "ticktext": tick_text
208    });
209
210    let layout = json!({
211        "title": "Density Matrix 3D Bar Plot",
212        "scene": {
213            "xaxis": axis_def,
214            "yaxis": axis_def,
215            "zaxis": {"title": "Re(ρ)"},
216            "aspectmode": "cube",
217            "camera": {"eye": {"x": 1.5, "y": 1.5, "z": 1.2}},
218            "domain": {"x": [0.0, 0.48], "y": [0.0, 1.0]},
219            "annotations": [{
220                "text": "Re(ρ)",
221                "x": 0.5, "y": 1.0, "z": 0.0,
222                "showarrow": false,
223                "font": {"size": 14}
224            }]
225        },
226        "scene2": {
227            "xaxis": axis_def,
228            "yaxis": axis_def,
229            "zaxis": {"title": "Im(ρ)"},
230            "aspectmode": "cube",
231            "camera": {"eye": {"x": 1.5, "y": 1.5, "z": 1.2}},
232            "domain": {"x": [0.52, 1.0], "y": [0.0, 1.0]},
233            "annotations": [{
234                "text": "Im(ρ)",
235                "x": 0.5, "y": 1.0, "z": 0.0,
236                "showarrow": false,
237                "font": {"size": 14}
238            }]
239        },
240        "height": 600,
241        "showlegend": false
242    });
243
244    let figure = json!({
245        "data": [re_trace, im_trace],
246        "layout": layout
247    });
248
249    serde_json::to_string(&figure).map_err(QuantRS2Error::from)
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// The 2-qubit computational basis state |00⟩.
    ///
    /// Its density matrix has a single non-zero entry, ρ[0,0] = 1.
    fn state_zero_2q() -> Array1<Complex64> {
        let mut s = Array1::zeros(4);
        s[0] = Complex64::new(1.0, 0.0);
        s
    }

    /// The uniform superposition |++⟩ = (|00⟩+|01⟩+|10⟩+|11⟩)/2.
    fn state_plus_plus() -> Array1<Complex64> {
        let half = 0.5;
        Array1::from(vec![
            Complex64::new(half, 0.0),
            Complex64::new(half, 0.0),
            Complex64::new(half, 0.0),
            Complex64::new(half, 0.0),
        ])
    }

    #[test]
    fn test_density_bars_identity_2qubit() {
        // For |++⟩ every amplitude is 1/2, so ρ = |++⟩⟨++| has all
        // entries equal to 1/4 with vanishing imaginary part.
        let state = state_plus_plus();
        let rho = density_matrix(&state);

        for i in 0..4 {
            for j in 0..4 {
                let re = rho[[i, j]].re;
                let im = rho[[i, j]].im;
                assert!(
                    (re - 0.25).abs() < 1e-10,
                    "Re(ρ[{},{}]) should be 0.25, got {}",
                    i,
                    j,
                    re
                );
                assert!(
                    im.abs() < 1e-10,
                    "Im(ρ[{},{}]) should be 0, got {}",
                    i,
                    j,
                    im
                );
            }
        }
    }

    #[test]
    fn test_density_bars_zero_state() {
        // |00⟩: ρ[0,0] = 1, every other entry = 0
        let state = state_zero_2q();
        let rho = density_matrix(&state);
        for i in 0..4 {
            for j in 0..4 {
                let expected = if (i, j) == (0, 0) { 1.0 } else { 0.0 };
                assert!(
                    (rho[[i, j]].re - expected).abs() < 1e-10,
                    "Re(ρ[{},{}]) should be {}, got {}",
                    i,
                    j,
                    expected,
                    rho[[i, j]].re
                );
                assert!(
                    rho[[i, j]].im.abs() < 1e-10,
                    "Im(ρ[{},{}]) should be 0, got {}",
                    i,
                    j,
                    rho[[i, j]].im
                );
            }
        }
    }

    #[test]
    fn test_density_bars_json_valid() {
        let state = state_zero_2q();
        let json_str =
            density_matrix_bars_plotly_json(&state, 2).expect("Density bars JSON failed");
        let parsed: serde_json::Value =
            serde_json::from_str(&json_str).expect("Output should be valid JSON");

        let traces = parsed["data"]
            .as_array()
            .expect("`data` should be an array");
        assert_eq!(traces.len(), 2);
        assert_eq!(traces[0]["scene"], "scene");
        assert_eq!(traces[1]["scene"], "scene2");

        for trace in traces {
            assert_eq!(trace["type"], "mesh3d");
            // 4×4 matrix → 16 bars with 8 vertices and 12 triangles each.
            for key in ["x", "y", "z", "vertexcolor"] {
                assert_eq!(trace[key].as_array().map(Vec::len), Some(16 * 8));
            }
            for key in ["i", "j", "k"] {
                assert_eq!(trace[key].as_array().map(Vec::len), Some(16 * 12));
            }
        }
    }

    #[test]
    fn test_density_bars_rejects_length_mismatch() {
        // A 4-element state cannot describe 3 qubits (which needs 8 amplitudes).
        let state = state_zero_2q();
        assert!(density_matrix_bars_plotly_json(&state, 3).is_err());
    }

    #[test]
    fn test_density_bars_rejects_zero_qubits() {
        let state = Array1::from(vec![Complex64::new(1.0, 0.0)]);
        assert!(density_matrix_bars_plotly_json(&state, 0).is_err());
    }
}
327}