use crate::decoder::blackbox_decoder::{self, black_box_decoder_server};
use crate::util::BitVector;
use hashbrown::HashMap;
use serde::{Deserialize, Serialize};
#[cfg(feature = "cli")]
use std::sync::Arc;
#[cfg(feature = "cli")]
use structdoc::StructDoc;
use tokio::sync::RwLock;
#[cfg(feature = "cli")]
use tonic::transport::server::Router;
use tonic::{Request, Response, Status};
/// Configuration accepted by `MockDecoder::from_config`.
///
/// Unknown fields are rejected during deserialization (`deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "cli", derive(StructDoc))]
#[serde(deny_unknown_fields)]
pub struct MockDecoderConfig {
/// Artificial delay, in milliseconds, applied before a decode reply is sent.
/// Defaults to `0` when omitted; `MockDecoder::from_config` treats `0` as "no delay".
#[serde(default)]
pub decode_delay_ms: u64,
}
/// In-memory mock of the black-box decoder service.
///
/// It records the requests it receives, answers decode requests from a table
/// of canned responses, and can optionally sleep before replying.
pub struct MockDecoder {
/// Recorded calls, loaded hypergraphs and canned responses, guarded by an async lock.
pub state: RwLock<MockDecoderState>,
/// Optional delay awaited by the decode handlers before they reply; `None` means no delay.
pub decode_delay: std::sync::Mutex<Option<std::time::Duration>>,
}
/// Mutable bookkeeping of a [`MockDecoder`].
///
/// Note: the derived `Default` leaves `next_hid` at `0`; the `MockDecoder`
/// constructors and `MockDecoder::clear` set it to `1`.
#[derive(Default)]
pub struct MockDecoderState {
/// Every `decode` request that carried both a hypergraph and a syndrome, in arrival order.
pub decode_calls: Vec<DecodeProblem>,
/// Hypergraphs registered through `load_hypergraph`, keyed by the id handed back to the caller.
pub loaded_hypergraphs: HashMap<u64, blackbox_decoder::DecodingHypergraph>,
/// Every `decode_loaded` request that referenced a known hypergraph id, in arrival order.
pub decode_loaded_calls: Vec<LoadedDecodeProblem>,
/// Id that the next `load_hypergraph` call will assign.
pub next_hid: u64,
/// Number of `reset` requests received since construction or the last `clear`.
pub reset_count: usize,
/// Canned responses: maps a syndrome's `data` bytes to the subgraph to return for it.
/// Syndromes without an entry are answered with an empty subgraph.
pub custom_responses: HashMap<Vec<u8>, Vec<u64>>,
}
/// Record of one `decode` request as received by the mock.
#[derive(Clone, Debug)]
pub struct DecodeProblem {
/// Hypergraph sent inline with the request.
pub hypergraph: blackbox_decoder::DecodingHypergraph,
/// Syndrome sent with the request.
pub syndrome: BitVector,
}
/// Record of one `decode_loaded` request as received by the mock.
#[derive(Clone, Debug)]
pub struct LoadedDecodeProblem {
/// Id of the previously loaded hypergraph the request referred to.
pub hid: u64,
/// Syndrome sent with the request.
pub syndrome: BitVector,
}
impl MockDecoder {
    /// Creates a mock decoder with empty state and no artificial decode delay.
    pub fn new() -> Self {
        Self::with_decode_delay(None)
    }

    /// Builds a mock decoder from a JSON value shaped like [`MockDecoderConfig`].
    ///
    /// A `decode_delay_ms` of `0` (the default) disables the artificial delay.
    ///
    /// # Panics
    ///
    /// Panics if `config` does not deserialize into [`MockDecoderConfig`]
    /// (wrong type, unknown field, ...). The panic message carries the
    /// underlying deserialization error.
    pub fn from_config(config: serde_json::Value) -> Self {
        let config: MockDecoderConfig = serde_json::from_value(config)
            .unwrap_or_else(|err| panic!("invalid MockDecoder config: {}", err));
        let delay = (config.decode_delay_ms > 0)
            .then(|| std::time::Duration::from_millis(config.decode_delay_ms));
        Self::with_decode_delay(delay)
    }

    /// Registers this mock as the `BlackBoxDecoder` gRPC service on `router`,
    /// with the maximum decoding message size raised to `usize::MAX`.
    #[cfg(feature = "cli")]
    pub fn add_service(self: &Arc<Self>, router: Router) -> Router {
        let service = black_box_decoder_server::BlackBoxDecoderServer::from_arc(Arc::clone(self))
            .max_decoding_message_size(usize::MAX);
        router.add_service(service)
    }

    /// Restores the recorded state to what a freshly constructed mock has:
    /// no recorded calls, no loaded hypergraphs, no canned responses,
    /// `reset_count == 0` and `next_hid == 1`.
    ///
    /// The configured decode delay is left untouched.
    pub async fn clear(&self) {
        *self.state.write().await = Self::fresh_state();
    }

    /// Registers `subgraph` as the reply for any syndrome whose `data` bytes
    /// equal `syndrome_data`, replacing a previous entry for the same bytes.
    pub async fn set_response(&self, syndrome_data: Vec<u8>, subgraph: Vec<u64>) {
        self.state
            .write()
            .await
            .custom_responses
            .insert(syndrome_data, subgraph);
    }

