include!("proto/deq.jit.rs");
pub mod jit_compiler;
use crate::bin;
use tokio_util::sync::CancellationToken;
pub async fn static_jit_compile(mut jit_library: JitLibrary) -> bin::Library {
let compiler = jit_compiler::JitCompiler::new();
let program = std::mem::take(&mut jit_library.program);
let token = CancellationToken::new();
let mut library = bin::Library::default();
for port_type in jit_library.port_types.iter() {
library.port_types.push(port_type.base.as_ref().unwrap().clone());
}
for gadget_type in jit_library.gadget_types.iter() {
library.gadget_types.push(gadget_type.base.as_ref().unwrap().clone());
}
compiler.load_library(jit_library).await;
if program.is_empty() {
return library;
}
let all_preassigned = program.iter().all(|i| i.gadget.as_ref().is_some_and(|g| g.gid != 0));
if !all_preassigned || program.len() < 64 {
return static_jit_compile_sequential(compiler, program, token, library).await;
}
use std::collections::HashMap;
use tokio::sync::watch;
let mut ready_txs: HashMap<u64, watch::Sender<bool>> = HashMap::new();
struct TaskInfo {
idx: usize,
instruction: JitInstruction,
gid: u64,
dep_gids: Vec<u64>,
}
let mut tasks: Vec<TaskInfo> = Vec::with_capacity(program.len());
for (idx, instruction) in program.into_iter().enumerate() {
let gadget = instruction.gadget.as_ref().unwrap();
let gid = gadget.gid;
let dep_gids: Vec<u64> = gadget.connectors.iter().map(|c| c.gid).collect();
let (tx, _) = watch::channel(false);
ready_txs.insert(gid, tx);
tasks.push(TaskInfo {
idx,
instruction,
gid,
dep_gids,
});
}
let mut all_dep_rxs: Vec<Vec<watch::Receiver<bool>>> = Vec::with_capacity(tasks.len());
for info in &tasks {
let rxs: Vec<watch::Receiver<bool>> = info.dep_gids.iter().map(|d| ready_txs[d].subscribe()).collect();
all_dep_rxs.push(rxs);
}
let n = tasks.len();
let mut handles = Vec::with_capacity(n);
for (info, dep_rxs) in tasks.into_iter().zip(all_dep_rxs) {
let ready_tx = ready_txs.remove(&info.gid).unwrap();
let comp = std::sync::Arc::clone(&compiler);
let tok = token.clone();
handles.push(tokio::spawn(async move {
for mut rx in dep_rxs {
let _ = rx.wait_for(|v| *v).await;
}
let (gadget, cmt, cm, error_future) = comp.compile(info.instruction, tok).await;
let _ = ready_tx.send(true);
let error_handle = tokio::spawn(error_future);
(info.idx, gadget, cmt, cm, error_handle)
}));
}
let mut results: Vec<_> = futures_util::future::join_all(handles)
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
results.sort_by_key(|(idx, _, _, _, _)| *idx);
let mut error_handles = Vec::with_capacity(n);
for (_, gadget, cmt, cm, error_handle) in results {
library.program.push(bin::Instruction {
create: Some(bin::instruction::Create::Gadget(gadget)),
});
library.check_model_types.push(cmt);
library.program.push(bin::Instruction {
create: Some(bin::instruction::Create::CheckModel(cm)),
});
error_handles.push(error_handle);
}
let error_results = futures_util::future::join_all(error_handles).await;
for result in error_results {
let (error_model_type, error_model) = result.unwrap();
library.error_model_types.push(error_model_type);
library.program.push(bin::Instruction {
create: Some(bin::instruction::Create::ErrorModel(error_model)),
});
}
library
}
/// Compiles `program` one instruction at a time, in program order, appending to `library`.
///
/// Each compiled instruction contributes a `Gadget` instruction followed by a `CheckModel`
/// instruction to `library.program`, and its check-model type to `library.check_model_types`.
/// The error-model futures returned by the compiler are only awaited once every instruction
/// has been compiled. Their results are appended in program order as `ErrorModel`
/// instructions, with their types added to `library.error_model_types`.
async fn static_jit_compile_sequential(
    compiler: std::sync::Arc<jit_compiler::JitCompiler>,
    program: Vec<JitInstruction>,
    token: CancellationToken,
    mut library: bin::Library,
) -> bin::Library {
    let mut pending_error_models = Vec::with_capacity(program.len());
    for jit_instruction in program {
        let (compiled_gadget, cm_type, cm, error_model_future) =
            compiler.compile(jit_instruction, token.clone()).await;
        library.check_model_types.push(cm_type);
        library.program.extend([
            bin::Instruction {
                create: Some(bin::instruction::Create::Gadget(compiled_gadget)),
            },
            bin::Instruction {
                create: Some(bin::instruction::Create::CheckModel(cm)),
            },
        ]);
        pending_error_models.push(error_model_future);
    }

    let resolved = futures_util::future::join_all(pending_error_models).await;
    for (em_type, em) in resolved {
        library.error_model_types.push(em_type);
        library.program.push(bin::Instruction {
            create: Some(bin::instruction::Create::ErrorModel(em)),
        });
    }
    library
}
#[cfg(feature = "python_binding")]
#[pyo3::pyfunction]
#[pyo3(name="static_jit_compile", signature = (jit_library))]
/// Python entry point: decodes a protobuf-encoded `JitLibrary`, compiles it with
/// `static_jit_compile` on a freshly created tokio runtime, and returns the protobuf-encoded
/// `bin::Library`.
///
/// Panics if the input bytes do not decode as a `JitLibrary` or if the runtime cannot be built.
pub fn py_static_jit_compile(jit_library: Vec<u8>) -> Vec<u8> {
    use prost::Message;

    let decoded = JitLibrary::decode(jit_library.as_slice()).unwrap();
    // The runtime is dropped at the end of this block, before the result is encoded.
    let compiled = {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(static_jit_compile(decoded))
    };
    let mut encoded = Vec::with_capacity(compiled.encoded_len());
    compiled.encode(&mut encoded).unwrap();
    encoded
}