use crate::{
serde::{ViaductDeserialize, ViaductSerialize},
ViaductEvent,
};
use interprocess::unnamed_pipe::{UnnamedPipeReader, UnnamedPipeWriter};
use parking_lot::{Condvar, Mutex};
use std::{
collections::BTreeSet,
io::{Read, Write},
marker::PhantomData,
mem::size_of,
sync::Arc,
time::{Duration, Instant},
};
use uuid::Uuid;
/// Packet tag: fire-and-forget RPC, laid out as `tag, u64 length (native-endian), payload`.
const RPC: u8 = 0x00;
/// Packet tag: request awaiting an answer, laid out as `tag, 16-byte id, u64 length, payload`.
const REQUEST: u8 = 0x01;
/// Packet tag: answer carrying a payload, laid out as `tag, 16-byte id, u64 length, payload`.
const SOME_RESPONSE: u8 = 0x02;
/// Packet tag: answer without a payload ("no response"), laid out as `tag, 16-byte id`.
const NONE_RESPONSE: u8 = 0x03;
/// Fixed greeting bytes. Not referenced in this file; presumably used as a handshake when the
/// pipes are set up elsewhere in the crate — TODO confirm.
pub(super) const HELLO: &[u8] = b"Read this if you are a beautiful strong unnamed pipe who don't need no handles";
/// Both ends of a viaduct: the sending half first, the receiving half second.
pub type Viaduct<RpcTx, RequestTx, RpcRx, RequestRx> =
    (ViaductTx<RpcTx, RequestTx, RpcRx, RequestRx>, ViaductRx<RpcTx, RequestTx, RpcRx, RequestRx>);
