#![cfg(any(
feature = "flume-transport",
feature = "hyper-transport",
feature = "quinn-transport",
feature = "iroh-transport",
))]
#![allow(dead_code)]
use std::{
io::{self, Write},
result,
};
use async_stream::stream;
use derive_more::{From, TryInto};
use futures_buffered::BufferedStreamExt;
use futures_lite::{Stream, StreamExt};
use futures_util::SinkExt;
use quic_rpc::{
message::{
BidiStreaming, BidiStreamingMsg, ClientStreaming, ClientStreamingMsg, Msg, RpcMsg,
ServerStreaming, ServerStreamingMsg,
},
server::{RpcChannel, RpcServerError},
transport::StreamTypes,
Connector, Listener, RpcClient, RpcServer, Service,
};
use serde::{Deserialize, Serialize};
use thousands::Separable;
use tokio_util::task::AbortOnDropHandle;
/// Request for the unary `sqr` call: the number to square.
#[derive(Debug, Serialize, Deserialize)]
pub struct Sqr(pub u64);
/// Response to [`Sqr`]: the square, held in a `u128`.
// PartialEq/Eq are derived so tests can `assert_eq!` on the response.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct SqrResponse(pub u128);
/// Opening message of the client-streaming `sum` call; it carries no data.
#[derive(Debug, Serialize, Deserialize)]
pub struct Sum;
/// One addend sent by the client during a `sum` call.
#[derive(Debug, Serialize, Deserialize)]
pub struct SumUpdate(pub u64);
/// Final response of a `sum` call: the total of all [`SumUpdate`] values.
// PartialEq/Eq are derived so tests can `assert_eq!` on the response.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct SumResponse(pub u128);
/// Request for the server-streaming `fibonacci` call: how many Fibonacci
/// numbers to stream back.
#[derive(Debug, Serialize, Deserialize)]
pub struct Fibonacci(pub u64);
/// One Fibonacci number in the `fibonacci` response stream.
#[derive(Debug, Serialize, Deserialize)]
pub struct FibonacciResponse(pub u128);
/// Opening message of the bidirectional `multiply` call: the fixed factor.
#[derive(Debug, Serialize, Deserialize)]
pub struct Multiply(pub u64);
/// One value sent by the client during a `multiply` call.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiplyUpdate(pub u64);
/// Response to one [`MultiplyUpdate`]: the update times the call's factor.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiplyResponse(pub u128);
/// Every message a client can send to [`ComputeService`]: the opening message
/// of each call, plus the update messages of the two client-side streaming calls.
///
/// `From`/`TryInto` (derive_more) convert between this enum and the individual
/// message types.
// NOTE(review): with serde, variant order presumably determines the encoded
// variant index for index-based formats — confirm before reordering.
#[derive(Debug, Serialize, Deserialize, From, TryInto)]
pub enum ComputeRequest {
    /// Opens a unary `sqr` call.
    Sqr(Sqr),
    /// Opens a client-streaming `sum` call.
    Sum(Sum),
    /// Update within a `sum` call; never valid as the first message.
    SumUpdate(SumUpdate),
    /// Opens a server-streaming `fibonacci` call.
    Fibonacci(Fibonacci),
    /// Opens a bidirectional `multiply` call.
    Multiply(Multiply),
    /// Update within a `multiply` call; never valid as the first message.
    MultiplyUpdate(MultiplyUpdate),
}
/// Every message [`ComputeService`] can send back to a client.
// Every variant name ends in `Response` (mirroring its payload type), which
// clippy's `enum_variant_names` lint would otherwise flag.
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Serialize, Deserialize, From, TryInto)]
pub enum ComputeResponse {
    /// Reply to a `sqr` call.
    SqrResponse(SqrResponse),
    /// Final reply to a `sum` call.
    SumResponse(SumResponse),
    /// One item of a `fibonacci` response stream.
    FibonacciResponse(FibonacciResponse),
    /// One item of a `multiply` response stream.
    MultiplyResponse(MultiplyResponse),
}
/// Stateless marker type for the compute service.
///
/// It ties [`ComputeRequest`] and [`ComputeResponse`] together as the service's
/// wire types and carries the handler methods. It is cheap to clone into
/// per-request tasks.
#[derive(Debug, Clone)]
pub struct ComputeService;
impl Service for ComputeService {
    type Req = ComputeRequest;
    type Res = ComputeResponse;
}
// `Sqr`: plain request/response — one `Sqr` in, one `SqrResponse` out.
impl RpcMsg<ComputeService> for Sqr {
    type Response = SqrResponse;
}
// `Sum`: client streaming — the client sends any number of `SumUpdate`s after
// the opening `Sum` and receives a single `SumResponse`.
impl Msg<ComputeService> for Sum {
    type Pattern = ClientStreaming;
}
impl ClientStreamingMsg<ComputeService> for Sum {
    type Update = SumUpdate;
    type Response = SumResponse;
}
// `Fibonacci`: server streaming — one request in, a stream of
// `FibonacciResponse`s out.
impl Msg<ComputeService> for Fibonacci {
    type Pattern = ServerStreaming;
}
impl ServerStreamingMsg<ComputeService> for Fibonacci {
    type Response = FibonacciResponse;
}
// `Multiply`: bidirectional streaming — `MultiplyUpdate`s flow in while
// `MultiplyResponse`s flow out.
impl Msg<ComputeService> for Multiply {
    type Pattern = BidiStreaming;
}
impl BidiStreamingMsg<ComputeService> for Multiply {
    type Update = MultiplyUpdate;
    type Response = MultiplyResponse;
}
impl ComputeService {
    /// Unary handler: returns the square of the request value.
    ///
    /// The operand is widened to `u128` before multiplying, so no `u64` input
    /// can overflow.
    async fn sqr(self, req: Sqr) -> SqrResponse {
        SqrResponse(req.0 as u128 * req.0 as u128)
    }

