// cindy 0.1.1
//
// Managing infrastructure at breakneck speed.
// Documentation
use std::{
    marker::PhantomData,
    pin::Pin,
    task::{Context, Poll},
};

use tokio::sync::oneshot;
use uuid::Uuid;

use crate::common::{RemoteFnPayload, RemoteFnResponse, RemoteFnResponsePayload};

pub struct OutboundRegistration {
    pub payload: RemoteFnPayload,
    pub tx: tokio::sync::oneshot::Sender<RemoteFnResponse>,
}

/// Process-wide, set-once slot holding the sending half of an unbounded
/// [`OutboundRegistration`] channel.
///
/// It starts out empty and is not initialized in the code shown here.
/// NOTE(review): presumably the paired receiver is the `registration_rx`
/// handed to [`rpc`] — confirm at the initialization site.
pub static ORCHESTRATOR_TX: std::sync::OnceLock<
    tokio::sync::mpsc::UnboundedSender<OutboundRegistration>,
> = std::sync::OnceLock::new();

/// Handle to the result of a remote call; awaiting it yields a `T`.
///
/// The handle owns only the receiving end of a oneshot channel. `T` appears
/// purely as a marker for the type the response bytes are deserialized into.
#[must_use]
pub struct Future<T> {
    /// Receives the raw [`RemoteFnResponse`] for this call.
    rx: oneshot::Receiver<RemoteFnResponse>,
    /// Ties the handle to its output type without storing a `T`.
    _marker: PhantomData<T>,
}

impl<T: for<'a> serde::Deserialize<'a>> Future<T> {
    /// Wraps the receiving end of a response channel in a typed [`Future`].
    ///
    /// `rx` is the counterpart of the sender on which the
    /// [`RemoteFnResponse`] for this call will be delivered.
    pub fn new(rx: oneshot::Receiver<RemoteFnResponse>) -> Self {
        let _marker = PhantomData;
        Self { rx, _marker }
    }
}

// `Future<T>` stores only a oneshot receiver (which is `Unpin`) and a
// `PhantomData<T>`; no `T` value ever lives inside it. Moving the handle
// after it has been pinned is therefore harmless for every `T`, so it is
// declared `Unpin` unconditionally. This lets `poll` use the safe
// `Pin::get_mut` instead of `unsafe { get_unchecked_mut() }`.
impl<T> Unpin for Future<T> {}

