quantrs2_core/decomposition/
solovay_kitaev.rs1use crate::error::{QuantRS2Error, QuantRS2Result};
8use crate::gate::{single::*, GateOp};
9use crate::matrix_ops::{matrices_approx_equal, DenseMatrix, QuantumMatrix};
10use crate::qubit::QubitId;
11use rustc_hash::FxHashMap;
12use scirs2_core::ndarray::{Array2, ArrayView2};
13use scirs2_core::Complex64;
14use smallvec::SmallVec;
15
/// Selects which gates seed level 0 of the Solovay–Kitaev sequence cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BaseGateSet {
    /// Clifford+T: seeded with `H`, `S`, `T`, `H·S` and `S·H`.
    CliffordT,
    /// Seeded with `H`, `T` and `X`.
    // NOTE(review): despite the name, the seeds are not the textbook V-basis gates.
    VBasis,
    /// No built-in seeds; the cache starts empty, so approximation has nothing
    /// to search.
    Custom,
}

/// Tuning parameters for the Solovay–Kitaev synthesizer.
#[derive(Debug, Clone)]
pub struct SolovayKitaevConfig {
    /// Upper bound on the recursion depth used when approximating; the
    /// sequence cache also gets one slot per level `0..=max_depth`.
    pub max_depth: usize,
    /// Target accuracy: recursion stops once the current approximation is
    /// within this distance of the target.
    pub epsilon: f64,
    /// Gate set used to seed the level-0 cache.
    pub base_set: BaseGateSet,
    /// Cache budget; each generated cache level keeps at most
    /// `cache_limit / 10` sequences.
    pub cache_limit: usize,
}

