1use std::collections::{HashMap, HashSet};
6
7use itertools::Itertools;
8
9use crate::{
10 error::{KetError, Result},
11 ir::qubit::{LogicalQubit, Qubit},
12 prelude::QuantumGate,
13 process::Process,
14};
15
/// Bookkeeping for auxiliary-qubit allocation groups and for validating the gates of
/// around-blocks (`around_begin` / `around_mid` / `around_undo` / `around_end`) that are
/// opened while clean auxiliary allocations are open.
#[derive(Debug, Default)]
pub(super) struct AuxQubit {
    /// Total number of auxiliary qubits ever registered. New auxiliary qubits are numbered
    /// starting from this value. Never decremented in this file.
    pub count: usize,

    /// Id handed out to the next allocation group.
    id_count: usize,

    /// Number of allocation groups registered and not yet freed.
    pub open_alloc: usize,

    /// Number of open allocation groups that are clean, i.e. registered without
    /// interacting qubits.
    pub open_clean_alloc: usize,

    /// Stack of `(group id, gate-queue length at allocation time)`. Groups are popped in
    /// LIFO order by `Process::free_aux`.
    pub alloc_stack: Vec<(usize, usize)>,

    /// Number of auxiliary qubits belonging to groups that are still open.
    pub being_used: usize,

    /// Group id -> (auxiliary qubits, interacting qubits). The second element is `None`
    /// for a clean group and `Some(..)` for a dirty one.
    pub registry: HashMap<usize, (Vec<LogicalQubit>, Option<Vec<LogicalQubit>>)>,

    /// Reverse index: auxiliary qubit -> group id. Holds entries only for qubits of groups
    /// that have not been freed yet.
    registry_rev: HashMap<LogicalQubit, usize>,

    /// Main qubits onto which already-freed auxiliary qubits were mapped while other groups
    /// were still open. Cleared once no group remains open.
    pub using: HashSet<LogicalQubit>,

    /// Stack of section states for the open around-blocks that pushed one (innermost last).
    pub state: Vec<AuxState>,
    /// One set per around-block: qubits targeted by non-diagonal gates in its `Mid` section.
    pub blocked_qubits: Vec<HashSet<LogicalQubit>>,
    /// One set per around-block: qubits targeted by non-diagonal gates in its `Undo` section.
    pub blocked_qubits_undo: Vec<HashSet<LogicalQubit>>,

