quic-rpc 0.20.0

A streaming RPC system based on QUIC
Documentation
#![cfg(any(
    feature = "flume-transport",
    feature = "hyper-transport",
    feature = "quinn-transport",
    feature = "iroh-transport",
))]
#![allow(dead_code)]
use std::{
    io::{self, Write},
    result,
};

use async_stream::stream;
use derive_more::{From, TryInto};
use futures_buffered::BufferedStreamExt;
use futures_lite::{Stream, StreamExt};
use futures_util::SinkExt;
use quic_rpc::{
    message::{
        BidiStreaming, BidiStreamingMsg, ClientStreaming, ClientStreamingMsg, Msg, RpcMsg,
        ServerStreaming, ServerStreamingMsg,
    },
    server::{RpcChannel, RpcServerError},
    transport::StreamTypes,
    Connector, Listener, RpcClient, RpcServer, Service,
};
use serde::{Deserialize, Serialize};
use thousands::Separable;
use tokio_util::task::AbortOnDropHandle;

/// Request: square the contained number.
#[derive(Serialize, Deserialize, Debug)]
pub struct Sqr(pub u64);

/// Response to [`Sqr`]: the square, held as `u128` because the square of a
/// `u64` does not fit in a `u64`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct SqrResponse(pub u128);

/// Client-streaming request: add up every [`SumUpdate`] the client sends.
#[derive(Serialize, Deserialize, Debug)]
pub struct Sum;

/// One number contributed to an open [`Sum`] call.
#[derive(Serialize, Deserialize, Debug)]
pub struct SumUpdate(pub u64);

/// Final total of a [`Sum`] call, accumulated as `u128`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct SumResponse(pub u128);

/// Server-streaming request: stream the first `n` Fibonacci numbers.
#[derive(Debug, Serialize, Deserialize)]
pub struct Fibonacci(pub u64);

/// One element of the stream answering a [`Fibonacci`] request.
///
/// Derives `PartialEq`/`Eq` like the other response types, so streamed values
/// can be compared directly in tests.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct FibonacciResponse(pub u128);

/// Bidi-streaming request: multiply each incoming update by this factor and
/// stream the products back.
#[derive(Debug, Serialize, Deserialize)]
pub struct Multiply(pub u64);

/// One number sent on an open [`Multiply`] call.
#[derive(Debug, Serialize, Deserialize)]
pub struct MultiplyUpdate(pub u64);

/// One product streamed back on a [`Multiply`] call.
///
/// Derives `PartialEq`/`Eq` like the other response types, so streamed values
/// can be compared directly in tests.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct MultiplyResponse(pub u128);

/// Every message a client may send to [`ComputeService`].
///
/// `Sqr`, `Sum`, `Fibonacci` and `Multiply` open a call; the `*Update`
/// variants only carry follow-up items on a call that is already open.
// Variant order is part of the wire format for index-based serializers; keep it.
#[derive(Serialize, Deserialize, Debug, From, TryInto)]
pub enum ComputeRequest {
    /// Opens an RPC call.
    Sqr(Sqr),
    /// Opens a client-streaming call.
    Sum(Sum),
    /// A streamed item for an open `Sum` call.
    SumUpdate(SumUpdate),
    /// Opens a server-streaming call.
    Fibonacci(Fibonacci),
    /// Opens a bidi-streaming call.
    Multiply(Multiply),
    /// A streamed item for an open `Multiply` call.
    MultiplyUpdate(MultiplyUpdate),
}

/// Every message [`ComputeService`] may send back to a client.
// Each variant is named after the payload type it wraps, so they all share the
// `Response` suffix that clippy's `enum_variant_names` lint would flag.
#[allow(clippy::enum_variant_names)]
#[derive(Serialize, Deserialize, Debug, From, TryInto)]
pub enum ComputeResponse {
    /// Answer to a `Sqr` call.
    SqrResponse(SqrResponse),
    /// Answer to a `Sum` call.
    SumResponse(SumResponse),
    /// One item of a `Fibonacci` stream.
    FibonacciResponse(FibonacciResponse),
    /// One item of a `Multiply` stream.
    MultiplyResponse(MultiplyResponse),
}

