use crate::syndrome::DetectorBitmap;
/// Parameters shared by the decoders in this module.
#[derive(Debug, Clone)]
pub struct DecoderConfig {
    /// Code distance `d`; the decoders lay detectors out on a `d x d` grid
    /// indexed as `row * d + col`.
    pub distance: usize,
    /// Physical error probability; the fusion_blossom-backed decoder derives its
    /// edge weights from `-ln` of this value.
    pub physical_error_rate: f64,
    /// Window size; `StreamingDecoder` keeps `max(window_size, 10)` entries of
    /// timing history.
    pub window_size: usize,
    /// Parallel-decoding flag (not read by any decoder in this module).
    pub parallel: bool,
}

impl Default for DecoderConfig {
    /// Distance 7 at a 0.1 % physical error rate, window size 1, serial decoding.
    fn default() -> Self {
        let (distance, physical_error_rate) = (7, 1e-3);
        DecoderConfig {
            distance,
            physical_error_rate,
            window_size: 1,
            parallel: false,
        }
    }
}
/// Outcome of decoding a single syndrome.
#[derive(Debug, Clone)]
pub struct Correction {
    /// Grid indices (`row * distance + col`) flagged for an X correction.
    pub x_corrections: Vec<usize>,
    /// Grid indices flagged for a Z correction; the decoders in this module
    /// always leave this empty.
    pub z_corrections: Vec<usize>,
    /// Heuristic confidence: the decoders report 1.0 for an empty syndrome and
    /// a fixed lower value otherwise.
    pub confidence: f64,
    /// Wall-clock decode time in nanoseconds.
    pub decode_time_ns: u64,
}

impl Default for Correction {
    /// The "nothing to fix" correction: no flips, full confidence, zero time.
    fn default() -> Self {
        Correction {
            confidence: 1.0,
            decode_time_ns: 0,
            x_corrections: vec![],
            z_corrections: vec![],
        }
    }
}
/// Minimum-weight perfect-matching decoder backed by the `fusion_blossom` crate.
///
/// The decoding graph is a `distance x distance` grid: detector `row * distance + col`
/// is vertex `row * distance + col`, linked to its right and lower neighbours. One
/// extra *virtual* boundary vertex (index `distance * distance`) is linked to every
/// detector on the grid's perimeter, so any number of defects may terminate there.
#[cfg(feature = "decoder")]
pub struct MWPMDecoder {
    config: DecoderConfig,
    solver: fusion_blossom::mwpm_solver::SolverSerial,
    /// `distance * distance` detectors plus the boundary vertex.
    vertex_count: usize,
    /// Graph edges as `(vertex_a, vertex_b, weight)`, in the order given to the solver.
    edges: Vec<(usize, usize, i32)>,
    /// Graph vertex for each detector index (the identity map for this layout).
    detector_to_vertex: Vec<usize>,
}

#[cfg(feature = "decoder")]
impl MWPMDecoder {
    /// Builds the decoding graph for `config.distance` and initialises the solver.
    ///
    /// A bulk edge weighs about `1000 * -ln(physical_error_rate)` and a boundary
    /// edge half of that. Both are rounded to even, strictly positive integers
    /// because fusion_blossom only accepts even edge weights. Out-of-range rates
    /// (`<= 0`, `>= 1`, NaN) are clamped instead of producing zero, negative or
    /// saturated weights.
    pub fn new(config: DecoderConfig) -> Self {
        use fusion_blossom::mwpm_solver::SolverSerial;
        use fusion_blossom::util::*;

        let d = config.distance;
        let num_detectors = d * d;
        let boundary_vertex = num_detectors;
        let vertex_count = num_detectors + 1;

        // A quarter of the bulk weight. `max` also maps NaN to 1; `min` keeps
        // `4 * unit` far from i32 overflow when the rate is 0.
        let unit = (-config.physical_error_rate.ln() * 250.0)
            .round()
            .max(1.0)
            .min(1.0e8) as i32;
        let bulk_weight = 4 * unit;
        let boundary_weight = 2 * unit;

        let mut edges = Vec::new();
        for row in 0..d {
            for col in 0..d {
                let v = row * d + col;
                if col + 1 < d {
                    edges.push((v, v + 1, bulk_weight));
                }
                if row + 1 < d {
                    edges.push((v, v + d, bulk_weight));
                }
                // One boundary edge per perimeter detector: corners touch two
                // sides, but a second parallel edge would add nothing.
                if row == 0 || col == 0 || row + 1 == d || col + 1 == d {
                    edges.push((v, boundary_vertex, boundary_weight));
                }
            }
        }

        let fb_edges: Vec<(VertexIndex, VertexIndex, Weight)> = edges
            .iter()
            .map(|&(a, b, w)| (a as VertexIndex, b as VertexIndex, w as Weight))
            .collect();
        // The boundary is registered as a virtual vertex, so the solver lets
        // defects match to it without treating it as a defect itself.
        let initializer = SolverInitializer::new(
            vertex_count as VertexNum,
            fb_edges,
            vec![boundary_vertex as VertexIndex],
        );
        let solver = SolverSerial::new(&initializer);
        let detector_to_vertex: Vec<usize> = (0..num_detectors).collect();

        Self {
            config,
            solver,
            vertex_count,
            edges,
            detector_to_vertex,
        }
    }

