use crate::error::{NeuralDecoderError, Result};
use ndarray::Array2;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Pauli type of the stabilizer a detector measures.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StabilizerType {
    /// X-type stabilizer.
    X,
    /// Z-type stabilizer.
    Z,
}
/// Lattice layout used when placing detectors in a [`DetectorGraph`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SurfaceCodeTopology {
    /// Rotated layout: X/Z detectors alternate by site parity on a `(d-1) x (d-1)` grid.
    Rotated,
    /// Unrotated layout: X detectors on a `(d-1) x d` grid, Z detectors on a `d x (d-1)` grid.
    Unrotated,
    /// Currently built exactly like `Unrotated`.
    Planar,
}
/// Parameters describing the surface code and noise model a detector graph is built for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranslateConfig {
    /// Code distance `d`; `validate` rejects values below 3.
    pub distance: usize,
    /// Lattice layout used to place detectors.
    pub topology: SurfaceCodeTopology,
    /// Error probability used to weight edges between detectors in the same round.
    pub error_rate: f64,
    /// Error probability used to weight edges that span different rounds.
    pub measurement_error_rate: f64,
    /// Number of measurement rounds for which detectors are created.
    pub num_rounds: usize,
    /// Whether virtual boundary nodes are added to the graph.
    pub include_boundaries: bool,
}
impl Default for TranslateConfig {
fn default() -> Self {
Self {
distance: 5,
topology: SurfaceCodeTopology::Rotated,
error_rate: 0.001,
measurement_error_rate: 0.001,
num_rounds: 1,
include_boundaries: true,
}
}
}
impl TranslateConfig {
    /// Checks that the configuration describes a usable code.
    ///
    /// # Errors
    /// Returns `NeuralDecoderError::ConfigError` when `distance < 3`, or when
    /// `error_rate` or `measurement_error_rate` lies outside `[0, 1]`
    /// (NaN is rejected as well).
    pub fn validate(&self) -> Result<()> {
        if self.distance < 3 {
            return Err(NeuralDecoderError::ConfigError(
                "Distance must be at least 3".to_string(),
            ));
        }
        // `RangeInclusive::contains` is false for NaN, whereas the pair
        // `x < 0.0 || x > 1.0` would let NaN through.
        if !(0.0..=1.0).contains(&self.error_rate) {
            return Err(NeuralDecoderError::ConfigError(format!(
                "Error rate must be in [0, 1], got {}",
                self.error_rate
            )));
        }
        // The measurement rate feeds edge weights just like `error_rate`, so it
        // needs the same range check.
        if !(0.0..=1.0).contains(&self.measurement_error_rate) {
            return Err(NeuralDecoderError::ConfigError(format!(
                "Measurement error rate must be in [0, 1], got {}",
                self.measurement_error_rate
            )));
        }
        Ok(())
    }

    /// Nominal number of X stabilizers for this distance and topology.
    ///
    /// For `Rotated` with odd `d` this equals `(d^2 - 1) / 2`; for
    /// `Unrotated`/`Planar` it is `d * (d - 1)`. Uses saturating arithmetic so
    /// an unvalidated `distance == 0` returns 0 instead of underflowing.
    ///
    /// NOTE(review): for `Rotated` this count includes boundary stabilizers,
    /// while `DetectorGraph` places only the bulk `(d-1)^2 / 2` X detectors per
    /// round — confirm which count callers expect.
    pub fn num_x_stabilizers(&self) -> usize {
        let d = self.distance;
        let d_minus_1 = d.saturating_sub(1);
        match self.topology {
            SurfaceCodeTopology::Rotated => d_minus_1 * d / 2 + d / 2,
            SurfaceCodeTopology::Unrotated => d_minus_1 * d,
            SurfaceCodeTopology::Planar => d_minus_1 * d,
        }
    }

    /// Nominal number of Z stabilizers; equal to the X count for every topology.
    pub fn num_z_stabilizers(&self) -> usize {
        self.num_x_stabilizers()
    }

