#![deny(missing_docs)]
#![warn(clippy::all)]
#![allow(clippy::module_name_repetitions)]
pub mod error;
pub mod graph;
pub mod gnn;
pub mod mamba;
pub mod fusion;
pub use error::{NeuralDecoderError, Result};
pub use graph::{DetectorGraph, GraphBuilder, Node, Edge};
pub use gnn::{GNNEncoder, GNNConfig, AttentionLayer};
pub use mamba::{MambaDecoder, MambaConfig, MambaState};
pub use fusion::{FeatureFusion, FusionConfig};
use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
/// Hyper-parameters for a [`NeuralDecoder`].
///
/// Every component of the decoder (GNN encoder, Mamba decoder and the optional
/// min-cut fusion module) is sized from these values in [`NeuralDecoder::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecoderConfig {
    /// Surface-code distance `d`.
    ///
    /// Used to build the detector graph and to size the decoder output to
    /// `d * d` scores.
    pub distance: usize,
    /// Node embedding dimension of the GNN encoder.
    pub embed_dim: usize,
    /// Hidden feature width.
    ///
    /// Used as the GNN hidden size, the Mamba input size, and the fusion
    /// input/output width.
    pub hidden_dim: usize,
    /// Number of layers in the GNN encoder.
    pub num_gnn_layers: usize,
    /// Number of attention heads in the GNN encoder.
    pub num_heads: usize,
    /// State dimension of the Mamba decoder.
    pub mamba_state_dim: usize,
    /// Whether to construct the min-cut feature-fusion module.
    ///
    /// NOTE(review): the fusion module is currently constructed but not applied
    /// by [`NeuralDecoder::decode`].
    pub use_mincut_fusion: bool,
    /// Dropout rate forwarded to the GNN encoder; expected to lie in `[0.0, 1.0]`.
    pub dropout: f32,
}
impl Default for DecoderConfig {
    /// Baseline settings for a distance-5 code.
    ///
    /// These are 64-wide embeddings, 128-wide hidden features, 3 GNN layers
    /// with 4 attention heads, and a 64-dimensional Mamba state. Min-cut fusion
    /// is disabled and the dropout rate is 0.1.
    fn default() -> Self {
        let distance = 5;
        let (embed_dim, hidden_dim) = (64, 128);
        let (num_gnn_layers, num_heads) = (3, 4);
        let mamba_state_dim = 64;

        Self {
            distance,
            embed_dim,
            hidden_dim,
            num_gnn_layers,
            num_heads,
            mamba_state_dim,
            use_mincut_fusion: false,
            dropout: 0.1,
        }
    }
}
/// Result of decoding a single syndrome with [`NeuralDecoder::decode`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Correction {
    /// Indices into the decoder's `distance * distance` output whose score
    /// exceeded the 0.5 decision threshold; interpreted as X-type corrections.
    pub x_corrections: Vec<usize>,
    /// Indices interpreted as Z-type corrections.
    ///
    /// NOTE(review): the current decoder output head never fills this, so it is
    /// always empty.
    pub z_corrections: Vec<usize>,
    /// Mean decision confidence, computed as the average of `|score - 0.5| * 2`.
    ///
    /// It is 0.0 when the decoder produced no output.
    pub confidence: f64,
    /// Wall-clock time spent inside `decode`, in nanoseconds.
    pub decode_time_ns: u64,
}
/// Neural surface-code decoder.
///
/// Syndromes are turned into a detector graph, encoded with a GNN, and mapped
/// to correction scores by a Mamba decoder.
pub struct NeuralDecoder {
    /// Configuration the decoder was built from.
    config: DecoderConfig,
    /// Graph encoder producing node embeddings from the detector graph.
    gnn: GNNEncoder,
    /// Sequence decoder turning embeddings into `distance * distance` scores.
    mamba: MambaDecoder,
    /// Optional min-cut fusion module; `Some` only when
    /// `config.use_mincut_fusion` is set.
    ///
    /// NOTE(review): not yet read by `decode`.
    fusion: Option<FeatureFusion>,
}
impl NeuralDecoder {
    /// Builds a decoder whose components are all derived from `config`.
    ///
    /// The components are the GNN encoder, the Mamba decoder and, when
    /// requested, the min-cut fusion module. The Mamba decoder is sized to
    /// emit `config.distance * config.distance` scores.
    ///
    /// # Errors
    ///
    /// Returns an error if the GNN encoder cannot be built from the derived
    /// [`GNNConfig`]. It also returns an error if `config.use_mincut_fusion` is
    /// set and the [`FeatureFusion`] module cannot be built.
    pub fn new(config: DecoderConfig) -> Result<Self> {
        let gnn_config = GNNConfig {
            // NOTE(review): presumably the per-node feature count produced by
            // `GraphBuilder`; confirm against the graph module before changing.
            input_dim: 5,
            embed_dim: config.embed_dim,
            hidden_dim: config.hidden_dim,
            num_layers: config.num_gnn_layers,
            num_heads: config.num_heads,
            dropout: config.dropout,
        };
        let mamba_config = MambaConfig {
            input_dim: config.hidden_dim,
            state_dim: config.mamba_state_dim,
            output_dim: config.distance * config.distance,
        };
        let fusion = if config.use_mincut_fusion {
            let fusion_config = FusionConfig {
                gnn_dim: config.hidden_dim,
                mincut_dim: 16,
                output_dim: config.hidden_dim,
                gnn_weight: 0.5,
                mincut_weight: 0.3,
                boundary_weight: 0.2,
                adaptive_weights: true,
                temperature: 1.0,
            };
            // The caller explicitly asked for fusion, so surface a construction
            // failure instead of silently running without it.
            Some(FeatureFusion::new(fusion_config)?)
        } else {
            None
        };
        Ok(Self {
            config,
            gnn: GNNEncoder::new(gnn_config)?,
            mamba: MambaDecoder::new(mamba_config),
            fusion,
        })
    }