    /// Decodes one syndrome into X corrections.
    ///
    /// Fired detectors without a vertex in the graph are ignored. Each pair of
    /// defects matched to one another is joined by an L-shaped path of grid cells.
    /// The path runs first along the lower-indexed defect's row, then along the
    /// partner's column, and excludes the partner's own cell. Defects matched to
    /// the boundary add no correction. Cells hit an even number of times cancel out.
    pub fn decode(&mut self, syndrome: &DetectorBitmap) -> Correction {
        use fusion_blossom::mwpm_solver::PrimalDualSolver;
        use fusion_blossom::util::{SyndromePattern, VertexIndex};
        use std::time::Instant;

        let start = Instant::now();
        let d = self.config.distance;
        let boundary_vertex = self.vertex_count - 1;

        let defect_vertices: Vec<VertexIndex> = syndrome
            .iter_fired()
            .filter_map(|idx| self.detector_to_vertex.get(idx).copied())
            .map(|vertex| vertex as VertexIndex)
            .collect();

        let mut x_corrections = Vec::new();
        if !defect_vertices.is_empty() {
            // The solver is reused across calls, so reset it before each syndrome.
            // Odd defect counts need no padding: the virtual boundary absorbs them.
            self.solver.clear();
            self.solver
                .solve(&SyndromePattern::new(defect_vertices.clone(), vec![]));
            // NOTE(review): `legacy_get_mwpm_result` returns, per defect, the vertex
            // it is matched to (peer defect or virtual vertex) in fusion_blossom 0.2.
            // Confirm this against the pinned crate version.
            let partners = self
                .solver
                .perfect_matching()
                .legacy_get_mwpm_result(defect_vertices.clone());
            for (&defect, &partner) in defect_vertices.iter().zip(&partners) {
                let (a, b) = (defect as usize, partner as usize);
                // Skip boundary matches. Each peer pair is listed twice, so
                // handle it once, from its lower vertex.
                if b >= boundary_vertex || b < a {
                    continue;
                }
                Self::push_l_path(&mut x_corrections, d, a, b);
            }
        }

        Correction {
            x_corrections: Self::odd_parity(x_corrections),
            z_corrections: Vec::new(),
            confidence: if syndrome.fired_count() == 0 { 1.0 } else { 0.9 },
            decode_time_ns: start.elapsed().as_nanos() as u64,
        }
    }

    /// Returns the configuration this decoder was built with.
    pub fn config(&self) -> &DecoderConfig {
        &self.config
    }

    /// Appends the cells on the L-shaped path from detector `from` to detector `to`.
    ///
    /// The path runs along `from`'s row to `to`'s column, then along that column
    /// to `to`'s row. `to`'s own cell is excluded.
    fn push_l_path(out: &mut Vec<usize>, d: usize, from: usize, to: usize) {
        let (row1, col1) = (from / d, from % d);
        let (row2, col2) = (to / d, to % d);
        let cols = if col1 <= col2 { col1..col2 } else { col2 + 1..col1 + 1 };
        out.extend(cols.map(|c| row1 * d + c));
        let rows = if row1 <= row2 { row1..row2 } else { row2 + 1..row1 + 1 };
        out.extend(rows.map(|r| r * d + col2));
    }