/// Handler type for the compute test service; it carries no state.
#[derive(Clone, Debug)]
pub struct ComputeService;

/// Binds the service to its request and response enums.
impl Service for ComputeService {
    type Res = ComputeResponse;
    type Req = ComputeRequest;
}

// Interaction patterns: each request type that opens a call declares how the
// call proceeds and which message types flow over it.

// `Sqr`: a single request answered by a single response.
impl RpcMsg<ComputeService> for Sqr {
    type Response = SqrResponse;
}

// `Sum`: the client streams `SumUpdate`s and receives one `SumResponse`.
impl Msg<ComputeService> for Sum {
    type Pattern = ClientStreaming;
}

impl ClientStreamingMsg<ComputeService> for Sum {
    type Response = SumResponse;
    type Update = SumUpdate;
}

// `Fibonacci`: one request answered by a stream of `FibonacciResponse`s.
impl Msg<ComputeService> for Fibonacci {
    type Pattern = ServerStreaming;
}

impl ServerStreamingMsg<ComputeService> for Fibonacci {
    type Response = FibonacciResponse;
}

// `Multiply`: the client streams `MultiplyUpdate`s and the server streams back
// one `MultiplyResponse` per update.
impl Msg<ComputeService> for Multiply {
    type Pattern = BidiStreaming;
}

impl BidiStreamingMsg<ComputeService> for Multiply {
    type Response = MultiplyResponse;
    type Update = MultiplyUpdate;
}

impl ComputeService {
    async fn sqr(self, req: Sqr) -> SqrResponse {
        SqrResponse(req.0 as u128 * req.0 as u128)
    }

    async fn sum(self, _req: Sum, updates: impl Stream<Item = SumUpdate>) -> SumResponse {
        let mut sum = 0u128;
        tokio::pin!(updates);
        while let Some(SumUpdate(n)) = updates.next().await {
            sum += n as u128;
        }
        SumResponse(sum)
    }

    fn fibonacci(self, req: Fibonacci) -> impl Stream<Item = FibonacciResponse> {
        let mut a = 0u128;
        let mut b = 1u128;
        let mut n = req.0;
        stream! {
            while n > 0 {
                yield FibonacciResponse(a);
                let c = a + b;
                a = b;
                b = c;
                n -= 1;
            }
        }
    }

    fn multiply(
        self,
        req: Multiply,
        updates: impl Stream<Item = MultiplyUpdate>,
    ) -> impl Stream<Item = MultiplyResponse> {
        let product = req.0 as u128;
        stream! {
            tokio::pin!(updates);
            while let Some(MultiplyUpdate(n)) = updates.next().await {
                yield MultiplyResponse(product * n as u128);
            }
        }
    }

    pub fn server<C: Listener<ComputeService>>(
        server: RpcServer<ComputeService, C>,
    ) -> AbortOnDropHandle<()> {
        server.spawn_accept_loop(|req, chan| Self::handle_rpc_request(ComputeService, req, chan))
    }

    pub async fn handle_rpc_request<E>(
        self,
        req: ComputeRequest,
        chan: RpcChannel<ComputeService, E>,
    ) -> Result<(), RpcServerError<E>>
    where
        E: StreamTypes<In = ComputeRequest, Out = ComputeResponse>,
    {
        use ComputeRequest::*;
        #[rustfmt::skip]
        match req {
            Sqr(msg) => chan.rpc(msg, self, Self::sqr).await,
            Sum(msg) => chan.client_streaming(msg, self, Self::sum).await,
            Fibonacci(msg) => chan.server_streaming(msg, self, Self::fibonacci).await,
            Multiply(msg) => chan.bidi_streaming(msg, self, Self::multiply).await,
            MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
            SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
        }?;
        Ok(())
    }