    /// Client-streaming handler: adds up every [`SumUpdate`] the client sends and
    /// replies once with the total after the update stream ends.
    async fn sum(self, _req: Sum, updates: impl Stream<Item = SumUpdate>) -> SumResponse {
        let mut sum = 0u128;
        tokio::pin!(updates);
        while let Some(SumUpdate(n)) = updates.next().await {
            sum += n as u128;
        }
        SumResponse(sum)
    }

    /// Server-streaming handler: streams the first `req.0` Fibonacci numbers,
    /// starting `0, 1, 1, 2, …`.
    ///
    /// Each term is computed only when it is about to be yielded, so every term
    /// that fits in a `u128` (F(0) through F(186)) is produced without overflow.
    /// If more terms are requested than a `u128` can hold, the stream ends after
    /// F(186). It does not panic (debug builds) or yield wrapped values (release
    /// builds).
    fn fibonacci(self, req: Fibonacci) -> impl Stream<Item = FibonacciResponse> {
        let count = req.0;
        stream! {
            // `cur` holds F(k) and `prev` holds F(k-1). Seeding `prev` with
            // F(-1) = 1 makes the first step produce F(1) = F(-1) + F(0) = 1.
            let mut prev = 1u128;
            let mut cur = 0u128;
            for k in 0..count {
                if k > 0 {
                    let next = match prev.checked_add(cur) {
                        Some(next) => next,
                        // F(k) no longer fits in a u128: end the stream.
                        None => break,
                    };
                    prev = cur;
                    cur = next;
                }
                yield FibonacciResponse(cur);
            }
        }
    }

    /// Bidirectional handler: answers every [`MultiplyUpdate`] with the update
    /// multiplied by the factor from the opening [`Multiply`], in order.
    fn multiply(
        self,
        req: Multiply,
        updates: impl Stream<Item = MultiplyUpdate>,
    ) -> impl Stream<Item = MultiplyResponse> {
        let product = req.0 as u128;
        stream! {
            tokio::pin!(updates);
            while let Some(MultiplyUpdate(n)) = updates.next().await {
                yield MultiplyResponse(product * n as u128);
            }
        }
    }

    /// Runs `server`'s accept loop in a background task, dispatching each
    /// request through [`ComputeService::handle_rpc_request`].
    ///
    /// Dropping the returned [`AbortOnDropHandle`] aborts the loop.
    pub fn server<C: Listener<ComputeService>>(
        server: RpcServer<ComputeService, C>,
    ) -> AbortOnDropHandle<()> {
        server.spawn_accept_loop(|req, chan| Self::handle_rpc_request(ComputeService, req, chan))
    }