    /// Sorts `labels` and keeps each value only if it occurs an odd number of
    /// times, since applying the same correction twice cancels out.
    fn odd_parity(mut labels: Vec<usize>) -> Vec<usize> {
        labels.sort_unstable();
        let mut kept: Vec<usize> = Vec::with_capacity(labels.len());
        for label in labels {
            if kept.last() == Some(&label) {
                kept.pop();
            } else {
                kept.push(label);
            }
        }
        kept
    }
}
/// Dependency-free stand-in for the matching decoder, compiled when the
/// `decoder` feature is disabled.
#[cfg(not(feature = "decoder"))]
pub struct MWPMDecoder {
    config: DecoderConfig,
}

#[cfg(not(feature = "decoder"))]
impl MWPMDecoder {
    /// Creates the fallback decoder. No decoding graph is built.
    pub fn new(config: DecoderConfig) -> Self {
        Self { config }
    }

    /// Decodes `syndrome` by greedy nearest-neighbour pairing, which is not a
    /// true minimum-weight matching.
    ///
    /// - Only detectors with index `< distance * distance` take part.
    /// - Each detector is paired with the closest still-unpaired later detector by
    ///   Manhattan distance. Ties go to the earliest candidate.
    /// - A detector left without a partner receives no correction.
    /// - Each pair is joined by an L-shaped cell path: along the first detector's
    ///   row, then along the second's column, excluding the second's own cell.
    /// - Cells hit an even number of times cancel out.
    ///
    /// Confidence is 1.0 when nothing fired at all and 0.7 otherwise.
    pub fn decode(&mut self, syndrome: &DetectorBitmap) -> Correction {
        let start = std::time::Instant::now();
        let d = self.config.distance;
        let num_detectors = d * d;

        let all_fired: Vec<usize> = syndrome.iter_fired().collect();
        let confidence = if all_fired.is_empty() { 1.0 } else { 0.7 };
        // Detectors off the d x d grid cannot be placed. Filtering them also
        // keeps the `/ d` below from dividing by zero when `distance == 0`.
        let fired: Vec<usize> = all_fired
            .into_iter()
            .filter(|&idx| idx < num_detectors)
            .collect();

        let mut x_corrections = Vec::new();
        let mut used = vec![false; fired.len()];
        for (i, &det1) in fired.iter().enumerate() {
            if used[i] {
                continue;
            }
            let (row1, col1) = (det1 / d, det1 % d);
            // `min_by_key` returns the first of several equal minima.
            let nearest = (i + 1..fired.len())
                .filter(|&j| !used[j])
                .min_by_key(|&j| row1.abs_diff(fired[j] / d) + col1.abs_diff(fired[j] % d));
            if let Some(j) = nearest {
                used[i] = true;
                used[j] = true;
                Self::push_l_path(&mut x_corrections, d, det1, fired[j]);
            }
        }

        Correction {
            x_corrections: Self::odd_parity(x_corrections),
            z_corrections: Vec::new(),
            confidence,
            decode_time_ns: start.elapsed().as_nanos() as u64,
        }
    }

    /// Returns the configuration this decoder was built with.
    pub fn config(&self) -> &DecoderConfig {
        &self.config
    }

    /// Appends the cells on the L-shaped path from detector `from` to detector `to`.
    ///
    /// The path runs along `from`'s row to `to`'s column, then along that column
    /// to `to`'s row. `to`'s own cell is excluded. The vertical leg must follow
    /// `to`'s column, not the bounding box's right edge, or anti-diagonal pairs
    /// get a path that never reaches `to`.
    fn push_l_path(out: &mut Vec<usize>, d: usize, from: usize, to: usize) {
        let (row1, col1) = (from / d, from % d);
        let (row2, col2) = (to / d, to % d);
        let cols = if col1 <= col2 { col1..col2 } else { col2 + 1..col1 + 1 };
        out.extend(cols.map(|c| row1 * d + c));
        let rows = if row1 <= row2 { row1..row2 } else { row2 + 1..row1 + 1 };
        out.extend(rows.map(|r| r * d + col2));
    }

