use crate::{TransportError, TransportFut};
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use tower::Service;
pub trait DualTransportHandler<L, R> {
fn call(&self, request: RequestPacket, left: L, right: R) -> TransportFut<'static>;
}
impl<F, L, R> DualTransportHandler<L, R> for F
where
F: Fn(RequestPacket, L, R) -> TransportFut<'static> + Send + Sync,
{
fn call(&self, request: RequestPacket, left: L, right: R) -> TransportFut<'static> {
(self)(request, left, right)
}
}
#[derive(Debug, Clone)]
pub struct DualTransport<L, R, H> {
left: L,
right: R,
handler: H,
}
impl<L, R, H> DualTransport<L, R, H> {
pub fn new(left: L, right: R, handler: H) -> Self {
Self { left, right, handler }
}
pub fn new_handler<F>(left: L, right: R, f: F) -> DualTransport<L, R, F>
where
F: Fn(RequestPacket, L, R) -> TransportFut<'static> + Send + Sync,
{
DualTransport { left, right, handler: f }
}
}
impl<L, R, H> Service<RequestPacket> for DualTransport<L, R, H>
where
L: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
+ Send
+ Sync
+ Clone
+ 'static,
R: Service<RequestPacket, Future = TransportFut<'static>, Error = TransportError>
+ Send
+ Sync
+ Clone
+ 'static,
H: DualTransportHandler<L, R> + 'static,
{
type Response = ResponsePacket;
type Error = TransportError;
type Future = TransportFut<'static>;
#[inline]
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
match (self.left.poll_ready(cx), self.right.poll_ready(cx)) {
(std::task::Poll::Ready(Ok(())), std::task::Poll::Ready(Ok(()))) => {
std::task::Poll::Ready(Ok(()))
}
(std::task::Poll::Ready(Err(e)), _) => std::task::Poll::Ready(Err(e)),
(_, std::task::Poll::Ready(Err(e))) => std::task::Poll::Ready(Err(e)),
_ => std::task::Poll::Pending,
}
}
#[inline]
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.handler.call(req, self.left.clone(), self.right.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
use alloy_json_rpc::{Id, Request, Response, ResponsePayload};
use alloy_primitives::B256;
use serde_json::value::RawValue;
use std::task::{Context, Poll};
fn request_fn<T>(f: T) -> RequestFn<T>
where
T: FnMut(RequestPacket) -> TransportFut<'static>,
{
RequestFn { f }
}
#[derive(Copy, Clone)]
struct RequestFn<T> {
f: T,
}
impl<T> Service<RequestPacket> for RequestFn<T>
where
T: FnMut(RequestPacket) -> TransportFut<'static>,
{
type Response = ResponsePacket;
type Error = TransportError;
type Future = TransportFut<'static>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), TransportError>> {
Ok(()).into()
}
fn call(&mut self, req: RequestPacket) -> Self::Future {
(self.f)(req)
}
}
fn make_hash_response() -> ResponsePacket {
ResponsePacket::Single(Response {
id: Id::Number(0),
payload: ResponsePayload::Success(
RawValue::from_string(serde_json::to_string(&B256::ZERO).unwrap()).unwrap(),
),
})
}
#[tokio::test]
async fn test_dual_transport() {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let left = request_fn(move |request: RequestPacket| {
let tx = tx.clone();
Box::pin(async move {
tx.send(request).unwrap();
Ok::<_, TransportError>(make_hash_response())
})
});
let right = request_fn(|_request: RequestPacket| {
Box::pin(async move { Ok::<_, TransportError>(make_hash_response()) })
});
let handler = |req: RequestPacket, mut left: RequestFn<_>, mut right: RequestFn<_>| {
let id = match &req {
RequestPacket::Single(req) => req.id().as_number().unwrap_or(0),
RequestPacket::Batch(reqs) => {
reqs.first().map(|r| r.id().as_number().unwrap_or(0)).unwrap_or(0)
}
};
if id % 2 == 0 {
left.call(req)
} else {
right.call(req)
}
};
let mut dual_transport = DualTransport::new(left, right, handler);
let req_even = RequestPacket::Single(
Request::new("test", Id::Number(2), None::<&'static RawValue>).try_into().unwrap(),
);
let _ = dual_transport.call(req_even.clone()).await.unwrap();
let received = rx.try_recv().unwrap();
match &received {
RequestPacket::Single(req) => assert_eq!(*req.id(), Id::Number(2)),
_ => panic!("Expected Single RequestPacket with id 2, but got something else"),
}
let req_odd = RequestPacket::Single(
Request::new("test", Id::Number(1), None::<&'static RawValue>)
.try_into()
.expect("Failed to serialize request"),
);
let _ = dual_transport.call(req_odd.clone()).await.unwrap();
assert!(rx.try_recv().is_err(), "Received unexpected request for odd ID");
}
}