/// Handle for answering a single request received from the peer.
///
/// Calling `respond` sends the answer; dropping the handle without responding sends a
/// "no response" packet for the same request id instead.
pub struct ViaductRequestResponder<RpcTx, RequestTx, RpcRx, RequestRx>
where
    RpcTx: ViaductSerialize,
    RpcRx: ViaductDeserialize,
    RequestTx: ViaductSerialize,
    RequestRx: ViaductDeserialize,
{
    /// Sending half used to write the answer packet.
    tx: ViaductTx<RpcTx, RequestTx, RpcRx, RequestRx>,
    /// Id of the request being answered; echoed back in the answer packet.
    request_id: Uuid,
}
impl<RpcTx, RequestTx, RpcRx, RequestRx> ViaductRequestResponder<RpcTx, RequestTx, RpcRx, RequestRx>
where
RpcTx: ViaductSerialize,
RequestTx: ViaductSerialize,
RpcRx: ViaductDeserialize,
RequestRx: ViaductDeserialize,
{
pub fn respond(self, response: impl ViaductSerialize) -> Result<(), std::io::Error> {
{
let mut state = self.tx.0.state.lock();
let ViaductTxState { tx, buf, .. } = &mut *state;
response
.to_pipeable({
buf.clear();
buf
})
.expect("Failed to serialize response");
tx.write_all(&[2])?;
tx.write_all(self.request_id.as_bytes())?;
tx.write_all(&u64::to_ne_bytes(buf.len() as _))?;
tx.write_all(buf)?;
}
std::mem::forget(self);
Ok(())
}
}
impl<RpcTx, RequestTx, RpcRx, RequestRx> Drop for ViaductRequestResponder<RpcTx, RequestTx, RpcRx, RequestRx>
where
RpcTx: ViaductSerialize,
RequestTx: ViaductSerialize,
RpcRx: ViaductDeserialize,
RequestRx: ViaductDeserialize,
{
fn drop(&mut self) {
let mut state = self.tx.0.state.lock();
let ViaductTxState { tx, .. } = &mut *state;
(|| {
tx.write_all(&[3])?;
tx.write_all(self.request_id.as_bytes())?;
Ok::<_, std::io::Error>(())
})()
.unwrap();
}
}
/// Receiving half of a viaduct: reads packets from the peer and dispatches them via `run`.
pub struct ViaductRx<RpcTx, RequestTx, RpcRx, RequestRx>
where
    RpcTx: ViaductSerialize,
    RpcRx: ViaductDeserialize,
    RequestTx: ViaductSerialize,
    RequestRx: ViaductDeserialize,
{
    /// Scratch buffer reused for incoming RPC and request payloads.
    pub(super) buf: Vec<u8>,
    /// Sending half; cloned into responders and used to publish incoming responses.
    pub(super) tx: ViaductTx<RpcTx, RequestTx, RpcRx, RequestRx>,
    /// Read end of the pipe carrying packets from the peer.
    pub(super) rx: UnnamedPipeReader,
    /// Marker for the incoming request type.
    pub(super) _phantom: PhantomData<RequestRx>,
}
impl<RpcTx, RequestTx, RpcRx, RequestRx> ViaductRx<RpcTx, RequestTx, RpcRx, RequestRx>
where
RpcTx: ViaductSerialize,
RpcRx: ViaductDeserialize,
RequestTx: ViaductSerialize,
RequestRx: ViaductDeserialize,
{
pub fn run<EventHandler>(mut self, mut event_handler: EventHandler) -> Result<(), std::io::Error>
where
EventHandler: FnMut(ViaductEvent<RpcTx, RequestTx, RpcRx, RequestRx>),
{
let recv_into_buf = |rx: &mut UnnamedPipeReader, buf: &mut Vec<u8>| -> Result<(), std::io::Error> {
let len = {
let mut len = [0u8; size_of::<u64>()];
rx.read_exact(&mut len)?;
usize::try_from(u64::from_ne_bytes(len)).expect("Viaduct packet was larger than what this architecture can handle")
};
buf.resize(len, 0);
rx.read_exact(buf)?;
Ok(())
};
loop {
let packet_type = {
let mut packet_type = [0u8];
self.rx.read_exact(&mut packet_type)?;
packet_type[0]
};
match packet_type {
RPC => {
recv_into_buf(&mut self.rx, &mut self.buf)?;
let rpc = RpcRx::from_pipeable(&self.buf).expect("Failed to deserialize RpcRx");
event_handler(ViaductEvent::Rpc(rpc));
}
REQUEST => {
let request_id = {
let mut request_id = [0u8; 16];
self.rx.read_exact(&mut request_id)?;
Uuid::from_bytes(request_id)
};
recv_into_buf(&mut self.rx, &mut self.buf)?;
event_handler(ViaductEvent::Request {
request: RequestRx::from_pipeable(&self.buf).expect("Failed to deserialize RequestRx"),
responder: ViaductRequestResponder {
tx: self.tx.clone(),
request_id,
},
});
}
SOME_RESPONSE => {
let mut response = self.tx.0.response.lock();
self.tx
.0
.response_condvar
.wait_while(&mut response, |response| response.for_request_id.is_some());
let request_id = {
let mut request_id = [0u8; 16];
self.rx.read_exact(&mut request_id)?;
Uuid::from_bytes(request_id)
};
response.buf.clear();
recv_into_buf(&mut self.rx, &mut response.buf)?;
if !response.pending.remove(&request_id) {
continue;
}
response.for_request_id = Some((request_id, true));
self.tx.0.response_condvar.notify_all();
}
NONE_RESPONSE => {
let mut response = self.tx.0.response.lock();
self.tx
.0
.response_condvar
.wait_while(&mut response, |response| response.for_request_id.is_some());
let request_id = {
let mut request_id = [0u8; 16];
self.rx.read_exact(&mut request_id)?;
Uuid::from_bytes(request_id)
};
if !response.pending.remove(&request_id) {
continue;
}
response.for_request_id = Some((request_id, false));
self.tx.0.response_condvar.notify_all();
}
_ => unreachable!(),
}
}
}
}
/// Single-slot mailbox through which the receive loop hands responses to waiting requesters.
#[derive(Default)]
pub(super) struct ViaductResponseState {
    /// Ids of requests that were sent and are still awaiting an answer.
    pending: BTreeSet<Uuid>,
    /// Request the slot currently holds an answer for, plus whether that answer has a payload
    /// stored in `buf`.
    for_request_id: Option<(Uuid, bool)>,
    /// Serialized payload of the answer currently in the slot.
    buf: Vec<u8>,
}
impl ViaductResponseState {
    /// Id of the request whose answer currently occupies the slot, or `None` if it is empty.
    #[inline]
    fn request_id(&self) -> Option<&Uuid> {
        let slot = self.for_request_id.as_ref();
        slot.map(|entry| &entry.0)
    }
}
/// Sending half of a viaduct: sends RPCs and requests to the peer.
///
/// The shared state lives behind an `Arc`, so clones of this handle all write to the same pipe
/// and share the same response slot.
pub struct ViaductTx<RpcTx, RequestTx, RpcRx, RequestRx>(pub(super) Arc<ViaductTxInner<RpcTx, RequestTx, RpcRx, RequestRx>>)
where
    RpcTx: ViaductSerialize,
    RpcRx: ViaductDeserialize,
    RequestTx: ViaductSerialize,
    RequestRx: ViaductDeserialize;
