use std::{fmt::Debug, sync::Arc};
use async_broadcast::RecvError;
use motore::{layer::Layer, service::Service};
use volo::{
context::Context,
discovery::Discover,
loadbalance::{LoadBalance, MkLbLayer, random::WeightedRandomBalance},
};
use super::dns::{DiscoverKey, DnsResolver};
use crate::{
context::ClientContext,
error::{
ClientError,
client::{lb_error, no_available_endpoint},
},
request::Request,
};
/// Default load-balancing configuration: a [`WeightedRandomBalance`] keyed by
/// [`DiscoverKey`], fed by a [`DnsResolver`].
pub type DefaultLb = LbConfig<WeightedRandomBalance<DiscoverKey>, DnsResolver>;
/// Service type obtained by layering the [`DefaultLb`] configuration over an
/// inner service `S`.
pub type DefaultLbService<S> =
    LoadBalanceService<WeightedRandomBalance<DiscoverKey>, DnsResolver, S>;
/// Load-balancing configuration: a balancer `L` plus an endpoint discoverer
/// `D`, turned into a [`LoadBalanceLayer`] through `MkLbLayer::make`.
///
/// `Debug` and `Clone` are derived so the configuration can be logged and
/// reused like other public types in this module.
#[derive(Clone, Debug)]
pub struct LbConfig<L, D> {
    load_balance: L,
    discover: D,
}
impl Default for DefaultLb {
    /// Pairs a freshly constructed `WeightedRandomBalance` with
    /// `DnsResolver::default()`.
    fn default() -> Self {
        Self {
            load_balance: WeightedRandomBalance::new(),
            discover: DnsResolver::default(),
        }
    }
}
impl<L, D> LbConfig<L, D> {
pub fn new(load_balance: L, discover: D) -> Self {
LbConfig {
load_balance,
discover,
}
}
pub fn load_balance<NL>(self, load_balance: NL) -> LbConfig<NL, D> {
LbConfig {
load_balance,
discover: self.discover,
}
}
pub fn discover<ND>(self, discover: ND) -> LbConfig<L, ND> {
LbConfig {
load_balance: self.load_balance,
discover,
}
}
}
impl<LB, D> MkLbLayer for LbConfig<LB, D> {
    type Layer = LoadBalanceLayer<LB, D>;

    /// Converts the configuration into a [`LoadBalanceLayer`] holding the
    /// same balancer and discoverer.
    fn make(self) -> Self::Layer {
        let Self {
            load_balance,
            discover,
        } = self;
        LoadBalanceLayer::new(load_balance, discover)
    }
}
/// Layer that wraps an inner service in a [`LoadBalanceService`] using the
/// stored balancer and discoverer.
///
/// `Debug` is derived alongside the existing traits so the layer is as
/// inspectable as the service it produces.
#[derive(Clone, Default, Copy, Debug)]
pub struct LoadBalanceLayer<LB, D> {
    load_balance: LB,
    discover: D,
}
impl<LB, D> LoadBalanceLayer<LB, D> {
fn new(load_balance: LB, discover: D) -> Self {
LoadBalanceLayer {
load_balance,
discover,
}
}
}
impl<LB, D, S> Layer<S> for LoadBalanceLayer<LB, D>
where
    LB: LoadBalance<D>,
    D: Discover,
{
    type Service = LoadBalanceService<LB, D, S>;

    /// Wraps `inner` in a [`LoadBalanceService`] built from this layer's
    /// balancer and discoverer.
    fn layer(self, inner: S) -> Self::Service {
        let Self {
            load_balance,
            discover,
        } = self;
        LoadBalanceService::new(load_balance, discover, inner)
    }
}
/// Service that, when the callee has no address set, picks one through the
/// load balancer before forwarding the request to the wrapped service.
#[derive(Clone)]
pub struct LoadBalanceService<LB, D, S> {
    // Held in an `Arc` because `new` hands a clone to the background task
    // that calls `rebalance` on discovery changes.
    load_balance: Arc<LB>,
    // Discoverer passed to `get_picker` and, in `new`, asked for a change
    // channel via `watch(None)`.
    discover: D,
    // The wrapped service that receives the request after address selection.
    service: S,
}
impl<LB, D, S> LoadBalanceService<LB, D, S>
where
LB: LoadBalance<D>,
D: Discover,
{
fn new(load_balance: LB, discover: D, service: S) -> Self {
let lb = Arc::new(load_balance);
let service = Self {
load_balance: lb.clone(),
discover,
service,
};
let Some(mut channel) = service.discover.watch(None) else {
return service;
};
tokio::spawn(async move {
loop {
match channel.recv().await {
Ok(recv) => lb.rebalance(recv),
Err(err) => match err {
RecvError::Closed => break,
_ => tracing::warn!("[Volo-HTTP] discovering subscription error: {err}"),
},
}
}
});
service
}
}
impl<LB, D, S, B> Service<ClientContext, Request<B>> for LoadBalanceService<LB, D, S>
where
    LB: LoadBalance<D>,
    D: Discover,
    S: Service<ClientContext, Request<B>, Error = ClientError> + Send + Sync,
    B: Send,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Chooses a target address, then forwards the request.
    ///
    /// If the callee already has an address, the request goes straight to the
    /// inner service. Otherwise a picker is obtained from the load balancer.
    /// The first address it yields is stored on the callee before forwarding.
    ///
    /// # Errors
    ///
    /// - a picker failure, converted through `lb_error`;
    /// - `no_available_endpoint()` when the picker yields no address;
    /// - any error returned by the inner service.
    async fn call(
        &self,
        cx: &mut ClientContext,
        req: Request<B>,
    ) -> Result<Self::Response, Self::Error> {
        let callee = cx.rpc_info().callee();
        if callee.address.is_some() {
            return self.service.call(cx, req).await;
        }

        let mut picker = self
            .load_balance
            .get_picker(callee, &self.discover)
            .await
            .map_err(lb_error)?;
        let picked = picker.next().ok_or_else(no_available_endpoint)?;

        cx.rpc_info_mut().callee_mut().set_address(picked);
        self.service.call(cx, req).await
    }
}
/// Debug output shows the balancer, the discoverer and the wrapped service.
impl<LB, D, S> Debug for LoadBalanceService<LB, D, S>
where
    LB: Debug,
    D: Debug,
    S: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LBService")
            .field("load_balancer", &self.load_balance)
            .field("discover", &self.discover)
            // `S: Debug` is already required by this impl; print the field
            // rather than imposing the bound without using it.
            .field("service", &self.service)
            .finish()
    }
}