impl Default for SolovayKitaevConfig {
    /// Depth 10, epsilon 1e-3, Clifford+T seeds and a 10 000-entry cache budget.
    fn default() -> Self {
        SolovayKitaevConfig {
            base_set: BaseGateSet::CliffordT,
            cache_limit: 10_000,
            epsilon: 1.0e-3,
            max_depth: 10,
        }
    }
}
50
51pub type GateSequence = SmallVec<[Box<dyn GateOp>; 8]>;
53
54#[derive(Debug)]
56pub struct GateSequenceWithMatrix {
57 pub sequence: GateSequence,
59 pub matrix: Array2<Complex64>,
61 pub cost: usize,
63}
64
65pub struct SolovayKitaev {
67 config: SolovayKitaevConfig,
68 sequence_cache: Vec<Vec<GateSequenceWithMatrix>>,
70 #[allow(dead_code)]
72 lookup_table: FxHashMap<u64, Vec<usize>>,
73}
74
75impl SolovayKitaev {
76 pub fn new(config: SolovayKitaevConfig) -> Self {
78 let max_depth = config.max_depth;
79 let mut sequence_cache = Vec::with_capacity(max_depth + 1);
80 for _ in 0..=max_depth {
81 sequence_cache.push(Vec::new());
82 }
83
84 let mut sk = Self {
85 config,
86 sequence_cache,
87 lookup_table: FxHashMap::default(),
88 };
89 sk.initialize_base_sequences();
90 sk
91 }
92
93 fn initialize_base_sequences(&mut self) {
95 let qubit = QubitId(0);
96 let base_gates = match self.config.base_set {
97 BaseGateSet::CliffordT => {
98 vec![
99 self.create_sequence_with_matrix(vec![Box::new(Hadamard { target: qubit })]),
101 self.create_sequence_with_matrix(vec![Box::new(Phase { target: qubit })]),
102 self.create_sequence_with_matrix(vec![Box::new(T { target: qubit })]),
103 self.create_sequence_with_matrix(vec![
105 Box::new(Hadamard { target: qubit }),
106 Box::new(Phase { target: qubit }),
107 ]),
108 self.create_sequence_with_matrix(vec![
109 Box::new(Phase { target: qubit }),
110 Box::new(Hadamard { target: qubit }),
111 ]),
112 ]
113 }
114 BaseGateSet::VBasis => {
115 vec![
116 self.create_sequence_with_matrix(vec![Box::new(Hadamard { target: qubit })]),
117 self.create_sequence_with_matrix(vec![Box::new(T { target: qubit })]),
118 self.create_sequence_with_matrix(vec![Box::new(PauliX { target: qubit })]),
119 ]
120 }
121 BaseGateSet::Custom => {
122 vec![]
124 }
125 };
126
127 for seq in base_gates {
129 if let Ok(seq) = seq {
130 self.sequence_cache[0].push(seq);
131 }
132 }
133
134 for level in 1..=self.config.max_depth.min(3) {
136 self.build_sequences_at_level(level);
137 }
138 }
139
140 fn create_sequence_with_matrix(
142 &self,
143 gates: Vec<Box<dyn GateOp>>,
144 ) -> QuantRS2Result<GateSequenceWithMatrix> {
145 let mut matrix = Array2::eye(2);
146 let mut cost = 0;
147
148 for gate in &gates {
149 let gate_matrix = gate.matrix()?;
150 let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)
151 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
152 matrix = matrix.dot(&gate_array);
153
154 if gate.name() == "T" || gate.name() == "T†" {
156 cost += 1;
157 }
158 }
159
160 Ok(GateSequenceWithMatrix {
161 sequence: SmallVec::from_vec(gates),
162 matrix,
163 cost,
164 })
165 }
166
167 fn build_sequences_at_level(&mut self, level: usize) {
169 if level == 0 || level > self.config.max_depth {
170 return;
171 }
172
173 let mut new_sequences = Vec::new();
174
175 for i in 0..level {
177 let j = level - 1 - i;
178 let seq1_count = self.sequence_cache[i].len();
179 let seq2_count = self.sequence_cache[j].len();
180
181 for idx1 in 0..seq1_count {
182 for idx2 in 0..seq2_count {
183 let seq1 = &self.sequence_cache[i][idx1];
184 let seq2 = &self.sequence_cache[j][idx2];
185
186 let mut combined = SmallVec::new();
188 combined.extend(seq1.sequence.iter().map(|g| g.clone()));
189 combined.extend(seq2.sequence.iter().map(|g| g.clone()));
190
191 let matrix = seq1.matrix.dot(&seq2.matrix);
192 let cost = seq1.cost + seq2.cost;
193
194 let new_seq = GateSequenceWithMatrix {
195 sequence: combined,
196 matrix,
197 cost,
198 };
199
200 if self.should_add_sequence(&new_seq, &new_sequences) {
202 new_sequences.push(new_seq);
203 }
204
205 if new_sequences.len() >= self.config.cache_limit / 10 {
207 break;
208 }
209 }
210 if new_sequences.len() >= self.config.cache_limit / 10 {
211 break;
212 }
213 }
214 }
215
216 self.sequence_cache[level] = new_sequences;
217 }
218
219 fn should_add_sequence(
221 &self,
222 new_seq: &GateSequenceWithMatrix,
223 existing: &[GateSequenceWithMatrix],
224 ) -> bool {
225 for seq in existing {
227 if matrices_approx_equal(&new_seq.matrix.view(), &seq.matrix.view(), 1e-10)
228 && seq.cost <= new_seq.cost
229 {
230 return false;
231 }
232 }
233 true
234 }
235
236 pub fn approximate(&mut self, target: &ArrayView2<Complex64>) -> QuantRS2Result<GateSequence> {
238 if target.shape() != &[2, 2] {
239 return Err(QuantRS2Error::InvalidInput(
240 "Target must be a 2x2 unitary matrix".to_string(),
241 ));
242 }
243
244 let target_dense = DenseMatrix::new(target.to_owned())?;
246 if !target_dense.is_unitary(1e-10)? {
247 return Err(QuantRS2Error::InvalidInput(
248 "Target matrix is not unitary".to_string(),
249 ));
250 }
251
252 let depth = self.calculate_required_depth();
254 self.approximate_recursive(target, depth)
255 }
256
257 fn calculate_required_depth(&self) -> usize {
259 let log_inv_eps = (1.0 / self.config.epsilon).ln();
262 let depth = (log_inv_eps * 2.0) as usize;
263 depth.min(self.config.max_depth)
264 }
265
266 fn approximate_recursive(
268 &mut self,
269 target: &ArrayView2<Complex64>,
270 depth: usize,
271 ) -> QuantRS2Result<GateSequence> {
272 if depth == 0 {
274 return self.find_closest_base_sequence(target);
275 }
276
277 let u_n_minus_1 = self.approximate_recursive(target, depth - 1)?;
279 let u_n_minus_1_matrix = self.compute_sequence_matrix(&u_n_minus_1)?;
280
281 let error = target.to_owned() - &u_n_minus_1_matrix;
283
284 let mut error_norm = 0.0;
286 for val in &error {
287 error_norm += val.norm_sqr();
288 }
289 let error_norm = error_norm.sqrt();
290
291 if error_norm < self.config.epsilon {
293 return Ok(u_n_minus_1);
294 }
295
296 self.group_commutator_correction(target, &u_n_minus_1, &u_n_minus_1_matrix, depth)
298 }
299
300 fn find_closest_base_sequence(
302 &self,
303 target: &ArrayView2<Complex64>,
304 ) -> QuantRS2Result<GateSequence> {
305 let mut best_sequence = None;
306 let mut best_distance = f64::INFINITY;
307
308 for sequences in &self.sequence_cache {
310 for seq in sequences {
311 let diff = target.to_owned() - &seq.matrix;
312
313 let mut distance = 0.0;
315 for val in &diff {
316 distance += val.norm_sqr();
317 }
318 let distance = distance.sqrt();
319
320 if distance < best_distance {
321 best_distance = distance;
322 best_sequence = Some(seq.sequence.iter().map(|g| g.clone()).collect());
323 }
324 }
325 }
326
327 best_sequence.ok_or_else(|| {
328 QuantRS2Error::ComputationError("No base sequences available".to_string())
329 })
330 }
331
332 fn group_commutator_correction(
334 &self,
335 target: &ArrayView2<Complex64>,
336 base_seq: &GateSequence,
337 base_matrix: &Array2<Complex64>,
338 depth: usize,
339 ) -> QuantRS2Result<GateSequence> {
340 let error = target.to_owned() - base_matrix;
342 let trace = error[[0, 0]] + error[[1, 1]];
343 let angle = (trace.re / 2.0).acos();
344
345 let (v_seq, w_seq) = self.find_commutator_sequences(angle, depth - 1)?;
347
348 let mut result = SmallVec::new();
350 result.extend(v_seq.iter().map(|g| g.clone()));
351 result.extend(w_seq.iter().map(|g| g.clone()));
352 result.extend(self.compute_inverse_sequence(&v_seq)?);
353 result.extend(self.compute_inverse_sequence(&w_seq)?);
354 result.extend(base_seq.iter().map(|g| g.clone()));
355
356 Ok(result)
357 }
358
359 fn find_commutator_sequences(
361 &self,
362 angle: f64,
363 _depth: usize,
364 ) -> QuantRS2Result<(GateSequence, GateSequence)> {
365 let qubit = QubitId(0);
368
369 let small_angle = angle / 4.0;
371
372 let v = vec![Box::new(RotationZ {
373 target: qubit,
374 theta: small_angle,
375 }) as Box<dyn GateOp>];
376
377 let w = vec![Box::new(RotationY {
378 target: qubit,
379 theta: small_angle,
380 }) as Box<dyn GateOp>];
381
382 Ok((SmallVec::from_vec(v), SmallVec::from_vec(w)))
383 }
384
385 fn compute_sequence_matrix(
387 &self,
388 sequence: &GateSequence,
389 ) -> QuantRS2Result<Array2<Complex64>> {
390 let mut matrix = Array2::eye(2);
391
392 for gate in sequence {
393 let gate_matrix = gate.matrix()?;
394 let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)
395 .map_err(|e| QuantRS2Error::InvalidInput(e.to_string()))?;
396 matrix = matrix.dot(&gate_array);
397 }
398
399 Ok(matrix)
400 }
401
402 fn compute_inverse_sequence(&self, sequence: &GateSequence) -> QuantRS2Result<GateSequence> {
404 let mut inverse = SmallVec::new();
405
406 for gate in sequence.iter().rev() {
408 inverse.push(self.invert_gate(gate.as_ref())?);
409 }
410
411 Ok(inverse)
412 }
413
414 fn invert_gate(&self, gate: &dyn GateOp) -> QuantRS2Result<Box<dyn GateOp>> {
416 let qubit = gate.qubits()[0]; match gate.name() {
419 "H" => Ok(Box::new(Hadamard { target: qubit })), "X" => Ok(Box::new(PauliX { target: qubit })), "Y" => Ok(Box::new(PauliY { target: qubit })), "Z" => Ok(Box::new(PauliZ { target: qubit })), "S" => Ok(Box::new(PhaseDagger { target: qubit })),
424 "S†" => Ok(Box::new(Phase { target: qubit })),
425 "T" => Ok(Box::new(TDagger { target: qubit })),
426 "T†" => Ok(Box::new(T { target: qubit })),
427 _ => {
428 if gate.is_parameterized() {
430 Err(QuantRS2Error::UnsupportedOperation(format!(
432 "Cannot invert parameterized gate {}",
433 gate.name()
434 )))
435 } else {
436 Err(QuantRS2Error::UnsupportedOperation(format!(
437 "Cannot invert gate {}",
438 gate.name()
439 )))
440 }
441 }
442 }
443 }
444}
445
446pub fn count_t_gates(sequence: &GateSequence) -> usize {
448 sequence
449 .iter()
450 .filter(|g| g.name() == "T" || g.name() == "T†")
451 .count()
452}
453
454pub fn optimize_sequence(sequence: GateSequence) -> GateSequence {
456 let mut optimized = SmallVec::new();
457 let mut i = 0;
458
459 while i < sequence.len() {
460 if i + 1 < sequence.len() {
461 let gate1 = &sequence[i];
462 let gate2 = &sequence[i + 1];
463
464 if gate1.qubits() == gate2.qubits() {
466 let combined = match (gate1.name(), gate2.name()) {
467 ("S", "S") | ("S†", "S†") => Some("Z"),
468 ("S", "S†") | ("S†", "S") | ("T", "T†") | ("T†", "T") | ("H", "H") => {
469 None
470 } _ => Some(""), };
473
474 match combined {
475 None => {
476 i += 2;
478 }
479 Some("Z") => {
480 optimized.push(Box::new(PauliZ {
481 target: gate1.qubits()[0],
482 }) as Box<dyn GateOp>);
483 i += 2;
484 }
485 _ => {
486 optimized.push(sequence[i].clone());
488 i += 1;
489 }
490 }
491 } else {
492 optimized.push(sequence[i].clone());
493 i += 1;
494 }
495 } else {
496 optimized.push(sequence[i].clone());
497 i += 1;
498 }
499 }
500
501 optimized
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// A default-configured synthesizer must seed its level-0 cache.
    #[test]
    fn test_solovay_kitaev_initialization() {
        let sk = SolovayKitaev::new(SolovayKitaevConfig::default());
        assert!(!sk.sequence_cache[0].is_empty());
    }

    /// Both `T` and `T†` contribute to the T-count; Clifford gates do not.
    #[test]
    fn test_t_gate_counting() {
        let q = QubitId(0);
        let gates = vec![
            Box::new(T { target: q }) as Box<dyn GateOp>,
            Box::new(Hadamard { target: q }),
            Box::new(TDagger { target: q }),
            Box::new(Phase { target: q }),
            Box::new(T { target: q }),
        ];
        let sequence: GateSequence = SmallVec::from_vec(gates);

        assert_eq!(count_t_gates(&sequence), 3);
    }

    /// `H·H` and `S·S†` both cancel, leaving nothing behind.
    #[test]
    fn test_sequence_optimization() {
        let q = QubitId(0);
        let gates = vec![
            Box::new(Hadamard { target: q }) as Box<dyn GateOp>,
            Box::new(Hadamard { target: q }),
            Box::new(Phase { target: q }),
            Box::new(PhaseDagger { target: q }),
        ];
        let sequence: GateSequence = SmallVec::from_vec(gates);

        let optimized = optimize_sequence(sequence);
        assert_eq!(optimized.len(), 0);
    }
}