/// State shared by every handle to one sending half.
pub(super) struct ViaductTxInner<RpcTx, RequestTx, RpcRx, RequestRx> {
    /// Pipe writer plus scratch buffer. Writers hold this lock for a whole packet, so packets are
    /// never interleaved on the pipe.
    pub(super) state: Mutex<ViaductTxState<RpcTx, RequestTx, RpcRx, RequestRx>>,
    /// Response slot, filled by the receive loop and emptied by requesters.
    pub(super) response: Mutex<ViaductResponseState>,
    /// Notified whenever the response slot is filled or emptied.
    pub(super) response_condvar: Condvar,
}
/// The pipe writer and its reusable serialization buffer, guarded together by one mutex.
pub(super) struct ViaductTxState<RpcTx, RequestTx, RpcRx, RequestRx> {
    /// Write end of the pipe carrying packets to the peer.
    pub(super) tx: UnnamedPipeWriter,
    /// Scratch buffer that outgoing payloads are serialized into before being written.
    buf: Vec<u8>,
    /// Ties the message types to this state without storing any values of them.
    _phantom: PhantomData<(RpcTx, RequestTx, RpcRx, RequestRx)>,
}
impl<RpcTx, RequestTx, RpcRx, RequestRx> ViaductTxState<RpcTx, RequestTx, RpcRx, RequestRx>
where
RpcTx: ViaductSerialize,
RpcRx: ViaductDeserialize,
RequestTx: ViaductSerialize,
RequestRx: ViaductDeserialize,
{
#[inline]
pub(super) fn new(tx: UnnamedPipeWriter) -> Self {
Self {
buf: Vec::new(),
tx,
_phantom: Default::default(),
}
}
}
impl<RpcTx, RequestTx, RpcRx, RequestRx> ViaductTx<RpcTx, RequestTx, RpcRx, RequestRx>
where
RpcTx: ViaductSerialize,
RpcRx: ViaductDeserialize,
RequestTx: ViaductSerialize,
RequestRx: ViaductDeserialize,
{
pub fn rpc(&self, rpc: RpcTx) -> Result<(), std::io::Error> {
let mut state = self.0.state.lock();
let ViaductTxState { buf, tx, .. } = &mut *state;
rpc.to_pipeable({
buf.clear();
buf
})
.expect("Failed to serialize RpcTx");
tx.write_all(&[0])?;
tx.write_all(&u64::to_ne_bytes(buf.len() as _))?;
tx.write_all(&*buf)?;
Ok(())
}
pub fn request<Response: ViaductDeserialize>(&self, request: RequestTx) -> Result<Option<Response>, std::io::Error> {
let mut response = self.0.response.lock();
let request_id = Uuid::new_v4();
response.pending.insert(request_id);
{
let mut state = self.0.state.lock();
let ViaductTxState { buf, tx, .. } = &mut *state;
request
.to_pipeable({
buf.clear();
buf
})
.expect("Failed to serialize RequestTx");
tx.write_all(&[1])?;
tx.write_all(request_id.as_bytes())?;
tx.write_all(&u64::to_ne_bytes(buf.len() as _))?;
tx.write_all(&*buf)?;
}
self.0
.response_condvar
.wait_while(&mut response, |response| response.request_id() != Some(&request_id));
let (for_request_id, some) = response.for_request_id.take().unwrap();
debug_assert_eq!(for_request_id, request_id);
self.0.response_condvar.notify_all();
Ok(if some {
Some(Response::from_pipeable(&response.buf).expect("Failed to deserialize Response"))
} else {
None
})
}
pub fn request_timeout_at<Response: ViaductDeserialize>(
&self,
timeout_at: Instant,
request: RequestTx,
) -> Result<Option<Response>, std::io::Error> {
let mut response = self
.0
.response
.try_lock_until(timeout_at)
.ok_or_else(|| std::io::Error::from(std::io::ErrorKind::TimedOut))?;
let request_id = Uuid::new_v4();
response.pending.insert(request_id);
{
let mut state = self
.0
.state
.try_lock_until(timeout_at)
.ok_or_else(|| std::io::Error::from(std::io::ErrorKind::TimedOut))?;
let ViaductTxState { buf, tx, .. } = &mut *state;
request
.to_pipeable({
buf.clear();
buf
})
.expect("Failed to serialize RequestTx");
tx.write_all(&[1])?;
tx.write_all(request_id.as_bytes())?;
tx.write_all(&u64::to_ne_bytes(buf.len() as _))?;
tx.write_all(&*buf)?;
}
if self
.0
.response_condvar
.wait_while_until(&mut response, |response| response.request_id() != Some(&request_id), timeout_at)
.timed_out()
{
response.pending.remove(&request_id);
return Err(std::io::Error::from(std::io::ErrorKind::TimedOut));
}
let (for_request_id, some) = response.for_request_id.take().unwrap();
debug_assert_eq!(for_request_id, request_id);
self.0.response_condvar.notify_all();
Ok(if some {
Some(Response::from_pipeable(&response.buf).expect("Failed to deserialize Response"))
} else {
None
})
}
#[inline]
pub fn request_timeout<Response: ViaductDeserialize>(&self, timeout: Duration, request: RequestTx) -> Result<Option<Response>, std::io::Error> {
self.request_timeout_at(Instant::now() + timeout, request)
}
}
impl<RpcTx, RequestTx, RpcRx, RequestRx> Clone for ViaductTx<RpcTx, RequestTx, RpcRx, RequestRx>
where
RpcTx: ViaductSerialize,
RpcRx: ViaductDeserialize,
RequestTx: ViaductSerialize,
RequestRx: ViaductDeserialize,
{
#[inline]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}