use std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::oneshot;
use uuid::Uuid;
use crate::common::{RemoteFnPayload, RemoteFnResponse, RemoteFnResponsePayload, WorkerInbound};
pub struct OutboundRegistration {
pub payload: RemoteFnPayload,
pub tx: tokio::sync::oneshot::Sender<RemoteFnResponse>,
}
/// Process-wide, set-once sender for submitting [`OutboundRegistration`]s.
///
/// NOTE(review): presumably initialized with the sender whose receiver is passed to [`rpc`]
/// as `registration_rx`; the initialization site is not in this file — confirm.
pub static ORCHESTRATOR_TX: std::sync::OnceLock<
    tokio::sync::mpsc::UnboundedSender<OutboundRegistration>,
> = std::sync::OnceLock::new();
/// Handle to the result of a remote call, resolving to the deserialized return value `T`.
///
/// Wraps the receiving half of the oneshot channel whose sender travels to the orchestrator
/// inside an [`OutboundRegistration`].
///
/// # Panics
///
/// Polling panics if the remote task panicked, if the returned bytes fail to deserialize
/// as `T`, or if the response sender is dropped without a response being sent.
#[must_use]
pub struct Future<T> {
    rx: oneshot::Receiver<RemoteFnResponse>,
    // `fn() -> T` rather than `T`: no `T` is ever stored, so the handle's `Send`/`Sync`/`Unpin`
    // should not depend on `T`. In particular this makes `Future<T>: Unpin` for every `T`,
    // which lets `poll` reach `rx` without `unsafe`. Variance stays covariant in `T`.
    _marker: PhantomData<fn() -> T>,
}

impl<T: serde::de::DeserializeOwned> Future<T> {
    /// Wraps `rx`, the receiver paired with the `tx` of an [`OutboundRegistration`].
    pub fn new(rx: oneshot::Receiver<RemoteFnResponse>) -> Self {
        Self {
            rx,
            _marker: PhantomData,
        }
    }
}

