1use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::Complex64;
16use std::f64::consts::PI;
17
18pub struct QuantumPhaseEstimation {
22 precision_bits: usize,
24 unitary: Array2<Complex64>,
26 target_qubits: usize,
28}
29
30impl QuantumPhaseEstimation {
31 pub fn new(precision_bits: usize, unitary: Array2<Complex64>) -> Self {
33 let target_qubits = (unitary.shape()[0] as f64).log2() as usize;
34
35 Self {
36 precision_bits,
37 unitary,
38 target_qubits,
39 }
40 }
41
42 fn apply_controlled_u_power(&self, state: &mut [Complex64], control: usize, k: usize) {
44 let power = 1 << k;
45 let n = self.target_qubits;
46 let dim = 1 << n;
47
48 let mut u_power = Array2::eye(dim);
50 let mut temp = self.unitary.clone();
51 let mut p = power;
52
53 while p > 0 {
54 if p & 1 == 1 {
55 u_power = u_power.dot(&temp);
56 }
57 temp = temp.dot(&temp);
58 p >>= 1;
59 }
60
61 let total_qubits = self.precision_bits + self.target_qubits;
67 let precision_dim = 1 << self.precision_bits;
68
69 for prec in 0..precision_dim {
70 if (prec >> (self.precision_bits - control - 1)) & 1 == 1 {
74 let base_idx = prec << n; let mut amplitudes = vec![Complex64::new(0.0, 0.0); dim];
79 for i in 0..dim {
80 amplitudes[i] = state[base_idx | i];
81 }
82
83 let result = u_power.dot(&Array1::from(amplitudes));
85
86 for i in 0..dim {
88 state[base_idx | i] = result[i];
89 }
90 }
91 }
92 }
93
94 fn apply_inverse_qft(&self, state: &mut [Complex64]) {
96 let n = self.precision_bits;
97 let total_qubits = n + self.target_qubits;
98
99 for j in (0..n).rev() {
101 self.apply_hadamard(state, j, total_qubits);
103
104 for k in (0..j).rev() {
106 let angle = -PI / (1 << (j - k)) as f64;
107 self.apply_controlled_phase(state, k, j, angle, total_qubits);
108 }
109 }
110
111 for i in 0..n / 2 {
113 self.swap_qubits(state, i, n - 1 - i, total_qubits);
114 }
115 }
116
117 fn apply_hadamard(&self, state: &mut [Complex64], qubit: usize, total_qubits: usize) {
119 let h = 1.0 / std::f64::consts::SQRT_2;
120 let dim = 1 << total_qubits;
121
122 for i in 0..dim {
123 if (i >> (total_qubits - qubit - 1)) & 1 == 0 {
124 let j = i | (1 << (total_qubits - qubit - 1));
125 let a = state[i];
126 let b = state[j];
127 state[i] = h * (a + b);
128 state[j] = h * (a - b);
129 }
130 }
131 }
132
133 fn apply_controlled_phase(
135 &self,
136 state: &mut [Complex64],
137 control: usize,
138 target: usize,
139 angle: f64,
140 total_qubits: usize,
141 ) {
142 let phase = Complex64::new(angle.cos(), angle.sin());
143
144 for (i, amp) in state.iter_mut().enumerate() {
145 let control_bit = (i >> (total_qubits - control - 1)) & 1;
146 let target_bit = (i >> (total_qubits - target - 1)) & 1;
147
148 if control_bit == 1 && target_bit == 1 {
149 *amp *= phase;
150 }
151 }
152 }
153
154 fn swap_qubits(&self, state: &mut [Complex64], q1: usize, q2: usize, total_qubits: usize) {
156 let dim = 1 << total_qubits;
157
158 for i in 0..dim {
159 let bit1 = (i >> (total_qubits - q1 - 1)) & 1;
160 let bit2 = (i >> (total_qubits - q2 - 1)) & 1;
161
162 if bit1 != bit2 {
163 let j = i ^ (1 << (total_qubits - q1 - 1)) ^ (1 << (total_qubits - q2 - 1));
164 if i < j {
165 state.swap(i, j);
166 }
167 }
168 }
169 }
170
171 pub fn estimate_phase(&self, eigenstate: Vec<Complex64>) -> f64 {
173 let total_qubits = self.precision_bits + self.target_qubits;
174 let state_dim = 1 << total_qubits;
175 let mut state = vec![Complex64::new(0.0, 0.0); state_dim];
176
177 for i in 0..(1 << self.target_qubits) {
179 if i < eigenstate.len() {
180 state[i] = eigenstate[i];
181 }
182 }
183
184 for j in 0..self.precision_bits {
186 self.apply_hadamard(&mut state, j, total_qubits);
187 }
188
189 for j in 0..self.precision_bits {
199 let power_k = j; let control_qubit = self.precision_bits - 1 - j;
202 self.apply_controlled_u_power(&mut state, control_qubit, power_k);
203 }
204
205 self.apply_inverse_qft(&mut state);
207
208 let mut max_prob = 0.0;
210 let mut measured_value = 0;
211
212 for (i, amp) in state.iter().enumerate() {
213 let precision_bits_value = i >> self.target_qubits;
214 let prob = amp.norm_sqr();
215
216 if prob > max_prob {
217 max_prob = prob;
218 measured_value = precision_bits_value;
219 }
220 }
221
222 measured_value as f64 / (1 << self.precision_bits) as f64
224 }
225}
226
/// Counts how many indices in `0..n_items` satisfy `oracle`, by running
/// phase estimation on a Grover operator built from that oracle.
pub struct QuantumCounting {
    /// Size of the search space; indices `0..n_items` are passed to `oracle`.
    pub n_items: usize,
    /// Number of readout qubits for the underlying phase estimation.
    pub precision_bits: usize,
    /// Predicate that returns `true` for the marked ("good") indices.
    pub oracle: Box<dyn Fn(usize) -> bool + 'static>,
}
238
239impl QuantumCounting {
240 pub fn new(n_items: usize, precision_bits: usize, oracle: Box<dyn Fn(usize) -> bool>) -> Self {
242 Self {
243 n_items,
244 precision_bits,
245 oracle,
246 }
247 }
248
249 fn build_grover_operator(&self) -> Array2<Complex64> {
251 let n = self.n_items;
252 let mut grover = Array2::zeros((n, n));
253
254 for i in 0..n {
256 if (self.oracle)(i) {
257 grover[[i, i]] = Complex64::new(-1.0, 0.0);
258 } else {
259 grover[[i, i]] = Complex64::new(1.0, 0.0);
260 }
261 }
262
263 let s_amplitude = 1.0 / (n as f64).sqrt();
265 let diffusion =
266 Array2::from_elem((n, n), Complex64::new(2.0 * s_amplitude * s_amplitude, 0.0))
267 - Array2::<Complex64>::eye(n);
268
269 -diffusion.dot(&grover)
271 }
272
273 pub fn count(&self) -> f64 {
275 let grover = self.build_grover_operator();
277
278 let qpe = QuantumPhaseEstimation::new(self.precision_bits, grover);
280
281 let n = self.n_items;
283 let amplitude = Complex64::new(1.0 / (n as f64).sqrt(), 0.0);
284 let eigenstate = vec![amplitude; n];
285
286 let phase = qpe.estimate_phase(eigenstate);
288
289 let theta = phase * PI;
292 let sin_theta = theta.sin();
293 sin_theta * sin_theta * n as f64
294 }
295}
296
297pub struct QuantumAmplitudeEstimation {
301 pub state_prep: Array2<Complex64>,
303 pub oracle: Array2<Complex64>,
305 pub precision_bits: usize,
307}
308
309impl QuantumAmplitudeEstimation {
310 pub const fn new(
312 state_prep: Array2<Complex64>,
313 oracle: Array2<Complex64>,
314 precision_bits: usize,
315 ) -> Self {
316 Self {
317 state_prep,
318 oracle,
319 precision_bits,
320 }
321 }
322
323 fn build_q_operator(&self) -> Array2<Complex64> {
325 let n = self.state_prep.shape()[0];
326 let identity = Array2::<Complex64>::eye(n);
327
328 let reflection_good = &identity - &self.oracle * 2.0;
330
331 let zero_state = Array1::zeros(n);
333 let mut zero_state_vec = zero_state.to_vec();
334 zero_state_vec[0] = Complex64::new(1.0, 0.0);
335
336 let initial = self.state_prep.dot(&Array1::from(zero_state_vec));
337 let mut reflection_initial = Array2::zeros((n, n));
338
339 for i in 0..n {
340 for j in 0..n {
341 reflection_initial[[i, j]] = 2.0 * initial[i] * initial[j].conj();
342 }
343 }
344 reflection_initial -= &identity;
345
346 -reflection_initial.dot(&reflection_good)
348 }
349
350 pub fn estimate(&self) -> f64 {
352 let q_operator = self.build_q_operator();
354
355 let qpe = QuantumPhaseEstimation::new(self.precision_bits, q_operator);
357
358 let n = self.state_prep.shape()[0];
360 let mut zero_state = vec![Complex64::new(0.0, 0.0); n];
361 zero_state[0] = Complex64::new(1.0, 0.0);
362 let initial_state = self.state_prep.dot(&Array1::from(zero_state));
363
364 let phase = qpe.estimate_phase(initial_state.to_vec());
366
367 let theta = phase * PI;
370 theta.sin().abs()
371 }
372}
373
374pub fn quantum_counting_example() {
376 println!("Quantum Counting Example");
377 println!("=======================");
378
379 let oracle = Box::new(|x: usize| x % 3 == 0 && x > 0);
381
382 let counter = QuantumCounting::new(16, 4, oracle);
383 let count = counter.count();
384
385 println!("Counting numbers divisible by 3 in range 1-15:");
386 println!("Estimated count: {count:.1}");
387 println!("Actual count: 5 (3, 6, 9, 12, 15)");
388 println!("Error: {:.1}", (count - 5.0).abs());
389}
390
391pub fn amplitude_estimation_example() {
393 println!("\nAmplitude Estimation Example");
394 println!("============================");
395
396 let n = 8;
398 let state_prep = Array2::from_elem((n, n), Complex64::new(1.0 / (n as f64).sqrt(), 0.0));
399
400 let mut oracle = Array2::zeros((n, n));
402 oracle[[2, 2]] = Complex64::new(1.0, 0.0);
403 oracle[[5, 5]] = Complex64::new(1.0, 0.0);
404
405 let qae = QuantumAmplitudeEstimation::new(state_prep, oracle, 4);
406 let amplitude = qae.estimate();
407
408 println!("Estimating amplitude of marked states (2 and 5) in uniform superposition:");
409 println!("Estimated amplitude: {amplitude:.3}");
410 println!("Actual amplitude: {:.3}", (2.0 / n as f64).sqrt());
411 println!("Error: {:.3}", (amplitude - (2.0 / n as f64).sqrt()).abs());
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Phase gate `diag(1, e^{i*angle})`; its `|1>` eigenvector has eigenphase
    /// `angle / (2*pi)` turns.
    fn phase_gate(angle: f64) -> Array2<Complex64> {
        let zero = Complex64::new(0.0, 0.0);
        let entries = vec![
            Complex64::new(1.0, 0.0),
            zero,
            zero,
            Complex64::new(angle.cos(), angle.sin()),
        ];
        Array2::from_shape_vec((2, 2), entries)
            .expect("2x2 matrix from 4-element vector should succeed")
    }

    #[test]
    fn test_phase_estimation_basic() {
        let phase = PI / 4.0;
        let precision_bits = 4usize;

        // Feed the |1> eigenvector so the ideal readout is phase / (2*pi).
        let eigenstate = vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)];
        let estimated =
            QuantumPhaseEstimation::new(precision_bits, phase_gate(phase)).estimate_phase(eigenstate);

        assert!(
            (0.0..=1.0).contains(&estimated),
            "estimated phase {estimated} is outside [0, 1]"
        );

        // The readout must be an exact multiple of 2^-precision_bits.
        let grid = 1.0 / (1u64 << precision_bits) as f64;
        let snapped = (estimated / grid).round() * grid;
        assert!(
            (snapped - estimated).abs() < 1e-9,
            "estimated {estimated} is not on the {precision_bits}-bit phase grid"
        );

        // Accept either the phase or its conjugate, within one grid step.
        let true_phase = phase / (2.0 * PI);
        let conjugate_phase = 1.0 - true_phase;
        let slack = grid + 1e-9;
        let within_slack = |target: f64| (estimated - target).abs() <= slack;
        assert!(
            within_slack(true_phase) || within_slack(conjugate_phase),
            "QPE estimate {estimated:.6} is not near true phase {true_phase:.6} \
             or conjugate {conjugate_phase:.6} (tolerance {slack:.6})"
        );
    }

    #[test]
    fn test_quantum_counting_simple() {
        // Exactly one marked item out of N = 4.
        let count = QuantumCounting::new(4, 4, Box::new(|x: usize| x == 2)).count();

        assert!(count >= 0.0, "count {count} must be non-negative");
        assert!(count <= 4.0 + 1e-6, "count {count} must not exceed N=4");
    }
}