Skip to main content

quantrs2_core/quantum_walk/
discrete.rs

1//! Discrete-time quantum walk implementation.
2
3use super::graph::{CoinOperator, Graph};
4use scirs2_core::ndarray::{Array1, Array2};
5use scirs2_core::Complex64;
6
7/// Discrete-time quantum walk
8pub struct DiscreteQuantumWalk {
9    pub(crate) graph: Graph,
10    coin_operator: CoinOperator,
11    pub(crate) coin_dimension: usize,
12    /// Total Hilbert space dimension: coin_dimension * num_vertices
13    pub(crate) hilbert_dim: usize,
14    /// Current state vector
15    pub(crate) state: Vec<Complex64>,
16}
17
18impl DiscreteQuantumWalk {
19    /// Create a new discrete quantum walk with specified coin operator
20    pub fn new(graph: Graph, coin_operator: CoinOperator) -> Self {
21        // Coin dimension is the maximum degree for standard walks
22        // For hypercube, it's the dimension
23        let coin_dimension = match graph.num_vertices {
24            n if n > 0 => {
25                (0..graph.num_vertices)
26                    .map(|v| graph.degree(v))
27                    .max()
28                    .unwrap_or(2)
29                    .max(2) // At least 2-dimensional coin
30            }
31            _ => 2,
32        };
33
34        let hilbert_dim = coin_dimension * graph.num_vertices;
35
36        Self {
37            graph,
38            coin_operator,
39            coin_dimension,
40            hilbert_dim,
41            state: vec![Complex64::new(0.0, 0.0); hilbert_dim],
42        }
43    }
44
45    /// Initialize walker at a specific position
46    pub fn initialize_position(&mut self, position: usize) {
47        self.state = vec![Complex64::new(0.0, 0.0); self.hilbert_dim];
48
49        // Equal superposition over all coin states at the position
50        let degree = self.graph.degree(position) as f64;
51        if degree > 0.0 {
52            let amplitude = Complex64::new(1.0 / degree.sqrt(), 0.0);
53
54            for coin in 0..self.coin_dimension.min(self.graph.degree(position)) {
55                let index = self.state_index(position, coin);
56                if index < self.state.len() {
57                    self.state[index] = amplitude;
58                }
59            }
60        }
61    }
62
63    /// Perform one step of the quantum walk
64    pub fn step(&mut self) {
65        // Apply coin operator
66        self.apply_coin();
67
68        // Apply shift operator
69        self.apply_shift();
70    }
71
72    /// Get position probabilities
73    pub fn position_probabilities(&self) -> Vec<f64> {
74        let mut probs = vec![0.0; self.graph.num_vertices];
75
76        for (vertex, prob) in probs.iter_mut().enumerate() {
77            for coin in 0..self.coin_dimension {
78                let idx = self.state_index(vertex, coin);
79                if idx < self.state.len() {
80                    *prob += self.state[idx].norm_sqr();
81                }
82            }
83        }
84
85        probs
86    }
87
88    /// Get the index in the state vector for (vertex, coin) pair
89    pub(crate) const fn state_index(&self, vertex: usize, coin: usize) -> usize {
90        vertex * self.coin_dimension + coin
91    }
92
93    /// Apply the coin operator
94    fn apply_coin(&mut self) {
95        match &self.coin_operator {
96            CoinOperator::Hadamard => self.apply_hadamard_coin(),
97            CoinOperator::Grover => self.apply_grover_coin(),
98            CoinOperator::DFT => self.apply_dft_coin(),
99            CoinOperator::Custom(matrix) => self.apply_custom_coin(matrix.clone()),
100        }
101    }
102
103    /// Apply Hadamard coin
104    fn apply_hadamard_coin(&mut self) {
105        let h = 1.0 / std::f64::consts::SQRT_2;
106
107        for vertex in 0..self.graph.num_vertices {
108            if self.coin_dimension == 2 {
109                let idx0 = self.state_index(vertex, 0);
110                let idx1 = self.state_index(vertex, 1);
111
112                if idx1 < self.state.len() {
113                    let a0 = self.state[idx0];
114                    let a1 = self.state[idx1];
115
116                    self.state[idx0] = h * (a0 + a1);
117                    self.state[idx1] = h * (a0 - a1);
118                }
119            }
120        }
121    }
122
123    /// Apply Grover coin
124    fn apply_grover_coin(&mut self) {
125        // Grover coin: 2|s><s| - I, where |s> is uniform superposition
126        for vertex in 0..self.graph.num_vertices {
127            let degree = self.graph.degree(vertex);
128            if degree <= 1 {
129                continue; // No coin needed for degree 0 or 1
130            }
131
132            // Calculate sum of amplitudes for this vertex
133            let mut sum = Complex64::new(0.0, 0.0);
134            for coin in 0..degree.min(self.coin_dimension) {
135                let idx = self.state_index(vertex, coin);
136                if idx < self.state.len() {
137                    sum += self.state[idx];
138                }
139            }
140
141            // Apply Grover coin
142            let factor = Complex64::new(2.0 / degree as f64, 0.0);
143            for coin in 0..degree.min(self.coin_dimension) {
144                let idx = self.state_index(vertex, coin);
145                if idx < self.state.len() {
146                    let old_amp = self.state[idx];
147                    self.state[idx] = factor * sum - old_amp;
148                }
149            }
150        }
151    }
152
153    /// Apply DFT coin
154    fn apply_dft_coin(&mut self) {
155        // DFT coin for 2-dimensional coin space
156        if self.coin_dimension == 2 {
157            self.apply_hadamard_coin(); // DFT is same as Hadamard for 2D
158        }
159        // For higher dimensions, would implement full DFT
160    }
161
162    /// Apply custom coin operator
163    fn apply_custom_coin(&mut self, matrix: Array2<Complex64>) {
164        if matrix.shape() != [self.coin_dimension, self.coin_dimension] {
165            return; // Matrix size mismatch
166        }
167
168        for vertex in 0..self.graph.num_vertices {
169            let mut coin_state = vec![Complex64::new(0.0, 0.0); self.coin_dimension];
170
171            // Extract coin state for this vertex
172            for (coin, cs) in coin_state.iter_mut().enumerate() {
173                let idx = self.state_index(vertex, coin);
174                if idx < self.state.len() {
175                    *cs = self.state[idx];
176                }
177            }
178
179            // Apply coin operator
180            let new_coin_state = matrix.dot(&Array1::from(coin_state));
181
182            // Write back
183            for coin in 0..self.coin_dimension {
184                let idx = self.state_index(vertex, coin);
185                if idx < self.state.len() {
186                    self.state[idx] = new_coin_state[coin];
187                }
188            }
189        }
190    }
191
192    /// Apply the shift operator
193    fn apply_shift(&mut self) {
194        let mut new_state = vec![Complex64::new(0.0, 0.0); self.hilbert_dim];
195
196        for vertex in 0..self.graph.num_vertices {
197            for (coin, &neighbor) in self.graph.edges[vertex].iter().enumerate() {
198                if coin < self.coin_dimension {
199                    let from_idx = self.state_index(vertex, coin);
200
201                    // Find which coin state corresponds to coming from 'vertex' at 'neighbor'
202                    let to_coin = self.graph.edges[neighbor]
203                        .iter()
204                        .position(|&v| v == vertex)
205                        .unwrap_or(0);
206
207                    if to_coin < self.coin_dimension && from_idx < self.state.len() {
208                        let to_idx = self.state_index(neighbor, to_coin);
209                        if to_idx < new_state.len() {
210                            new_state[to_idx] = self.state[from_idx];
211                        }
212                    }
213                }
214            }
215        }
216
217        self.state.copy_from_slice(&new_state);
218    }
219}