Skip to main content

quantrs2_core/state_visualization_3d/
bloch.rs

1//! Bloch sphere visualization for multi-qubit states.
2//!
//! Provides single-qubit Bloch vector computation via partial trace
//! and a Plotly JSON generator that draws a grid of Bloch spheres, one per qubit.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use serde_json::{json, Value};
9
10use crate::error::{QuantRS2Error, QuantRS2Result};
11
12/// Compute the Bloch vector for a single qubit by partial-tracing the full state.
13///
14/// Returns `(x, y, z)` where:
15/// - `x = 2 · Re(ρ₀₁)`
16/// - `y = 2 · Im(ρ₁₀)` (= −2 · Im(ρ₀₁) due to Hermiticity)
17/// - `z = ρ₀₀ − ρ₁₁`
18pub fn bloch_vector_for_qubit(
19    state: &Array1<Complex64>,
20    qubit_idx: usize,
21    n_qubits: usize,
22) -> QuantRS2Result<(f64, f64, f64)> {
23    let dim = 1usize << n_qubits;
24    if state.len() != dim {
25        return Err(QuantRS2Error::InvalidInput(format!(
26            "State length {} does not match 2^{} = {}",
27            state.len(),
28            n_qubits,
29            dim
30        )));
31    }
32    if qubit_idx >= n_qubits {
33        return Err(QuantRS2Error::InvalidInput(format!(
34            "qubit_idx {} out of range for {} qubits",
35            qubit_idx, n_qubits
36        )));
37    }
38
39    // Build density matrix ρ = |ψ⟩⟨ψ|
40    let mut rho = Array2::zeros((dim, dim));
41    for i in 0..dim {
42        for j in 0..dim {
43            rho[[i, j]] = state[i] * state[j].conj();
44        }
45    }
46
47    // Partial trace: keep only qubit_idx
48    let reduced = crate::matrix_ops::partial_trace(&rho, &[qubit_idx], n_qubits)?;
49
50    // Bloch vector components from the 2×2 reduced density matrix
51    let x = 2.0 * reduced[[0, 1]].re;
52    let y = 2.0 * reduced[[1, 0]].im; // Im(ρ₁₀) = −Im(ρ₀₁)
53    let z = reduced[[0, 0]].re - reduced[[1, 1]].re;
54
55    Ok((x, y, z))
56}
57
58/// Build a unit-sphere surface mesh as a Plotly `surface` trace.
59///
60/// Returns a `serde_json::Value` ready to be embedded in a `data` array.
61fn sphere_surface_trace(x_axis: &str) -> Value {
62    // Parametric sphere: 20×20 mesh is sufficient for background rendering
63    let n = 20usize;
64    let mut x_vals: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
65    let mut y_vals: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
66    let mut z_vals: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
67
68    for i in 0..=n {
69        let theta = std::f64::consts::PI * (i as f64) / (n as f64);
70        let mut row_x = Vec::with_capacity(n + 1);
71        let mut row_y = Vec::with_capacity(n + 1);
72        let mut row_z = Vec::with_capacity(n + 1);
73        for j in 0..=n {
74            let phi = 2.0 * std::f64::consts::PI * (j as f64) / (n as f64);
75            row_x.push(theta.sin() * phi.cos());
76            row_y.push(theta.sin() * phi.sin());
77            row_z.push(theta.cos());
78        }
79        x_vals.push(row_x);
80        y_vals.push(row_y);
81        z_vals.push(row_z);
82    }
83
84    json!({
85        "type": "surface",
86        "x": x_vals,
87        "y": y_vals,
88        "z": z_vals,
89        "opacity": 0.25,
90        "colorscale": [[0, "lightblue"], [1, "lightblue"]],
91        "showscale": false,
92        "scene": format!("scene{}", x_axis.trim_start_matches("xaxis")),
93        "hoverinfo": "none"
94    })
95}
96
97/// Build a cone trace representing the Bloch vector.
98fn bloch_vector_trace(bx: f64, by: f64, bz: f64, scene: &str) -> Value {
99    json!({
100        "type": "cone",
101        "x": [0.0],
102        "y": [0.0],
103        "z": [0.0],
104        "u": [bx],
105        "v": [by],
106        "w": [bz],
107        "colorscale": [[0, "red"], [1, "darkred"]],
108        "showscale": false,
109        "sizemode": "absolute",
110        "sizeref": 0.5,
111        "anchor": "tail",
112        "scene": scene,
113        "hoverinfo": "text",
114        "text": [format!("({:.3}, {:.3}, {:.3})", bx, by, bz)]
115    })
116}
117
118/// Returns a Plotly-JSON string for a grid of Bloch spheres, one per qubit.
119///
120/// The grid is arranged in a `ceil(sqrt(N)) × ceil(sqrt(N))` layout.
121/// Each subplot shows the Bloch sphere surface and the qubit's Bloch vector.
122pub fn bloch_array_plotly_json(
123    state: &Array1<Complex64>,
124    n_qubits: usize,
125) -> QuantRS2Result<String> {
126    if n_qubits == 0 {
127        return Err(QuantRS2Error::InvalidInput(
128            "n_qubits must be > 0".to_string(),
129        ));
130    }
131
132    // Compute Bloch vectors for all qubits
133    let mut vectors: Vec<(f64, f64, f64)> = Vec::with_capacity(n_qubits);
134    for i in 0..n_qubits {
135        vectors.push(bloch_vector_for_qubit(state, i, n_qubits)?);
136    }
137
138    let cols = (n_qubits as f64).sqrt().ceil() as usize;
139    let cols = cols.max(1);
140    let rows = (n_qubits + cols - 1) / cols;
141
142    let mut data: Vec<Value> = Vec::new();
143    let mut layout = json!({});
144
145    // Build scene layout and traces for each qubit
146    for (idx, &(bx, by, bz)) in vectors.iter().enumerate() {
147        let scene_name = if idx == 0 {
148            "scene".to_string()
149        } else {
150            format!("scene{}", idx + 1)
151        };
152
153        let row = idx / cols;
154        let col = idx % cols;
155
156        // Each subplot occupies a fraction of the total plotting area
157        let w = 1.0 / (cols as f64);
158        let h = 1.0 / (rows as f64);
159        let x_start = (col as f64) * w;
160        let y_start = 1.0 - ((row + 1) as f64) * h;
161
162        // Sphere surface trace
163        let mut sphere = sphere_surface_trace("x");
164        // Overwrite scene key to correct scene
165        if let Value::Object(ref mut map) = sphere {
166            map.insert("scene".to_string(), json!(scene_name));
167        }
168        data.push(sphere);
169
170        // Bloch vector cone trace
171        let mut cone = bloch_vector_trace(bx, by, bz, &scene_name);
172        // scene is already set in the function above; override if needed
173        if let Value::Object(ref mut map) = cone {
174            map.insert("scene".to_string(), json!(scene_name));
175        }
176        data.push(cone);
177
178        // Scene layout entry
179        let scene_def = json!({
180            "xaxis": {"title": "x", "range": [-1.2, 1.2]},
181            "yaxis": {"title": "y", "range": [-1.2, 1.2]},
182            "zaxis": {"title": "z", "range": [-1.2, 1.2]},
183            "aspectmode": "cube",
184            "annotations": [{
185                "text": format!("Qubit {}", idx),
186                "x": 0.5,
187                "y": 1.05,
188                "z": 1.0,
189                "showarrow": false,
190                "font": {"size": 12}
191            }],
192            "domain": {
193                "x": [x_start, x_start + w],
194                "y": [y_start, y_start + h]
195            }
196        });
197
198        if let Value::Object(ref mut layout_map) = layout {
199            layout_map.insert(scene_name, scene_def);
200        }
201    }
202
203    if let Value::Object(ref mut layout_map) = layout {
204        layout_map.insert("title".to_string(), json!("Bloch Sphere Array"));
205        layout_map.insert("showlegend".to_string(), json!(false));
206        layout_map.insert("height".to_string(), json!(400 * rows));
207    }
208
209    let figure = json!({
210        "data": data,
211        "layout": layout
212    });
213
214    serde_json::to_string(&figure).map_err(QuantRS2Error::from)
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    fn state_zero() -> Array1<Complex64> {
        Array1::from(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)])
    }

    fn state_one() -> Array1<Complex64> {
        Array1::from(vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)])
    }

    fn state_plus() -> Array1<Complex64> {
        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
        Array1::from(vec![
            Complex64::new(inv_sqrt2, 0.0),
            Complex64::new(inv_sqrt2, 0.0),
        ])
    }

    fn state_plus_i() -> Array1<Complex64> {
        // |+i⟩ = (|0⟩ + i|1⟩)/√2
        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
        Array1::from(vec![
            Complex64::new(inv_sqrt2, 0.0),
            Complex64::new(0.0, inv_sqrt2),
        ])
    }

    fn state_bell() -> Array1<Complex64> {
        // |Φ+⟩ = (|00⟩ + |11⟩)/√2
        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
        Array1::from(vec![
            Complex64::new(inv_sqrt2, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(inv_sqrt2, 0.0),
        ])
    }

    fn state_plus_plus() -> Array1<Complex64> {
        // |++⟩ = (|00⟩ + |01⟩ + |10⟩ + |11⟩)/2. This is a symmetric product state,
        // so the expected result does not depend on the qubit-to-bit ordering.
        Array1::from(vec![Complex64::new(0.5, 0.0); 4])
    }

    #[test]
    fn test_bloch_zero_state() {
        let v = bloch_vector_for_qubit(&state_zero(), 0, 1).expect("Bloch vector failed");
        assert!((v.0).abs() < 1e-10, "x should be 0, got {}", v.0);
        assert!((v.1).abs() < 1e-10, "y should be 0, got {}", v.1);
        assert!((v.2 - 1.0).abs() < 1e-10, "z should be 1, got {}", v.2);
    }

    #[test]
    fn test_bloch_one_state() {
        let v = bloch_vector_for_qubit(&state_one(), 0, 1).expect("Bloch vector failed");
        assert!((v.0).abs() < 1e-10, "x should be 0, got {}", v.0);
        assert!((v.1).abs() < 1e-10, "y should be 0, got {}", v.1);
        assert!((v.2 + 1.0).abs() < 1e-10, "z should be -1, got {}", v.2);
    }

    #[test]
    fn test_bloch_plus_state() {
        let v = bloch_vector_for_qubit(&state_plus(), 0, 1).expect("Bloch vector failed");
        assert!((v.0 - 1.0).abs() < 1e-10, "x should be 1, got {}", v.0);
        assert!((v.1).abs() < 1e-10, "y should be 0, got {}", v.1);
        assert!((v.2).abs() < 1e-10, "z should be 0, got {}", v.2);
    }

    #[test]
    fn test_bloch_plus_i_state() {
        // |+i⟩ = (|0⟩+i|1⟩)/√2 → Bloch vector = (0, 1, 0)
        let v = bloch_vector_for_qubit(&state_plus_i(), 0, 1).expect("Bloch vector failed");
        assert!((v.0).abs() < 1e-10, "x should be 0, got {}", v.0);
        assert!((v.1 - 1.0).abs() < 1e-10, "y should be 1, got {}", v.1);
        assert!((v.2).abs() < 1e-10, "z should be 0, got {}", v.2);
    }

    #[test]
    fn test_bloch_bell_state() {
        // Both qubits of |Φ+⟩ should be maximally mixed → (0, 0, 0)
        let bell = state_bell();
        for q in 0..2 {
            let v = bloch_vector_for_qubit(&bell, q, 2).expect("Bloch vector failed");
            assert!(
                v.0.abs() < 1e-10 && v.1.abs() < 1e-10 && v.2.abs() < 1e-10,
                "Bell qubit {} Bloch vector should be (0,0,0), got {:?}",
                q,
                v
            );
        }
    }

    #[test]
    fn test_bloch_plus_plus_product_state() {
        // Each qubit of |++⟩ is pure |+⟩ → (1, 0, 0)
        let state = state_plus_plus();
        for q in 0..2 {
            let v = bloch_vector_for_qubit(&state, q, 2).expect("Bloch vector failed");
            assert!(
                (v.0 - 1.0).abs() < 1e-10 && v.1.abs() < 1e-10 && v.2.abs() < 1e-10,
                "|++⟩ qubit {} Bloch vector should be (1,0,0), got {:?}",
                q,
                v
            );
        }
    }

    #[test]
    fn test_bloch_rejects_length_mismatch() {
        // A 1-qubit state cannot be interpreted as a 2-qubit state.
        assert!(bloch_vector_for_qubit(&state_zero(), 0, 2).is_err());
    }

    #[test]
    fn test_bloch_rejects_qubit_out_of_range() {
        assert!(bloch_vector_for_qubit(&state_zero(), 1, 1).is_err());
    }

    #[test]
    fn test_bloch_json_rejects_zero_qubits() {
        assert!(bloch_array_plotly_json(&state_zero(), 0).is_err());
    }

    #[test]
    fn test_bloch_json_valid() {
        let state = state_bell();
        let json_str = bloch_array_plotly_json(&state, 2).expect("JSON generation failed");
        let _parsed: serde_json::Value =
            serde_json::from_str(&json_str).expect("Output should be valid JSON");
    }

    #[test]
    fn test_bloch_json_structure() {
        let json_str = bloch_array_plotly_json(&state_bell(), 2).expect("JSON generation failed");
        let parsed: serde_json::Value =
            serde_json::from_str(&json_str).expect("Output should be valid JSON");

        // One sphere surface plus one cone per qubit, each bound to that qubit's scene.
        let data = parsed["data"].as_array().expect("data should be an array");
        assert_eq!(data.len(), 4);
        let scenes: Vec<&str> = data
            .iter()
            .map(|trace| trace["scene"].as_str().expect("trace should name a scene"))
            .collect();
        assert_eq!(scenes, vec!["scene", "scene", "scene2", "scene2"]);

        // Two qubits → 2 columns × 1 row → 400 px tall.
        let layout = &parsed["layout"];
        assert!(layout.get("scene").is_some());
        assert!(layout.get("scene2").is_some());
        assert_eq!(layout["height"].as_u64(), Some(400));
    }
}