Skip to main content

quantrs2_core/
quantum_counting.rs

1//! Quantum Counting and Amplitude Estimation
2//!
//! This module implements quantum counting and amplitude estimation, both of which
//! are built on quantum phase estimation (QPE). QPE is also the core subroutine of
//! Shor's algorithm, and quantum counting estimates the number of solutions in a
//! Grover-style quantum database search.
6//!
7//! TODO: The current implementations are simplified versions. Full implementations
8//! would require:
9//! - Proper controlled unitary implementations
10//! - Full QFT and inverse QFT
11//! - Better phase extraction from measurement results
12//! - Integration with circuit builder for more complex operations
13
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::Complex64;
16use std::f64::consts::PI;
17
18/// Quantum Phase Estimation (QPE) algorithm
19///
20/// Estimates the phase φ in the eigenvalue e^(2πiφ) of a unitary operator U
21pub struct QuantumPhaseEstimation {
22    /// Number of precision bits
23    precision_bits: usize,
24    /// The unitary operator U
25    unitary: Array2<Complex64>,
26    /// Number of target qubits
27    target_qubits: usize,
28}
29
30impl QuantumPhaseEstimation {
31    /// Create a new QPE instance
32    pub fn new(precision_bits: usize, unitary: Array2<Complex64>) -> Self {
33        let target_qubits = (unitary.shape()[0] as f64).log2() as usize;
34
35        Self {
36            precision_bits,
37            unitary,
38            target_qubits,
39        }
40    }
41
42    /// Apply controlled-U^(2^k) operation
43    fn apply_controlled_u_power(&self, state: &mut [Complex64], control: usize, k: usize) {
44        let power = 1 << k;
45        let n = self.target_qubits;
46        let dim = 1 << n;
47
48        // Build U^power by repeated squaring
49        let mut u_power = Array2::eye(dim);
50        let mut temp = self.unitary.clone();
51        let mut p = power;
52
53        while p > 0 {
54            if p & 1 == 1 {
55                u_power = u_power.dot(&temp);
56            }
57            temp = temp.dot(&temp);
58            p >>= 1;
59        }
60
61        // Apply controlled operation.
62        // We iterate over each unique "precision register" configuration; for
63        // those with the control qubit set, we apply U^power to the target
64        // register.  We use only the "canonical" representative (target = 0)
65        // to avoid applying the operation more than once per configuration.
66        let total_qubits = self.precision_bits + self.target_qubits;
67        let precision_dim = 1 << self.precision_bits;
68
69        for prec in 0..precision_dim {
70            // Check if the control qubit is |1⟩ within the precision register.
71            // Precision qubits occupy the HIGH bits; qubit `control` corresponds
72            // to bit (precision_bits - control - 1) inside `prec`.
73            if (prec >> (self.precision_bits - control - 1)) & 1 == 1 {
74                // Base index: precision register `prec` in the high bits, target = 0
75                let base_idx = prec << n; // precision in MSBs, target bits = 0
76
77                // Read all 2^n amplitudes for this precision configuration
78                let mut amplitudes = vec![Complex64::new(0.0, 0.0); dim];
79                for i in 0..dim {
80                    amplitudes[i] = state[base_idx | i];
81                }
82
83                // Apply U^power to the target register amplitudes
84                let result = u_power.dot(&Array1::from(amplitudes));
85
86                // Write back
87                for i in 0..dim {
88                    state[base_idx | i] = result[i];
89                }
90            }
91        }
92    }
93
94    /// Apply inverse QFT to precision qubits
95    fn apply_inverse_qft(&self, state: &mut [Complex64]) {
96        let n = self.precision_bits;
97        let total_qubits = n + self.target_qubits;
98
99        // Implement inverse QFT on the first n qubits
100        for j in (0..n).rev() {
101            // Apply Hadamard to qubit j
102            self.apply_hadamard(state, j, total_qubits);
103
104            // Apply controlled phase rotations
105            for k in (0..j).rev() {
106                let angle = -PI / (1 << (j - k)) as f64;
107                self.apply_controlled_phase(state, k, j, angle, total_qubits);
108            }
109        }
110
111        // Swap qubits to reverse order
112        for i in 0..n / 2 {
113            self.swap_qubits(state, i, n - 1 - i, total_qubits);
114        }
115    }
116
117    /// Apply Hadamard gate to a specific qubit
118    fn apply_hadamard(&self, state: &mut [Complex64], qubit: usize, total_qubits: usize) {
119        let h = 1.0 / std::f64::consts::SQRT_2;
120        let dim = 1 << total_qubits;
121
122        for i in 0..dim {
123            if (i >> (total_qubits - qubit - 1)) & 1 == 0 {
124                let j = i | (1 << (total_qubits - qubit - 1));
125                let a = state[i];
126                let b = state[j];
127                state[i] = h * (a + b);
128                state[j] = h * (a - b);
129            }
130        }
131    }
132
133    /// Apply controlled phase rotation
134    fn apply_controlled_phase(
135        &self,
136        state: &mut [Complex64],
137        control: usize,
138        target: usize,
139        angle: f64,
140        total_qubits: usize,
141    ) {
142        let phase = Complex64::new(angle.cos(), angle.sin());
143
144        for (i, amp) in state.iter_mut().enumerate() {
145            let control_bit = (i >> (total_qubits - control - 1)) & 1;
146            let target_bit = (i >> (total_qubits - target - 1)) & 1;
147
148            if control_bit == 1 && target_bit == 1 {
149                *amp *= phase;
150            }
151        }
152    }
153
154    /// Swap two qubits
155    fn swap_qubits(&self, state: &mut [Complex64], q1: usize, q2: usize, total_qubits: usize) {
156        let dim = 1 << total_qubits;
157
158        for i in 0..dim {
159            let bit1 = (i >> (total_qubits - q1 - 1)) & 1;
160            let bit2 = (i >> (total_qubits - q2 - 1)) & 1;
161
162            if bit1 != bit2 {
163                let j = i ^ (1 << (total_qubits - q1 - 1)) ^ (1 << (total_qubits - q2 - 1));
164                if i < j {
165                    state.swap(i, j);
166                }
167            }
168        }
169    }
170
171    /// Run the QPE algorithm
172    pub fn estimate_phase(&self, eigenstate: Vec<Complex64>) -> f64 {
173        let total_qubits = self.precision_bits + self.target_qubits;
174        let state_dim = 1 << total_qubits;
175        let mut state = vec![Complex64::new(0.0, 0.0); state_dim];
176
177        // Initialize precision qubits to |0⟩ and target qubits to eigenstate
178        for i in 0..(1 << self.target_qubits) {
179            if i < eigenstate.len() {
180                state[i] = eigenstate[i];
181            }
182        }
183
184        // Apply Hadamard to all precision qubits
185        for j in 0..self.precision_bits {
186            self.apply_hadamard(&mut state, j, total_qubits);
187        }
188
189        // Apply controlled-U^{2^j} operations.
190        // Standard QPE: ancilla qubit j (j = 0 is the first/MSB ancilla)
191        // controls U^{2^(precision_bits - 1 - j)} so that after inverse QFT
192        // the MSB carries the most significant bit of the phase.
193        // Equivalently, iterating j from (precision_bits-1) down to 0:
194        //   - ancilla qubit 0 controls U^{2^(n-1)}
195        //   - ancilla qubit 1 controls U^{2^(n-2)}
196        //   - ...
197        //   - ancilla qubit n-1 controls U^{2^0} = U
198        for j in 0..self.precision_bits {
199            // power_k is the exponent for the j-th controlled-U application
200            let power_k = j; // U^{2^j} controlled on ancilla qubit (precision_bits - 1 - j)
201            let control_qubit = self.precision_bits - 1 - j;
202            self.apply_controlled_u_power(&mut state, control_qubit, power_k);
203        }
204
205        // Apply inverse QFT
206        self.apply_inverse_qft(&mut state);
207
208        // Measure precision qubits
209        let mut max_prob = 0.0;
210        let mut measured_value = 0;
211
212        for (i, amp) in state.iter().enumerate() {
213            let precision_bits_value = i >> self.target_qubits;
214            let prob = amp.norm_sqr();
215
216            if prob > max_prob {
217                max_prob = prob;
218                measured_value = precision_bits_value;
219            }
220        }
221
222        // Convert to phase estimate
223        measured_value as f64 / (1 << self.precision_bits) as f64
224    }
225}
226
/// Quantum counting.
///
/// Estimates how many items of a search space satisfy a predicate by running
/// phase estimation on the corresponding Grover operator.
pub struct QuantumCounting {
    /// Size `N` of the search space; items are indexed `0..n_items`.
    pub n_items: usize,
    /// Number of phase-estimation bits (the phase resolution is `2^-precision_bits`).
    pub precision_bits: usize,
    /// Predicate returning `true` for the marked (solution) items.
    pub oracle: Box<dyn Fn(usize) -> bool>,
}
238
239impl QuantumCounting {
240    /// Create a new quantum counting instance
241    pub fn new(n_items: usize, precision_bits: usize, oracle: Box<dyn Fn(usize) -> bool>) -> Self {
242        Self {
243            n_items,
244            precision_bits,
245            oracle,
246        }
247    }
248
249    /// Build the Grover operator
250    fn build_grover_operator(&self) -> Array2<Complex64> {
251        let n = self.n_items;
252        let mut grover = Array2::zeros((n, n));
253
254        // Oracle: flip phase of marked items
255        for i in 0..n {
256            if (self.oracle)(i) {
257                grover[[i, i]] = Complex64::new(-1.0, 0.0);
258            } else {
259                grover[[i, i]] = Complex64::new(1.0, 0.0);
260            }
261        }
262
263        // Diffusion operator: 2|s⟩⟨s| - I
264        let s_amplitude = 1.0 / (n as f64).sqrt();
265        let diffusion =
266            Array2::from_elem((n, n), Complex64::new(2.0 * s_amplitude * s_amplitude, 0.0))
267                - Array2::<Complex64>::eye(n);
268
269        // Grover operator = -Diffusion × Oracle
270        -diffusion.dot(&grover)
271    }
272
273    /// Count the number of solutions
274    pub fn count(&self) -> f64 {
275        // Build Grover operator
276        let grover = self.build_grover_operator();
277
278        // Use QPE to estimate the phase
279        let qpe = QuantumPhaseEstimation::new(self.precision_bits, grover);
280
281        // Prepare uniform superposition as eigenstate
282        let n = self.n_items;
283        let amplitude = Complex64::new(1.0 / (n as f64).sqrt(), 0.0);
284        let eigenstate = vec![amplitude; n];
285
286        // Estimate phase
287        let phase = qpe.estimate_phase(eigenstate);
288
289        // Convert phase to count
290        // For Grover operator, eigenvalues are e^(±2iθ) where sin²(θ) = M/N
291        let theta = phase * PI;
292        let sin_theta = theta.sin();
293        sin_theta * sin_theta * n as f64
294    }
295}
296
297/// Quantum Amplitude Estimation
298///
299/// Estimates the amplitude of marked states in a superposition
300pub struct QuantumAmplitudeEstimation {
301    /// State preparation operator A
302    pub state_prep: Array2<Complex64>,
303    /// Oracle that identifies good states
304    pub oracle: Array2<Complex64>,
305    /// Number of precision bits
306    pub precision_bits: usize,
307}
308
309impl QuantumAmplitudeEstimation {
310    /// Create a new amplitude estimation instance
311    pub const fn new(
312        state_prep: Array2<Complex64>,
313        oracle: Array2<Complex64>,
314        precision_bits: usize,
315    ) -> Self {
316        Self {
317            state_prep,
318            oracle,
319            precision_bits,
320        }
321    }
322
323    /// Build the Q operator for amplitude estimation
324    fn build_q_operator(&self) -> Array2<Complex64> {
325        let n = self.state_prep.shape()[0];
326        let identity = Array2::<Complex64>::eye(n);
327
328        // Reflection about good states: I - 2P where P projects onto good states
329        let reflection_good = &identity - &self.oracle * 2.0;
330
331        // Reflection about initial state: 2A|0⟩⟨0|A† - I
332        let zero_state = Array1::zeros(n);
333        let mut zero_state_vec = zero_state.to_vec();
334        zero_state_vec[0] = Complex64::new(1.0, 0.0);
335
336        let initial = self.state_prep.dot(&Array1::from(zero_state_vec));
337        let mut reflection_initial = Array2::zeros((n, n));
338
339        for i in 0..n {
340            for j in 0..n {
341                reflection_initial[[i, j]] = 2.0 * initial[i] * initial[j].conj();
342            }
343        }
344        reflection_initial -= &identity;
345
346        // Q = -reflection_initial × reflection_good
347        -reflection_initial.dot(&reflection_good)
348    }
349
350    /// Estimate the amplitude
351    pub fn estimate(&self) -> f64 {
352        // Build Q operator
353        let q_operator = self.build_q_operator();
354
355        // Use QPE to find eigenphase
356        let qpe = QuantumPhaseEstimation::new(self.precision_bits, q_operator);
357
358        // Prepare initial state A|0⟩
359        let n = self.state_prep.shape()[0];
360        let mut zero_state = vec![Complex64::new(0.0, 0.0); n];
361        zero_state[0] = Complex64::new(1.0, 0.0);
362        let initial_state = self.state_prep.dot(&Array1::from(zero_state));
363
364        // Estimate phase
365        let phase = qpe.estimate_phase(initial_state.to_vec());
366
367        // Convert phase to amplitude
368        // For Q operator, eigenvalues are e^(±2iθ) where sin(θ) = amplitude
369        let theta = phase * PI;
370        theta.sin().abs()
371    }
372}
373
374/// Example: Count solutions to a simple search problem
375pub fn quantum_counting_example() {
376    println!("Quantum Counting Example");
377    println!("=======================");
378
379    // Count numbers divisible by 3 in range 0-15
380    let oracle = Box::new(|x: usize| x % 3 == 0 && x > 0);
381
382    let counter = QuantumCounting::new(16, 4, oracle);
383    let count = counter.count();
384
385    println!("Counting numbers divisible by 3 in range 1-15:");
386    println!("Estimated count: {count:.1}");
387    println!("Actual count: 5 (3, 6, 9, 12, 15)");
388    println!("Error: {:.1}", (count - 5.0).abs());
389}
390
391/// Example: Estimate amplitude of marked states
392pub fn amplitude_estimation_example() {
393    println!("\nAmplitude Estimation Example");
394    println!("============================");
395
396    // Create a simple state preparation that creates equal superposition
397    let n = 8;
398    let state_prep = Array2::from_elem((n, n), Complex64::new(1.0 / (n as f64).sqrt(), 0.0));
399
400    // Oracle marks states 2 and 5
401    let mut oracle = Array2::zeros((n, n));
402    oracle[[2, 2]] = Complex64::new(1.0, 0.0);
403    oracle[[5, 5]] = Complex64::new(1.0, 0.0);
404
405    let qae = QuantumAmplitudeEstimation::new(state_prep, oracle, 4);
406    let amplitude = qae.estimate();
407
408    println!("Estimating amplitude of marked states (2 and 5) in uniform superposition:");
409    println!("Estimated amplitude: {amplitude:.3}");
410    println!("Actual amplitude: {:.3}", (2.0 / n as f64).sqrt());
411    println!("Error: {:.3}", (amplitude - (2.0 / n as f64).sqrt()).abs());
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// `diag(1, e^{iθ})`: the `|1⟩` eigenvalue is `e^{iθ}`.
    fn phase_gate(theta: f64) -> Array2<Complex64> {
        let mut u = Array2::<Complex64>::zeros((2, 2));
        u[[0, 0]] = Complex64::new(1.0, 0.0);
        u[[1, 1]] = Complex64::new(theta.cos(), theta.sin());
        u
    }

    #[test]
    fn test_phase_estimation_basic() {
        // For θ = π/4 the |1⟩ eigenvalue is e^{2πi·(1/8)}, so the exact phase is
        // 1/8. The checks are deliberately tolerant: the estimate must be an
        // on-grid value within one grid step of 1/8 or of its conjugate 7/8.
        let theta = PI / 4.0;
        let precision_bits = 4usize;
        let qpe = QuantumPhaseEstimation::new(precision_bits, phase_gate(theta));

        let one = vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        let estimated = qpe.estimate_phase(one);

        assert!(
            (0.0..=1.0).contains(&estimated),
            "estimated phase {estimated} is outside [0, 1]"
        );

        // Every QPE outcome is a multiple of 1/2^precision_bits.
        let grid = 1.0 / (1u64 << precision_bits) as f64;
        let residual = (estimated / grid).round() * grid - estimated;
        assert!(
            residual.abs() < 1e-9,
            "estimated {estimated} is not on the {precision_bits}-bit phase grid"
        );

        let true_phase = theta / (2.0 * PI);
        let conjugate_phase = 1.0 - true_phase;
        let slack = grid + 1e-9;
        let within = |target: f64| (estimated - target).abs() <= slack;
        assert!(
            within(true_phase) || within(conjugate_phase),
            "QPE estimate {estimated:.6} is not near true phase {true_phase:.6} \
             or conjugate {conjugate_phase:.6} (tolerance {slack:.6})"
        );
    }

    #[test]
    fn test_quantum_counting_simple() {
        // N = 4 items, exactly one marked (index 2). `count()` is
        // N·sin²(π·φ) for the QPE estimate φ, hence always within [0, N].
        let counter = QuantumCounting::new(4, 4, Box::new(|item: usize| item == 2));
        let count = counter.count();

        assert!(count >= 0.0, "count {count} must be non-negative");
        assert!(count <= 4.0 + 1e-6, "count {count} must not exceed N=4");
    }
}