    /// Serves one call whose first message `req` has already been read from `chan`.
    ///
    /// # Errors
    /// Returns [`RpcServerError::UnexpectedStartMessage`] if `req` is an update
    /// message (`SumUpdate`/`MultiplyUpdate`), which cannot open a call.
    /// Otherwise it propagates errors from the channel's interaction-pattern
    /// helper.
    pub async fn handle_rpc_request<E>(
        self,
        req: ComputeRequest,
        chan: RpcChannel<ComputeService, E>,
    ) -> Result<(), RpcServerError<E>>
    where
        E: StreamTypes<In = ComputeRequest, Out = ComputeResponse>,
    {
        use ComputeRequest::*;
        #[rustfmt::skip]
        match req {
            Sqr(msg) => chan.rpc(msg, self, Self::sqr).await,
            Sum(msg) => chan.client_streaming(msg, self, Self::sum).await,
            Fibonacci(msg) => chan.server_streaming(msg, self, Self::fibonacci).await,
            Multiply(msg) => chan.bidi_streaming(msg, self, Self::multiply).await,
            MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
            SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
        }?;
        Ok(())
    }

    /// Accepts exactly `count` connections, reads the first request of each,
    /// and spawns a task to serve it. It then hands `server` back so the caller
    /// can keep using it.
    ///
    /// The function returns once the `count`-th request has been *dispatched*;
    /// spawned handlers may still be running. Handler failures are logged and
    /// not returned.
    ///
    /// # Errors
    /// Fails if accepting a connection or reading its first message fails.
    pub async fn server_bounded<C: Listener<ComputeService>>(
        server: RpcServer<ComputeService, C>,
        count: usize,
    ) -> result::Result<RpcServer<ComputeService, C>, RpcServerError<C>> {
        tracing::info!(%count, "server running");
        for _ in 0..count {
            let (req, chan) = server.accept().await?.read_first().await?;
            tokio::spawn(async move {
                tracing::info!(?req, "got request");
                // The JoinHandle is discarded, so log failures here rather
                // than returning them where nobody would observe them.
                if let Err(err) = ComputeService.handle_rpc_request(req, chan).await {
                    tracing::warn!(?err, "request failed");
                }
            });
        }
        tracing::info!(%count, "server finished");
        Ok(server)
    }

    /// Serves requests from `server`, handling up to `parallelism` calls
    /// concurrently, and prints each failed call to stderr.
    ///
    /// The first failed accept is reported like any other failure and ends the
    /// request stream. Once in-flight calls finish, the function returns `Ok(())`.
    pub async fn server_par<C: Listener<ComputeService>>(
        server: RpcServer<ComputeService, C>,
        parallelism: usize,
    ) -> result::Result<(), RpcServerError<C>> {
        // Within async-stream's `stream!`, `?` yields the error as an item and
        // then ends the stream.
        let request_stream = stream! {
            loop {
                yield server.accept().await?.read_first().await;
            }
        };
        let process_stream = request_stream.map(|r| async move {
            let (req, chan) = r?;
            ComputeService.handle_rpc_request(req, chan).await?;
            Ok::<_, RpcServerError<C>>(())
        });
        process_stream
            .buffered_unordered(parallelism)
            .for_each(|x| {
                if let Err(e) = x {
                    eprintln!("error: {e:?}");
                }
            })
            .await;
        Ok(())
    }
}
pub async fn smoke_test<C: Connector<ComputeService>>(client: C) -> anyhow::Result<()> {
let client = RpcClient::<ComputeService, C>::new(client);
tracing::debug!("calling rpc S(1234)");
let res = client.rpc(Sqr(1234)).await?;
tracing::debug!("got response {:?}", res);
assert_eq!(res, SqrResponse(1522756));
tracing::debug!("calling client_streaming Sum");
let (mut send, recv) = client.client_streaming(Sum).await?;
tokio::task::spawn(async move {
for i in 1..=3 {
send.send(SumUpdate(i)).await?;
}
Ok::<_, C::SendError>(())
});
let res = recv.await?;
tracing::debug!("got response {:?}", res);
assert_eq!(res, SumResponse(6));
tracing::debug!("calling server_streaming Fibonacci(10)");
let s = client.server_streaming(Fibonacci(10)).await?;
let res: Vec<_> = s.map(|x| x.map(|x| x.0)).try_collect().await?;
tracing::debug!("got response {:?}", res);
assert_eq!(res, vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);
tracing::debug!("calling bidi Multiply(2)");
let (mut send, recv) = client.bidi(Multiply(2)).await?;
tokio::task::spawn(async move {
for i in 1..=3 {
send.send(MultiplyUpdate(i)).await?;
}
Ok::<_, C::SendError>(())
});
let res: Vec<_> = recv.map(|x| x.map(|x| x.0)).try_collect().await?;
tracing::debug!("got response {:?}", res);
assert_eq!(res, vec![2, 4, 6]);
tracing::debug!("dropping client!");
Ok(())
}
/// Blanks the current terminal line. It moves the cursor to column 0,
/// overwrites 80 columns with spaces, and returns the cursor to column 0.
fn clear_line() {
    let blank = " ".repeat(80);
    print!("\r{blank}\r");
}
/// Prints a progress dot whenever `i` is a multiple of 10 000.
///
/// Stdout is flushed so the dot appears immediately, because `print!` does not
/// end the line.
fn print_progress(i: u64) -> io::Result<()> {
    if i % 10000 == 0 {
        print!(".");
        io::stdout().flush()?;
    }
    Ok(())
}

