1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use serde_json::{json, Value};
9
10use crate::error::{QuantRS2Error, QuantRS2Result};
11
12fn density_matrix(state: &Array1<Complex64>) -> Array2<Complex64> {
14 let d = state.len();
15 let mut rho = Array2::zeros((d, d));
16 for i in 0..d {
17 for j in 0..d {
18 rho[[i, j]] = state[i] * state[j].conj();
19 }
20 }
21 rho
22}
23
/// Render computational-basis index `idx` as a ket label such as `|010⟩`.
///
/// Exactly `n_qubits` bits are printed, most-significant bit first; bits of `idx` above
/// position `n_qubits - 1` are ignored.
fn basis_label(idx: usize, n_qubits: usize) -> String {
    let mut label = String::with_capacity(n_qubits + 4);
    label.push('|');
    for bit in (0..n_qubits).rev() {
        let is_set = (idx >> bit) & 1 == 1;
        label.push(if is_set { '1' } else { '0' });
    }
    label.push('⟩');
    label
}
34
/// Map a signed value onto a blue → green → red colour ramp and return it as a CSS
/// `rgb(r,g,b)` string.
///
/// `v` is normalised by `vmax` and clamped to `[-1, 1]`, so `-vmax` is pure blue, `0` is the
/// greenish midpoint and `+vmax` is pure red. When `vmax` is below `1e-12`, every value maps
/// to the midpoint colour.
fn value_to_color(v: f64, vmax: f64) -> String {
    // Position along the ramp, in [0, 1].
    let t = if vmax < 1e-12 {
        0.5
    } else {
        let ratio = (v / vmax).clamp(-1.0, 1.0);
        0.5 + 0.5 * ratio
    };

    // Scale a fraction to a colour channel; `as u8` truncates toward zero.
    let channel = |frac: f64, scale: f64| (frac * scale) as u8;

    let red = channel(t, 255.0);
    // Green peaks at the midpoint and fades to zero at both ends of the ramp.
    let green = channel(1.0 - (2.0 * t - 1.0).abs(), 180.0);
    let blue = channel(1.0 - t, 255.0);

    format!("rgb({},{},{})", red, green, blue)
}
51
/// Build the vertex and triangle arrays for one axis-aligned rectangular bar.
///
/// The bar's footprint spans `[xi, xi + width] × [yi, yi + depth]` on the `z = 0` plane. The
/// bar extends to `z = height`, which is downward for negative heights.
///
/// Vertex layout: indices 0–3 are the base corners and 4–7 the matching top corners, each
/// group ordered `(xi, yi)`, `(xi + width, yi)`, `(xi + width, yi + depth)`, `(xi, yi + depth)`.
///
/// # Returns
/// `(x, y, z, i, j, k)` in Plotly `mesh3d` form:
/// - `x`, `y`, `z`: coordinates of the 8 vertices.
/// - `i`, `j`, `k`: vertex indices of the 12 triangles, two per face.
fn bar_mesh(
    xi: f64,
    yi: f64,
    width: f64,
    depth: f64,
    height: f64,
) -> (
    [f64; 8],
    [f64; 8],
    [f64; 8],
    [usize; 12],
    [usize; 12],
    [usize; 12],
) {
    let x0 = xi;
    let x1 = xi + width;
    let y0 = yi;
    let y1 = yi + depth;

    let x = [x0, x1, x1, x0, x0, x1, x1, x0];
    let y = [y0, y0, y1, y1, y0, y0, y1, y1];
    // A (near-)zero height is replaced by 1e-15 so the top face never coincides exactly with
    // the base.
    let z_top = if height.abs() < 1e-15 { 1e-15 } else { height };
    let z = [0.0, 0.0, 0.0, 0.0, z_top, z_top, z_top, z_top];

    // Triangle t is (i[t], j[t], k[t]). Each face is split along one of its diagonals so the
    // two triangles tile it completely:
    //   bottom (z = 0):          (0,1,2) (0,2,3)
    //   top    (z = height):     (4,5,6) (4,6,7)
    //   front  (y = yi):         (0,1,5) (0,5,4)
    //   back   (y = yi + depth): (3,2,6) (3,6,7)
    //   left   (x = xi):         (0,3,7) (0,7,4)
    //   right  (x = xi + width): (1,2,6) (1,6,5)
    let i = [0, 0, 4, 4, 0, 0, 3, 3, 0, 0, 1, 1];
    let j = [1, 2, 5, 6, 1, 5, 2, 6, 3, 7, 2, 6];
    let k = [2, 3, 6, 7, 5, 4, 6, 7, 7, 4, 6, 5];
    (x, y, z, i, j, k)
}
100
101fn matrix_to_mesh3d(values: &Array2<f64>, scene: &str, title: &str) -> Value {
105 let d = values.nrows();
106 let bar_size = 0.7f64; let gap = 1.0f64; let vmax = values.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
111
112 let mut all_x: Vec<f64> = Vec::new();
114 let mut all_y: Vec<f64> = Vec::new();
115 let mut all_z: Vec<f64> = Vec::new();
116 let mut all_i: Vec<usize> = Vec::new();
117 let mut all_j: Vec<usize> = Vec::new();
118 let mut all_k: Vec<usize> = Vec::new();
119 let mut all_colors: Vec<String> = Vec::new();
120 let mut offset = 0usize;
121
122 for i in 0..d {
123 for j in 0..d {
124 let v = values[[i, j]];
125 let xi = (j as f64) * gap;
126 let yi = (i as f64) * gap;
127
128 let (bx, by, bz, bi, bj, bk) = bar_mesh(xi, yi, bar_size, bar_size, v);
129
130 all_x.extend_from_slice(&bx);
131 all_y.extend_from_slice(&by);
132 all_z.extend_from_slice(&bz);
133
134 for &fi in &bi {
135 all_i.push(fi + offset);
136 }
137 for &fi in &bj {
138 all_j.push(fi + offset);
139 }
140 for &fi in &bk {
141 all_k.push(fi + offset);
142 }
143
144 let color = value_to_color(v, vmax);
145 for _ in 0..8 {
146 all_colors.push(color.clone());
147 }
148 offset += 8;
149 }
150 }
151
152 json!({
153 "type": "mesh3d",
154 "x": all_x,
155 "y": all_y,
156 "z": all_z,
157 "i": all_i,
158 "j": all_j,
159 "k": all_k,
160 "vertexcolor": all_colors,
161 "opacity": 0.9,
162 "scene": scene,
163 "name": title,
164 "hoverinfo": "none"
165 })
166}
167
168pub fn density_matrix_bars_plotly_json(
173 state: &Array1<Complex64>,
174 n_qubits: usize,
175) -> QuantRS2Result<String> {
176 let dim = 1usize << n_qubits;
177 if state.len() != dim {
178 return Err(QuantRS2Error::InvalidInput(format!(
179 "State length {} does not match 2^{} = {}",
180 state.len(),
181 n_qubits,
182 dim
183 )));
184 }
185 if n_qubits == 0 {
186 return Err(QuantRS2Error::InvalidInput(
187 "n_qubits must be > 0".to_string(),
188 ));
189 }
190
191 let rho = density_matrix(state);
192 let labels: Vec<String> = (0..dim).map(|i| basis_label(i, n_qubits)).collect();
193
194 let re_matrix: Array2<f64> = rho.mapv(|c| c.re);
196 let im_matrix: Array2<f64> = rho.mapv(|c| c.im);
197
198 let re_trace = matrix_to_mesh3d(&re_matrix, "scene", "Re(ρ)");
199 let im_trace = matrix_to_mesh3d(&im_matrix, "scene2", "Im(ρ)");
200
201 let tick_vals: Vec<f64> = (0..dim).map(|k| (k as f64) * 1.0 + 0.35).collect();
203 let tick_text: Vec<String> = labels.clone();
204
205 let axis_def = json!({
206 "tickvals": tick_vals,
207 "ticktext": tick_text
208 });
209
210 let layout = json!({
211 "title": "Density Matrix 3D Bar Plot",
212 "scene": {
213 "xaxis": axis_def,
214 "yaxis": axis_def,
215 "zaxis": {"title": "Re(ρ)"},
216 "aspectmode": "cube",
217 "camera": {"eye": {"x": 1.5, "y": 1.5, "z": 1.2}},
218 "domain": {"x": [0.0, 0.48], "y": [0.0, 1.0]},
219 "annotations": [{
220 "text": "Re(ρ)",
221 "x": 0.5, "y": 1.0, "z": 0.0,
222 "showarrow": false,
223 "font": {"size": 14}
224 }]
225 },
226 "scene2": {
227 "xaxis": axis_def,
228 "yaxis": axis_def,
229 "zaxis": {"title": "Im(ρ)"},
230 "aspectmode": "cube",
231 "camera": {"eye": {"x": 1.5, "y": 1.5, "z": 1.2}},
232 "domain": {"x": [0.52, 1.0], "y": [0.0, 1.0]},
233 "annotations": [{
234 "text": "Im(ρ)",
235 "x": 0.5, "y": 1.0, "z": 0.0,
236 "showarrow": false,
237 "font": {"size": 14}
238 }]
239 },
240 "height": 600,
241 "showlegend": false
242 });
243
244 let figure = json!({
245 "data": [re_trace, im_trace],
246 "layout": layout
247 });
248
249 serde_json::to_string(&figure).map_err(QuantRS2Error::from)
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// `|00⟩` as a 4-amplitude state vector.
    fn state_zero_2q() -> Array1<Complex64> {
        let mut s = Array1::zeros(4);
        s[0] = Complex64::new(1.0, 0.0);
        s
    }

    /// `|++⟩ = ½(|00⟩ + |01⟩ + |10⟩ + |11⟩)`.
    fn state_plus_plus() -> Array1<Complex64> {
        let half = 0.5;
        Array1::from(vec![
            Complex64::new(half, 0.0),
            Complex64::new(half, 0.0),
            Complex64::new(half, 0.0),
            Complex64::new(half, 0.0),
        ])
    }

    /// ρ for `|++⟩` has every entry equal to 1/4. This is the uniform matrix, not the identity.
    #[test]
    fn test_density_bars_uniform_plus_plus() {
        let rho = density_matrix(&state_plus_plus());

        for i in 0..4 {
            for j in 0..4 {
                let re = rho[[i, j]].re;
                let im = rho[[i, j]].im;
                assert!(
                    (re - 0.25).abs() < 1e-10,
                    "Re(ρ[{},{}]) should be 0.25, got {}",
                    i,
                    j,
                    re
                );
                assert!(
                    im.abs() < 1e-10,
                    "Im(ρ[{},{}]) should be 0, got {}",
                    i,
                    j,
                    im
                );
            }
        }
    }

    /// ρ for `|00⟩` is 1 at (0, 0) and 0 everywhere else.
    #[test]
    fn test_density_bars_zero_state() {
        let rho = density_matrix(&state_zero_2q());
        for i in 0..4 {
            for j in 0..4 {
                let expected = if (i, j) == (0, 0) { 1.0 } else { 0.0 };
                let err = (rho[[i, j]] - Complex64::new(expected, 0.0)).norm();
                assert!(err < 1e-10, "ρ[{},{}] off by {}", i, j, err);
            }
        }
    }

    /// The second factor is conjugated: for ψ = (|0⟩ + i|1⟩)/√2, ρ₀₁ = −i/2 and ρ₁₀ = +i/2.
    #[test]
    fn test_density_matrix_conjugates_second_factor() {
        let a = std::f64::consts::FRAC_1_SQRT_2;
        let state = Array1::from(vec![Complex64::new(a, 0.0), Complex64::new(0.0, a)]);
        let rho = density_matrix(&state);
        assert!((rho[[0, 0]] - Complex64::new(0.5, 0.0)).norm() < 1e-12);
        assert!((rho[[1, 1]] - Complex64::new(0.5, 0.0)).norm() < 1e-12);
        assert!((rho[[0, 1]] - Complex64::new(0.0, -0.5)).norm() < 1e-12);
        assert!((rho[[1, 0]] - Complex64::new(0.0, 0.5)).norm() < 1e-12);
    }

    #[test]
    fn test_basis_label_msb_first() {
        assert_eq!(basis_label(0, 2), "|00⟩");
        assert_eq!(basis_label(1, 2), "|01⟩");
        assert_eq!(basis_label(6, 3), "|110⟩");
    }

    #[test]
    fn test_density_bars_json_valid() {
        let state = state_zero_2q();
        let json_str =
            density_matrix_bars_plotly_json(&state, 2).expect("Density bars JSON failed");
        let parsed: serde_json::Value =
            serde_json::from_str(&json_str).expect("Output should be valid JSON");

        let data = parsed["data"].as_array().expect("`data` should be an array");
        assert_eq!(data.len(), 2);
        assert_eq!(data[0]["scene"], "scene");
        assert_eq!(data[1]["scene"], "scene2");

        // 2 qubits -> 4x4 matrix -> 16 bars, each with 8 vertices and 12 triangles.
        for trace in data {
            assert_eq!(trace["type"], "mesh3d");
            assert_eq!(trace["x"].as_array().map(Vec::len), Some(8 * 16));
            assert_eq!(trace["z"].as_array().map(Vec::len), Some(8 * 16));
            assert_eq!(trace["i"].as_array().map(Vec::len), Some(12 * 16));
            assert_eq!(trace["k"].as_array().map(Vec::len), Some(12 * 16));
            let colors = trace["vertexcolor"]
                .as_array()
                .expect("`vertexcolor` should be an array");
            assert_eq!(colors.len(), 8 * 16);
            assert!(colors
                .iter()
                .all(|c| c.as_str().map_or(false, |s| s.starts_with("rgb("))));
        }
    }

    #[test]
    fn test_density_bars_rejects_length_mismatch() {
        let state = state_zero_2q();
        assert!(density_matrix_bars_plotly_json(&state, 3).is_err());
        assert!(density_matrix_bars_plotly_json(&state, 1).is_err());
    }

    #[test]
    fn test_density_bars_rejects_zero_qubits() {
        let state = Array1::from(vec![Complex64::new(1.0, 0.0)]);
        assert!(density_matrix_bars_plotly_json(&state, 0).is_err());
    }
}