use crate::bin;
use crate::coordinator;
use crate::coordinator::{CoordinatorClient, ResetRequest};
#[cfg(feature = "cli")]
use crate::util::BitVector;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[cfg(feature = "cli")]
use structdoc::StructDoc;
use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinSet;
use tonic::Status;
#[cfg(feature = "cli")]
use tonic::transport::server::Router;
#[cfg(feature = "cli")]
use tonic::{Request, Response};
include!("../proto/deq.controller.static_controller.rs");
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "cli", derive(StructDoc))]
#[serde(deny_unknown_fields)]
pub struct StaticControllerConfig {
pub filepath: String,
#[serde(default)]
pub reset_library: bool,
#[serde(default)]
pub reset_decoder_service: bool,
}
pub struct StaticController {
pub config: StaticControllerConfig,
pub library: crate::bin::Library,
coordinator: RwLock<Option<CoordinatorClient>>,
#[cfg_attr(not(feature = "cli"), allow(dead_code))]
info: Arc<StaticControllerInfo>,
#[cfg_attr(not(feature = "cli"), allow(dead_code))]
state: Arc<Mutex<StaticControllerState>>,
}
struct StaticControllerState {
next_index: usize,
gid_vec: Vec<u64>,
outcomes: Vec<bool>,
pending_decodes: JoinSet<Result<(usize, coordinator::Readouts), Status>>,
pending_readouts: Vec<Option<coordinator::Readouts>>,
dispatched_count: usize,
}
struct StaticControllerInfo {
accumulated_measurements: Vec<usize>,
total_measurements: usize,
}
impl StaticControllerInfo {
pub fn new(library: &bin::Library) -> Self {
let mut accumulated_measurements: Vec<usize> = vec![];
let mut total_measurements: usize = 0;
for instruction in library.program.iter() {
if let bin::instruction::Create::Gadget(gadget) = instruction.create.as_ref().unwrap() {
let gadget_type = library.gadget_types.iter().find(|gt| gt.gtype == gadget.gtype).unwrap();
total_measurements += gadget_type.measurements.len();
accumulated_measurements.push(total_measurements);
}
}
Self {
accumulated_measurements,
total_measurements,
}
}
}
impl StaticController {
pub fn new(config: serde_json::Value) -> Self {
let config: StaticControllerConfig = serde_json::from_value(config).unwrap();
let data = std::fs::read(config.filepath.clone()).unwrap();
let library: bin::Library = prost::Message::decode(&mut data.as_slice()).unwrap();
let info = Arc::new(StaticControllerInfo::new(&library));
let state = StaticControllerState {
next_index: 0,
gid_vec: Vec::with_capacity(info.accumulated_measurements.len()),
outcomes: Vec::with_capacity(info.total_measurements),
pending_decodes: JoinSet::new(),
pending_readouts: Vec::new(),
dispatched_count: 0,
};
Self {
config,
library,
coordinator: RwLock::new(None),
info,
state: Arc::new(Mutex::new(state)),
}
}
#[cfg(feature = "cli")]
pub fn add_service(self: &Arc<Self>, router: Router) -> Router {
let service =
static_controller_server::StaticControllerServer::from_arc(self.clone()).max_decoding_message_size(usize::MAX);
router.add_service(service)
}
pub async fn start(self: &Arc<Self>, mut client: CoordinatorClient) {
let mut coordinator = self.coordinator.write().await;
client.load_library(self.library.clone()).await.unwrap();
self.reset_and_instantiate_all(&mut client).await;
coordinator.replace(client);
}
#[cfg(feature = "cli")]
async fn wait_until_library_loaded(&self) -> CoordinatorClient {
while self.coordinator.read().await.is_none() {
tokio::task::yield_now().await;
}
self.coordinator.read().await.clone().unwrap()
}
async fn reset_and_instantiate_all(&self, coordinator: &mut CoordinatorClient) {
coordinator
.reset(ResetRequest {
reset_library: self.config.reset_library,
reset_decoder_service: self.config.reset_decoder_service,
..Default::default()
})
.await
.unwrap();
let mut gid_vec: Vec<u64> = vec![];
let mut cid_vec: Vec<u64> = vec![];
for instruction in self.library.program.iter() {
match instruction.clone().create.unwrap() {
bin::instruction::Create::Gadget(mut gadget) => {
for connector in gadget.connectors.iter_mut() {
connector.gid = gid_vec[(connector.gid as usize) - 1];
}
let gid = coordinator
.execute(bin::Instruction {
create: Some(bin::instruction::Create::Gadget(gadget)),
})
.await
.unwrap()
.id;
gid_vec.push(gid);
}
bin::instruction::Create::CheckModel(mut check_model) => {
check_model.gid = gid_vec[(check_model.gid as usize) - 1];
let cid = coordinator
.execute(bin::Instruction {
create: Some(bin::instruction::Create::CheckModel(check_model)),
})
.await
.unwrap()
.id;
cid_vec.push(cid);
}
bin::instruction::Create::ErrorModel(mut error_model) => {
error_model.cid = cid_vec[(error_model.cid as usize) - 1];
let _eid = coordinator
.execute(bin::Instruction {
create: Some(bin::instruction::Create::ErrorModel(error_model)),
})
.await
.unwrap()
.id;
}
}
}
let mut state = self.state.lock().await;
state.next_index = 0;
state.gid_vec = gid_vec;
state.outcomes.clear();
state.pending_decodes.shutdown().await;
state.pending_readouts.clear();
state.dispatched_count = 0;
}
}
#[cfg(feature = "cli")]
#[tonic::async_trait]
impl static_controller_server::StaticController for StaticController {
async fn decode(&self, request: Request<BitVector>) -> std::result::Result<Response<BitVector>, Status> {
let outcomes = request.into_inner();
let coordinator = self.wait_until_library_loaded().await;
let outcomes = crate::misc::bit_vector::unpack_bits(&outcomes.data, outcomes.size);
let is_complete;
{
let mut state = self.state.lock().await;
state.outcomes.extend_from_slice(&outcomes);
assert!(state.outcomes.len() <= self.info.total_measurements, "too many outcomes");
is_complete = state.outcomes.len() == self.info.total_measurements;
while state.next_index < self.info.accumulated_measurements.len() {
if state.outcomes.len() < self.info.accumulated_measurements[state.next_index] {
break; }
let gid = state.gid_vec[state.next_index];
let start = if state.next_index == 0 {
0
} else {
self.info.accumulated_measurements[state.next_index - 1]
};
let slice: Vec<bool> = state.outcomes[start..self.info.accumulated_measurements[state.next_index]].into();
let bit_vector = BitVector {
size: slice.len() as u64,
data: crate::misc::bit_vector::pack_bits(&slice),
};
let dispatch_idx = state.dispatched_count;
state.dispatched_count += 1;
state.pending_readouts.push(None);
let coordinator_clone = coordinator.clone();
state.pending_decodes.spawn(async move {
coordinator_clone
.decode(coordinator::Outcomes {
gid,
outcomes: Some(bit_vector),
modifiers: vec![],
})
.await
.map(|readouts| (dispatch_idx, readouts))
});
state.next_index += 1;
}
while let Some(res) = state.pending_decodes.try_join_next() {
let (idx, readouts) = res.map_err(|e| Status::internal(format!("join error: {}", e)))??;
state.pending_readouts[idx] = Some(readouts);
}
}
if !is_complete {
let empty = BitVector { size: 0, data: vec![] };
return Ok(empty.into());
}
{
let mut state = self.state.lock().await;
while let Some(res) = state.pending_decodes.try_join_next() {
let (idx, readouts) = res.map_err(|e| Status::internal(format!("join error: {}", e)))??;
state.pending_readouts[idx] = Some(readouts);
}
while let Some(res) = state.pending_decodes.join_next().await {
let (idx, readouts) = res.map_err(|e| Status::internal(format!("join error: {}", e)))??;
state.pending_readouts[idx] = Some(readouts);
}
}
let state = self.state.lock().await;
let mut gathered_readouts = vec![];
for readouts in state.pending_readouts.iter() {
if let Some(r) = readouts {
let bit_vector = r
.readouts
.as_ref()
.ok_or_else(|| Status::internal("empty bit vector in readouts"))?;
gathered_readouts
.extend_from_slice(&crate::misc::bit_vector::unpack_bits(&bit_vector.data, bit_vector.size));
} else {
return Err(Status::internal("missing readouts"));
}
}
let gathered_readouts = BitVector {
size: gathered_readouts.len() as u64,
data: crate::misc::bit_vector::pack_bits(&gathered_readouts),
};
Ok(gathered_readouts.into())
}
async fn reset(&self, _request: Request<()>) -> std::result::Result<Response<()>, Status> {
let mut coordinator = self.wait_until_library_loaded().await;
self.reset_and_instantiate_all(&mut coordinator).await;
Ok(Response::new(()))
}
}