    /// Nesting depth of `is_permutation` regions. While non-zero, `validate_gate` accepts
    /// non-permutation gates in `Begin` and `Undo` sections.
    pub is_permutation: usize,
    /// Nesting depth of `is_diagonal` regions.
    /// NOTE(review): not read by any code in this file; confirm where (or whether) it is used.
    pub is_diagonal: usize,
}
51
/// Section of an around-block whose gates `AuxQubit::validate_gate` is currently checking.
#[derive(Debug, Clone, Copy, Default)]
pub(super) enum AuxState {
    /// From `around_begin` until `around_mid`. Gates must be permutations unless inside an
    /// `is_permutation` region.
    #[default]
    Begin,
    /// From `around_mid` until `around_undo`. Non-diagonal gates may not target clean
    /// auxiliary qubits, and their targets are recorded as blocked.
    Mid,
    /// From `around_undo` until `around_end`. Same permutation rule as `Begin`, and gates
    /// may not touch qubits blocked during `Mid`.
    Undo,
}
59
60impl AuxQubit {
61 pub fn register_alloc(
62 &mut self,
63 aux_qubits: Vec<LogicalQubit>,
64 interacting_qubits: Option<Vec<LogicalQubit>>,
65 next_gate_queue: usize,
66 ) -> usize {
67 let id = self.id_count;
68 self.open_alloc += 1;
69 if interacting_qubits.is_none() {
70 self.open_clean_alloc += 1;
71 }
72 self.id_count += 1;
73
74 self.count += aux_qubits.len();
75 self.being_used += aux_qubits.len();
76
77 for a in &aux_qubits {
78 self.registry_rev.insert(*a, id);
79 }
80
81 self.registry.insert(id, (aux_qubits, interacting_qubits));
82
83 self.alloc_stack.push((id, next_gate_queue));
84
85 id
86 }
87
88 pub fn register_free(
89 &mut self,
90 aux_qubits: Vec<LogicalQubit>,
91 is_clean: bool,
92 main_qubits: HashSet<LogicalQubit>,
93 ) {
94 if is_clean {
95 self.open_clean_alloc -= 1;
96 }
97 self.being_used -= aux_qubits.len();
98 for a in &aux_qubits {
99 self.registry_rev.remove(a);
100 }
101 self.open_alloc -= 1;
102 if self.open_alloc == 0 {
103 self.using.clear();
104 } else {
105 self.using.extend(main_qubits);
106 }
107 }
108
109 pub fn is_dirty(&self, qubit: &LogicalQubit) -> bool {
110 self.registry_rev
111 .get(qubit)
112 .is_some_and(|id| self.registry.get(id).unwrap().1.is_some())
113 }
114
115 pub fn validate_gate(
116 &mut self,
117 gate: &QuantumGate,
118 target: &LogicalQubit,
119 ctrl: &[LogicalQubit],
120 ) -> Result<()> {
121 if let Some(aux_state) = self.state.last() {
122 match aux_state {
123 AuxState::Begin => {
124 if self.is_permutation == 0 && !gate.is_permutation() {
125 return Err(KetError::UncomputeFaill);
126 }
127 }
128 AuxState::Mid => {
129 if target.is_aux() && !self.is_dirty(target) && !gate.is_diagonal() {
130 return Err(KetError::UncomputeFaill);
131 } else if !gate.is_diagonal() {
132 self.blocked_qubits.last_mut().unwrap().insert(*target);
133 }
134 }
135 AuxState::Undo => {
136 if self.is_permutation == 0 && !gate.is_permutation() {
137 return Err(KetError::UncomputeFaill);
138 }
139 if ctrl
140 .iter()
141 .chain([target])
142 .any(|q| self.blocked_qubits.last().unwrap().contains(q))
143 {
144 return Err(KetError::UncomputeFaill);
145 }
146
147 if !gate.is_diagonal() {
148 self.blocked_qubits_undo.last_mut().unwrap().insert(*target);
149 }
150 }
151 }
152 }
153
154 Ok(())
155 }
156}
157
158impl Process {
159 pub fn alloc_aux(
160 &mut self,
161 num_qubits: usize,
162 interacting_qubits: Option<&[LogicalQubit]>,
163 ) -> Result<(Vec<LogicalQubit>, usize)> {
164 let num_qubits_needed = if let Some(interacting_qubits) = interacting_qubits {
165 interacting_qubits.len()
166 } else {
167 self.allocated_qubits
168 } + num_qubits
169 + self.aux.using.len()
170 + self.aux.being_used;
171
172 if num_qubits_needed > self.execution_target.num_qubits {
173 return Err(KetError::MaxQubitsReached);
174 }
175
176 let result: Vec<_> = (0..num_qubits)
177 .map(|index| LogicalQubit::aux(index + self.aux.count))
178 .collect();
179
180 let id = self.aux.register_alloc(
181 result.clone(),
182 interacting_qubits.map(|iq| iq.to_owned()),
183 self.gate_queue.len(),
184 );
185
186 Ok((result, id))
187 }
188
189 pub fn free_aux(&mut self, group_id: usize) {
190 let gate_until = if let Some((next_id, gate_until)) = self.aux.alloc_stack.pop() {
191 assert!(next_id == group_id);
192 gate_until
193 } else {
194 panic!("No aux qubits to free")
195 };
196
197 let (aux_qubits, interacting_qubits) = self.aux.registry.remove(&group_id).unwrap();
198
199 let mut allocated = HashSet::new();
200
201 for aux_qubit in &aux_qubits {
202 let main_qubit = if let Some(interacting_qubits) = &interacting_qubits {
203 let mut main_qubit = None;
204
205 for candidate_qubit in (0..self.execution_target.num_qubits)
206 .map(LogicalQubit::main)
207 .sorted_by_key(|q| self.logical_circuit.qubit_depth.get(q).unwrap_or(&0))
208 {
209 if !allocated.contains(&candidate_qubit)
210 && !interacting_qubits.contains(&candidate_qubit)
211 && !self.aux.using.contains(&candidate_qubit)
212 {
213 main_qubit = Some(candidate_qubit);
214 break;
215 }
216 }
217 main_qubit.unwrap()
218 } else {
219 let mut main_qubit = None;
220 for candidate_qubit in (self.allocated_qubits..self.execution_target.num_qubits)
221 .map(LogicalQubit::main)
222 .sorted_by_key(|q| self.logical_circuit.qubit_depth.get(q).unwrap_or(&0))
223 {
224 if !allocated.contains(&candidate_qubit)
225 && !self.aux.using.contains(&candidate_qubit)
226 {
227 main_qubit = Some(candidate_qubit);
228 break;
229 }
230 }
231 main_qubit.unwrap()
232 };
233
234 allocated.insert(main_qubit);
235 self.logical_circuit.alloc_aux_qubit(*aux_qubit, main_qubit);
236 }
237
238 self.execute_gate_queue(gate_until);
239 for q in &aux_qubits {
240 self.valid_qubit.insert(*q, false);
241 }
242
243 self.aux
244 .register_free(aux_qubits, interacting_qubits.is_none(), allocated);
245 }
246
247 pub fn is_diagonal_begin(&mut self) {
248 self.aux.is_diagonal += 1;
249 }
250
251 pub fn is_diagonal_end(&mut self) {
252 self.aux.is_diagonal -= 1;
253 }
254
255 pub fn is_permutation_begin(&mut self) {
256 self.aux.is_permutation += 1;
257 }
258
259 pub fn is_permutation_end(&mut self) {
260 self.aux.is_permutation -= 1;
261 }
262
263 pub fn around_begin(&mut self) {
264 if self.aux.open_clean_alloc != 0 {
265 self.aux.state.push(AuxState::Begin);
266 }
267 }
268
269 pub fn around_mid(&mut self) {
270 if let Some(state) = self.aux.state.last_mut() {
271 *state = match state {
272 AuxState::Begin => {
273 self.aux.blocked_qubits.push(Default::default());
274 AuxState::Mid
275 }
276 state => unreachable!("around_mid: unreachable state={:?}", state),
277 };
278 }
279 }
280
281 pub fn around_undo(&mut self) {
282 if let Some(state) = self.aux.state.last_mut() {
283 *state = match state {
284 AuxState::Mid => {
285 self.aux.blocked_qubits_undo.push(Default::default());
286 AuxState::Undo
287 }
288 state => unreachable!("around_undo: unreachable state={:?}", state),
289 };
290 }
291 }
292
293 pub fn around_end(&mut self) {
294 if let Some(state) = self.aux.state.pop() {
295 match state {
296 AuxState::Undo => {
297 let blocked_qubits = self.aux.blocked_qubits.pop().unwrap();
298 let blocked_qubits_undo = self.aux.blocked_qubits_undo.pop().unwrap();
299
300 if let Some(blocked_qubits_last) = self.aux.blocked_qubits.last_mut() {
301 blocked_qubits_last.extend(blocked_qubits);
302 blocked_qubits_last.extend(blocked_qubits_undo);
303 }
304 }
305 state => unreachable!("around_end: unreachable state={:?}", state),
306 };
307 }
308 }
309}