ruqu_neural_decoder/
graph.rs

1//! Syndrome Graph Construction
2//!
3//! This module provides functionality to construct graphs from syndrome bitmaps
4//! for quantum error correction codes, particularly surface codes.
5//!
6//! ## Graph Structure
7//!
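//! Each detector in the syndrome becomes a node in the graph, with edges
//! connecting horizontally and vertically neighboring detectors. Edge weights
//! are derived from the physical error rate `p` (currently a uniform `-ln(p)`
//! on every edge) rather than from a full correlation model of the errors.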
11
12use crate::error::{NeuralDecoderError, Result};
13use ndarray::{Array1, Array2};
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17/// A node in the detector graph
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Node {
20    /// Unique node identifier
21    pub id: usize,
22    /// Row position in the surface code lattice
23    pub row: usize,
24    /// Column position in the surface code lattice
25    pub col: usize,
26    /// Whether this detector is fired (syndrome bit is 1)
27    pub fired: bool,
28    /// Node type (X-type or Z-type stabilizer)
29    pub node_type: NodeType,
30    /// Feature vector for this node
31    pub features: Vec<f32>,
32}
33
34/// Type of stabilizer node
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36pub enum NodeType {
37    /// X-type stabilizer (measures bit flips)
38    XStabilizer,
39    /// Z-type stabilizer (measures phase flips)
40    ZStabilizer,
41    /// Boundary node (virtual)
42    Boundary,
43}
44
45/// An edge in the detector graph
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct Edge {
48    /// Source node index
49    pub from: usize,
50    /// Target node index
51    pub to: usize,
52    /// Edge weight (derived from error probability)
53    pub weight: f32,
54    /// Edge type
55    pub edge_type: EdgeType,
56}
57
58/// Type of edge
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
60pub enum EdgeType {
61    /// Horizontal edge in the lattice
62    Horizontal,
63    /// Vertical edge in the lattice
64    Vertical,
65    /// Temporal edge (between measurement rounds)
66    Temporal,
67    /// Boundary edge (to virtual boundary node)
68    Boundary,
69}
70
71/// The detector graph representation
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct DetectorGraph {
74    /// All nodes in the graph
75    pub nodes: Vec<Node>,
76    /// All edges in the graph
77    pub edges: Vec<Edge>,
78    /// Adjacency list representation
79    adjacency: HashMap<usize, Vec<usize>>,
80    /// Code distance
81    pub distance: usize,
82    /// Number of fired detectors
83    pub num_fired: usize,
84}
85
86impl DetectorGraph {
87    /// Create an empty detector graph
88    pub fn new(distance: usize) -> Self {
89        Self {
90            nodes: Vec::new(),
91            edges: Vec::new(),
92            adjacency: HashMap::new(),
93            distance,
94            num_fired: 0,
95        }
96    }
97
98    /// Add a node to the graph
99    pub fn add_node(&mut self, node: Node) {
100        let id = node.id;
101        if node.fired {
102            self.num_fired += 1;
103        }
104        self.nodes.push(node);
105        self.adjacency.entry(id).or_default();
106    }
107
108    /// Add an edge to the graph
109    pub fn add_edge(&mut self, edge: Edge) {
110        self.adjacency.entry(edge.from).or_default().push(edge.to);
111        self.adjacency.entry(edge.to).or_default().push(edge.from);
112        self.edges.push(edge);
113    }
114
115    /// Get neighbors of a node
116    pub fn neighbors(&self, node_id: usize) -> Option<&Vec<usize>> {
117        self.adjacency.get(&node_id)
118    }
119
120    /// Get the node features as a matrix
121    pub fn node_features(&self) -> Array2<f32> {
122        if self.nodes.is_empty() {
123            return Array2::zeros((0, 1));
124        }
125
126        let feature_dim = self.nodes[0].features.len();
127        let mut features = Array2::zeros((self.nodes.len(), feature_dim));
128
129        for (i, node) in self.nodes.iter().enumerate() {
130            for (j, &f) in node.features.iter().enumerate() {
131                features[[i, j]] = f;
132            }
133        }
134
135        features
136    }
137
138    /// Get the adjacency matrix
139    pub fn adjacency_matrix(&self) -> Array2<f32> {
140        let n = self.nodes.len();
141        let mut adj = Array2::zeros((n, n));
142
143        for edge in &self.edges {
144            adj[[edge.from, edge.to]] = edge.weight;
145            adj[[edge.to, edge.from]] = edge.weight;
146        }
147
148        adj
149    }
150
151    /// Get edge weights as a vector
152    pub fn edge_weights(&self) -> Array1<f32> {
153        Array1::from_iter(self.edges.iter().map(|e| e.weight))
154    }
155
156    /// Get fired detector indices
157    pub fn fired_indices(&self) -> Vec<usize> {
158        self.nodes
159            .iter()
160            .filter(|n| n.fired)
161            .map(|n| n.id)
162            .collect()
163    }
164
165    /// Check if the graph is valid
166    pub fn validate(&self) -> Result<()> {
167        if self.nodes.is_empty() {
168            return Err(NeuralDecoderError::EmptyGraph);
169        }
170
171        // Check all edge endpoints are valid
172        for edge in &self.edges {
173            if edge.from >= self.nodes.len() || edge.to >= self.nodes.len() {
174                return Err(NeuralDecoderError::InvalidDetector(
175                    edge.from.max(edge.to)
176                ));
177            }
178        }
179
180        Ok(())
181    }
182
183    /// Get the number of nodes
184    pub fn num_nodes(&self) -> usize {
185        self.nodes.len()
186    }
187
188    /// Get the number of edges
189    pub fn num_edges(&self) -> usize {
190        self.edges.len()
191    }
192}
193
194/// Builder for constructing detector graphs
195pub struct GraphBuilder {
196    distance: usize,
197    syndrome: Option<Vec<bool>>,
198    node_type_pattern: NodeTypePattern,
199    error_rate: f64,
200}
201
/// Rule for assigning a [`NodeType`] to each detector of the grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeTypePattern {
    /// Alternating stabilizers: X-type where `row + col` is even, Z-type where it is odd.
    Checkerboard,
    /// Every detector is an X-type stabilizer.
    AllX,
    /// Every detector is a Z-type stabilizer.
    AllZ,
}
212
213impl GraphBuilder {
214    /// Create a builder for a surface code of given distance
215    pub fn from_surface_code(distance: usize) -> Self {
216        Self {
217            distance,
218            syndrome: None,
219            node_type_pattern: NodeTypePattern::Checkerboard,
220            error_rate: 0.001,
221        }
222    }
223
224    /// Set the syndrome bitmap
225    pub fn with_syndrome(mut self, syndrome: &[bool]) -> Result<Self> {
226        let expected = self.distance * self.distance;
227        if syndrome.len() != expected {
228            return Err(NeuralDecoderError::syndrome_dim(
229                self.distance,
230                syndrome.len(),
231                1,
232            ));
233        }
234        self.syndrome = Some(syndrome.to_vec());
235        Ok(self)
236    }
237
238    /// Set the node type pattern
239    pub fn with_node_pattern(mut self, pattern: NodeTypePattern) -> Self {
240        self.node_type_pattern = pattern;
241        self
242    }
243
244    /// Set the error rate (for edge weights)
245    pub fn with_error_rate(mut self, rate: f64) -> Self {
246        self.error_rate = rate;
247        self
248    }
249
250    /// Build the detector graph
251    pub fn build(self) -> Result<DetectorGraph> {
252        let d = self.distance;
253        let mut graph = DetectorGraph::new(d);
254
255        // Default syndrome: all zeros
256        let syndrome = self.syndrome.unwrap_or_else(|| vec![false; d * d]);
257
258        // Create nodes
259        for row in 0..d {
260            for col in 0..d {
261                let id = row * d + col;
262                let fired = syndrome.get(id).copied().unwrap_or(false);
263
264                let node_type = match self.node_type_pattern {
265                    NodeTypePattern::Checkerboard => {
266                        if (row + col) % 2 == 0 {
267                            NodeType::XStabilizer
268                        } else {
269                            NodeType::ZStabilizer
270                        }
271                    }
272                    NodeTypePattern::AllX => NodeType::XStabilizer,
273                    NodeTypePattern::AllZ => NodeType::ZStabilizer,
274                };
275
276                // Feature vector: [fired, row_norm, col_norm, node_type_x, node_type_z]
277                let features = vec![
278                    if fired { 1.0 } else { 0.0 },
279                    row as f32 / d as f32,
280                    col as f32 / d as f32,
281                    if node_type == NodeType::XStabilizer { 1.0 } else { 0.0 },
282                    if node_type == NodeType::ZStabilizer { 1.0 } else { 0.0 },
283                ];
284
285                graph.add_node(Node {
286                    id,
287                    row,
288                    col,
289                    fired,
290                    node_type,
291                    features,
292                });
293            }
294        }
295
296        // Create edges (grid connectivity)
297        let weight = (-self.error_rate.ln()) as f32;
298
299        for row in 0..d {
300            for col in 0..d {
301                let id = row * d + col;
302
303                // Horizontal edge
304                if col + 1 < d {
305                    let neighbor = row * d + (col + 1);
306                    graph.add_edge(Edge {
307                        from: id,
308                        to: neighbor,
309                        weight,
310                        edge_type: EdgeType::Horizontal,
311                    });
312                }
313
314                // Vertical edge
315                if row + 1 < d {
316                    let neighbor = (row + 1) * d + col;
317                    graph.add_edge(Edge {
318                        from: id,
319                        to: neighbor,
320                        weight,
321                        edge_type: EdgeType::Vertical,
322                    });
323                }
324            }
325        }
326
327        graph.validate()?;
328        Ok(graph)
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Weight the grid builder assigns to every edge for error rate `p`.
    fn expected_weight(p: f64) -> f32 {
        (-p.ln()) as f32
    }

    #[test]
    fn test_node_creation() {
        let node = Node {
            id: 0,
            row: 0,
            col: 0,
            fired: true,
            node_type: NodeType::XStabilizer,
            features: vec![1.0],
        };
        assert_eq!(node.id, 0);
        assert!(node.fired);
    }

    #[test]
    fn test_edge_creation() {
        let edge = Edge {
            from: 0,
            to: 1,
            weight: 1.5,
            edge_type: EdgeType::Horizontal,
        };
        assert_eq!(edge.from, 0);
        assert_eq!(edge.to, 1);
    }

    #[test]
    fn test_graph_construction_d3() {
        let graph = GraphBuilder::from_surface_code(3).build().unwrap();

        // 3x3 = 9 nodes
        assert_eq!(graph.num_nodes(), 9);

        // Grid edges: 2*3 horizontal + 3*2 vertical = 12 edges
        assert_eq!(graph.num_edges(), 12);

        // Ids are assigned row-major and coincide with the node's position.
        for (index, node) in graph.nodes.iter().enumerate() {
            assert_eq!(node.id, index);
            assert_eq!(node.id, node.row * 3 + node.col);
        }
    }

    #[test]
    fn test_graph_construction_d5() {
        let graph = GraphBuilder::from_surface_code(5).build().unwrap();

        // 5x5 = 25 nodes
        assert_eq!(graph.num_nodes(), 25);

        // Grid edges: 4*5 horizontal + 5*4 vertical = 40 edges
        assert_eq!(graph.num_edges(), 40);
    }

    #[test]
    fn test_zero_distance_build_fails() {
        // A distance-0 code has no detectors, so validation rejects the empty graph.
        assert!(GraphBuilder::from_surface_code(0).build().is_err());
    }

    #[test]
    fn test_graph_with_syndrome() {
        let syndrome = vec![true, false, true, false, true, false, false, false, true];
        let graph = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome)
            .unwrap()
            .build()
            .unwrap();

        assert_eq!(graph.num_fired, 4);
        assert_eq!(graph.fired_indices(), vec![0, 2, 4, 8]);
        for (node, &bit) in graph.nodes.iter().zip(&syndrome) {
            assert_eq!(node.fired, bit);
        }
    }

    #[test]
    fn test_graph_syndrome_dimension_mismatch() {
        let too_short = vec![true, false, true];
        assert!(GraphBuilder::from_surface_code(3)
            .with_syndrome(&too_short)
            .is_err());

        let too_long = vec![false; 10];
        assert!(GraphBuilder::from_surface_code(3)
            .with_syndrome(&too_long)
            .is_err());
    }

    #[test]
    fn test_graph_adjacency() {
        let graph = GraphBuilder::from_surface_code(3).build().unwrap();

        // Corner node (0) is adjacent to its right (1) and lower (3) neighbours.
        let mut corner = graph.neighbors(0).unwrap().clone();
        corner.sort_unstable();
        assert_eq!(corner, vec![1, 3]);

        // Center node (4) is adjacent to all four grid neighbours.
        let mut center = graph.neighbors(4).unwrap().clone();
        center.sort_unstable();
        assert_eq!(center, vec![1, 3, 5, 7]);

        // Unknown ids have no adjacency entry.
        assert!(graph.neighbors(100).is_none());
    }

    #[test]
    fn test_node_features_matrix() {
        let mut syndrome = vec![false; 9];
        syndrome[4] = true;
        let graph = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome)
            .unwrap()
            .build()
            .unwrap();

        let features = graph.node_features();
        assert_eq!(features.shape(), &[9, 5]);

        let third = 1.0_f32 / 3.0;
        // Node 4 sits at (1, 1): fired, X-type under the checkerboard pattern.
        assert_eq!(features.row(4).to_vec(), vec![1.0, third, third, 1.0, 0.0]);
        // Node 1 sits at (0, 1): quiet, Z-type.
        assert_eq!(features.row(1).to_vec(), vec![0.0, 0.0, third, 0.0, 1.0]);
    }

    #[test]
    fn test_adjacency_matrix() {
        let graph = GraphBuilder::from_surface_code(3).build().unwrap();

        let adj = graph.adjacency_matrix();
        assert_eq!(adj.shape(), &[9, 9]);

        // Matrix should be symmetric with an empty diagonal.
        for i in 0..9 {
            assert_eq!(adj[[i, i]], 0.0);
            for j in 0..9 {
                assert_eq!(adj[[i, j]], adj[[j, i]]);
            }
        }

        // Grid neighbours carry the default weight; non-neighbours are zero.
        assert_eq!(adj[[0, 1]], expected_weight(0.001));
        assert_eq!(adj[[0, 4]], 0.0);
    }

    #[test]
    fn test_edge_weights() {
        let graph = GraphBuilder::from_surface_code(3)
            .with_error_rate(0.01)
            .build()
            .unwrap();

        let weights = graph.edge_weights();
        assert_eq!(weights.len(), 12);

        // Every edge is weighted -ln(p), which is positive for 0 < p < 1.
        let expected = expected_weight(0.01);
        assert!(expected > 0.0);
        for w in weights.iter() {
            assert_eq!(*w, expected);
        }
    }

    #[test]
    fn test_node_type_pattern_checkerboard() {
        let graph = GraphBuilder::from_surface_code(3)
            .with_node_pattern(NodeTypePattern::Checkerboard)
            .build()
            .unwrap();

        for node in &graph.nodes {
            let expected = if (node.row + node.col) % 2 == 0 {
                NodeType::XStabilizer
            } else {
                NodeType::ZStabilizer
            };
            assert_eq!(node.node_type, expected);
        }
    }

    #[test]
    fn test_node_type_pattern_all_x() {
        let graph = GraphBuilder::from_surface_code(3)
            .with_node_pattern(NodeTypePattern::AllX)
            .build()
            .unwrap();

        for node in &graph.nodes {
            assert_eq!(node.node_type, NodeType::XStabilizer);
        }
    }

    #[test]
    fn test_node_type_pattern_all_z() {
        let graph = GraphBuilder::from_surface_code(3)
            .with_node_pattern(NodeTypePattern::AllZ)
            .build()
            .unwrap();

        for node in &graph.nodes {
            assert_eq!(node.node_type, NodeType::ZStabilizer);
        }
    }

    #[test]
    fn test_empty_syndrome() {
        let syndrome = vec![false; 9];
        let graph = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome)
            .unwrap()
            .build()
            .unwrap();

        assert_eq!(graph.num_fired, 0);
        assert!(graph.fired_indices().is_empty());
    }

    #[test]
    fn test_all_fired_syndrome() {
        let syndrome = vec![true; 9];
        let graph = GraphBuilder::from_surface_code(3)
            .with_syndrome(&syndrome)
            .unwrap()
            .build()
            .unwrap();

        assert_eq!(graph.num_fired, 9);
        assert_eq!(graph.fired_indices().len(), 9);
    }

    #[test]
    fn test_graph_validation() {
        let graph = GraphBuilder::from_surface_code(3).build().unwrap();

        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_empty_graph_validation() {
        let graph = DetectorGraph::new(3);
        assert!(graph.validate().is_err());
    }
}