volo 0.12.3

Volo is a high-performance, highly extensible Rust RPC framework that helps developers build microservices.
Documentation
use std::{fmt::Debug, sync::Arc};

use async_broadcast::RecvError;
use motore::Service;
use tracing::warn;

use super::error::{LoadBalanceError, Retryable};
use crate::{Layer, context::Context, discovery::Discover, loadbalance::LoadBalance};

/// Service wrapper that selects a callee address through a load balancer
/// before delegating the request to the wrapped inner service.
#[derive(Clone)]
pub struct LoadBalanceService<D, LB, S> {
    /// Discovery source; passed to the load balancer when a picker is requested
    /// and watched for changes in `new`.
    discover: D,
    /// Load balancer, held in an `Arc` so the background task that `new` may
    /// spawn can share it.
    load_balance: Arc<LB>,
    /// Inner service that performs the actual call.
    service: S,
    /// Number of additional attempts allowed after the first one.
    retry: usize,
}

impl<D, LB, S> LoadBalanceService<D, LB, S>
where
    D: Discover,
    LB: LoadBalance<D>,
{
    pub fn new(discover: D, load_balance: LB, service: S, retry: usize) -> Self {
        let lb = Arc::new(load_balance);

        let service = Self {
            discover,
            load_balance: lb.clone(),
            service,
            retry,
        };

        if let Some(mut channel) = service.discover.watch(None) {
            tokio::spawn(async move {
                loop {
                    match channel.recv().await {
                        Ok(recv) => lb.rebalance(recv),
                        Err(err) => match err {
                            RecvError::Closed => break,
                            _ => warn!("[VOLO] discovering subscription error: {:?}", err),
                        },
                    }
                }
            });
        }
        service
    }
}

impl<Cx, Req, D, LB, S> Service<Cx, Req> for LoadBalanceService<D, LB, S>
where
    Cx: 'static + Context + Send + Sync,
    D: Discover,
    LB: LoadBalance<D>,
    S: Service<Cx, Req> + 'static + Send + Sync,
    LoadBalanceError: Into<S::Error>,
    S::Error: Debug + Retryable,
    Req: Clone + Send + Sync + 'static,
{
    type Response = S::Response;

    type Error = S::Error;

    /// Calls the inner service, choosing the callee address via the load balancer.
    ///
    /// * If the callee already has an address, the request is forwarded to the
    ///   inner service unchanged and its result is returned as-is.
    /// * Otherwise a picker is obtained from the load balancer and at most
    ///   `retry + 1` addresses are tried in the order the picker yields them.
    ///   Each attempt stores the picked address on the callee and sends a clone
    ///   of `req`.
    ///
    /// # Errors
    ///
    /// * The converted load-balancer error when no picker can be obtained.
    /// * The inner service's error as soon as an attempt fails with a
    ///   non-retryable error.
    /// * `LoadBalanceError::Retry` (converted) when every attempt failed with a
    ///   retryable error, or when the picker yielded no address at all.
    async fn call(&self, cx: &mut Cx, req: Req) -> Result<Self::Response, Self::Error> {
        let callee = cx.rpc_info().callee();

        let picker = match &callee.address {
            None => self
                .load_balance
                .get_picker(callee, &self.discover)
                .await
                .map_err(|err| err.into())?,
            _ => {
                // An explicit address was already chosen; skip load balancing.
                return self.service.call(cx, req).await;
            }
        };

        // One initial attempt plus `retry` retries. Saturating keeps
        // `retry == usize::MAX` from overflowing, and `take` (unlike zipping with
        // a range) never pulls an address from the picker that will not be used.
        let max_attempts = self.retry.saturating_add(1);
        let mut attempted = false;
        for addr in picker.take(max_attempts) {
            attempted = true;
            cx.rpc_info_mut().callee_mut().address = Some(addr.clone());

            match self.service.call(cx, req.clone()).await {
                Ok(resp) => {
                    return Ok(resp);
                }
                Err(err) => {
                    warn!("[VOLO] call rpcinfo: {:?}, error: {:?}", cx.rpc_info(), err);
                    if !err.retryable() {
                        return Err(err);
                    }
                }
            }
        }
        if !attempted {
            // The picker produced no address, so nothing was ever called.
            warn!("[VOLO] zero call count, call rpcinfo: {:?}", cx.rpc_info());
        }
        Err(LoadBalanceError::Retry.into())
    }
}

impl<D, LB, S> Debug for LoadBalanceService<D, LB, S>
where
    D: Debug,
    LB: Debug,
    S: Debug,
{
    /// Formats the service as an `LBService` struct showing only the discover
    /// and load balancer; the inner service and retry count are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            discover,
            load_balance,
            ..
        } = self;

        let mut out = f.debug_struct("LBService");
        out.field("discover", discover);
        out.field("load_balancer", load_balance);
        out.finish()
    }
}

/// Layer that holds the pieces needed to build a load-balanced service:
/// a discover, a load balancer and a retry count.
#[derive(Clone, Default, Copy)]
pub struct LoadBalanceLayer<D, LB> {
    /// Discovery source stored for the service this layer builds.
    discover: D,
    /// Load balancer stored for the service this layer builds.
    load_balance: LB,
    /// Retry count stored for the service this layer builds.
    retry_count: usize,
}

impl<D, LB> LoadBalanceLayer<D, LB> {
    /// Creates a layer that stores the given discover, load balancer and retry count.
    pub fn new(discover: D, load_balance: LB, retry_count: usize) -> Self {
        Self {
            discover,
            load_balance,
            retry_count,
        }
    }
}

impl<D, LB, S> Layer<S> for LoadBalanceLayer<D, LB>
where
    D: Discover,
    LB: LoadBalance<D>,
{
    type Service = LoadBalanceService<D, LB, S>;

    /// Consumes the layer and wraps `inner` in a `LoadBalanceService` built from
    /// the stored discover, load balancer and retry count.
    fn layer(self, inner: S) -> Self::Service {
        let Self {
            discover,
            load_balance,
            retry_count,
        } = self;

        LoadBalanceService::new(discover, load_balance, inner, retry_count)
    }
}

#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use motore::service::service_fn;

    use super::LoadBalanceService;
    use crate::{discovery::StaticDiscover, loadbalance::random::WeightedRandomBalance};

    /// Placeholder context type used by the test handler.
    #[derive(Debug)]
    struct MotoreContext;

    /// Prints the context and request, then returns the request upper-cased.
    async fn handle(cx: &mut MotoreContext, request: String) -> Result<String, Infallible> {
        println!("{cx:?}, {request:?}");
        let upper = request.to_uppercase();
        Ok(upper)
    }

    /// Smoke test: a service can be constructed from a static discover,
    /// a weighted-random balancer and a function-backed inner service.
    #[test]
    fn test_service() {
        let addr = "127.0.0.1:8000".parse().unwrap();
        let discover = StaticDiscover::from(vec![addr]);
        let balancer = WeightedRandomBalance::with_discover(&discover);
        let inner = service_fn(handle);

        LoadBalanceService::new(discover, balancer, inner, 1);
    }
}