    /// Total nominal stabilizer count (X plus Z) per measurement round.
    pub fn num_detectors_per_round(&self) -> usize {
        self.num_x_stabilizers() + self.num_z_stabilizers()
    }
}
/// A single detector (stabilizer measurement site) in a detector graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectorNode {
    /// Index of this node within its graph's node list.
    pub index: usize,
    /// Pauli type of the measured stabilizer.
    pub stabilizer_type: StabilizerType,
    /// `(row, col)` lattice coordinates; syndrome cell `[[row, col]]` maps here.
    pub position: (usize, usize),
    /// Measurement round the detector belongs to.
    pub round: usize,
    /// True for virtual boundary nodes rather than real stabilizers.
    pub is_boundary: bool,
    /// Feature storage. `new` initializes it empty, and nothing visible in this
    /// module fills it; `to_features` computes features on demand instead.
    pub features: Vec<f32>,
}
impl DetectorNode {
    /// Creates a non-boundary detector with an empty `features` vector.
    pub fn new(
        index: usize,
        stabilizer_type: StabilizerType,
        position: (usize, usize),
        round: usize,
    ) -> Self {
        Self {
            index,
            stabilizer_type,
            position,
            round,
            is_boundary: false,
            features: vec![],
        }
    }

    /// Encodes this detector as an 8-element feature vector:
    ///
    /// 0. row / `max_distance`
    /// 1. column / `max_distance`
    /// 2. 1.0 if X-type, else 0.0
    /// 3. 1.0 if Z-type, else 0.0
    /// 4. round / `num_rounds`
    /// 5. 1.0 if boundary node, else 0.0
    /// 6. Euclidean distance from the lattice centre, normalised to at most ~0.5
    /// 7. site parity `(row + col) % 2`
    ///
    /// Both `max_distance` and `num_rounds` are clamped to at least 1, so a
    /// degenerate call yields finite values instead of NaN/inf.
    pub fn to_features(&self, max_distance: usize, num_rounds: usize) -> Vec<f32> {
        // Clamp the spatial scale the same way `num_rounds` is clamped; a zero
        // `max_distance` previously produced 0/0 and x/0 features.
        let scale = max_distance.max(1) as f32;
        let rounds = num_rounds.max(1) as f32;
        let row = self.position.0 as f32;
        let col = self.position.1 as f32;
        let center = scale / 2.0;
        let radial = ((row - center).powi(2) + (col - center).powi(2)).sqrt();
        let one_hot = |flag: bool| -> f32 {
            if flag {
                1.0
            } else {
                0.0
            }
        };
        vec![
            row / scale,
            col / scale,
            one_hot(self.stabilizer_type == StabilizerType::X),
            one_hot(self.stabilizer_type == StabilizerType::Z),
            self.round as f32 / rounds,
            one_hot(self.is_boundary),
            // 1.414 ≈ √2: the farthest corner is √2 · scale / 2 from the centre,
            // so this feature is bounded by roughly 0.5.
            radial / (scale * 1.414),
            ((self.position.0 + self.position.1) % 2) as f32,
        ]
    }
}
/// Detector graph for a surface-code patch over one or more measurement rounds,
/// together with the set of detectors currently flagged by a syndrome.
///
/// NOTE(review): `edge_weights` uses tuple keys, which formats such as JSON
/// cannot represent as map keys — confirm which serde format this is used with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectorGraph {
    /// Configuration the graph was built from.
    config: TranslateConfig,
    /// All detectors; a node's position in this vector equals its `index`.
    nodes: Vec<DetectorNode>,
    /// Neighbour lists keyed by node index.
    adjacency: HashMap<usize, Vec<usize>>,
    /// Edge weights keyed by `(smaller_index, larger_index)`; each value is the
    /// negative log of a distance-scaled error probability.
    edge_weights: HashMap<(usize, usize), f32>,
    /// Indices of detectors flagged by the latest syndrome translation/updates.
    active_detectors: HashSet<usize>,
    /// Indices of virtual boundary nodes (all created in round 0).
    boundary_nodes: Vec<usize>,
}
impl DetectorGraph {
pub fn new(config: TranslateConfig) -> Result<Self> {
config.validate()?;
let mut graph = Self {
config: config.clone(),
nodes: vec![],
adjacency: HashMap::new(),
edge_weights: HashMap::new(),
active_detectors: HashSet::new(),
boundary_nodes: vec![],
};
graph.build_topology()?;
Ok(graph)
}
fn build_topology(&mut self) -> Result<()> {
match self.config.topology {
SurfaceCodeTopology::Rotated => self.build_rotated_topology(),
SurfaceCodeTopology::Unrotated => self.build_unrotated_topology(),
SurfaceCodeTopology::Planar => self.build_planar_topology(),
}
}
fn build_rotated_topology(&mut self) -> Result<()> {
let d = self.config.distance;
let mut node_idx = 0;
for round in 0..self.config.num_rounds {
for i in 0..d - 1 {
for j in 0..d - 1 {
if (i + j) % 2 == 1 {
self.nodes.push(DetectorNode::new(
node_idx,
StabilizerType::X,
(i, j),
round,
));
self.adjacency.insert(node_idx, vec![]);
node_idx += 1;
}
}
}
for i in 0..d - 1 {
for j in 0..d - 1 {
if (i + j) % 2 == 0 {
self.nodes.push(DetectorNode::new(
node_idx,
StabilizerType::Z,
(i, j),
round,
));
self.adjacency.insert(node_idx, vec![]);
node_idx += 1;
}
}
}
}
if self.config.include_boundaries {
self.add_boundary_nodes(&mut node_idx);
}
self.build_edges();
Ok(())
}
fn build_unrotated_topology(&mut self) -> Result<()> {
let d = self.config.distance;
let mut node_idx = 0;
for round in 0..self.config.num_rounds {
for i in 0..d - 1 {
for j in 0..d {
self.nodes.push(DetectorNode::new(
node_idx,
StabilizerType::X,
(i, j),
round,
));
self.adjacency.insert(node_idx, vec![]);
node_idx += 1;
}
}
for i in 0..d {
for j in 0..d - 1 {
self.nodes.push(DetectorNode::new(
node_idx,
StabilizerType::Z,
(i, j),
round,
));
self.adjacency.insert(node_idx, vec![]);
node_idx += 1;
}
}
}
if self.config.include_boundaries {
self.add_boundary_nodes(&mut node_idx);
}
self.build_edges();
Ok(())
}
fn build_planar_topology(&mut self) -> Result<()> {
self.build_unrotated_topology()
}
fn add_boundary_nodes(&mut self, node_idx: &mut usize) {
let d = self.config.distance;
for i in 0..d {
let mut node = DetectorNode::new(*node_idx, StabilizerType::X, (i, 0), 0);
node.is_boundary = true;
self.nodes.push(node);
self.boundary_nodes.push(*node_idx);
self.adjacency.insert(*node_idx, vec![]);
*node_idx += 1;
let mut node = DetectorNode::new(*node_idx, StabilizerType::X, (i, d - 1), 0);
node.is_boundary = true;
self.nodes.push(node);
self.boundary_nodes.push(*node_idx);
self.adjacency.insert(*node_idx, vec![]);
*node_idx += 1;
}
for j in 0..d {
let mut node = DetectorNode::new(*node_idx, StabilizerType::Z, (0, j), 0);
node.is_boundary = true;
self.nodes.push(node);
self.boundary_nodes.push(*node_idx);
self.adjacency.insert(*node_idx, vec![]);
*node_idx += 1;
let mut node = DetectorNode::new(*node_idx, StabilizerType::Z, (d - 1, j), 0);
node.is_boundary = true;
self.nodes.push(node);
self.boundary_nodes.push(*node_idx);
self.adjacency.insert(*node_idx, vec![]);
*node_idx += 1;
}
}
fn build_edges(&mut self) {
let n_nodes = self.nodes.len();
for i in 0..n_nodes {
for j in (i + 1)..n_nodes {
let node_i = &self.nodes[i];
let node_j = &self.nodes[j];
if node_i.stabilizer_type == node_j.stabilizer_type
&& node_i.round == node_j.round
{
let dist = self.manhattan_distance(node_i.position, node_j.position);
if dist <= 2 {
self.add_edge(i, j);
}
}
if node_i.position == node_j.position
&& node_i.stabilizer_type == node_j.stabilizer_type
&& (node_i.round as i32 - node_j.round as i32).abs() == 1
{
self.add_edge(i, j);
}
}
}
for &boundary_idx in &self.boundary_nodes.clone() {
let boundary = &self.nodes[boundary_idx];
for i in 0..n_nodes {
if i == boundary_idx {
continue;
}
let node = &self.nodes[i];
if node.stabilizer_type == boundary.stabilizer_type
&& !node.is_boundary
{
let dist = self.manhattan_distance(boundary.position, node.position);
if dist <= 2 {
self.add_edge(boundary_idx, i);
}
}
}
}
}
fn add_edge(&mut self, i: usize, j: usize) {
let weight = self.compute_edge_weight(i, j);
self.adjacency.entry(i).or_default().push(j);
self.adjacency.entry(j).or_default().push(i);
let key = if i < j { (i, j) } else { (j, i) };
self.edge_weights.insert(key, weight);
}
fn compute_edge_weight(&self, i: usize, j: usize) -> f32 {
let node_i = &self.nodes[i];
let node_j = &self.nodes[j];
let spatial_dist = self.manhattan_distance(node_i.position, node_j.position);
let temporal_dist = (node_i.round as i32 - node_j.round as i32).abs() as usize;
let base_weight = if temporal_dist > 0 {
self.config.measurement_error_rate as f32
} else {
self.config.error_rate as f32
};
let dist = spatial_dist + temporal_dist;
let scaled_weight = base_weight * (1.0 / (dist as f32 + 1.0));
-scaled_weight.max(1e-10).ln()
}
fn manhattan_distance(&self, p1: (usize, usize), p2: (usize, usize)) -> usize {
((p1.0 as i32 - p2.0 as i32).abs() + (p1.1 as i32 - p2.1 as i32).abs()) as usize
}
pub fn translate_syndrome(&mut self, syndrome: &Array2<u8>, round: usize) -> Result<()> {
let rows = syndrome.shape()[0];
let cols = syndrome.shape()[1];
self.active_detectors.clear();
for i in 0..rows {
for j in 0..cols {
if syndrome[[i, j]] == 1 {
if let Some(idx) = self.find_detector_at((i, j), round) {
self.active_detectors.insert(idx);
}
}
}
}
Ok(())
}
fn find_detector_at(&self, position: (usize, usize), round: usize) -> Option<usize> {
self.nodes.iter().position(|node| {
node.position == position && node.round == round && !node.is_boundary
})
}
pub fn update_incremental(
&mut self,
changed_positions: &[(usize, usize)],
new_values: &[u8],
round: usize,
) -> Result<()> {
if changed_positions.len() != new_values.len() {
return Err(NeuralDecoderError::ConfigError(
"Position and value arrays must have same length".to_string(),
));
}
for (pos, &value) in changed_positions.iter().zip(new_values.iter()) {
if let Some(idx) = self.find_detector_at(*pos, round) {
if value == 1 {
self.active_detectors.insert(idx);
} else {
self.active_detectors.remove(&idx);
}
}
}
Ok(())
}
pub fn get_node_features(&self) -> Array2<f32> {
let feature_dim = 8; let n_nodes = self.nodes.len();
let mut features = Array2::zeros((n_nodes, feature_dim));
for (i, node) in self.nodes.iter().enumerate() {
let node_features = node.to_features(self.config.distance, self.config.num_rounds);
for (j, &f) in node_features.iter().enumerate() {
features[[i, j]] = f;
}
}
features
}
pub fn get_active_mask(&self) -> Vec<bool> {
(0..self.nodes.len())
.map(|i| self.active_detectors.contains(&i))
.collect()
}
pub fn adjacency(&self) -> &HashMap<usize, Vec<usize>> {
&self.adjacency
}
pub fn edge_weights(&self) -> &HashMap<(usize, usize), f32> {
&self.edge_weights
}
pub fn get_positions(&self) -> Vec<(f32, f32)> {
self.nodes
.iter()
.map(|n| (n.position.0 as f32, n.position.1 as f32))
.collect()
}
pub fn get_boundary_distances(&self) -> Vec<f32> {
let d = self.config.distance as f32;
self.nodes
.iter()
.map(|n| {
let (i, j) = n.position;
let dist_to_boundary = [
i as f32,
(self.config.distance - 1 - i) as f32,
j as f32,
(self.config.distance - 1 - j) as f32,
]
.iter()
.cloned()
.fold(f32::INFINITY, f32::min);
dist_to_boundary / d
})
.collect()
}
pub fn num_nodes(&self) -> usize {
self.nodes.len()
}
pub fn num_edges(&self) -> usize {
self.edge_weights.len()
}
pub fn config(&self) -> &TranslateConfig {
&self.config
}
pub fn active_detectors(&self) -> &HashSet<usize> {
&self.active_detectors
}
}
/// Stateful front end that feeds successive syndromes into a [`DetectorGraph`].
/// The first syndrome is translated fully; later ones are applied as
/// incremental changes.
#[derive(Debug, Clone)]
pub struct SyndromeTranslator {
    /// Graph that receives the translated syndromes.
    graph: DetectorGraph,
    /// Most recently processed syndrome; `None` before the first call and after `reset`.
    prev_syndrome: Option<Array2<u8>>,
    /// Round index assigned to the next processed syndrome.
    current_round: usize,
}
impl SyndromeTranslator {
pub fn new(config: TranslateConfig) -> Result<Self> {
Ok(Self {
graph: DetectorGraph::new(config)?,
prev_syndrome: None,
current_round: 0,
})
}
pub fn process(&mut self, syndrome: &Array2<u8>) -> Result<&DetectorGraph> {
if let Some(ref prev) = self.prev_syndrome {
let mut changed_positions = Vec::new();
let mut new_values = Vec::new();
for i in 0..syndrome.shape()[0] {
for j in 0..syndrome.shape()[1] {
if syndrome[[i, j]] != prev[[i, j]] {
changed_positions.push((i, j));
new_values.push(syndrome[[i, j]]);
}
}
}
if !changed_positions.is_empty() {
self.graph.update_incremental(
&changed_positions,
&new_values,
self.current_round,
)?;
}
} else {
self.graph.translate_syndrome(syndrome, self.current_round)?;
}
self.prev_syndrome = Some(syndrome.clone());
self.current_round += 1;
Ok(&self.graph)
}
pub fn reset(&mut self) {
self.prev_syndrome = None;
self.current_round = 0;
self.graph.active_detectors.clear();
}
pub fn graph(&self) -> &DetectorGraph {
&self.graph
}
pub fn graph_mut(&mut self) -> &mut DetectorGraph {
&mut self.graph
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Distance-5 rotated graph with one round and no virtual boundary nodes.
    fn rotated_bulk_graph() -> DetectorGraph {
        let config = TranslateConfig {
            distance: 5,
            topology: SurfaceCodeTopology::Rotated,
            include_boundaries: false,
            ..Default::default()
        };
        DetectorGraph::new(config).unwrap()
    }

    #[test]
    fn test_config_validation() {
        let mut config = TranslateConfig::default();
        assert!(config.validate().is_ok());
        config.distance = 2;
        assert!(config.validate().is_err());
        config.distance = 5;
        config.error_rate = 1.5;
        assert!(config.validate().is_err());
        config.error_rate = -0.1;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_detector_node_features() {
        let node = DetectorNode::new(0, StabilizerType::X, (2, 3), 0);
        let features = node.to_features(5, 1);
        assert_eq!(features.len(), 8);
        assert!(features.iter().all(|&f| (0.0..=1.0).contains(&f)));
        assert_eq!(features[0], 0.4);
        assert_eq!(features[1], 0.6);
        // One-hot stabilizer type: X set, Z clear.
        assert_eq!(features[2], 1.0);
        assert_eq!(features[3], 0.0);
        // Not a boundary node.
        assert_eq!(features[5], 0.0);
        // (2 + 3) is odd.
        assert_eq!(features[7], 1.0);
    }

    #[test]
    fn test_detector_graph_creation() {
        let config = TranslateConfig {
            distance: 5,
            topology: SurfaceCodeTopology::Rotated,
            ..Default::default()
        };
        let graph = DetectorGraph::new(config).unwrap();
        // 16 bulk sites on the 4x4 grid plus 4 * 5 virtual boundary nodes.
        assert_eq!(graph.num_nodes(), 36);
        assert!(graph.num_edges() > 0);
    }

    #[test]
    fn test_syndrome_translation() {
        let mut graph = rotated_bulk_graph();
        let mut syndrome = Array2::zeros((4, 4));
        syndrome[[1, 1]] = 1;
        syndrome[[2, 2]] = 1;
        graph.translate_syndrome(&syndrome, 0).unwrap();
        assert_eq!(graph.active_detectors().len(), 2);
        let mask = graph.get_active_mask();
        assert_eq!(mask.len(), graph.num_nodes());
        assert_eq!(mask.iter().filter(|&&on| on).count(), 2);
        // A fresh translation replaces, rather than accumulates, the active set.
        graph.translate_syndrome(&Array2::zeros((4, 4)), 0).unwrap();
        assert!(graph.active_detectors().is_empty());
    }

    #[test]
    fn test_incremental_update() {
        let mut graph = rotated_bulk_graph();
        let syndrome: Array2<u8> = Array2::zeros((4, 4));
        graph.translate_syndrome(&syndrome, 0).unwrap();
        assert!(graph.active_detectors().is_empty());
        graph.update_incremental(&[(1, 1)], &[1], 0).unwrap();
        assert_eq!(graph.active_detectors().len(), 1);
        graph.update_incremental(&[(1, 1)], &[0], 0).unwrap();
        assert!(graph.active_detectors().is_empty());
        // Mismatched slice lengths are rejected.
        assert!(graph.update_incremental(&[(1, 1)], &[], 0).is_err());
    }

    #[test]
    fn test_node_features_matrix() {
        let config = TranslateConfig::default();
        let graph = DetectorGraph::new(config).unwrap();
        let features = graph.get_node_features();
        assert_eq!(features.shape()[0], graph.num_nodes());
        assert_eq!(features.shape()[1], 8);
    }

    #[test]
    fn test_boundary_distances() {
        let config = TranslateConfig {
            distance: 5,
            include_boundaries: true,
            ..Default::default()
        };
        let graph = DetectorGraph::new(config).unwrap();
        let distances = graph.get_boundary_distances();
        assert_eq!(distances.len(), graph.num_nodes());
        for &d in &distances {
            assert!((0.0..=0.5).contains(&d));
        }
    }

    #[test]
    fn test_syndrome_translator() {
        let config = TranslateConfig::default();
        let mut translator = SyndromeTranslator::new(config).unwrap();
        let syndrome1 = Array2::zeros((4, 4));
        let graph1 = translator.process(&syndrome1).unwrap();
        assert_eq!(graph1.active_detectors().len(), 0);
        let mut syndrome2 = Array2::zeros((4, 4));
        syndrome2[[1, 1]] = 1;
        let _ = translator.process(&syndrome2).unwrap();
        translator.reset();
        assert_eq!(translator.graph().active_detectors().len(), 0);
    }

    #[test]
    fn test_syndrome_translator_multi_round() {
        let config = TranslateConfig {
            num_rounds: 2,
            include_boundaries: false,
            ..Default::default()
        };
        let mut translator = SyndromeTranslator::new(config).unwrap();
        translator.process(&Array2::zeros((4, 4))).unwrap();
        let mut flipped = Array2::zeros((4, 4));
        flipped[[1, 1]] = 1;
        let graph = translator.process(&flipped).unwrap();
        // The change is applied to the round-1 detector at (1, 1).
        assert_eq!(graph.active_detectors().len(), 1);
    }

    #[test]
    fn test_different_topologies() {
        for topology in &[
            SurfaceCodeTopology::Rotated,
            SurfaceCodeTopology::Unrotated,
            SurfaceCodeTopology::Planar,
        ] {
            let config = TranslateConfig {
                distance: 5,
                topology: *topology,
                ..Default::default()
            };
            let graph = DetectorGraph::new(config).unwrap();
            assert!(graph.num_nodes() > 0);
            assert!(graph.num_edges() > 0);
        }
    }

    #[test]
    fn test_edge_weights() {
        let config = TranslateConfig {
            distance: 5,
            error_rate: 0.01,
            ..Default::default()
        };
        let graph = DetectorGraph::new(config).unwrap();
        for &weight in graph.edge_weights().values() {
            assert!(weight > 0.0);
            assert!(weight.is_finite());
        }
    }

    #[test]
    fn test_positions_and_adjacency() {
        let config = TranslateConfig::default();
        let graph = DetectorGraph::new(config).unwrap();
        let positions = graph.get_positions();
        assert_eq!(positions.len(), graph.num_nodes());
        for (&node, neighbors) in graph.adjacency() {
            for &neighbor in neighbors {
                assert!(
                    graph.adjacency().get(&neighbor).map_or(false, |n| n.contains(&node)),
                    "Adjacency should be symmetric"
                );
            }
        }
    }
}