Skip to main content

quantrs2_core/state_visualization_3d/
qsphere.rs

//! Q-sphere visualization (Qiskit-style) for quantum states.
//!
//! Maps each computational basis state to a point on a unit sphere:
//! the polar angle (measured from the north pole) is proportional to the
//! state's Hamming weight, and the longitude is determined by the state's
//! rank within its Hamming shell (the set of states with equal weight).
6
7use scirs2_core::ndarray::Array1;
8use scirs2_core::Complex64;
9use serde_json::{json, Value};
10
11use crate::error::{QuantRS2Error, QuantRS2Result};
12
/// Count set bits (Hamming weight / popcount) of `x`.
///
/// Thin wrapper around the standard library's [`usize::count_ones`] that
/// keeps the result as a `usize` so it can be used directly as an index.
fn popcount(x: usize) -> usize {
    // `count_ones` yields a `u32` no larger than `usize::BITS`, so widening
    // it to `usize` can never lose information.
    x.count_ones() as usize
}
22
/// Map a phase angle (in radians) to an RGB hex string using a cyclic colormap.
///
/// The phase is wrapped onto a single turn, so values outside `[0, 2π)` are
/// accepted: the hue is the fractional part of `phase / (2π)`. The hue is
/// then converted with a simple HSV→RGB mapping at full saturation and value
/// (S = 1, V = 1).
///
/// Returns a lowercase `"#rrggbb"` string. Each channel is scaled by 255 and
/// truncated (not rounded) to 8 bits.
fn phase_to_color(phase: f64) -> String {
    let two_pi = 2.0 * std::f64::consts::PI;

    // Wrap the phase into a hue by keeping only the fractional turn.
    let mut hue = phase / two_pi;
    hue -= hue.floor();

    // Locate the hue on the six-sector colour wheel: `sector` picks the pair
    // of primaries being blended, `f` is the progress through that sector.
    let scaled = hue * 6.0;
    let sector = scaled.floor();
    let f = scaled - sector;

    // `% 6` folds the boundary case `scaled == 6.0` back onto sector 0 (red).
    let (r, g, b): (f64, f64, f64) = match sector as u32 % 6 {
        0 => (1.0, f, 0.0),
        1 => (1.0 - f, 1.0, 0.0),
        2 => (0.0, 1.0, f),
        3 => (0.0, 1.0 - f, 1.0),
        4 => (f, 0.0, 1.0),
        _ => (1.0, 0.0, 1.0 - f),
    };

    format!(
        "#{:02x}{:02x}{:02x}",
        (r * 255.0) as u8,
        (g * 255.0) as u8,
        (b * 255.0) as u8
    )
}
52
53/// Returns a Plotly-JSON string for a Q-sphere visualization of the state.
54///
55/// The sphere surface is rendered as a background mesh; each non-zero
56/// amplitude is plotted as a 3D scatter marker whose size encodes
57/// `|a|²` and colour encodes `arg(a)`.
58pub fn qsphere_plotly_json(state: &Array1<Complex64>, n_qubits: usize) -> QuantRS2Result<String> {
59    let dim = 1usize << n_qubits;
60    if state.len() != dim {
61        return Err(QuantRS2Error::InvalidInput(format!(
62            "State length {} does not match 2^{} = {}",
63            state.len(),
64            n_qubits,
65            dim
66        )));
67    }
68    if n_qubits == 0 {
69        return Err(QuantRS2Error::InvalidInput(
70            "n_qubits must be > 0".to_string(),
71        ));
72    }
73
74    // Precompute Hamming shells: for each Hamming weight w, list indices with that weight
75    let max_w = n_qubits;
76    let mut shells: Vec<Vec<usize>> = vec![Vec::new(); max_w + 1];
77    for x in 0..dim {
78        shells[popcount(x)].push(x);
79    }
80
81    // Compute spherical coordinates for every non-negligible basis state
82    let pi = std::f64::consts::PI;
83
84    let mut scatter_x: Vec<f64> = Vec::new();
85    let mut scatter_y: Vec<f64> = Vec::new();
86    let mut scatter_z: Vec<f64> = Vec::new();
87    let mut marker_sizes: Vec<f64> = Vec::new();
88    let mut marker_colors: Vec<String> = Vec::new();
89    let mut hover_texts: Vec<String> = Vec::new();
90
91    for x in 0..dim {
92        let amp = state[x];
93        let prob = amp.norm_sqr();
94        if prob < 1e-12 {
95            continue;
96        }
97
98        let w = popcount(x);
99        let theta = if n_qubits == 1 {
100            // n=1: |0⟩ at north (0), |1⟩ at south (π)
101            pi * (w as f64)
102        } else {
103            pi * (w as f64) / (n_qubits as f64)
104        };
105
106        // Rank of x within its Hamming shell
107        let shell = &shells[w];
108        let rank = shell
109            .iter()
110            .position(|&v| v == x)
111            .ok_or_else(|| QuantRS2Error::InvalidInput("Shell rank not found".to_string()))?;
112
113        let phi = if shell.len() == 1 {
114            0.0
115        } else {
116            2.0 * pi * (rank as f64) / (shell.len() as f64)
117        };
118
119        let sx = theta.sin() * phi.cos();
120        let sy = theta.sin() * phi.sin();
121        let sz = theta.cos();
122
123        scatter_x.push(sx);
124        scatter_y.push(sy);
125        scatter_z.push(sz);
126
127        // Marker size: scale probability to 5–30 range
128        let size = 5.0 + 25.0 * prob;
129        marker_sizes.push(size);
130
131        // Marker colour encodes phase
132        let phase = amp.im.atan2(amp.re);
133        let phase = if phase < 0.0 { phase + 2.0 * pi } else { phase };
134        marker_colors.push(phase_to_color(phase));
135
136        // Binary string label, e.g. "|101⟩"
137        let label: String = (0..n_qubits)
138            .rev()
139            .map(|bit| if (x >> bit) & 1 == 1 { '1' } else { '0' })
140            .collect();
141        hover_texts.push(format!("|{}⟩  p={:.4}  arg={:.3}rad", label, prob, phase));
142    }
143
144    // Background sphere surface
145    let sphere = build_sphere_mesh3d();
146
147    // Scatter trace for basis states
148    let scatter = json!({
149        "type": "scatter3d",
150        "x": scatter_x,
151        "y": scatter_y,
152        "z": scatter_z,
153        "mode": "markers",
154        "marker": {
155            "size": marker_sizes,
156            "color": marker_colors,
157            "opacity": 0.9,
158            "line": {"width": 1, "color": "black"}
159        },
160        "text": hover_texts,
161        "hoverinfo": "text",
162        "name": "Basis states"
163    });
164
165    let layout = json!({
166        "title": "Q-Sphere",
167        "scene": {
168            "xaxis": {"title": "x", "range": [-1.3, 1.3]},
169            "yaxis": {"title": "y", "range": [-1.3, 1.3]},
170            "zaxis": {"title": "z", "range": [-1.3, 1.3]},
171            "aspectmode": "cube",
172            "camera": {"eye": {"x": 1.4, "y": 1.4, "z": 1.0}}
173        },
174        "showlegend": false,
175        "height": 600
176    });
177
178    let figure = json!({
179        "data": [sphere, scatter],
180        "layout": layout
181    });
182
183    serde_json::to_string(&figure).map_err(QuantRS2Error::from)
184}
185
186/// Build a unit-sphere as a `mesh3d` background trace.
187fn build_sphere_mesh3d() -> Value {
188    let n = 18usize; // 18 latitude × 18 longitude segments
189    let mut xs: Vec<f64> = Vec::new();
190    let mut ys: Vec<f64> = Vec::new();
191    let mut zs: Vec<f64> = Vec::new();
192    let mut is: Vec<usize> = Vec::new();
193    let mut js: Vec<usize> = Vec::new();
194    let mut ks: Vec<usize> = Vec::new();
195
196    let pi = std::f64::consts::PI;
197
198    // Vertices
199    for i in 0..=n {
200        let theta = pi * (i as f64) / (n as f64);
201        for j in 0..=n {
202            let phi = 2.0 * pi * (j as f64) / (n as f64);
203            xs.push(theta.sin() * phi.cos());
204            ys.push(theta.sin() * phi.sin());
205            zs.push(theta.cos());
206        }
207    }
208
209    let stride = n + 1;
210    // Triangles (two triangles per quad)
211    for i in 0..n {
212        for j in 0..n {
213            let a = i * stride + j;
214            let b = i * stride + j + 1;
215            let c = (i + 1) * stride + j;
216            let d = (i + 1) * stride + j + 1;
217            // Triangle 1: a, b, c
218            is.push(a);
219            js.push(b);
220            ks.push(c);
221            // Triangle 2: b, d, c
222            is.push(b);
223            js.push(d);
224            ks.push(c);
225        }
226    }
227
228    json!({
229        "type": "mesh3d",
230        "x": xs,
231        "y": ys,
232        "z": zs,
233        "i": is,
234        "j": js,
235        "k": ks,
236        "opacity": 0.15,
237        "color": "lightblue",
238        "hoverinfo": "none",
239        "name": "Sphere"
240    })
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::Complex64;

    /// Three-qubit GHZ state (|000⟩ + |111⟩)/√2.
    fn state_ghz() -> Array1<Complex64> {
        let amplitude = Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0);
        let mut psi = Array1::zeros(8);
        // Only the all-zeros and all-ones basis states are populated.
        psi[0] = amplitude; // |000⟩
        psi[7] = amplitude; // |111⟩
        psi
    }

    /// The GHZ state has two non-zero amplitudes (|000⟩ with w=0 and |111⟩
    /// with w=3), so the scatter trace must contain exactly two markers.
    #[test]
    fn test_qsphere_ghz() {
        let json_str = qsphere_plotly_json(&state_ghz(), 3).expect("Q-sphere failed");
        let parsed: serde_json::Value =
            serde_json::from_str(&json_str).expect("Should be valid JSON");

        // Locate the scatter trace by type rather than by position.
        let scatter = parsed["data"]
            .as_array()
            .expect("data array missing")
            .iter()
            .find(|trace| trace["type"] == "scatter3d")
            .expect("No scatter3d trace found");

        let marker_count = scatter["x"].as_array().expect("scatter x missing").len();
        assert_eq!(marker_count, 2, "GHZ should have exactly 2 markers");
    }

    /// The output for a two-qubit state (|00⟩ + |11⟩)/√2 must parse as JSON.
    #[test]
    fn test_qsphere_json_valid() {
        let amplitude = Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0);
        let zero = Complex64::new(0.0, 0.0);
        let state = Array1::from(vec![amplitude, zero, zero, amplitude]);

        let json_str = qsphere_plotly_json(&state, 2).expect("Q-sphere failed");
        let _parsed: serde_json::Value =
            serde_json::from_str(&json_str).expect("Output should be valid JSON");
    }
}