1use crate::{
7 error::{QuantRS2Error, QuantRS2Result},
8 gate::{multi::*, single::*, GateOp},
9 qubit::QubitId,
10 zx_calculus::{CircuitToZX, EdgeType, SpiderType, ZXDiagram, ZXOptimizer},
11};
12use rustc_hash::FxHashMap;
13use std::collections::{HashSet, VecDeque};
14use std::f64::consts::PI;
15
16#[derive(Debug, Clone)]
18struct GateLayer {
19 gates: Vec<Box<dyn GateOp>>,
21}
22
23pub struct ZXExtractor {
25 diagram: ZXDiagram,
26 spider_positions: FxHashMap<usize, (usize, usize)>, layers: Vec<GateLayer>,
30}
31
32impl ZXExtractor {
33 pub fn new(diagram: ZXDiagram) -> Self {
35 Self {
36 diagram,
37 spider_positions: FxHashMap::default(),
38 layers: Vec::new(),
39 }
40 }
41
42 pub fn extract_circuit(&mut self) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
44 self.analyze_diagram()?;
46
47 self.extract_gates()?;
49
50 let mut circuit = Vec::new();
52 for layer in &self.layers {
53 circuit.extend(layer.gates.clone());
54 }
55
56 Ok(circuit)
57 }
58
59 fn analyze_diagram(&mut self) -> QuantRS2Result<()> {
61 let inputs = self.diagram.inputs.clone();
63 let outputs = self.diagram.outputs.clone();
64
65 let topo_order = self.topological_sort(&inputs, &outputs)?;
67
68 for (layer_idx, spider_id) in topo_order.iter().enumerate() {
70 self.spider_positions.insert(*spider_id, (layer_idx, 0));
71 }
72
73 Ok(())
74 }
75
76 fn topological_sort(&self, inputs: &[usize], outputs: &[usize]) -> QuantRS2Result<Vec<usize>> {
78 let mut in_degree: FxHashMap<usize, usize> = FxHashMap::default();
79 let mut adjacency: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
80
81 for &spider_id in self.diagram.spiders.keys() {
83 in_degree.insert(spider_id, 0);
84 adjacency.insert(spider_id, Vec::new());
85 }
86
87 for (&source, neighbors) in &self.diagram.adjacency {
89 for &(target, _) in neighbors {
90 if self.is_forward_edge(source, target, inputs, outputs) {
92 *in_degree.get_mut(&target).unwrap() += 1;
93 adjacency.get_mut(&source).unwrap().push(target);
94 }
95 }
96 }
97
98 let mut queue: VecDeque<usize> = inputs.iter().cloned().collect();
100 let mut topo_order = Vec::new();
101
102 while let Some(spider) = queue.pop_front() {
103 topo_order.push(spider);
104
105 if let Some(neighbors) = adjacency.get(&spider) {
106 for &neighbor in neighbors {
107 if let Some(degree) = in_degree.get_mut(&neighbor) {
108 *degree -= 1;
109 if *degree == 0 {
110 queue.push_back(neighbor);
111 }
112 }
113 }
114 }
115 }
116
117 if topo_order.len() != self.diagram.spiders.len() {
118 return Err(QuantRS2Error::InvalidInput(
119 "ZX-diagram contains cycles or disconnected components".to_string(),
120 ));
121 }
122
123 Ok(topo_order)
124 }
125
126 fn is_forward_edge(
128 &self,
129 source: usize,
130 target: usize,
131 inputs: &[usize],
132 outputs: &[usize],
133 ) -> bool {
134 if inputs.contains(&source) && !inputs.contains(&target) {
137 return true;
138 }
139 if outputs.contains(&target) && !outputs.contains(&source) {
140 return true;
141 }
142
143 let source_dist = self.distance_from_inputs(source, inputs);
145 let target_dist = self.distance_from_inputs(target, inputs);
146
147 source_dist < target_dist
148 }
149
150 fn distance_from_inputs(&self, spider: usize, inputs: &[usize]) -> usize {
152 if inputs.contains(&spider) {
153 return 0;
154 }
155
156 let mut visited = HashSet::new();
157 let mut queue = VecDeque::new();
158
159 for &input in inputs {
160 queue.push_back((input, 0));
161 visited.insert(input);
162 }
163
164 while let Some((current, dist)) = queue.pop_front() {
165 if current == spider {
166 return dist;
167 }
168
169 for (neighbor, _) in self.diagram.neighbors(current) {
170 if !visited.contains(&neighbor) {
171 visited.insert(neighbor);
172 queue.push_back((neighbor, dist + 1));
173 }
174 }
175 }
176
177 usize::MAX }
179
180 fn extract_gates(&mut self) -> QuantRS2Result<()> {
182 let qubit_spiders = self.group_by_qubit()?;
184
185 for (qubit, spider_chain) in qubit_spiders {
187 self.extract_single_qubit_gates(&qubit, &spider_chain)?;
188 }
189
190 self.extract_two_qubit_gates()?;
192
193 Ok(())
194 }
195
196 fn group_by_qubit(&self) -> QuantRS2Result<FxHashMap<QubitId, Vec<usize>>> {
198 let mut qubit_spiders: FxHashMap<QubitId, Vec<usize>> = FxHashMap::default();
199
200 for &input_id in &self.diagram.inputs {
202 if let Some(input_spider) = self.diagram.spiders.get(&input_id) {
203 if let Some(qubit) = input_spider.qubit {
204 let chain = self.trace_qubit_line(input_id)?;
205 qubit_spiders.insert(qubit, chain);
206 }
207 }
208 }
209
210 Ok(qubit_spiders)
211 }
212
213 fn trace_qubit_line(&self, start: usize) -> QuantRS2Result<Vec<usize>> {
215 let mut chain = vec![start];
216 let mut current = start;
217 let mut visited = HashSet::new();
218 visited.insert(start);
219
220 while !self.diagram.outputs.contains(¤t) {
222 let neighbors = self.diagram.neighbors(current);
223
224 let next = neighbors
226 .iter()
227 .find(|(id, _)| !visited.contains(id) && self.is_on_qubit_line(*id))
228 .map(|(id, _)| *id);
229
230 if let Some(next_id) = next {
231 chain.push(next_id);
232 visited.insert(next_id);
233 current = next_id;
234 } else {
235 break;
236 }
237 }
238
239 Ok(chain)
240 }
241
242 fn is_on_qubit_line(&self, spider_id: usize) -> bool {
244 if let Some(spider) = self.diagram.spiders.get(&spider_id) {
246 spider.spider_type == SpiderType::Boundary || self.diagram.degree(spider_id) <= 2
247 } else {
248 false
249 }
250 }
251
252 fn extract_single_qubit_gates(
254 &mut self,
255 qubit: &QubitId,
256 spider_chain: &[usize],
257 ) -> QuantRS2Result<()> {
258 let mut i = 0;
259
260 while i < spider_chain.len() {
261 let spider_id = spider_chain[i];
262
263 if let Some(spider) = self.diagram.spiders.get(&spider_id) {
264 match spider.spider_type {
265 SpiderType::Z if spider.phase.abs() > 1e-10 => {
266 let gate: Box<dyn GateOp> = Box::new(RotationZ {
268 target: *qubit,
269 theta: spider.phase,
270 });
271 self.add_gate_to_layer(gate, i);
272 }
273 SpiderType::X if spider.phase.abs() > 1e-10 => {
274 let gate: Box<dyn GateOp> = Box::new(RotationX {
276 target: *qubit,
277 theta: spider.phase,
278 });
279 self.add_gate_to_layer(gate, i);
280 }
281 _ => {}
282 }
283
284 if i + 1 < spider_chain.len() {
286 let next_id = spider_chain[i + 1];
287 let edge_type = self.get_edge_type(spider_id, next_id);
288
289 if edge_type == Some(EdgeType::Hadamard) {
290 let gate: Box<dyn GateOp> = Box::new(Hadamard { target: *qubit });
291 self.add_gate_to_layer(gate, i);
292 }
293 }
294 }
295
296 i += 1;
297 }
298
299 Ok(())
300 }
301
302 fn extract_two_qubit_gates(&mut self) -> QuantRS2Result<()> {
304 let mut processed = HashSet::new();
305
306 for (&spider_id, spider) in &self.diagram.spiders.clone() {
307 if processed.contains(&spider_id) {
308 continue;
309 }
310
311 if self.diagram.degree(spider_id) > 2 {
313 let neighbors = self.diagram.neighbors(spider_id);
315
316 if spider.spider_type == SpiderType::Z && spider.phase.abs() < 1e-10 {
318 for &(neighbor_id, edge_type) in &neighbors {
319 if let Some(neighbor) = self.diagram.spiders.get(&neighbor_id) {
320 if neighbor.spider_type == SpiderType::X
321 && neighbor.phase.abs() < 1e-10
322 && edge_type == EdgeType::Regular
323 {
324 if let (Some(control_qubit), Some(target_qubit)) = (
326 self.get_spider_qubit(spider_id),
327 self.get_spider_qubit(neighbor_id),
328 ) {
329 let gate: Box<dyn GateOp> = Box::new(CNOT {
330 control: control_qubit,
331 target: target_qubit,
332 });
333
334 let layer = self
335 .spider_positions
336 .get(&spider_id)
337 .map(|(l, _)| *l)
338 .unwrap_or(0);
339
340 self.add_gate_to_layer(gate, layer);
341 processed.insert(spider_id);
342 processed.insert(neighbor_id);
343 }
344 }
345 }
346 }
347 }
348
349 if spider.spider_type == SpiderType::Z && spider.phase.abs() < 1e-10 {
351 for &(neighbor_id, edge_type) in &neighbors {
352 if let Some(neighbor) = self.diagram.spiders.get(&neighbor_id) {
353 if neighbor.spider_type == SpiderType::Z
354 && neighbor.phase.abs() < 1e-10
355 && edge_type == EdgeType::Hadamard
356 {
357 if let (Some(qubit1), Some(qubit2)) = (
359 self.get_spider_qubit(spider_id),
360 self.get_spider_qubit(neighbor_id),
361 ) {
362 let gate: Box<dyn GateOp> = Box::new(CZ {
363 control: qubit1,
364 target: qubit2,
365 });
366
367 let layer = self
368 .spider_positions
369 .get(&spider_id)
370 .map(|(l, _)| *l)
371 .unwrap_or(0);
372
373 self.add_gate_to_layer(gate, layer);
374 processed.insert(spider_id);
375 processed.insert(neighbor_id);
376 }
377 }
378 }
379 }
380 }
381 }
382 }
383
384 Ok(())
385 }
386
387 fn get_spider_qubit(&self, spider_id: usize) -> Option<QubitId> {
389 if let Some(spider) = self.diagram.spiders.get(&spider_id) {
391 if let Some(qubit) = spider.qubit {
392 return Some(qubit);
393 }
394 }
395
396 self.find_connected_boundary(spider_id)
398 }
399
400 fn find_connected_boundary(&self, spider_id: usize) -> Option<QubitId> {
402 let mut visited = HashSet::new();
403 let mut queue = VecDeque::new();
404 queue.push_back(spider_id);
405 visited.insert(spider_id);
406
407 while let Some(current) = queue.pop_front() {
408 if let Some(spider) = self.diagram.spiders.get(¤t) {
409 if let Some(qubit) = spider.qubit {
410 return Some(qubit);
411 }
412 }
413
414 for (neighbor, _) in self.diagram.neighbors(current) {
415 if !visited.contains(&neighbor) {
416 visited.insert(neighbor);
417 queue.push_back(neighbor);
418 }
419 }
420 }
421
422 None
423 }
424
425 fn get_edge_type(&self, spider1: usize, spider2: usize) -> Option<EdgeType> {
427 self.diagram
428 .neighbors(spider1)
429 .iter()
430 .find(|(id, _)| *id == spider2)
431 .map(|(_, edge_type)| *edge_type)
432 }
433
434 fn add_gate_to_layer(&mut self, gate: Box<dyn GateOp>, layer_idx: usize) {
436 while self.layers.len() <= layer_idx {
438 self.layers.push(GateLayer { gates: Vec::new() });
439 }
440
441 self.layers[layer_idx].gates.push(gate);
442 }
443}
444
445pub struct ZXPipeline {
447 optimizer: ZXOptimizer,
448}
449
450impl ZXPipeline {
451 pub fn new() -> Self {
453 Self {
454 optimizer: ZXOptimizer::new(),
455 }
456 }
457
458 pub fn optimize(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
460 let num_qubits = gates
462 .iter()
463 .flat_map(|g| g.qubits())
464 .map(|q| q.0 + 1)
465 .max()
466 .unwrap_or(0);
467
468 let mut converter = CircuitToZX::new(num_qubits as usize);
470 for gate in gates {
471 converter.add_gate(gate.as_ref())?;
472 }
473
474 let mut diagram = converter.into_diagram();
475
476 let rewrites = diagram.simplify(100);
478 println!("Applied {} ZX-calculus rewrites", rewrites);
479
480 let mut extractor = ZXExtractor::new(diagram);
482 extractor.extract_circuit()
483 }
484
485 pub fn compare_t_count(
487 &self,
488 original: &[Box<dyn GateOp>],
489 optimized: &[Box<dyn GateOp>],
490 ) -> (usize, usize) {
491 let count_t = |gates: &[Box<dyn GateOp>]| {
492 gates
493 .iter()
494 .filter(|g| {
495 g.name() == "T"
496 || (g.name() == "RZ" && {
497 if let Some(rz) = g.as_any().downcast_ref::<RotationZ>() {
498 (rz.theta - PI / 4.0).abs() < 1e-10
499 } else {
500 false
501 }
502 })
503 })
504 .count()
505 };
506
507 (count_t(original), count_t(optimized))
508 }
509}
510
511impl Default for ZXPipeline {
512 fn default() -> Self {
513 Self::new()
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a fresh extractor over `diagram` and returns the extracted gates.
    fn extract(diagram: ZXDiagram) -> Vec<Box<dyn GateOp>> {
        ZXExtractor::new(diagram)
            .extract_circuit()
            .expect("extraction should succeed")
    }

    #[test]
    fn test_circuit_extraction_identity() {
        let qubit = QubitId(0);
        let mut diagram = ZXDiagram::new();
        let input = diagram.add_boundary(qubit, true);
        let output = diagram.add_boundary(qubit, false);
        diagram.add_edge(input, output, EdgeType::Regular);

        // A bare wire carries no gates.
        assert!(extract(diagram).is_empty());
    }

    #[test]
    fn test_circuit_extraction_single_gate() {
        let qubit = QubitId(0);
        let mut diagram = ZXDiagram::new();
        let input = diagram.add_boundary(qubit, true);
        let phase = diagram.add_spider(SpiderType::Z, PI / 2.0);
        let output = diagram.add_boundary(qubit, false);
        diagram.add_edge(input, phase, EdgeType::Regular);
        diagram.add_edge(phase, output, EdgeType::Regular);

        let circuit = extract(diagram);

        assert_eq!(circuit.len(), 1);
        assert_eq!(circuit[0].name(), "RZ");
    }

    #[test]
    fn test_zx_pipeline_optimization() {
        let q0 = QubitId(0);
        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: q0 }),
            Box::new(PauliZ { target: q0 }),
            Box::new(Hadamard { target: q0 }),
        ];

        let optimized = ZXPipeline::new().optimize(&gates).unwrap();

        assert!(optimized.len() <= gates.len());
    }

    #[test]
    fn test_t_count_reduction() {
        let t_rotation = || -> Box<dyn GateOp> {
            Box::new(RotationZ {
                target: QubitId(0),
                theta: PI / 4.0,
            })
        };
        let gates: Vec<Box<dyn GateOp>> = vec![t_rotation(), t_rotation()];

        let pipeline = ZXPipeline::new();
        let optimized = pipeline.optimize(&gates).unwrap();
        let (before, after) = pipeline.compare_t_count(&gates, &optimized);

        assert_eq!(before, 2);
        assert!(after <= before);
    }
}