/// Awaiting a [`Future`] yields the deserialized return value of the remote
/// call.
///
/// # Panics
///
/// Polling panics when:
/// * the response bytes cannot be deserialized into `T`,
/// * the remote side reported a panic (`RemoteFnResponse::Panic`), or
/// * the sending half of the response channel was dropped before a response
///   was delivered.
impl<T: for<'a> serde::Deserialize<'a>> ::std::future::Future for Future<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Safe projection: `Self: Unpin` (see the impl above).
        let this = self.get_mut();

        // Called by full path so the code does not depend on the std `Future`
        // trait being nameable here; the local `Future` struct shadows it.
        match ::std::future::Future::poll(Pin::new(&mut this.rx), cx) {
            Poll::Ready(Ok(RemoteFnResponse::Ok(return_value_bytes))) => Poll::Ready(
                postcard::from_bytes(&return_value_bytes)
                    .expect("Failed to deserialize return value"),
            ),
            // The remote worker caught a panic in the user's `#[remote]`
            // function and forwarded the message to us. Re-raise it here so
            // it bubbles up through whatever `tokio::join!`/`try_join_all`
            // the orchestrator's `main` is using and ultimately surfaces in
            // the `#[cindy::main]` `JoinError` arm as `exit(1)`. The remote
            // already printed the full panic location to its stderr, which
            // the CLI tags with `[hostname]`, so the user sees both the
            // origin and the propagation point.
            Poll::Ready(Ok(RemoteFnResponse::Panic(msg))) => {
                panic!("remote task panicked: {msg}")
            }
            // The sender was dropped without ever delivering a response.
            Poll::Ready(Err(_)) => panic!("RPC channel closed unexpectedly"),
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Multiplexes remote-call traffic over a pair of pipes.
///
/// Two helper tasks are spawned:
/// * a **reader** that splits `rpc_in` on `0x00` delimiters, decodes each
///   frame as a postcard/COBS [`RemoteFnResponsePayload`] and forwards it to
///   the main loop;
/// * a **writer** that writes every queued packet to `rpc_out` and flushes
///   after each one.
///
/// The main loop then
/// * takes each [`OutboundRegistration`] from `registration_rx`, queues its
///   encoded payload for the writer and remembers its `tx` under the
///   payload's `uuid`;
/// * routes each decoded response to the remembered `tx` with the same
///   `uuid`.
///
/// Once the reader has stopped, no response can ever arrive. All remembered
/// senders are then dropped, and the sender of any later registration is
/// dropped without queueing the call. An awaiting [`Future`] sees its channel
/// close instead of staying pending forever.
///
/// The function returns when `registration_rx` is closed and the reader has
/// stopped.
///
/// # Panics
///
/// * if a `RemoteFnPayload` cannot be serialized;
/// * if a response arrives whose `uuid` has no remembered sender.
///
/// The spawned reader task additionally panics on a frame that fails to
/// decode.
pub async fn rpc(
    mut registration_rx: tokio::sync::mpsc::UnboundedReceiver<OutboundRegistration>,
    rpc_in: tokio::fs::File,
    rpc_out: tokio::fs::File,
) {
    let (reader_tx, mut reader_rx) =
        tokio::sync::mpsc::unbounded_channel::<RemoteFnResponsePayload>();
    let (writer_tx, mut writer_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();

    // Response senders for calls that are in flight, keyed by call uuid.
    let mut pending_futures: std::collections::HashMap<
        Uuid,
        tokio::sync::oneshot::Sender<RemoteFnResponse>,
    > = std::collections::HashMap::new();

    // Becomes true once the reader task has dropped `reader_tx`, i.e. once
    // it is certain that no further response will be delivered.
    let mut reader_closed = false;

    // Reader task: one decoded response per 0x00-terminated frame.
    tokio::spawn(async move {
        use tokio::io::AsyncBufReadExt as _;

        let mut reader = tokio::io::BufReader::new(rpc_in);
        let mut buffer = Vec::with_capacity(4096);

        // A read error is treated like end-of-input (`unwrap_or(0)`).
        while reader.read_until(0x00, &mut buffer).await.unwrap_or(0) > 0 {
            match postcard::from_bytes_cobs::<RemoteFnResponsePayload>(&mut buffer) {
                Ok(response) => {
                    // A send error only means the main loop is gone.
                    let _ = reader_tx.send(response);
                }
                Err(e) => {
                    panic!(
                        "RPC Framing Corruption! Error: {:?}\nRaw Bytes: {:?}",
                        e,
                        String::from_utf8_lossy(&buffer)
                    );
                }
            }
            // Reuse the allocation for the next frame.
            buffer.clear();
        }
    });

    // Writer task: best-effort write + flush of every queued packet.
    tokio::spawn(async move {
        use tokio::io::AsyncWriteExt as _;

        let mut writer = rpc_out;
        while let Some(cobs_packet) = writer_rx.recv().await {
            let _ = writer.write_all(&cobs_packet).await;
            let _ = writer.flush().await;
        }
    });

    loop {
        tokio::select! {
            Some(registration) = registration_rx.recv() => {
                if reader_closed {
                    // No response can be read any more. Dropping the
                    // registration drops its `tx`, so the caller's `Future`
                    // observes a closed channel instead of hanging.
                    drop(registration);
                } else {
                    let uuid = registration.payload.uuid;

                    let outbound_bytes = postcard::to_allocvec_cobs(&registration.payload)
                        .expect("Failed to serialize `RemoteFnPayload`");
                    let _ = writer_tx.send(outbound_bytes);

                    pending_futures.insert(uuid, registration.tx);
                }
            }

            response = reader_rx.recv(), if !reader_closed => {
                match response {
                    Some(response) => match pending_futures.remove(&response.uuid) {
                        // If the receiving `Future<T>` was dropped before we
                        // got the response (e.g. a sibling task in a
                        // `try_join!` already errored and the join was
                        // short-circuited), there's nothing to deliver to.
                        // That's not a bug — ignore the failed send.
                        Some(tx) => {
                            let _ = tx.send(response.response);
                        }
                        None => panic!("received response for unknown uuid {}", response.uuid),
                    },
                    // The reader task is gone: fail every in-flight call by
                    // dropping its sender rather than leaving it pending.
                    None => {
                        reader_closed = true;
                        pending_futures.clear();
                    }
                }
            }

            // Registrations are closed and the reader has stopped.
            else => break,
        }
    }
}