    /// Runs the service until `count` requests have been received.
    pub async fn server_bounded<C: Listener<ComputeService>>(
        server: RpcServer<ComputeService, C>,
        count: usize,
    ) -> result::Result<RpcServer<ComputeService, C>, RpcServerError<C>> {
        tracing::info!(%count, "server running");
        let s = server;
        let mut received = 0;
        let service = ComputeService;
        while received < count {
            received += 1;
            let (req, chan) = s.accept().await?.read_first().await?;
            let service = service.clone();
            tokio::spawn(async move {
                use ComputeRequest::*;
                tracing::info!(?req, "got request");
                #[rustfmt::skip]
                match req {
                    Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await,
                    Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await,
                    Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await,
                    Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await,
                    SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
                    MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
                }?;
                Ok::<_, RpcServerError<C>>(())
            });
        }
        tracing::info!(%count, "server finished");
        Ok(s)
    }

    pub async fn server_par<C: Listener<ComputeService>>(
        server: RpcServer<ComputeService, C>,
        parallelism: usize,
    ) -> result::Result<(), RpcServerError<C>> {
        let s = server.clone();
        let s2 = s.clone();
        let service = ComputeService;
        let request_stream = stream! {
            loop {
                yield s2.accept().await?.read_first().await;
            }
        };
        let process_stream = request_stream.map(move |r| {
            let service = service.clone();
            async move {
                let (req, chan) = r?;
                use ComputeRequest::*;
                #[rustfmt::skip]
                match req {
                    Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await,
                    Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await,
                    Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await,
                    Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await,
                    SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
                    MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
                }?;
                Ok::<_, RpcServerError<C>>(())
            }
        });
        process_stream
            .buffered_unordered(parallelism)
            .for_each(|x| {
                if let Err(e) = x {
                    eprintln!("error: {e:?}");
                }
            })
            .await;
        Ok(())
    }
}

pub async fn smoke_test<C: Connector<ComputeService>>(client: C) -> anyhow::Result<()> {
    let client = RpcClient::<ComputeService, C>::new(client);
    // a rpc call
    tracing::debug!("calling rpc S(1234)");
    let res = client.rpc(Sqr(1234)).await?;
    tracing::debug!("got response {:?}", res);
    assert_eq!(res, SqrResponse(1522756));

    // client streaming call
    tracing::debug!("calling client_streaming Sum");
    let (mut send, recv) = client.client_streaming(Sum).await?;
    tokio::task::spawn(async move {
        for i in 1..=3 {
            send.send(SumUpdate(i)).await?;
        }
        Ok::<_, C::SendError>(())
    });
    let res = recv.await?;
    tracing::debug!("got response {:?}", res);
    assert_eq!(res, SumResponse(6));

    // server streaming call
    tracing::debug!("calling server_streaming Fibonacci(10)");
    let s = client.server_streaming(Fibonacci(10)).await?;
    let res: Vec<_> = s.map(|x| x.map(|x| x.0)).try_collect().await?;
    tracing::debug!("got response {:?}", res);
    assert_eq!(res, vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);

    // bidi streaming call
    tracing::debug!("calling bidi Multiply(2)");
    let (mut send, recv) = client.bidi(Multiply(2)).await?;
    tokio::task::spawn(async move {
        for i in 1..=3 {
            send.send(MultiplyUpdate(i)).await?;
        }
        Ok::<_, C::SendError>(())
    });
    let res: Vec<_> = recv.map(|x| x.map(|x| x.0)).try_collect().await?;
    tracing::debug!("got response {:?}", res);
    assert_eq!(res, vec![2, 4, 6]);

    tracing::debug!("dropping client!");
    Ok(())
}