    /// Sorts `labels` and keeps each value only if it occurs an odd number of
    /// times, since applying the same correction twice cancels out.
    fn odd_parity(mut labels: Vec<usize>) -> Vec<usize> {
        labels.sort_unstable();
        let mut kept: Vec<usize> = Vec::with_capacity(labels.len());
        for label in labels {
            if kept.last() == Some(&label) {
                kept.pop();
            } else {
                kept.push(label);
            }
        }
        kept
    }
}
/// Wraps an [`MWPMDecoder`] for a stream of syndromes and tracks how long the
/// most recent decodes took.
pub struct StreamingDecoder {
    inner: MWPMDecoder,
    /// Decode times (ns) of the most recent `process` calls, oldest first.
    recent_decode_times_ns: std::collections::VecDeque<u64>,
    /// Maximum number of entries kept in `recent_decode_times_ns`.
    history_size: usize,
}

impl StreamingDecoder {
    /// Creates a streaming decoder whose timing history holds
    /// `max(config.window_size, 10)` entries.
    pub fn new(config: DecoderConfig) -> Self {
        let history_size = config.window_size.max(10);
        Self {
            inner: MWPMDecoder::new(config),
            recent_decode_times_ns: std::collections::VecDeque::with_capacity(history_size),
            history_size,
        }
    }

    /// Decodes one syndrome, records its decode time, and returns the correction.
    pub fn process(&mut self, syndrome: &DetectorBitmap) -> Correction {
        let correction = self.inner.decode(syndrome);
        // Only the timing is ever read back, so store that rather than cloning
        // the whole correction. Evict the oldest sample (O(1)) once the window is full.
        if self.recent_decode_times_ns.len() >= self.history_size {
            self.recent_decode_times_ns.pop_front();
        }
        self.recent_decode_times_ns
            .push_back(correction.decode_time_ns);
        correction
    }

    /// Mean decode time in nanoseconds over the retained history, or 0 when
    /// nothing has been recorded.
    pub fn average_decode_time_ns(&self) -> u64 {
        let samples = self.recent_decode_times_ns.len();
        if samples == 0 {
            return 0;
        }
        // Sum in u128 so long windows of large timings cannot overflow; the
        // mean of u64 samples always fits back into u64.
        let total: u128 = self
            .recent_decode_times_ns
            .iter()
            .map(|&t| u128::from(t))
            .sum();
        (total / samples as u128) as u64
    }

    /// Returns the configuration of the wrapped decoder.
    pub fn config(&self) -> &DecoderConfig {
        self.inner.config()
    }

    /// Discards all recorded decode times.
    pub fn clear_history(&mut self) {
        self.recent_decode_times_ns.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bitmap of `len` detectors with each index in `fired` set.
    fn fired_bitmap(len: usize, fired: &[usize]) -> DetectorBitmap {
        let mut bitmap = DetectorBitmap::new(len);
        for &idx in fired {
            bitmap.set(idx, true);
        }
        bitmap
    }

    #[test]
    fn test_decoder_config_default() {
        let DecoderConfig {
            distance,
            physical_error_rate,
            ..
        } = DecoderConfig::default();
        assert_eq!(distance, 7);
        assert!((physical_error_rate - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_decoder_empty_syndrome() {
        let mut decoder = MWPMDecoder::new(DecoderConfig::default());
        let correction = decoder.decode(&fired_bitmap(49, &[]));
        assert!(correction.x_corrections.is_empty());
        assert_eq!(correction.confidence, 1.0);
    }

    #[test]
    fn test_decoder_single_pair() {
        let config = DecoderConfig {
            distance: 5,
            physical_error_rate: 0.01,
            ..DecoderConfig::default()
        };
        let mut decoder = MWPMDecoder::new(config);
        let correction = decoder.decode(&fired_bitmap(25, &[0, 1]));
        assert!(!correction.x_corrections.is_empty());
        assert!(correction.decode_time_ns > 0);
    }

    #[test]
    fn test_streaming_decoder() {
        let mut decoder = StreamingDecoder::new(DecoderConfig::default());
        for round in 0..5 {
            // Even rounds carry a detector pair; odd rounds are quiet.
            let fired: &[usize] = if round % 2 == 0 { &[0, 6] } else { &[] };
            let _ = decoder.process(&fired_bitmap(49, fired));
        }
        assert!(decoder.average_decode_time_ns() > 0);
    }

    #[test]
    fn test_correction_default() {
        let Correction {
            x_corrections,
            z_corrections,
            confidence,
            ..
        } = Correction::default();
        assert!(x_corrections.is_empty());
        assert!(z_corrections.is_empty());
        assert_eq!(confidence, 1.0);
    }
}