/// Benchmarks `client` with `n` requests in each of three modes and prints the
/// measured requests per second:
///
/// 1. `RPC seq`: unary `Sqr` calls issued one after another.
/// 2. `RPC par`: unary `Sqr` calls with up to 32 in flight.
/// 3. `bidi seq`: one `Multiply(2)` bidi stream carrying `n` updates.
///
/// In each mode the clock is read before the locally computed expected value
/// is checked, so verification does not count toward the rate.
///
/// # Errors
/// Propagates RPC/stream errors, sender-task failures, and stdout flush errors.
///
/// # Panics
/// Panics if any mode's results do not match the expected sum.
pub async fn bench<C>(client: RpcClient<ComputeService, C>, n: u64) -> anyhow::Result<()>
where
    // Presumably required so the bidi sender's error converts into
    // `anyhow::Error` via `?` inside the forwarding task.
    C::SendError: std::error::Error,
    C: Connector<ComputeService>,
{
    // Sequential unary RPCs: each request waits for the previous response.
    {
        let mut sum = 0;
        let t0 = std::time::Instant::now();
        for i in 0..n {
            sum += client.rpc(Sqr(i)).await?.0;
            print_progress(i)?;
        }
        let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round();
        assert_eq!(sum, sum_of_squares(n));
        clear_line();
        println!("RPC seq {} rps", rps.separate_with_underscores(),);
    }
    // Concurrent unary RPCs, at most 32 outstanding at a time.
    {
        let t0 = std::time::Instant::now();
        let reqs = futures_lite::stream::iter((0..n).map(Sqr));
        let resp: Vec<_> = reqs
            .map(|x| {
                let client = client.clone();
                async move {
                    let res = client.rpc(x).await?.0;
                    anyhow::Ok(res)
                }
            })
            .buffered_unordered(32)
            .try_collect()
            .await?;
        let sum = resp.into_iter().sum::<u128>();
        let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round();
        assert_eq!(sum, sum_of_squares(n));
        clear_line();
        println!("RPC par {} rps", rps.separate_with_underscores(),);
    }
    // One bidi stream: a spawned task pushes all updates while this task
    // drains the responses.
    {
        let t0 = std::time::Instant::now();
        let (send, recv) = client.bidi(Multiply(2)).await?;
        let handle = tokio::task::spawn(async move {
            let requests = futures_lite::stream::iter((0..n).map(MultiplyUpdate));
            futures_util::StreamExt::forward(requests.map(Ok), send).await?;
            anyhow::Result::<()>::Ok(())
        });
        let mut sum = 0;
        tokio::pin!(recv);
        // Explicitly u64. An unannotated literal would default to i32 and
        // overflow after i32::MAX responses.
        let mut i: u64 = 0;
        while let Some(res) = recv.next().await {
            sum += res?.0;
            print_progress(i)?;
            i += 1;
        }
        // Read the clock before the O(n) local verification, as the other
        // sections do.
        let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round();
        assert_eq!(sum, (0..n as u128).map(|x| x * 2).sum());
        clear_line();
        println!("bidi seq {} rps", rps.separate_with_underscores(),);
        handle.await??;
    }
    Ok(())
}
/// Returns `0² + 1² + … + (n-1)²`, the sum of the squares of every integer
/// in `0..n`.
///
/// Each term is widened to `u128` *before* squaring. Squaring in `u64` would
/// overflow for any term greater than `u32::MAX`.
fn sum_of_squares(n: u64) -> u128 {
    (0..u128::from(n)).map(|x| x * x).sum()
}