// quantrs2_core/zx_extraction.rs

//! Circuit extraction from ZX-diagrams
//!
//! This module provides algorithms to extract quantum circuits from
//! optimized ZX-diagrams, completing the ZX-calculus optimization pipeline
//! (circuit → ZX-diagram → simplification → circuit).

use std::collections::{HashSet, VecDeque};
use std::f64::consts::PI;
use std::mem;

use rustc_hash::FxHashMap;

use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::{multi::*, single::*, GateOp},
    qubit::QubitId,
    zx_calculus::{CircuitToZX, EdgeType, SpiderType, ZXDiagram, ZXOptimizer},
};

16/// Represents a layer of gates in the extracted circuit
17#[derive(Debug, Clone)]
18struct GateLayer {
19    /// Gates in this layer (can be applied in parallel)
20    gates: Vec<Box<dyn GateOp>>,
21}
22
23/// Circuit extractor from ZX-diagrams
24pub struct ZXExtractor {
25    diagram: ZXDiagram,
26    /// Maps spider IDs to their positions in the circuit
27    spider_positions: FxHashMap<usize, (usize, usize)>, // (layer, position_in_layer)
28    /// Layers of the circuit
29    layers: Vec<GateLayer>,
30}
31
32impl ZXExtractor {
33    /// Create a new extractor from a ZX-diagram
34    pub fn new(diagram: ZXDiagram) -> Self {
35        Self {
36            diagram,
37            spider_positions: FxHashMap::default(),
38            layers: Vec::new(),
39        }
40    }
41
42    /// Extract a quantum circuit from the ZX-diagram
43    pub fn extract_circuit(&mut self) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
44        // First, perform graph analysis to understand the structure
45        self.analyze_diagram()?;
46
47        // Extract gates layer by layer
48        self.extract_gates()?;
49
50        // Flatten layers into a single circuit
51        let mut circuit = Vec::new();
52        for layer in &self.layers {
53            circuit.extend(layer.gates.clone());
54        }
55
56        Ok(circuit)
57    }
58
59    /// Analyze the diagram structure
60    fn analyze_diagram(&mut self) -> QuantRS2Result<()> {
61        // Find the flow from inputs to outputs using BFS
62        let inputs = self.diagram.inputs.clone();
63        let outputs = self.diagram.outputs.clone();
64
65        // Create a topological ordering of spiders
66        let topo_order = self.topological_sort(&inputs, &outputs)?;
67
68        // Assign positions to spiders
69        for (layer_idx, spider_id) in topo_order.iter().enumerate() {
70            self.spider_positions.insert(*spider_id, (layer_idx, 0));
71        }
72
73        Ok(())
74    }
75
76    /// Perform topological sort from inputs to outputs
77    fn topological_sort(&self, inputs: &[usize], outputs: &[usize]) -> QuantRS2Result<Vec<usize>> {
78        let mut in_degree: FxHashMap<usize, usize> = FxHashMap::default();
79        let mut adjacency: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
80
81        // Build directed graph (from inputs to outputs)
82        for &spider_id in self.diagram.spiders.keys() {
83            in_degree.insert(spider_id, 0);
84            adjacency.insert(spider_id, Vec::new());
85        }
86
87        // Count in-degrees and build adjacency
88        for (&source, neighbors) in &self.diagram.adjacency {
89            for &(target, _) in neighbors {
90                // Only count edges going "forward" (avoid double counting)
91                if self.is_forward_edge(source, target, inputs, outputs) {
92                    if let Some(degree) = in_degree.get_mut(&target) {
93                        *degree += 1;
94                    }
95                    if let Some(adj) = adjacency.get_mut(&source) {
96                        adj.push(target);
97                    }
98                }
99            }
100        }
101
102        // Kahn's algorithm for topological sort
103        let mut queue: VecDeque<usize> = inputs.iter().copied().collect();
104        let mut topo_order = Vec::new();
105
106        while let Some(spider) = queue.pop_front() {
107            topo_order.push(spider);
108
109            if let Some(neighbors) = adjacency.get(&spider) {
110                for &neighbor in neighbors {
111                    if let Some(degree) = in_degree.get_mut(&neighbor) {
112                        *degree -= 1;
113                        if *degree == 0 {
114                            queue.push_back(neighbor);
115                        }
116                    }
117                }
118            }
119        }
120
121        if topo_order.len() != self.diagram.spiders.len() {
122            return Err(QuantRS2Error::InvalidInput(
123                "ZX-diagram contains cycles or disconnected components".to_string(),
124            ));
125        }
126
127        Ok(topo_order)
128    }
129
130    /// Determine if an edge goes "forward" in the circuit
131    fn is_forward_edge(
132        &self,
133        source: usize,
134        target: usize,
135        inputs: &[usize],
136        outputs: &[usize],
137    ) -> bool {
138        // More sophisticated flow analysis
139        // Boundary nodes have clear direction
140        if inputs.contains(&source) && !inputs.contains(&target) {
141            return true;
142        }
143        if outputs.contains(&target) && !outputs.contains(&source) {
144            return true;
145        }
146
147        // For non-boundary nodes, check if source is closer to inputs
148        let source_dist = self.distance_from_inputs(source, inputs);
149        let target_dist = self.distance_from_inputs(target, inputs);
150
151        source_dist < target_dist
152    }
153
154    /// Calculate minimum distance from a spider to any input
155    fn distance_from_inputs(&self, spider: usize, inputs: &[usize]) -> usize {
156        if inputs.contains(&spider) {
157            return 0;
158        }
159
160        let mut visited = HashSet::new();
161        let mut queue = VecDeque::new();
162
163        for &input in inputs {
164            queue.push_back((input, 0));
165            visited.insert(input);
166        }
167
168        while let Some((current, dist)) = queue.pop_front() {
169            if current == spider {
170                return dist;
171            }
172
173            for (neighbor, _) in self.diagram.neighbors(current) {
174                if visited.insert(neighbor) {
175                    queue.push_back((neighbor, dist + 1));
176                }
177            }
178        }
179
180        usize::MAX // Not reachable from inputs
181    }
182
183    /// Extract gates from the analyzed diagram
184    fn extract_gates(&mut self) -> QuantRS2Result<()> {
185        // Group spiders by qubit
186        let qubit_spiders = self.group_by_qubit()?;
187
188        // Process each qubit line
189        for (qubit, spider_chain) in qubit_spiders {
190            self.extract_single_qubit_gates(&qubit, &spider_chain)?;
191        }
192
193        // Extract two-qubit gates
194        self.extract_two_qubit_gates()?;
195
196        Ok(())
197    }
198
199    /// Group spiders by their associated qubit
200    fn group_by_qubit(&self) -> QuantRS2Result<FxHashMap<QubitId, Vec<usize>>> {
201        let mut qubit_spiders: FxHashMap<QubitId, Vec<usize>> = FxHashMap::default();
202
203        // Start from input boundaries
204        for &input_id in &self.diagram.inputs {
205            if let Some(input_spider) = self.diagram.spiders.get(&input_id) {
206                if let Some(qubit) = input_spider.qubit {
207                    let chain = self.trace_qubit_line(input_id)?;
208                    qubit_spiders.insert(qubit, chain);
209                }
210            }
211        }
212
213        Ok(qubit_spiders)
214    }
215
216    /// Trace a qubit line from input to output
217    fn trace_qubit_line(&self, start: usize) -> QuantRS2Result<Vec<usize>> {
218        let mut chain = vec![start];
219        let mut current = start;
220        let mut visited = HashSet::new();
221        visited.insert(start);
222
223        // Follow the qubit line
224        while !self.diagram.outputs.contains(&current) {
225            let neighbors = self.diagram.neighbors(current);
226
227            // Find the next spider in the chain (not visited, not backwards)
228            let next = neighbors
229                .iter()
230                .find(|(id, _)| !visited.contains(id) && self.is_on_qubit_line(*id))
231                .map(|(id, _)| *id);
232
233            if let Some(next_id) = next {
234                chain.push(next_id);
235                visited.insert(next_id);
236                current = next_id;
237            } else {
238                break;
239            }
240        }
241
242        Ok(chain)
243    }
244
245    /// Check if a spider is on a qubit line (not a control/target connection)
246    fn is_on_qubit_line(&self, spider_id: usize) -> bool {
247        // Boundary spiders are always on qubit lines
248        if let Some(spider) = self.diagram.spiders.get(&spider_id) {
249            spider.spider_type == SpiderType::Boundary || self.diagram.degree(spider_id) <= 2
250        } else {
251            false
252        }
253    }
254
255    /// Extract single-qubit gates from a chain of spiders
256    fn extract_single_qubit_gates(
257        &mut self,
258        qubit: &QubitId,
259        spider_chain: &[usize],
260    ) -> QuantRS2Result<()> {
261        let mut i = 0;
262
263        while i < spider_chain.len() {
264            let spider_id = spider_chain[i];
265
266            if let Some(spider) = self.diagram.spiders.get(&spider_id) {
267                match spider.spider_type {
268                    SpiderType::Z if spider.phase.abs() > 1e-10 => {
269                        // Z rotation
270                        let gate: Box<dyn GateOp> = Box::new(RotationZ {
271                            target: *qubit,
272                            theta: spider.phase,
273                        });
274                        self.add_gate_to_layer(gate, i);
275                    }
276                    SpiderType::X if spider.phase.abs() > 1e-10 => {
277                        // X rotation
278                        let gate: Box<dyn GateOp> = Box::new(RotationX {
279                            target: *qubit,
280                            theta: spider.phase,
281                        });
282                        self.add_gate_to_layer(gate, i);
283                    }
284                    _ => {}
285                }
286
287                // Check for Hadamard edges
288                if i + 1 < spider_chain.len() {
289                    let next_id = spider_chain[i + 1];
290                    let edge_type = self.get_edge_type(spider_id, next_id);
291
292                    if edge_type == Some(EdgeType::Hadamard) {
293                        let gate: Box<dyn GateOp> = Box::new(Hadamard { target: *qubit });
294                        self.add_gate_to_layer(gate, i);
295                    }
296                }
297            }
298
299            i += 1;
300        }
301
302        Ok(())
303    }
304
305    /// Extract two-qubit gates from spider connections
306    fn extract_two_qubit_gates(&mut self) -> QuantRS2Result<()> {
307        let mut processed = HashSet::new();
308
309        for (&spider_id, spider) in &self.diagram.spiders.clone() {
310            if processed.contains(&spider_id) {
311                continue;
312            }
313
314            // Look for patterns that represent two-qubit gates
315            if self.diagram.degree(spider_id) > 2 {
316                // This spider connects multiple qubits
317                let neighbors = self.diagram.neighbors(spider_id);
318
319                // Check for CNOT pattern (Z spider connected to X spider)
320                if spider.spider_type == SpiderType::Z && spider.phase.abs() < 1e-10 {
321                    for &(neighbor_id, edge_type) in &neighbors {
322                        if let Some(neighbor) = self.diagram.spiders.get(&neighbor_id) {
323                            if neighbor.spider_type == SpiderType::X
324                                && neighbor.phase.abs() < 1e-10
325                                && edge_type == EdgeType::Regular
326                            {
327                                // Found CNOT pattern
328                                if let (Some(control_qubit), Some(target_qubit)) = (
329                                    self.get_spider_qubit(spider_id),
330                                    self.get_spider_qubit(neighbor_id),
331                                ) {
332                                    let gate: Box<dyn GateOp> = Box::new(CNOT {
333                                        control: control_qubit,
334                                        target: target_qubit,
335                                    });
336
337                                    let layer = self
338                                        .spider_positions
339                                        .get(&spider_id)
340                                        .map_or(0, |(l, _)| *l);
341
342                                    self.add_gate_to_layer(gate, layer);
343                                    processed.insert(spider_id);
344                                    processed.insert(neighbor_id);
345                                }
346                            }
347                        }
348                    }
349                }
350
351                // Check for CZ pattern (two Z spiders connected with Hadamard)
352                if spider.spider_type == SpiderType::Z && spider.phase.abs() < 1e-10 {
353                    for &(neighbor_id, edge_type) in &neighbors {
354                        if let Some(neighbor) = self.diagram.spiders.get(&neighbor_id) {
355                            if neighbor.spider_type == SpiderType::Z
356                                && neighbor.phase.abs() < 1e-10
357                                && edge_type == EdgeType::Hadamard
358                            {
359                                // Found CZ pattern
360                                if let (Some(qubit1), Some(qubit2)) = (
361                                    self.get_spider_qubit(spider_id),
362                                    self.get_spider_qubit(neighbor_id),
363                                ) {
364                                    let gate: Box<dyn GateOp> = Box::new(CZ {
365                                        control: qubit1,
366                                        target: qubit2,
367                                    });
368
369                                    let layer = self
370                                        .spider_positions
371                                        .get(&spider_id)
372                                        .map_or(0, |(l, _)| *l);
373
374                                    self.add_gate_to_layer(gate, layer);
375                                    processed.insert(spider_id);
376                                    processed.insert(neighbor_id);
377                                }
378                            }
379                        }
380                    }
381                }
382            }
383        }
384
385        Ok(())
386    }
387
388    /// Get the qubit associated with a spider
389    fn get_spider_qubit(&self, spider_id: usize) -> Option<QubitId> {
390        // First check if it's a boundary spider
391        if let Some(spider) = self.diagram.spiders.get(&spider_id) {
392            if let Some(qubit) = spider.qubit {
393                return Some(qubit);
394            }
395        }
396
397        // Otherwise, trace back to find the input boundary
398        self.find_connected_boundary(spider_id)
399    }
400
401    /// Find a boundary spider connected to this spider
402    fn find_connected_boundary(&self, spider_id: usize) -> Option<QubitId> {
403        let mut visited = HashSet::new();
404        let mut queue = VecDeque::new();
405        queue.push_back(spider_id);
406        visited.insert(spider_id);
407
408        while let Some(current) = queue.pop_front() {
409            if let Some(spider) = self.diagram.spiders.get(&current) {
410                if let Some(qubit) = spider.qubit {
411                    return Some(qubit);
412                }
413            }
414
415            for (neighbor, _) in self.diagram.neighbors(current) {
416                if visited.insert(neighbor) {
417                    queue.push_back(neighbor);
418                }
419            }
420        }
421
422        None
423    }
424
425    /// Get the edge type between two spiders
426    fn get_edge_type(&self, spider1: usize, spider2: usize) -> Option<EdgeType> {
427        self.diagram
428            .neighbors(spider1)
429            .iter()
430            .find(|(id, _)| *id == spider2)
431            .map(|(_, edge_type)| *edge_type)
432    }
433
434    /// Add a gate to the appropriate layer
435    fn add_gate_to_layer(&mut self, gate: Box<dyn GateOp>, layer_idx: usize) {
436        // Ensure we have enough layers
437        while self.layers.len() <= layer_idx {
438            self.layers.push(GateLayer { gates: Vec::new() });
439        }
440
441        self.layers[layer_idx].gates.push(gate);
442    }
443}
444
445/// Complete ZX-calculus optimization pipeline
446pub struct ZXPipeline {
447    optimizer: ZXOptimizer,
448}
449
450impl ZXPipeline {
451    /// Create a new ZX optimization pipeline
452    pub const fn new() -> Self {
453        Self {
454            optimizer: ZXOptimizer::new(),
455        }
456    }
457
458    /// Optimize a circuit using ZX-calculus
459    pub fn optimize(&self, gates: &[Box<dyn GateOp>]) -> QuantRS2Result<Vec<Box<dyn GateOp>>> {
460        // Find number of qubits
461        let num_qubits = gates
462            .iter()
463            .flat_map(|g| g.qubits())
464            .map(|q| q.0 + 1)
465            .max()
466            .unwrap_or(0);
467
468        // Convert to ZX-diagram
469        let mut converter = CircuitToZX::new(num_qubits as usize);
470        for gate in gates {
471            converter.add_gate(gate.as_ref())?;
472        }
473
474        let mut diagram = converter.into_diagram();
475
476        // Optimize the diagram
477        let rewrites = diagram.simplify(100);
478        println!("Applied {rewrites} ZX-calculus rewrites");
479
480        // Extract optimized circuit
481        let mut extractor = ZXExtractor::new(diagram);
482        extractor.extract_circuit()
483    }
484
485    /// Compare T-count before and after optimization
486    pub fn compare_t_count(
487        &self,
488        original: &[Box<dyn GateOp>],
489        optimized: &[Box<dyn GateOp>],
490    ) -> (usize, usize) {
491        let count_t = |gates: &[Box<dyn GateOp>]| {
492            gates
493                .iter()
494                .filter(|g| {
495                    g.name() == "T"
496                        || (g.name() == "RZ" && {
497                            if let Some(rz) = g.as_any().downcast_ref::<RotationZ>() {
498                                (rz.theta - PI / 4.0).abs() < 1e-10
499                            } else {
500                                false
501                            }
502                        })
503                })
504                .count()
505        };
506
507        (count_t(original), count_t(optimized))
508    }
509}
510
511impl Default for ZXPipeline {
512    fn default() -> Self {
513        Self::new()
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Run circuit extraction on `diagram`, panicking if it fails.
    fn extract(diagram: ZXDiagram) -> Vec<Box<dyn GateOp>> {
        ZXExtractor::new(diagram)
            .extract_circuit()
            .expect("Failed to extract circuit")
    }

    #[test]
    fn test_circuit_extraction_identity() {
        // A bare wire: the input boundary is connected straight to the output.
        let mut wire = ZXDiagram::new();
        let input = wire.add_boundary(QubitId(0), true);
        let output = wire.add_boundary(QubitId(0), false);
        wire.add_edge(input, output, EdgeType::Regular);

        // The identity needs no gates at all.
        assert!(extract(wire).is_empty());
    }

    #[test]
    fn test_circuit_extraction_single_gate() {
        // input -- Z(π/2) -- output
        let mut diagram = ZXDiagram::new();
        let input = diagram.add_boundary(QubitId(0), true);
        let rz = diagram.add_spider(SpiderType::Z, PI / 2.0);
        let output = diagram.add_boundary(QubitId(0), false);

        for (src, dst) in [(input, rz), (rz, output)] {
            diagram.add_edge(src, dst, EdgeType::Regular);
        }

        let circuit = extract(diagram);

        // Exactly one Z rotation comes back out.
        assert_eq!(circuit.len(), 1);
        assert_eq!(circuit[0].name(), "RZ");
    }

    #[test]
    fn test_zx_pipeline_optimization() {
        // H·Z·H equals X, so the pipeline has room to shrink this circuit.
        let q0 = QubitId(0);
        let gates: Vec<Box<dyn GateOp>> = vec![
            Box::new(Hadamard { target: q0 }),
            Box::new(PauliZ { target: q0 }),
            Box::new(Hadamard { target: q0 }),
        ];

        let optimized = ZXPipeline::new()
            .optimize(&gates)
            .expect("Failed to optimize circuit");

        // Optimization must never grow the circuit.
        assert!(optimized.len() <= gates.len());
    }

    #[test]
    fn test_t_count_reduction() {
        // Two RZ(π/4) rotations, i.e. two T gates.
        let t_gate = || -> Box<dyn GateOp> {
            Box::new(RotationZ {
                target: QubitId(0),
                theta: PI / 4.0,
            })
        };
        let gates = vec![t_gate(), t_gate()];

        let pipeline = ZXPipeline::new();
        let optimized = pipeline
            .optimize(&gates)
            .expect("Failed to optimize circuit");

        let (original_t, optimized_t) = pipeline.compare_t_count(&gates, &optimized);

        // The pair may fuse into an S gate, so the count can only drop.
        assert_eq!(original_t, 2);
        assert!(optimized_t <= original_t);
    }
}