quantrs2_core/
zx_extraction.rs

1//! Circuit extraction from ZX-diagrams
2//!
3//! This module provides algorithms to extract quantum circuits from
4//! optimized ZX-diagrams, completing the optimization pipeline.
5
6use crate::{
7    error::{QuantRS2Error, QuantRS2Result},
8    gate::{multi::*, single::*, GateOp},
9    qubit::QubitId,
10    zx_calculus::{CircuitToZX, EdgeType, SpiderType, ZXDiagram, ZXOptimizer},
11};
12use rustc_hash::FxHashMap;
13use std::collections::{HashSet, VecDeque};
14use std::f64::consts::PI;
15
16/// Represents a layer of gates in the extracted circuit
17#[derive(Debug, Clone)]
18struct GateLayer {
19    /// Gates in this layer (can be applied in parallel)
20    gates: Vec<Box<dyn GateOp>>,
21}
22
23/// Circuit extractor from ZX-diagrams
24pub struct ZXExtractor {
25    diagram: ZXDiagram,
26    /// Maps spider IDs to their positions in the circuit
27    spider_positions: FxHashMap<usize, (usize, usize)>, // (layer, position_in_layer)
28    /// Layers of the circuit
29    layers: Vec<GateLayer>,
30}
31
32impl ZXExtractor {
33    /// Create a new extractor from a ZX-diagram
34    pub fn new(diagram: ZXDiagram) -> Self {
35        Self {
36            diagram,
37            spider_positions: FxHashMap::default(),
38            layers: Vec::new(),
39        }
40    }
41
42    /// Extract a quantum circuit from the ZX-diagram
43    pub fn extract_circuit(&mut self) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
44        // First, perform graph analysis to understand the structure
45        self.analyze_diagram()?;
46
47        // Extract gates layer by layer
48        self.extract_gates()?;
49
50        // Flatten layers into a single circuit
51        let mut circuit = Vec::new();
52        for layer in &self.layers {
53            circuit.extend(layer.gates.clone());
54        }
55
56        Ok(circuit)
57    }
58
59    /// Analyze the diagram structure
60    fn analyze_diagram(&mut self) -> QuantRS2Result<()> {
61        // Find the flow from inputs to outputs using BFS
62        let inputs = self.diagram.inputs.clone();
63        let outputs = self.diagram.outputs.clone();
64
65        // Create a topological ordering of spiders
66        let topo_order = self.topological_sort(&inputs, &outputs)?;
67
68        // Assign positions to spiders
69        for (layer_idx, spider_id) in topo_order.iter().enumerate() {
70            self.spider_positions.insert(*spider_id, (layer_idx, 0));
71        }
72
73        Ok(())
74    }
75
76    /// Perform topological sort from inputs to outputs
77    fn topological_sort(&self, inputs: &[usize], outputs: &[usize]) -> QuantRS2Result<Vec<usize>> {
78        let mut in_degree: FxHashMap<usize, usize> = FxHashMap::default();
79        let mut adjacency: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
80
81        // Build directed graph (from inputs to outputs)
82        for &spider_id in self.diagram.spiders.keys() {
83            in_degree.insert(spider_id, 0);
84            adjacency.insert(spider_id, Vec::new());
85        }
86
87        // Count in-degrees and build adjacency
88        for (&source, neighbors) in &self.diagram.adjacency {
89            for &(target, _) in neighbors {
90                // Only count edges going "forward" (avoid double counting)
91                if self.is_forward_edge(source, target, inputs, outputs) {
92                    *in_degree.get_mut(&target).unwrap() += 1;
93                    adjacency.get_mut(&source).unwrap().push(target);
94                }
95            }
96        }
97
98        // Kahn's algorithm for topological sort
99        let mut queue: VecDeque<usize> = inputs.iter().cloned().collect();
100        let mut topo_order = Vec::new();
101
102        while let Some(spider) = queue.pop_front() {
103            topo_order.push(spider);
104
105            if let Some(neighbors) = adjacency.get(&spider) {
106                for &neighbor in neighbors {
107                    if let Some(degree) = in_degree.get_mut(&neighbor) {
108                        *degree -= 1;
109                        if *degree == 0 {
110                            queue.push_back(neighbor);
111                        }
112                    }
113                }
114            }
115        }
116
117        if topo_order.len() != self.diagram.spiders.len() {
118            return Err(QuantRS2Error::InvalidInput(
119                "ZX-diagram contains cycles or disconnected components".to_string(),
120            ));
121        }
122
123        Ok(topo_order)
124    }
125
126    /// Determine if an edge goes "forward" in the circuit
127    fn is_forward_edge(
128        &self,
129        source: usize,
130        target: usize,
131        inputs: &[usize],
132        outputs: &[usize],
133    ) -> bool {
134        // More sophisticated flow analysis
135        // Boundary nodes have clear direction
136        if inputs.contains(&source) && !inputs.contains(&target) {
137            return true;
138        }
139        if outputs.contains(&target) && !outputs.contains(&source) {
140            return true;
141        }
142
143        // For non-boundary nodes, check if source is closer to inputs
144        let source_dist = self.distance_from_inputs(source, inputs);
145        let target_dist = self.distance_from_inputs(target, inputs);
146
147        source_dist < target_dist
148    }
149
150    /// Calculate minimum distance from a spider to any input
151    fn distance_from_inputs(&self, spider: usize, inputs: &[usize]) -> usize {
152        if inputs.contains(&spider) {
153            return 0;
154        }
155
156        let mut visited = HashSet::new();
157        let mut queue = VecDeque::new();
158
159        for &input in inputs {
160            queue.push_back((input, 0));
161            visited.insert(input);
162        }
163
164        while let Some((current, dist)) = queue.pop_front() {
165            if current == spider {
166                return dist;
167            }
168
169            for (neighbor, _) in self.diagram.neighbors(current) {
170                if !visited.contains(&neighbor) {
171                    visited.insert(neighbor);
172                    queue.push_back((neighbor, dist + 1));
173                }
174            }
175        }
176
177        usize::MAX // Not reachable from inputs
178    }
179
180    /// Extract gates from the analyzed diagram
181    fn extract_gates(&mut self) -> QuantRS2Result<()> {
182        // Group spiders by qubit
183        let qubit_spiders = self.group_by_qubit()?;
184
185        // Process each qubit line
186        for (qubit, spider_chain) in qubit_spiders {
187            self.extract_single_qubit_gates(&qubit, &spider_chain)?;
188        }
189
190        // Extract two-qubit gates
191        self.extract_two_qubit_gates()?;
192
193        Ok(())
194    }
195
196    /// Group spiders by their associated qubit
197    fn group_by_qubit(&self) -> QuantRS2Result<FxHashMap<QubitId, Vec<usize>>> {
198        let mut qubit_spiders: FxHashMap<QubitId, Vec<usize>> = FxHashMap::default();
199
200        // Start from input boundaries
201        for &input_id in &self.diagram.inputs {
202            if let Some(input_spider) = self.diagram.spiders.get(&input_id) {
203                if let Some(qubit) = input_spider.qubit {
204                    let chain = self.trace_qubit_line(input_id)?;
205                    qubit_spiders.insert(qubit, chain);
206                }
207            }
208        }
209
210        Ok(qubit_spiders)
211    }
212
213    /// Trace a qubit line from input to output
214    fn trace_qubit_line(&self, start: usize) -> QuantRS2Result<Vec<usize>> {
215        let mut chain = vec![start];
216        let mut current = start;
217        let mut visited = HashSet::new();
218        visited.insert(start);
219
220        // Follow the qubit line
221        while !self.diagram.outputs.contains(&current) {
222            let neighbors = self.diagram.neighbors(current);
223
224            // Find the next spider in the chain (not visited, not backwards)
225            let next = neighbors
226                .iter()
227                .find(|(id, _)| !visited.contains(id) && self.is_on_qubit_line(*id))
228                .map(|(id, _)| *id);
229
230            if let Some(next_id) = next {
231                chain.push(next_id);
232                visited.insert(next_id);
233                current = next_id;
234            } else {
235                break;
236            }
237        }
238
239        Ok(chain)
240    }
241
242    /// Check if a spider is on a qubit line (not a control/target connection)
243    fn is_on_qubit_line(&self, spider_id: usize) -> bool {
244        // Boundary spiders are always on qubit lines
245        if let Some(spider) = self.diagram.spiders.get(&spider_id) {
246            spider.spider_type == SpiderType::Boundary || self.diagram.degree(spider_id) <= 2
247        } else {
248            false
249        }
250    }
251
252    /// Extract single-qubit gates from a chain of spiders
253    fn extract_single_qubit_gates(
254        &mut self,
255        qubit: &QubitId,
256        spider_chain: &[usize],
257    ) -> QuantRS2Result<()> {
258        let mut i = 0;
259
260        while i < spider_chain.len() {
261            let spider_id = spider_chain[i];
262
263            if let Some(spider) = self.diagram.spiders.get(&spider_id) {
264                match spider.spider_type {
265                    SpiderType::Z if spider.phase.abs() > 1e-10 => {
266                        // Z rotation
267                        let gate: Box<dyn GateOp> = Box::new(RotationZ {
268                            target: *qubit,
269                            theta: spider.phase,
270                        });
271                        self.add_gate_to_layer(gate, i);
272                    }
273                    SpiderType::X if spider.phase.abs() > 1e-10 => {
274                        // X rotation
275                        let gate: Box<dyn GateOp> = Box::new(RotationX {
276                            target: *qubit,
277                            theta: spider.phase,
278                        });
279                        self.add_gate_to_layer(gate, i);
280                    }
281                    _ => {}
282                }
283
284                // Check for Hadamard edges
285                if i + 1 < spider_chain.len() {
286                    let next_id = spider_chain[i + 1];
287                    let edge_type = self.get_edge_type(spider_id, next_id);
288
289                    if edge_type == Some(EdgeType::Hadamard) {
290                        let gate: Box<dyn GateOp> = Box::new(Hadamard { target: *qubit });
291                        self.add_gate_to_layer(gate, i);
292                    }
293                }
294            }
295
296            i += 1;
297        }
298
299        Ok(())
300    }
301
302    /// Extract two-qubit gates from spider connections
303    fn extract_two_qubit_gates(&mut self) -> QuantRS2Result<()> {
304        let mut processed = HashSet::new();
305
306        for (&spider_id, spider) in &self.diagram.spiders.clone() {
307            if processed.contains(&spider_id) {
308                continue;
309            }
310
311            // Look for patterns that represent two-qubit gates
312            if self.diagram.degree(spider_id) > 2 {
313                // This spider connects multiple qubits
314                let neighbors = self.diagram.neighbors(spider_id);
315
316                // Check for CNOT pattern (Z spider connected to X spider)
317                if spider.spider_type == SpiderType::Z && spider.phase.abs() < 1e-10 {
318                    for &(neighbor_id, edge_type) in &neighbors {
319                        if let Some(neighbor) = self.diagram.spiders.get(&neighbor_id) {
320                            if neighbor.spider_type == SpiderType::X
321                                && neighbor.phase.abs() < 1e-10
322                                && edge_type == EdgeType::Regular
323                            {
324                                // Found CNOT pattern
325                                if let (Some(control_qubit), Some(target_qubit)) = (
326                                    self.get_spider_qubit(spider_id),
327                                    self.get_spider_qubit(neighbor_id),
328                                ) {
329                                    let gate: Box<dyn GateOp> = Box::new(CNOT {
330                                        control: control_qubit,
331                                        target: target_qubit,
332                                    });
333
334                                    let layer = self
335                                        .spider_positions
336                                        .get(&spider_id)
337                                        .map(|(l, _)| *l)
338                                        .unwrap_or(0);
339
340                                    self.add_gate_to_layer(gate, layer);
341                                    processed.insert(spider_id);
342                                    processed.insert(neighbor_id);
343                                }
344                            }
345                        }
346                    }
347                }
348
349                // Check for CZ pattern (two Z spiders connected with Hadamard)
350                if spider.spider_type == SpiderType::Z && spider.phase.abs() < 1e-10 {
351                    for &(neighbor_id, edge_type) in &neighbors {
352                        if let Some(neighbor) = self.diagram.spiders.get(&neighbor_id) {
353                            if neighbor.spider_type == SpiderType::Z
354                                && neighbor.phase.abs() < 1e-10
355                                && edge_type == EdgeType::Hadamard
356                            {
357                                // Found CZ pattern
358                                if let (Some(qubit1), Some(qubit2)) = (
359                                    self.get_spider_qubit(spider_id),
360                                    self.get_spider_qubit(neighbor_id),
361                                ) {
362                                    let gate: Box<dyn GateOp> = Box::new(CZ {
363                                        control: qubit1,
364                                        target: qubit2,
365                                    });
366
367                                    let layer = self
368                                        .spider_positions
369                                        .get(&spider_id)
370                                        .map(|(l, _)| *l)
371                                        .unwrap_or(0);
372
373                                    self.add_gate_to_layer(gate, layer);
374                                    processed.insert(spider_id);
375                                    processed.insert(neighbor_id);
376                                }
377                            }
378                        }
379                    }
380                }
381            }
382        }
383
384        Ok(())
385    }
386
387    /// Get the qubit associated with a spider
388    fn get_spider_qubit(&self, spider_id: usize) -> Option<QubitId> {
389        // First check if it's a boundary spider
390        if let Some(spider) = self.diagram.spiders.get(&spider_id) {
391            if let Some(qubit) = spider.qubit {
392                return Some(qubit);
393            }
394        }
395
396        // Otherwise, trace back to find the input boundary
397        self.find_connected_boundary(spider_id)
398    }
399
400    /// Find a boundary spider connected to this spider
401    fn find_connected_boundary(&self, spider_id: usize) -> Option<QubitId> {
402        let mut visited = HashSet::new();
403        let mut queue = VecDeque::new();
404        queue.push_back(spider_id);
405        visited.insert(spider_id);
406
407        while let Some(current) = queue.pop_front() {
408            if let Some(spider) = self.diagram.spiders.get(&current) {
409                if let Some(qubit) = spider.qubit {
410                    return Some(qubit);
411                }
412            }
413
414            for (neighbor, _) in self.diagram.neighbors(current) {
415                if !visited.contains(&neighbor) {
416                    visited.insert(neighbor);
417                    queue.push_back(neighbor);
418                }
419            }
420        }
421
422        None
423    }
424
425    /// Get the edge type between two spiders
426    fn get_edge_type(&self, spider1: usize, spider2: usize) -> Option<EdgeType> {
427        self.diagram
428            .neighbors(spider1)
429            .iter()
430            .find(|(id, _)| *id == spider2)
431            .map(|(_, edge_type)| *edge_type)
432    }
433
434    /// Add a gate to the appropriate layer
435    fn add_gate_to_layer(&mut self, gate: Box<dyn GateOp>, layer_idx: usize) {
436        // Ensure we have enough layers
437        while self.layers.len() <= layer_idx {
438            self.layers.push(GateLayer { gates: Vec::new() });
439        }
440
441        self.layers[layer_idx].gates.push(gate);
442    }
443}
444
445/// Complete ZX-calculus optimization pipeline
446pub struct ZXPipeline {
447    optimizer: ZXOptimizer,
448}
449
450impl ZXPipeline {
451    /// Create a new ZX optimization pipeline
452    pub fn new() -> Self {
453        Self {
454            optimizer: ZXOptimizer::new(),
455        }
456    }
457
458    /// Optimize a circuit using ZX-calculus
459    pub fn optimize(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
460        // Find number of qubits
461        let num_qubits = gates
462            .iter()
463            .flat_map(|g| g.qubits())
464            .map(|q| q.0 + 1)
465            .max()
466            .unwrap_or(0);
467
468        // Convert to ZX-diagram
469        let mut converter = CircuitToZX::new(num_qubits as usize);
470        for gate in gates {
471            converter.add_gate(gate.as_ref())?;
472        }
473
474        let mut diagram = converter.into_diagram();
475
476        // Optimize the diagram
477        let rewrites = diagram.simplify(100);
478        println!("Applied {} ZX-calculus rewrites", rewrites);
479
480        // Extract optimized circuit
481        let mut extractor = ZXExtractor::new(diagram);
482        extractor.extract_circuit()
483    }
484
485    /// Compare T-count before and after optimization
486    pub fn compare_t_count(
487        &self,
488        original: &[Box<dyn GateOp>],
489        optimized: &[Box<dyn GateOp>],
490    ) -> (usize, usize) {
491        let count_t = |gates: &[Box<dyn GateOp>]| {
492            gates
493                .iter()
494                .filter(|g| {
495                    g.name() == "T"
496                        || (g.name() == "RZ" && {
497                            if let Some(rz) = g.as_any().downcast_ref::<RotationZ>() {
498                                (rz.theta - PI / 4.0).abs() < 1e-10
499                            } else {
500                                false
501                            }
502                        })
503                })
504                .count()
505        };
506
507        (count_t(original), count_t(optimized))
508    }
509}
510
511impl Default for ZXPipeline {
512    fn default() -> Self {
513        Self::new()
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Run a fresh extractor over `diagram` and return the extracted gates.
    fn extract(diagram: ZXDiagram) -> Vec<Box<dyn GateOp>> {
        ZXExtractor::new(diagram)
            .extract_circuit()
            .expect("extraction should succeed")
    }

    #[test]
    fn test_circuit_extraction_identity() {
        // A bare wire: the input boundary is connected straight to the output.
        let mut diagram = ZXDiagram::new();
        let input = diagram.add_boundary(QubitId(0), true);
        let output = diagram.add_boundary(QubitId(0), false);
        diagram.add_edge(input, output, EdgeType::Regular);

        // A plain wire is the identity, i.e. no gates at all.
        assert!(extract(diagram).is_empty());
    }

    #[test]
    fn test_circuit_extraction_single_gate() {
        // input -- Z(π/2) -- output
        let mut diagram = ZXDiagram::new();
        let input = diagram.add_boundary(QubitId(0), true);
        let z_spider = diagram.add_spider(SpiderType::Z, PI / 2.0);
        let output = diagram.add_boundary(QubitId(0), false);

        for (from, to) in [(input, z_spider), (z_spider, output)] {
            diagram.add_edge(from, to, EdgeType::Regular);
        }

        let circuit = extract(diagram);

        // Exactly one Z rotation comes back out.
        assert_eq!(circuit.len(), 1);
        assert_eq!(circuit[0].name(), "RZ");
    }

    #[test]
    fn test_zx_pipeline_optimization() {
        // H · Z · H = X, so the pipeline has room to shrink this circuit.
        let target = QubitId(0);
        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target }),
            Box::new(PauliZ { target }),
            Box::new(Hadamard { target }),
        ];

        let optimized = ZXPipeline::new()
            .optimize(&gates)
            .expect("pipeline should succeed");

        // Optimization must never make the circuit longer.
        assert!(optimized.len() <= gates.len());
    }

    #[test]
    fn test_t_count_reduction() {
        // Two consecutive RZ(π/4) rotations, i.e. two T gates.
        let t_gate = || -> Box<dyn GateOp> {
            Box::new(RotationZ {
                target: QubitId(0),
                theta: PI / 4.0,
            })
        };
        let gates = vec![t_gate(), t_gate()];

        let pipeline = ZXPipeline::new();
        let optimized = pipeline
            .optimize(&gates)
            .expect("pipeline should succeed");
        let (original_t, optimized_t) = pipeline.compare_t_count(&gates, &optimized);

        // Two T gates may fuse into an S gate (or an equivalent rotation).
        assert_eq!(original_t, 2);
        assert!(optimized_t <= original_t);
    }
}