impl<T: serde::de::DeserializeOwned> ::std::future::Future for Future<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Future<T>` is `Unpin` (see `_marker`), so the pin can be unwrapped safely.
        let rx = Pin::new(&mut self.get_mut().rx);
        // Fully qualified: the trait's name is shadowed by this struct.
        match ::std::future::Future::poll(rx, cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(RemoteFnResponse::Ok(return_value_bytes))) => Poll::Ready(
                postcard::from_bytes(&return_value_bytes)
                    .expect("Failed to deserialize return value"),
            ),
            Poll::Ready(Ok(RemoteFnResponse::Panic(msg))) => {
                panic!("remote task panicked: {msg}")
            }
            // The sender was dropped without replying.
            Poll::Ready(Err(_)) => panic!("RPC channel closed unexpectedly"),
        }
    }
}
/// Drives the RPC connection to a worker process until there is nothing left to do.
///
/// First queues a `WorkerInbound::Handshake` carrying `vault_keys`. Then, for every
/// [`OutboundRegistration`] received on `registration_rx`, it serializes the payload as a
/// `WorkerInbound::Call` (via `postcard::to_allocvec_cobs`) and writes it to `rpc_out`. The
/// registration's sender is parked until a `RemoteFnResponsePayload` with the same `uuid` is
/// read from `rpc_in`, which carries `0x00`-terminated COBS frames.
///
/// Once the inbound stream ends (EOF, read error, or a corrupt frame), no response can
/// arrive any more. Every parked sender is then dropped, and so is the sender of every call
/// registered afterwards. The corresponding [`Future`] therefore fails with
/// "RPC channel closed unexpectedly" instead of hanging forever.
///
/// Returns once `registration_rx` is closed and the inbound stream has ended.
///
/// # Panics
///
/// Panics if the handshake or a call fails to serialize.
pub async fn rpc(
    mut registration_rx: tokio::sync::mpsc::UnboundedReceiver<OutboundRegistration>,
    rpc_in: tokio::fs::File,
    rpc_out: tokio::fs::File,
    vault_keys: std::collections::HashMap<String, Vec<u8>>,
) {
    let (reader_tx, mut reader_rx) =
        tokio::sync::mpsc::unbounded_channel::<RemoteFnResponsePayload>();
    let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();

    // Queued before the writer task exists. The unbounded channel buffers it, so it is still
    // the first frame written.
    let handshake = postcard::to_allocvec_cobs(&WorkerInbound::Handshake { vault_keys })
        .expect("Failed to serialize handshake");
    let _ = writer_tx.send(handshake);

    spawn_reader(rpc_in, reader_tx);
    spawn_writer(rpc_out, writer_rx);

    let mut pending_futures: std::collections::HashMap<Uuid, oneshot::Sender<RemoteFnResponse>> =
        std::collections::HashMap::new();
    // Cleared once the reader task has stopped (its sender dropped).
    let mut reader_open = true;

    loop {
        tokio::select! {
            Some(registration) = registration_rx.recv() => {
                if reader_open {
                    let uuid = registration.payload.uuid;
                    let outbound_bytes =
                        postcard::to_allocvec_cobs(&WorkerInbound::Call(registration.payload))
                            .expect("Failed to serialize `WorkerInbound::Call`");
                    let _ = writer_tx.send(outbound_bytes);
                    pending_futures.insert(uuid, registration.tx);
                }
                // Otherwise no response can ever arrive. `registration` (and its `tx`) is
                // dropped here, so the caller's future fails immediately.
            }
            response = reader_rx.recv(), if reader_open => {
                match response {
                    Some(response) => {
                        if let Some(tx) = pending_futures.remove(&response.uuid) {
                            // Err only means the caller dropped its future; nothing to do.
                            let _ = tx.send(response.response);
                        }
                    }
                    None => {
                        // Inbound stream is gone. Drop every parked sender so the waiting
                        // futures observe a closed channel instead of hanging.
                        reader_open = false;
                        pending_futures.clear();
                    }
                }
            }
            // `registration_rx` is closed and the reader has stopped.
            else => break,
        }
    }
}

/// Spawns the task that decodes `0x00`-terminated COBS frames from `rpc_in` into
/// [`RemoteFnResponsePayload`]s and forwards them on `reader_tx`.
///
/// The task ends on EOF, on a read error (treated like EOF), or once the receiving side of
/// `reader_tx` is gone. Ending drops `reader_tx`, which tells [`rpc`] that no more
/// responses will arrive.
///
/// # Panics
///
/// The spawned task panics if a frame fails to decode. That also ends it and drops
/// `reader_tx`.
fn spawn_reader(
    rpc_in: tokio::fs::File,
    reader_tx: tokio::sync::mpsc::UnboundedSender<RemoteFnResponsePayload>,
) {
    tokio::spawn(async move {
        use tokio::io::AsyncBufReadExt as _;
        let mut stdin = tokio::io::BufReader::new(rpc_in);
        let mut buffer = Vec::with_capacity(4096);
        while stdin.read_until(0x00, &mut buffer).await.unwrap_or(0) > 0 {
            // NOTE: `from_bytes_cobs` decodes `buffer` in place, even on error. The
            // "Raw Bytes" dump below may therefore show partially decoded data.
            match postcard::from_bytes_cobs::<RemoteFnResponsePayload>(&mut buffer) {
                Ok(response) => {
                    if reader_tx.send(response).is_err() {
                        // The orchestrator loop has exited; nobody is listening.
                        break;
                    }
                }
                Err(e) => {
                    panic!(
                        "RPC Framing Corruption! Error: {:?}\nRaw Bytes: {:?}",
                        e,
                        String::from_utf8_lossy(&buffer)
                    );
                }
            }
            buffer.clear();
        }
    });
}

/// Spawns the task that writes each already-framed packet from `writer_rx` to `rpc_out`,
/// flushing after every packet.
///
/// The task ends when every sender of `writer_rx` has been dropped. Write and flush failures
/// are ignored (best-effort), matching the fire-and-forget sends in [`rpc`].
fn spawn_writer(
    rpc_out: tokio::fs::File,
    mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>,
) {
    tokio::spawn(async move {
        use tokio::io::AsyncWriteExt as _;
        let mut stdout = rpc_out;
        while let Some(cobs_packet) = writer_rx.recv().await {
            let _ = stdout.write_all(&cobs_packet).await;
            let _ = stdout.flush().await;
        }
    });
}