ket/process/
aux.rs

1// SPDX-FileCopyrightText: 2025 Evandro Chagas Ribeiro da Rosa <evandro@quantuloop.com>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5use std::collections::{HashMap, HashSet};
6
7use itertools::Itertools;
8
9use crate::{
10    error::{KetError, Result},
11    ir::qubit::{LogicalQubit, Qubit},
12    prelude::QuantumGate,
13    process::Process,
14};
15
16#[derive(Debug, Default)]
17pub(super) struct AuxQubit {
18    /// Number of auxiliary allocated
19    pub count: usize,
20
21    /// Number of auxiliary group allocated
22    id_count: usize,
23
24    /// Number of aux group allocate and not free
25    pub open_alloc: usize,
26
27    pub open_clean_alloc: usize,
28
29    /// Top stack (Next to be free, Can execute gate until)
30    pub alloc_stack: Vec<(usize, usize)>,
31
32    /// Number of qubits allocated and not free
33    pub being_used: usize,
34
35    /// Map group id into (aux, interaction group)
36    pub registry: HashMap<usize, (Vec<LogicalQubit>, Option<Vec<LogicalQubit>>)>,
37
38    /// Map qubit to its group id
39    registry_rev: HashMap<LogicalQubit, usize>,
40
41    /// Main qubits being used as aux
42    pub using: HashSet<LogicalQubit>,
43
44    pub state: Vec<AuxState>,
45    pub blocked_qubits: Vec<HashSet<LogicalQubit>>,
46    pub blocked_qubits_undo: Vec<HashSet<LogicalQubit>>,
47
48    pub is_permutation: usize,
49    pub is_diagonal: usize,
50}
51
52#[derive(Debug, Clone, Copy, Default)]
53pub(super) enum AuxState {
54    #[default]
55    Begin,
56    Mid,
57    Undo,
58}
59
60impl AuxQubit {
61    pub fn register_alloc(
62        &mut self,
63        aux_qubits: Vec<LogicalQubit>,
64        interacting_qubits: Option<Vec<LogicalQubit>>,
65        next_gate_queue: usize,
66    ) -> usize {
67        let id = self.id_count;
68        self.open_alloc += 1;
69        if interacting_qubits.is_none() {
70            self.open_clean_alloc += 1;
71        }
72        self.id_count += 1;
73
74        self.count += aux_qubits.len();
75        self.being_used += aux_qubits.len();
76
77        for a in &aux_qubits {
78            self.registry_rev.insert(*a, id);
79        }
80
81        self.registry.insert(id, (aux_qubits, interacting_qubits));
82
83        self.alloc_stack.push((id, next_gate_queue));
84
85        id
86    }
87
88    pub fn register_free(
89        &mut self,
90        aux_qubits: Vec<LogicalQubit>,
91        is_clean: bool,
92        main_qubits: HashSet<LogicalQubit>,
93    ) {
94        if is_clean {
95            self.open_clean_alloc -= 1;
96        }
97        self.being_used -= aux_qubits.len();
98        for a in &aux_qubits {
99            self.registry_rev.remove(a);
100        }
101        self.open_alloc -= 1;
102        if self.open_alloc == 0 {
103            self.using.clear();
104        } else {
105            self.using.extend(main_qubits);
106        }
107    }
108
109    pub fn is_dirty(&self, qubit: &LogicalQubit) -> bool {
110        self.registry_rev
111            .get(qubit)
112            .is_some_and(|id| self.registry.get(id).unwrap().1.is_some())
113    }
114
115    pub fn validate_gate(
116        &mut self,
117        gate: &QuantumGate,
118        target: &LogicalQubit,
119        ctrl: &[LogicalQubit],
120    ) -> Result<()> {
121        if let Some(aux_state) = self.state.last() {
122            match aux_state {
123                AuxState::Begin => {
124                    if self.is_permutation == 0 && !gate.is_permutation() {
125                        return Err(KetError::UncomputeFaill);
126                    }
127                }
128                AuxState::Mid => {
129                    if target.is_aux() && !self.is_dirty(target) && !gate.is_diagonal() {
130                        return Err(KetError::UncomputeFaill);
131                    } else if !gate.is_diagonal() {
132                        self.blocked_qubits.last_mut().unwrap().insert(*target);
133                    }
134                }
135                AuxState::Undo => {
136                    if self.is_permutation == 0 && !gate.is_permutation() {
137                        return Err(KetError::UncomputeFaill);
138                    }
139                    if ctrl
140                        .iter()
141                        .chain([target])
142                        .any(|q| self.blocked_qubits.last().unwrap().contains(q))
143                    {
144                        return Err(KetError::UncomputeFaill);
145                    }
146
147                    if !gate.is_diagonal() {
148                        self.blocked_qubits_undo.last_mut().unwrap().insert(*target);
149                    }
150                }
151            }
152        }
153
154        Ok(())
155    }
156}
157
158impl Process {
159    pub fn alloc_aux(
160        &mut self,
161        num_qubits: usize,
162        interacting_qubits: Option<&[LogicalQubit]>,
163    ) -> Result<(Vec<LogicalQubit>, usize)> {
164        let num_qubits_needed = if let Some(interacting_qubits) = interacting_qubits {
165            interacting_qubits.len()
166        } else {
167            self.allocated_qubits
168        } + num_qubits
169            + self.aux.using.len()
170            + self.aux.being_used;
171
172        if num_qubits_needed > self.execution_target.num_qubits {
173            return Err(KetError::MaxQubitsReached);
174        }
175
176        let result: Vec<_> = (0..num_qubits)
177            .map(|index| LogicalQubit::aux(index + self.aux.count))
178            .collect();
179
180        let id = self.aux.register_alloc(
181            result.clone(),
182            interacting_qubits.map(|iq| iq.to_owned()),
183            self.gate_queue.len(),
184        );
185
186        Ok((result, id))
187    }
188
189    pub fn free_aux(&mut self, group_id: usize) {
190        let gate_until = if let Some((next_id, gate_until)) = self.aux.alloc_stack.pop() {
191            assert!(next_id == group_id);
192            gate_until
193        } else {
194            panic!("No aux qubits to free")
195        };
196
197        let (aux_qubits, interacting_qubits) = self.aux.registry.remove(&group_id).unwrap();
198
199        let mut allocated = HashSet::new();
200
201        for aux_qubit in &aux_qubits {
202            let main_qubit = if let Some(interacting_qubits) = &interacting_qubits {
203                let mut main_qubit = None;
204
205                for candidate_qubit in (0..self.execution_target.num_qubits)
206                    .map(LogicalQubit::main)
207                    .sorted_by_key(|q| self.logical_circuit.qubit_depth.get(q).unwrap_or(&0))
208                {
209                    if !allocated.contains(&candidate_qubit)
210                        && !interacting_qubits.contains(&candidate_qubit)
211                        && !self.aux.using.contains(&candidate_qubit)
212                    {
213                        main_qubit = Some(candidate_qubit);
214                        break;
215                    }
216                }
217                main_qubit.unwrap()
218            } else {
219                let mut main_qubit = None;
220                for candidate_qubit in (self.allocated_qubits..self.execution_target.num_qubits)
221                    .map(LogicalQubit::main)
222                    .sorted_by_key(|q| self.logical_circuit.qubit_depth.get(q).unwrap_or(&0))
223                {
224                    if !allocated.contains(&candidate_qubit)
225                        && !self.aux.using.contains(&candidate_qubit)
226                    {
227                        main_qubit = Some(candidate_qubit);
228                        break;
229                    }
230                }
231                main_qubit.unwrap()
232            };
233
234            allocated.insert(main_qubit);
235            self.logical_circuit.alloc_aux_qubit(*aux_qubit, main_qubit);
236        }
237
238        self.execute_gate_queue(gate_until);
239        for q in &aux_qubits {
240            self.valid_qubit.insert(*q, false);
241        }
242
243        self.aux
244            .register_free(aux_qubits, interacting_qubits.is_none(), allocated);
245    }
246
247    pub fn is_diagonal_begin(&mut self) {
248        self.aux.is_diagonal += 1;
249    }
250
251    pub fn is_diagonal_end(&mut self) {
252        self.aux.is_diagonal -= 1;
253    }
254
255    pub fn is_permutation_begin(&mut self) {
256        self.aux.is_permutation += 1;
257    }
258
259    pub fn is_permutation_end(&mut self) {
260        self.aux.is_permutation -= 1;
261    }
262
263    pub fn around_begin(&mut self) {
264        if self.aux.open_clean_alloc != 0 {
265            self.aux.state.push(AuxState::Begin);
266        }
267    }
268
269    pub fn around_mid(&mut self) {
270        if let Some(state) = self.aux.state.last_mut() {
271            *state = match state {
272                AuxState::Begin => {
273                    self.aux.blocked_qubits.push(Default::default());
274                    AuxState::Mid
275                }
276                state => unreachable!("around_mid: unreachable state={:?}", state),
277            };
278        }
279    }
280
281    pub fn around_undo(&mut self) {
282        if let Some(state) = self.aux.state.last_mut() {
283            *state = match state {
284                AuxState::Mid => {
285                    self.aux.blocked_qubits_undo.push(Default::default());
286                    AuxState::Undo
287                }
288                state => unreachable!("around_undo: unreachable state={:?}", state),
289            };
290        }
291    }
292
293    pub fn around_end(&mut self) {
294        if let Some(state) = self.aux.state.pop() {
295            match state {
296                AuxState::Undo => {
297                    let blocked_qubits = self.aux.blocked_qubits.pop().unwrap();
298                    let blocked_qubits_undo = self.aux.blocked_qubits_undo.pop().unwrap();
299
300                    if let Some(blocked_qubits_last) = self.aux.blocked_qubits.last_mut() {
301                        blocked_qubits_last.extend(blocked_qubits);
302                        blocked_qubits_last.extend(blocked_qubits_undo);
303                    }
304                }
305                state => unreachable!("around_end: unreachable state={:?}", state),
306            };
307        }
308    }
309}