    /// Sets the delay awaited before every subsequent decode reply.
    ///
    /// # Panics
    ///
    /// Panics if the delay mutex is poisoned.
    pub fn set_decode_delay(&self, delay: std::time::Duration) {
        *self.decode_delay.lock().unwrap() = Some(delay);
    }

    /// Looks up the canned subgraph for `syndrome`; an unknown syndrome
    /// yields an empty subgraph.
    fn get_response(state: &MockDecoderState, syndrome: &BitVector) -> Vec<u64> {
        state
            .custom_responses
            .get(&syndrome.data)
            .cloned()
            .unwrap_or_default()
    }

    /// Sleeps for the configured decode delay, if one is set.
    async fn apply_delay(&self) {
        // Copy the value out so the `std::sync::Mutex` guard is released
        // before the `.await` below.
        let delay = *self.decode_delay.lock().unwrap();
        if let Some(delay) = delay {
            tokio::time::sleep(delay).await;
        }
    }

    /// Single source of truth for the state of a new or cleared mock.
    fn fresh_state() -> MockDecoderState {
        MockDecoderState {
            // Hypergraph ids start at 1.
            next_hid: 1,
            ..Default::default()
        }
    }

    /// Shared constructor: fresh state plus the given optional decode delay.
    fn with_decode_delay(delay: Option<std::time::Duration>) -> Self {
        Self {
            state: RwLock::new(Self::fresh_state()),
            decode_delay: std::sync::Mutex::new(delay),
        }
    }
}
impl Default for MockDecoder {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for MockDecoder {
    /// Prints only the type name; the internal state is deliberately not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("MockDecoder")
    }
}
#[tonic::async_trait]
impl black_box_decoder_server::BlackBoxDecoder for MockDecoder {
    /// Records the request in `decode_calls` and replies with the canned
    /// subgraph for its syndrome (empty if none was registered), after the
    /// configured delay.
    ///
    /// # Errors
    ///
    /// Returns `invalid_argument` if the hypergraph or the syndrome is missing;
    /// such requests are not recorded.
    async fn decode(
        &self,
        request: Request<blackbox_decoder::DecodingProblem>,
    ) -> Result<Response<blackbox_decoder::ParityFactor>, Status> {
        let problem = request.into_inner();
        let hypergraph = problem
            .hypergraph
            .ok_or_else(|| Status::invalid_argument("missing hypergraph"))?;
        let syndrome = problem
            .syndrome
            .ok_or_else(|| Status::invalid_argument("missing syndrome"))?;

        let subgraph = {
            let mut state = self.state.write().await;
            // Look the reply up first so the owned request parts can be moved
            // into the call log instead of being deep-copied.
            let subgraph = Self::get_response(&state, &syndrome);
            state.decode_calls.push(DecodeProblem { hypergraph, syndrome });
            subgraph
        }; // lock released here, before the delay is awaited

        self.apply_delay().await;
        Ok(Response::new(blackbox_decoder::ParityFactor { subgraph }))
    }

    /// Stores the hypergraph and returns the id assigned to it. Ids are handed
    /// out sequentially from `next_hid`.
    async fn load_hypergraph(
        &self,
        request: Request<blackbox_decoder::DecodingHypergraph>,
    ) -> Result<Response<blackbox_decoder::LoadHypergraphResponse>, Status> {
        let hypergraph = request.into_inner();
        let mut state = self.state.write().await;
        let hid = state.next_hid;
        state.next_hid += 1;
        state.loaded_hypergraphs.insert(hid, hypergraph);
        Ok(Response::new(blackbox_decoder::LoadHypergraphResponse { hid }))
    }

    /// Records the request in `decode_loaded_calls` and replies with the canned
    /// subgraph for its syndrome (empty if none was registered), after the
    /// configured delay.
    ///
    /// # Errors
    ///
    /// Returns `invalid_argument` if the syndrome is missing, and `not_found`
    /// if `hid` does not refer to a currently loaded hypergraph. Rejected
    /// requests are not recorded.
    async fn decode_loaded(
        &self,
        request: Request<blackbox_decoder::LoadedDecodingProblem>,
    ) -> Result<Response<blackbox_decoder::ParityFactor>, Status> {
        let problem = request.into_inner();
        let syndrome = problem
            .syndrome
            .ok_or_else(|| Status::invalid_argument("missing syndrome"))?;
        let hid = problem.hid;

        let subgraph = {
            let mut state = self.state.write().await;
            if !state.loaded_hypergraphs.contains_key(&hid) {
                return Err(Status::not_found(format!("hid={}", hid)));
            }
            // Look the reply up first so the syndrome can be moved into the
            // call log instead of being cloned.
            let subgraph = Self::get_response(&state, &syndrome);
            state.decode_loaded_calls.push(LoadedDecodeProblem { hid, syndrome });
            subgraph
        }; // lock released here, before the delay is awaited

        self.apply_delay().await;
        Ok(Response::new(blackbox_decoder::ParityFactor { subgraph }))
    }

    /// Counts the reset and, when `reset_hypergraphs` is set, forgets all
    /// loaded hypergraphs and restarts id assignment at 1.
    ///
    /// Recorded decode calls and canned responses are not touched.
    async fn reset(
        &self,
        request: Request<blackbox_decoder::ResetRequest>,
    ) -> Result<Response<()>, Status> {
        let flags = request.into_inner();
        let mut state = self.state.write().await;
        state.reset_count += 1;
        if flags.reset_hypergraphs {
            state.loaded_hypergraphs.clear();
            state.next_hid = 1;
        }
        Ok(Response::new(()))
    }
}