from typing import Any, Dict, List
import numpy as np
from scipy.sparse import csr_matrix
from relay_bp import RelayDecoderF64
class RelayBPDecoder:
    """Adapter that feeds sparse syndromes to a ``RelayDecoderF64`` solver.

    A syndrome is supplied as a list of vertex indices. It is expanded into
    a dense ``uint8`` vector of length ``vertex_num`` before being handed to
    the solver, and the solver's result is reported back as the list of
    positions holding a nonzero value.
    """

    def __init__(self, vertex_num: int, num_hyperedges: int, solver: "RelayDecoderF64"):
        """Store the problem dimensions and the solver to delegate to.

        Args:
            vertex_num: Length of the dense syndrome vector built by
                ``decode``.
            num_hyperedges: Number of hyperedges; stored but not read by the
                methods visible here.
            solver: Object whose ``decode(dense)`` method performs the
                actual decoding.
        """
        self._vertex_num = vertex_num
        self._num_hyperedges = num_hyperedges
        self._solver = solver

    def decode(self, syndrome: List[int]) -> List[int]:
        """Decode a sparse syndrome and return the nonzero result positions.

        Args:
            syndrome: List of vertex indices to set to 1 in the dense
                syndrome vector. Each entry is converted with ``int``.

        Returns:
            Positions at which the solver's result is nonzero, in ascending
            order (presumably the selected hyperedge indices; confirm
            against the ``relay_bp`` documentation).

        Raises:
            AssertionError: If ``syndrome`` is not a ``list`` (only when
                assertions are enabled).
            IndexError: If an index lies outside ``[0, vertex_num)``.
        """
        assert isinstance(syndrome, list)
        dense = np.zeros(self._vertex_num, dtype=np.uint8)
        for raw_index in syndrome:
            vertex = int(raw_index)
            # NumPy would silently wrap a negative index around to the end of
            # the vector, flagging the wrong vertex; reject it explicitly.
            if not 0 <= vertex < self._vertex_num:
                raise IndexError(
                    f"syndrome index {vertex} is out of range for "
                    f"{self._vertex_num} vertices"
                )
            dense[vertex] = 1
        result = self._solver.decode(dense)
        return [int(i) for i in np.flatnonzero(np.asarray(result))]

    def reset(self) -> None:
        """Do nothing; this class keeps no per-decode state of its own."""
        return None
def new(hypergraph: Any, config: Dict[str, Any]) -> Any:
    """Build a ``RelayBPDecoder`` for ``hypergraph``.

    A sparse matrix with one row per vertex and one column per hyperedge is
    assembled, holding a 1 wherever a hyperedge lists a vertex. It is passed
    to ``RelayDecoderF64`` together with the per-hyperedge probabilities and
    any extra keyword options from ``config``.

    Args:
        hypergraph: Object exposing ``vertex_num`` and an iterable
            ``hyperedges``; each hyperedge exposes ``vertices`` and
            ``probability``.
        config: Keyword arguments forwarded to ``RelayDecoderF64``. A falsy
            value means no extra arguments. A list under
            ``"gamma_dist_interval"`` is converted to a tuple.

    Returns:
        A ``RelayBPDecoder`` wrapping the constructed solver.
    """
    n_vertices = int(hypergraph.vertex_num)
    edges = list(hypergraph.hyperedges)
    n_edges = len(edges)

    row_indices: List[int] = []
    col_indices: List[int] = []
    priors = np.empty(n_edges, dtype=np.float64)
    for edge_index, edge in enumerate(edges):
        first_new = len(row_indices)
        row_indices.extend(int(v) for v in edge.vertices)
        # One column entry for every row entry just appended.
        col_indices.extend([edge_index] * (len(row_indices) - first_new))
        priors[edge_index] = float(edge.probability)

    entries = np.ones(len(row_indices), dtype=np.uint8)
    parity_checks = csr_matrix(
        (entries, (row_indices, col_indices)),
        shape=(n_vertices, n_edges),
    )

    # Copy so the caller's mapping is never mutated.
    options = dict(config) if config else {}
    interval = options.get("gamma_dist_interval")
    if isinstance(interval, list):
        options["gamma_dist_interval"] = tuple(interval)

    relay_solver = RelayDecoderF64(parity_checks, priors, **options)
    return RelayBPDecoder(n_vertices, n_edges, relay_solver)