    /// Decodes one syndrome into a [`Correction`].
    ///
    /// The syndrome is attached to a surface-code detector graph for
    /// `config.distance`. That graph is encoded by the GNN and passed through
    /// the Mamba decoder. Output scores above 0.5 become X corrections.
    ///
    /// # Errors
    ///
    /// Propagates errors from graph construction, GNN encoding, and Mamba
    /// decoding. Graph construction may reject a syndrome that does not fit
    /// the graph; confirm this in `GraphBuilder::with_syndrome`.
    pub fn decode(&mut self, syndrome: &[bool]) -> Result<Correction> {
        let start = std::time::Instant::now();
        let graph = GraphBuilder::from_surface_code(self.config.distance)
            .with_syndrome(syndrome)?
            .build()?;
        let node_embeddings = self.gnn.encode(&graph)?;
        // NOTE(review): `self.fusion` is built when `use_mincut_fusion` is set
        // but is not applied here, so the flag does not yet affect decoding.
        // TODO: fuse `node_embeddings` with min-cut features once that API is wired in.
        let output = self.mamba.decode(&node_embeddings)?;
        let (x_corrections, z_corrections, confidence) = self.output_to_corrections(&output)?;
        // `as_nanos` returns u128; saturate instead of silently truncating.
        let decode_time_ns = u64::try_from(start.elapsed().as_nanos()).unwrap_or(u64::MAX);
        Ok(Correction {
            x_corrections,
            z_corrections,
            confidence,
            decode_time_ns,
        })
    }

    /// Converts raw decoder scores into `(x_corrections, z_corrections, confidence)`.
    ///
    /// * X corrections are the indices whose score is strictly greater than 0.5.
    ///   NaN scores are never selected.
    /// * Z corrections are not produced by the current output head and are
    ///   always empty.
    /// * Confidence is the mean of `|score - 0.5| * 2` over all scores, or 0.0
    ///   for an empty output. It lies in `[0, 1]` when scores are in `[0, 1]`.
    fn output_to_corrections(&self, output: &Array1<f32>) -> Result<(Vec<usize>, Vec<usize>, f64)> {
        const THRESHOLD: f32 = 0.5;

        let x_corrections: Vec<usize> = output
            .iter()
            .enumerate()
            .filter(|&(_, &score)| score > THRESHOLD)
            .map(|(index, _)| index)
            .collect();

        let confidence = if output.is_empty() {
            0.0
        } else {
            // Accumulate in f64 so long outputs do not lose precision to f32 rounding.
            let total: f64 = output
                .iter()
                .map(|&score| (f64::from(score) - f64::from(THRESHOLD)).abs() * 2.0)
                .sum();
            total / output.len() as f64
        };

        Ok((x_corrections, Vec::new(), confidence))
    }

    /// Returns the configuration this decoder was built from.
    #[must_use]
    pub fn config(&self) -> &DecoderConfig {
        &self.config
    }

    /// Resets the internal Mamba decoder by calling its `reset` method.
    ///
    /// Presumably this clears recurrent state between independent syndrome
    /// streams; confirm in `MambaDecoder::reset`.
    pub fn reset(&mut self) {
        self.mamba.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decoder_config_default() {
        let config = DecoderConfig::default();
        assert_eq!(config.distance, 5);
        assert_eq!(config.embed_dim, 64);
        assert_eq!(config.hidden_dim, 128);
        assert!((0.0..=1.0).contains(&config.dropout));
    }

    #[test]
    fn test_decoder_creation() {
        let decoder = NeuralDecoder::new(DecoderConfig::default())
            .expect("default config must build a decoder");
        assert_eq!(decoder.config().distance, 5);
    }

    #[test]
    fn test_correction_default() {
        let correction = Correction::default();
        assert!(correction.x_corrections.is_empty());
        assert!(correction.z_corrections.is_empty());
        assert_eq!(correction.confidence, 0.0);
        assert_eq!(correction.decode_time_ns, 0);
    }

    #[test]
    fn test_decoder_empty_syndrome() {
        let config = DecoderConfig {
            distance: 3,
            ..Default::default()
        };
        let mut decoder = NeuralDecoder::new(config).expect("distance-3 config must build a decoder");
        // All-false syndrome sized `distance * distance`.
        let syndrome = [false; 9];
        let correction = decoder
            .decode(&syndrome)
            .expect("decoding an all-false syndrome must succeed");
        // The current output head never produces Z corrections.
        assert!(correction.z_corrections.is_empty());
        // Confidence is a mean of absolute values, so it is never negative (or NaN).
        assert!(correction.confidence >= 0.0);
    }

    #[test]
    fn test_decoder_reset_then_decode() {
        let config = DecoderConfig {
            distance: 3,
            ..Default::default()
        };
        let mut decoder = NeuralDecoder::new(config).expect("distance-3 config must build a decoder");
        let syndrome = [false; 9];
        decoder.decode(&syndrome).expect("first decode must succeed");
        decoder.reset();
        decoder.decode(&syndrome).expect("decode after reset must succeed");
    }
}