use std::future::Future;
use std::pin::Pin;
use std::{collections::HashMap, sync::LazyLock};
use crate::common::{RemoteFnPayload, RemoteFnResponse, RemoteFnResponsePayload, WorkerInbound};
/// Turns a caught panic payload into a human-readable message.
///
/// The two payload types `panic!` commonly carries — `&'static str` and `String` — are returned
/// as their text; any other payload type yields a fixed placeholder string.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| String::from("<non-string panic payload>"))
}
/// Identifier under which a remote function is registered and by which calls look it up
/// (it is the key of `REMOTE_FN_MAP`).
#[doc(hidden)]
pub type RemoteFnId = &'static str;
/// Pinned, boxed, `Send` future returned by a remote function's entry point. It resolves to the
/// result bytes that are sent back to the caller (presumably the serialized return value —
/// confirm against the code generating the entry points).
#[doc(hidden)]
pub type RemoteFnFuture = Pin<Box<dyn Future<Output = Vec<u8>> + Send>>;
/// Entry point of a registered remote function: receives the call's argument bytes and returns
/// the future that produces the result bytes.
#[doc(hidden)]
pub type RemoteFnSignature = fn(args: &[u8]) -> RemoteFnFuture;
/// Registration record for one remote function, gathered at startup through `inventory`.
#[doc(hidden)]
pub struct RemoteFn {
    /// Key that incoming calls use to select this function.
    pub id: RemoteFnId,
    /// Entry point invoked with the call's argument bytes.
    pub function: RemoteFnSignature,
}
inventory::collect!(RemoteFn);

/// Lookup table from remote-function id to entry point. It is built lazily on first access from
/// every [`RemoteFn`] registered through `inventory`.
///
/// If several registrations share an id, the one iterated last replaces the earlier entries.
pub static REMOTE_FN_MAP: LazyLock<HashMap<RemoteFnId, RemoteFnSignature>> = LazyLock::new(|| {
    let mut table = HashMap::new();
    for registration in inventory::iter::<RemoteFn> {
        table.insert(registration.id, registration.function);
    }
    table
});
/// Confirms that this worker holds a decryption key for every registered vault. It is run right
/// after the handshake's vault keys are installed.
///
/// # Errors
///
/// Fails with a message listing every vault that `crate::secret::missing_vaults` reports as
/// lacking a key. The message points the operator at the controller's `keys/<name>.dek` files.
fn worker_vault_preflight() -> crate::Result<()> {
    let missing = crate::secret::missing_vaults(crate::secret::registered_vaults());
    if missing.is_empty() {
        Ok(())
    } else {
        crate::bail!(
            "missing decryption keys on the remote worker: no DEK for vault(s) {missing:?}. \
             The orchestrator did not ship a key for them — ensure `keys/<name>.dek` exists \
             on the controller for every vault this play's `secret!`s reference."
        )
    }
}
pub async fn rpc(rpc_in: tokio::fs::File, rpc_out: tokio::fs::File) {
std::fs::remove_file(std::env::current_exe().expect("Couldn't get current executable path"))
.expect("Couldn't remove the executable");
let (process_tx, mut process_rx) = tokio::sync::mpsc::unbounded_channel::<RemoteFnPayload>();
let (writer_tx, mut writer_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
tokio::spawn(async move {
use tokio::io::AsyncBufReadExt as _;
let mut stdin = tokio::io::BufReader::new(rpc_in);
let mut buffer = Vec::with_capacity(4096);
while stdin.read_until(0x00, &mut buffer).await.unwrap_or(0) > 0 {
match postcard::from_bytes_cobs::<WorkerInbound>(&mut buffer) {
Ok(WorkerInbound::Handshake { vault_keys }) => {
if let Err(e) = crate::secret::keychain::install_raw_keys(vault_keys)
.and_then(|()| worker_vault_preflight())
{
eprintln!("\x1b[31m{e:?}\x1b[0m");
std::process::exit(1);
}
}
Ok(WorkerInbound::Call(payload)) => {
process_tx.send(payload).expect("Processing channel closed");
}
Err(_) => {}
}
buffer.clear();
}
});
tokio::spawn(async move {
let mut stdout = rpc_out;
while let Some(cobs_packet) = writer_rx.recv().await {
use tokio::io::AsyncWriteExt as _;
stdout
.write_all(&cobs_packet)
.await
.expect("Failed to write to stdout");
stdout.flush().await.expect("Failed to flush stdout");
}
});
while let Some(payload) = process_rx.recv().await {
let writer_tx = writer_tx.clone();
tokio::spawn(async move {
let uuid = payload.uuid;
let work = tokio::spawn(async move {
let registered_fn =
REMOTE_FN_MAP
.get(payload.fn_id.as_str())
.unwrap_or_else(|| {
panic!("called an unregistered remote function: {}", payload.fn_id)
});
registered_fn(&payload.data).await
});
let response = match work.await {
Ok(bytes) => RemoteFnResponse::Ok(bytes),
Err(join_err) if join_err.is_panic() => {
RemoteFnResponse::Panic(panic_message(&*join_err.into_panic()))
}
Err(_) => RemoteFnResponse::Panic("remote task was cancelled".to_string()),
};
let response_payload = RemoteFnResponsePayload { uuid, response };
let outbound_cobs = postcard::to_allocvec_cobs(&response_payload)
.expect("Failed to serialize response");
writer_tx
.send(outbound_cobs)
.expect("Writer channel closed");
});
}
}