1use crate::{
7 error::{QuantRS2Error, QuantRS2Result},
8 gate::{multi::*, single::*, GateOp},
9 qubit::QubitId,
10 zx_calculus::{CircuitToZX, EdgeType, SpiderType, ZXDiagram, ZXOptimizer},
11};
12use rustc_hash::FxHashMap;
13use std::collections::{HashSet, VecDeque};
14use std::f64::consts::PI;
15
16#[derive(Debug, Clone)]
18struct GateLayer {
19 gates: Vec<Box<dyn GateOp>>,
21}
22
23pub struct ZXExtractor {
25 diagram: ZXDiagram,
26 spider_positions: FxHashMap<usize, (usize, usize)>, layers: Vec<GateLayer>,
30}
31
32impl ZXExtractor {
33 pub fn new(diagram: ZXDiagram) -> Self {
35 Self {
36 diagram,
37 spider_positions: FxHashMap::default(),
38 layers: Vec::new(),
39 }
40 }
41
42 pub fn extract_circuit(&mut self) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
44 self.analyze_diagram()?;
46
47 self.extract_gates()?;
49
50 let mut circuit = Vec::new();
52 for layer in &self.layers {
53 circuit.extend(layer.gates.clone());
54 }
55
56 Ok(circuit)
57 }
58
59 fn analyze_diagram(&mut self) -> QuantRS2Result<()> {
61 let inputs = self.diagram.inputs.clone();
63 let outputs = self.diagram.outputs.clone();
64
65 let topo_order = self.topological_sort(&inputs, &outputs)?;
67
68 for (layer_idx, spider_id) in topo_order.iter().enumerate() {
70 self.spider_positions.insert(*spider_id, (layer_idx, 0));
71 }
72
73 Ok(())
74 }
75
76 fn topological_sort(&self, inputs: &[usize], outputs: &[usize]) -> QuantRS2Result<Vec<usize>> {
78 let mut in_degree: FxHashMap<usize, usize> = FxHashMap::default();
79 let mut adjacency: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
80
81 for &spider_id in self.diagram.spiders.keys() {
83 in_degree.insert(spider_id, 0);
84 adjacency.insert(spider_id, Vec::new());
85 }
86
87 for (&source, neighbors) in &self.diagram.adjacency {
89 for &(target, _) in neighbors {
90 if self.is_forward_edge(source, target, inputs, outputs) {
92 if let Some(degree) = in_degree.get_mut(&target) {
93 *degree += 1;
94 }
95 if let Some(adj) = adjacency.get_mut(&source) {
96 adj.push(target);
97 }
98 }
99 }
100 }
101
102 let mut queue: VecDeque<usize> = inputs.iter().copied().collect();
104 let mut topo_order = Vec::new();
105
106 while let Some(spider) = queue.pop_front() {
107 topo_order.push(spider);
108
109 if let Some(neighbors) = adjacency.get(&spider) {
110 for &neighbor in neighbors {
111 if let Some(degree) = in_degree.get_mut(&neighbor) {
112 *degree -= 1;
113 if *degree == 0 {
114 queue.push_back(neighbor);
115 }
116 }
117 }
118 }
119 }
120
121 if topo_order.len() != self.diagram.spiders.len() {
122 return Err(QuantRS2Error::InvalidInput(
123 "ZX-diagram contains cycles or disconnected components".to_string(),
124 ));
125 }
126
127 Ok(topo_order)
128 }
129
130 fn is_forward_edge(
132 &self,
133 source: usize,
134 target: usize,
135 inputs: &[usize],
136 outputs: &[usize],
137 ) -> bool {
138 if inputs.contains(&source) && !inputs.contains(&target) {
141 return true;
142 }
143 if outputs.contains(&target) && !outputs.contains(&source) {
144 return true;
145 }
146
147 let source_dist = self.distance_from_inputs(source, inputs);
149 let target_dist = self.distance_from_inputs(target, inputs);
150
151 source_dist < target_dist
152 }
153
154 fn distance_from_inputs(&self, spider: usize, inputs: &[usize]) -> usize {
156 if inputs.contains(&spider) {
157 return 0;
158 }
159
160 let mut visited = HashSet::new();
161 let mut queue = VecDeque::new();
162
163 for &input in inputs {
164 queue.push_back((input, 0));
165 visited.insert(input);
166 }
167
168 while let Some((current, dist)) = queue.pop_front() {
169 if current == spider {
170 return dist;
171 }
172
173 for (neighbor, _) in self.diagram.neighbors(current) {
174 if visited.insert(neighbor) {
175 queue.push_back((neighbor, dist + 1));
176 }
177 }
178 }
179
180 usize::MAX }
182
183 fn extract_gates(&mut self) -> QuantRS2Result<()> {
185 let qubit_spiders = self.group_by_qubit()?;
187
188 for (qubit, spider_chain) in qubit_spiders {
190 self.extract_single_qubit_gates(&qubit, &spider_chain)?;
191 }
192
193 self.extract_two_qubit_gates()?;
195
196 Ok(())
197 }
198
199 fn group_by_qubit(&self) -> QuantRS2Result<FxHashMap<QubitId, Vec<usize>>> {
201 let mut qubit_spiders: FxHashMap<QubitId, Vec<usize>> = FxHashMap::default();
202
203 for &input_id in &self.diagram.inputs {
205 if let Some(input_spider) = self.diagram.spiders.get(&input_id) {
206 if let Some(qubit) = input_spider.qubit {
207 let chain = self.trace_qubit_line(input_id)?;
208 qubit_spiders.insert(qubit, chain);
209 }
210 }
211 }
212
213 Ok(qubit_spiders)
214 }
215
216 fn trace_qubit_line(&self, start: usize) -> QuantRS2Result<Vec<usize>> {
218 let mut chain = vec![start];
219 let mut current = start;
220 let mut visited = HashSet::new();
221 visited.insert(start);
222
223 while !self.diagram.outputs.contains(¤t) {
225 let neighbors = self.diagram.neighbors(current);
226
227 let next = neighbors
229 .iter()
230 .find(|(id, _)| !visited.contains(id) && self.is_on_qubit_line(*id))
231 .map(|(id, _)| *id);
232
233 if let Some(next_id) = next {
234 chain.push(next_id);
235 visited.insert(next_id);
236 current = next_id;
237 } else {
238 break;
239 }
240 }
241
242 Ok(chain)
243 }
244
245 fn is_on_qubit_line(&self, spider_id: usize) -> bool {
247 if let Some(spider) = self.diagram.spiders.get(&spider_id) {
249 spider.spider_type == SpiderType::Boundary || self.diagram.degree(spider_id) <= 2
250 } else {
251 false
252 }
253 }
254
255 fn extract_single_qubit_gates(
257 &mut self,
258 qubit: &QubitId,
259 spider_chain: &[usize],
260 ) -> QuantRS2Result<()> {
261 let mut i = 0;
262
263 while i < spider_chain.len() {
264 let spider_id = spider_chain[i];
265
266 if let Some(spider) = self.diagram.spiders.get(&spider_id) {
267 match spider.spider_type {
268 SpiderType::Z if spider.phase.abs() > 1e-10 => {
269 let gate: Box<dyn GateOp> = Box::new(RotationZ {
271 target: *qubit,
272 theta: spider.phase,
273 });
274 self.add_gate_to_layer(gate, i);
275 }
276 SpiderType::X if spider.phase.abs() > 1e-10 => {
277 let gate: Box<dyn GateOp> = Box::new(RotationX {
279 target: *qubit,
280 theta: spider.phase,
281 });
282 self.add_gate_to_layer(gate, i);
283 }
284 _ => {}
285 }
286
287 if i + 1 < spider_chain.len() {
289 let next_id = spider_chain[i + 1];
290 let edge_type = self.get_edge_type(spider_id, next_id);
291
292 if edge_type == Some(EdgeType::Hadamard) {
293 let gate: Box<dyn GateOp> = Box::new(Hadamard { target: *qubit });
294 self.add_gate_to_layer(gate, i);
295 }
296 }
297 }
298
299 i += 1;
300 }
301
302 Ok(())
303 }
304
305 fn extract_two_qubit_gates(&mut self) -> QuantRS2Result<()> {
307 let mut processed = HashSet::new();
308
309 for (&spider_id, spider) in &self.diagram.spiders.clone() {
310 if processed.contains(&spider_id) {
311 continue;
312 }
313
314 if self.diagram.degree(spider_id) > 2 {
316 let neighbors = self.diagram.neighbors(spider_id);
318
319 if spider.spider_type == SpiderType::Z && spider.phase.abs() < 1e-10 {
321 for &(neighbor_id, edge_type) in &neighbors {
322 if let Some(neighbor) = self.diagram.spiders.get(&neighbor_id) {
323 if neighbor.spider_type == SpiderType::X
324 && neighbor.phase.abs() < 1e-10
325 && edge_type == EdgeType::Regular
326 {
327 if let (Some(control_qubit), Some(target_qubit)) = (
329 self.get_spider_qubit(spider_id),
330 self.get_spider_qubit(neighbor_id),
331 ) {
332 let gate: Box<dyn GateOp> = Box::new(CNOT {
333 control: control_qubit,
334 target: target_qubit,
335 });
336
337 let layer = self
338 .spider_positions
339 .get(&spider_id)
340 .map_or(0, |(l, _)| *l);
341
342 self.add_gate_to_layer(gate, layer);
343 processed.insert(spider_id);
344 processed.insert(neighbor_id);
345 }
346 }
347 }
348 }
349 }
350
351 if spider.spider_type == SpiderType::Z && spider.phase.abs() < 1e-10 {
353 for &(neighbor_id, edge_type) in &neighbors {
354 if let Some(neighbor) = self.diagram.spiders.get(&neighbor_id) {
355 if neighbor.spider_type == SpiderType::Z
356 && neighbor.phase.abs() < 1e-10
357 && edge_type == EdgeType::Hadamard
358 {
359 if let (Some(qubit1), Some(qubit2)) = (
361 self.get_spider_qubit(spider_id),
362 self.get_spider_qubit(neighbor_id),
363 ) {
364 let gate: Box<dyn GateOp> = Box::new(CZ {
365 control: qubit1,
366 target: qubit2,
367 });
368
369 let layer = self
370 .spider_positions
371 .get(&spider_id)
372 .map_or(0, |(l, _)| *l);
373
374 self.add_gate_to_layer(gate, layer);
375 processed.insert(spider_id);
376 processed.insert(neighbor_id);
377 }
378 }
379 }
380 }
381 }
382 }
383 }
384
385 Ok(())
386 }
387
388 fn get_spider_qubit(&self, spider_id: usize) -> Option<QubitId> {
390 if let Some(spider) = self.diagram.spiders.get(&spider_id) {
392 if let Some(qubit) = spider.qubit {
393 return Some(qubit);
394 }
395 }
396
397 self.find_connected_boundary(spider_id)
399 }
400
401 fn find_connected_boundary(&self, spider_id: usize) -> Option<QubitId> {
403 let mut visited = HashSet::new();
404 let mut queue = VecDeque::new();
405 queue.push_back(spider_id);
406 visited.insert(spider_id);
407
408 while let Some(current) = queue.pop_front() {
409 if let Some(spider) = self.diagram.spiders.get(¤t) {
410 if let Some(qubit) = spider.qubit {
411 return Some(qubit);
412 }
413 }
414
415 for (neighbor, _) in self.diagram.neighbors(current) {
416 if visited.insert(neighbor) {
417 queue.push_back(neighbor);
418 }
419 }
420 }
421
422 None
423 }
424
425 fn get_edge_type(&self, spider1: usize, spider2: usize) -> Option<EdgeType> {
427 self.diagram
428 .neighbors(spider1)
429 .iter()
430 .find(|(id, _)| *id == spider2)
431 .map(|(_, edge_type)| *edge_type)
432 }
433
434 fn add_gate_to_layer(&mut self, gate: Box<dyn GateOp>, layer_idx: usize) {
436 while self.layers.len() <= layer_idx {
438 self.layers.push(GateLayer { gates: Vec::new() });
439 }
440
441 self.layers[layer_idx].gates.push(gate);
442 }
443}
444
445pub struct ZXPipeline {
447 optimizer: ZXOptimizer,
448}
449
450impl ZXPipeline {
451 pub const fn new() -> Self {
453 Self {
454 optimizer: ZXOptimizer::new(),
455 }
456 }
457
458 pub fn optimize(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
460 let num_qubits = gates
462 .iter()
463 .flat_map(|g| g.qubits())
464 .map(|q| q.0 + 1)
465 .max()
466 .unwrap_or(0);
467
468 let mut converter = CircuitToZX::new(num_qubits as usize);
470 for gate in gates {
471 converter.add_gate(gate.as_ref())?;
472 }
473
474 let mut diagram = converter.into_diagram();
475
476 let rewrites = diagram.simplify(100);
478 println!("Applied {rewrites} ZX-calculus rewrites");
479
480 let mut extractor = ZXExtractor::new(diagram);
482 extractor.extract_circuit()
483 }
484
485 pub fn compare_t_count(
487 &self,
488 original: &[Box<dyn GateOp>],
489 optimized: &[Box<dyn GateOp>],
490 ) -> (usize, usize) {
491 let count_t = |gates: &[Box<dyn GateOp>]| {
492 gates
493 .iter()
494 .filter(|g| {
495 g.name() == "T"
496 || (g.name() == "RZ" && {
497 if let Some(rz) = g.as_any().downcast_ref::<RotationZ>() {
498 (rz.theta - PI / 4.0).abs() < 1e-10
499 } else {
500 false
501 }
502 })
503 })
504 .count()
505 };
506
507 (count_t(original), count_t(optimized))
508 }
509}
510
511impl Default for ZXPipeline {
512 fn default() -> Self {
513 Self::new()
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_extraction_identity() {
        // A bare wire joining input and output carries no gates.
        let mut zx = ZXDiagram::new();
        let wire_in = zx.add_boundary(QubitId(0), true);
        let wire_out = zx.add_boundary(QubitId(0), false);
        zx.add_edge(wire_in, wire_out, EdgeType::Regular);

        let extracted = ZXExtractor::new(zx)
            .extract_circuit()
            .expect("Failed to extract circuit");

        assert!(extracted.is_empty());
    }

    #[test]
    fn test_circuit_extraction_single_gate() {
        // input - Z(pi/2) - output should come back as exactly one RZ rotation.
        let mut zx = ZXDiagram::new();
        let wire_in = zx.add_boundary(QubitId(0), true);
        let phase_spider = zx.add_spider(SpiderType::Z, PI / 2.0);
        let wire_out = zx.add_boundary(QubitId(0), false);
        zx.add_edge(wire_in, phase_spider, EdgeType::Regular);
        zx.add_edge(phase_spider, wire_out, EdgeType::Regular);

        let extracted = ZXExtractor::new(zx)
            .extract_circuit()
            .expect("Failed to extract circuit");

        assert_eq!(extracted.len(), 1);
        assert_eq!(extracted[0].name(), "RZ");
    }

    #[test]
    fn test_zx_pipeline_optimization() {
        // H, Z, H on one qubit: the round trip must not make the circuit longer.
        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: QubitId(0) }),
            Box::new(PauliZ { target: QubitId(0) }),
            Box::new(Hadamard { target: QubitId(0) }),
        ];

        let optimized = ZXPipeline::new()
            .optimize(&gates)
            .expect("Failed to optimize circuit");

        assert!(optimized.len() <= gates.len());
    }

    #[test]
    fn test_t_count_reduction() {
        // Two RZ(pi/4) rotations on the same qubit count as two T gates up front.
        let quarter_pi_rz = || -> Box<dyn GateOp> {
            Box::new(RotationZ {
                target: QubitId(0),
                theta: PI / 4.0,
            })
        };
        let gates: Vec<Box<dyn GateOp>> = vec![quarter_pi_rz(), quarter_pi_rz()];

        let pipeline = ZXPipeline::new();
        let optimized = pipeline
            .optimize(&gates)
            .expect("Failed to optimize circuit");
        let (original_t, optimized_t) = pipeline.compare_t_count(&gates, &optimized);

        assert_eq!(original_t, 2);
        assert!(optimized_t <= original_t);
    }
}
602}