/// Blanks the current terminal line: returns the cursor to column 0, overwrites
/// 80 columns with spaces, and returns to column 0 again. Does not flush.
fn clear_line() {
    // `{:80}` pads the empty string to 80 spaces.
    print!("\r{:80}\r", "");
}

/// Measures throughput of `client` against the compute service and prints one
/// requests-per-second line per phase.
///
/// Each phase issues `n` operations and verifies the results:
/// 1. `n` [`Sqr`] calls, awaited one after another;
/// 2. `n` [`Sqr`] calls with up to 32 in flight at once;
/// 3. one `Multiply(2)` bidi stream carrying `n` updates.
///
/// Phases 1 and 3 print a progress dot every 10_000 items.
///
/// # Errors
///
/// Propagates RPC, stream, send-task and stdout errors. A wrong result panics
/// via `assert_eq!`.
pub async fn bench<C>(client: RpcClient<ComputeService, C>, n: u64) -> anyhow::Result<()>
where
    C::SendError: std::error::Error,
    C: Connector<ComputeService>,
{
    /// Prints a progress dot for every 10_000th item and flushes stdout so it
    /// appears immediately.
    fn progress(i: u64) -> io::Result<()> {
        if i % 10_000 == 0 {
            print!(".");
            io::stdout().flush()?;
        }
        Ok(())
    }

    // individual RPCs
    {
        let mut sum = 0;
        let t0 = std::time::Instant::now();
        for i in 0..n {
            sum += client.rpc(Sqr(i)).await?.0;
            progress(i)?;
        }
        let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round();
        assert_eq!(sum, sum_of_squares(n));
        clear_line();
        println!("RPC seq {} rps", rps.separate_with_underscores());
    }
    // parallel RPCs
    {
        let t0 = std::time::Instant::now();
        let reqs = futures_lite::stream::iter((0..n).map(Sqr));
        let resp: Vec<_> = reqs
            .map(|x| {
                let client = client.clone();
                async move {
                    let res = client.rpc(x).await?.0;
                    anyhow::Ok(res)
                }
            })
            .buffered_unordered(32)
            .try_collect()
            .await?;
        let sum = resp.into_iter().sum::<u128>();
        let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round();
        assert_eq!(sum, sum_of_squares(n));
        clear_line();
        println!("RPC par {} rps", rps.separate_with_underscores());
    }
    // sequential streaming
    {
        let t0 = std::time::Instant::now();
        let (send, recv) = client.bidi(Multiply(2)).await?;
        let handle = tokio::task::spawn(async move {
            let requests = futures_lite::stream::iter((0..n).map(MultiplyUpdate));
            futures_util::StreamExt::forward(requests.map(Ok), send).await?;
            anyhow::Result::<()>::Ok(())
        });
        let mut sum = 0;
        // Explicitly u64: an unannotated counter would default to i32 and
        // overflow after 2^31 responses.
        let mut i: u64 = 0;
        tokio::pin!(recv);
        while let Some(res) = recv.next().await {
            sum += res?.0;
            progress(i)?;
            i += 1;
        }
        // Take the timing before the O(n) verification below, as the RPC phases
        // do, so checking the result is not counted against throughput.
        let rps = ((n as f64) / t0.elapsed().as_secs_f64()).round();
        // Join the sender first so a send failure is reported as such rather
        // than as a mismatched total.
        handle.await??;
        assert_eq!(sum, (0..n as u128).map(|x| x * 2).sum());
        clear_line();
        println!("bidi seq {} rps", rps.separate_with_underscores());
    }
    Ok(())
}

/// Returns `0² + 1² + … + (n-1)²`, the expected total of the `Sqr` benchmark.
///
/// Each value is widened to `u128` *before* squaring. Squaring in `u64`
/// overflows for any value of at least 2^32.
fn sum_of_squares(n: u64) -> u128 {
    (0..n).map(|x| u128::from(x) * u128::from(x)).sum()
}