1use std::{fmt::Debug, sync::Arc};
12
13use async_broadcast::RecvError;
14use motore::{layer::Layer, service::Service};
15use volo::{
16 context::Context,
17 discovery::Discover,
18 loadbalance::{LoadBalance, MkLbLayer, random::WeightedRandomBalance},
19};
20
21use super::dns::{DiscoverKey, DnsResolver};
22use crate::{
23 context::ClientContext,
24 error::{
25 ClientError,
26 client::{lb_error, no_available_endpoint},
27 },
28 request::Request,
29};
30
31pub type DefaultLb = LbConfig<WeightedRandomBalance<DiscoverKey>, DnsResolver>;
33pub type DefaultLbService<S> =
35 LoadBalanceService<WeightedRandomBalance<DiscoverKey>, DnsResolver, S>;
36
/// Configuration that pairs a load balancer (`L`) with a discover (`D`).
///
/// The two parts are only stored here; they are handed over to a
/// [`LoadBalanceLayer`] when the configuration is turned into a layer.
pub struct LbConfig<L, D> {
    /// The load-balancing half of the configuration.
    load_balance: L,
    /// The discovery half of the configuration.
    discover: D,
}
42
43impl Default for DefaultLb {
44 fn default() -> Self {
45 LbConfig::new(WeightedRandomBalance::new(), DnsResolver::default())
46 }
47}
48
49impl<L, D> LbConfig<L, D> {
50 pub fn new(load_balance: L, discover: D) -> Self {
52 LbConfig {
53 load_balance,
54 discover,
55 }
56 }
57
58 pub fn load_balance<NL>(self, load_balance: NL) -> LbConfig<NL, D> {
60 LbConfig {
61 load_balance,
62 discover: self.discover,
63 }
64 }
65
66 pub fn discover<ND>(self, discover: ND) -> LbConfig<L, ND> {
68 LbConfig {
69 load_balance: self.load_balance,
70 discover,
71 }
72 }
73}
74
75impl<LB, D> MkLbLayer for LbConfig<LB, D> {
76 type Layer = LoadBalanceLayer<LB, D>;
77
78 fn make(self) -> Self::Layer {
79 LoadBalanceLayer::new(self.load_balance, self.discover)
80 }
81}
82
/// A layer holding a load balancer and a discover; applying it to an inner
/// service produces a [`LoadBalanceService`].
#[derive(Clone, Copy, Default)]
pub struct LoadBalanceLayer<LB, D> {
    /// Load balancer passed on to the produced service.
    load_balance: LB,
    /// Discover passed on to the produced service.
    discover: D,
}
89
90impl<LB, D> LoadBalanceLayer<LB, D> {
91 fn new(load_balance: LB, discover: D) -> Self {
92 LoadBalanceLayer {
93 load_balance,
94 discover,
95 }
96 }
97}
98
99impl<LB, D, S> Layer<S> for LoadBalanceLayer<LB, D>
100where
101 LB: LoadBalance<D>,
102 D: Discover,
103{
104 type Service = LoadBalanceService<LB, D, S>;
105
106 fn layer(self, inner: S) -> Self::Service {
107 LoadBalanceService::new(self.load_balance, self.discover, inner)
108 }
109}
110
/// A service that holds a load balancer, a discover and the wrapped inner
/// service.
///
/// The load balancer sits behind an [`Arc`], so clones of the service share
/// the same load balancer instance.
#[derive(Clone)]
pub struct LoadBalanceService<LB, D, S> {
    /// Shared load balancer.
    load_balance: Arc<LB>,
    /// Discover handed to the load balancer when a picker is requested.
    discover: D,
    /// The wrapped inner service.
    service: S,
}
118
119impl<LB, D, S> LoadBalanceService<LB, D, S>
120where
121 LB: LoadBalance<D>,
122 D: Discover,
123{
124 fn new(load_balance: LB, discover: D, service: S) -> Self {
125 let lb = Arc::new(load_balance);
126
127 let service = Self {
128 load_balance: lb.clone(),
129 discover,
130 service,
131 };
132
133 let Some(mut channel) = service.discover.watch(None) else {
134 return service;
135 };
136
137 tokio::spawn(async move {
138 loop {
139 match channel.recv().await {
140 Ok(recv) => lb.rebalance(recv),
141 Err(err) => match err {
142 RecvError::Closed => break,
143 _ => tracing::warn!("[Volo-HTTP] discovering subscription error: {err}"),
144 },
145 }
146 }
147 });
148
149 service
150 }
151}
152
153impl<LB, D, S, B> Service<ClientContext, Request<B>> for LoadBalanceService<LB, D, S>
154where
155 LB: LoadBalance<D>,
156 D: Discover,
157 S: Service<ClientContext, Request<B>, Error = ClientError> + Send + Sync,
158 B: Send,
159{
160 type Response = S::Response;
161 type Error = S::Error;
162
163 async fn call(
164 &self,
165 cx: &mut ClientContext,
166 req: Request<B>,
167 ) -> Result<Self::Response, Self::Error> {
168 let callee = cx.rpc_info().callee();
169
170 let mut picker = match &callee.address {
171 None => self
172 .load_balance
173 .get_picker(callee, &self.discover)
174 .await
175 .map_err(lb_error)?,
176 _ => {
177 return self.service.call(cx, req).await;
178 }
179 };
180
181 let addr = picker.next().ok_or_else(no_available_endpoint)?;
182 cx.rpc_info_mut().callee_mut().set_address(addr);
183
184 self.service.call(cx, req).await
185 }
186}
187
188impl<LB, D, S> Debug for LoadBalanceService<LB, D, S>
189where
190 LB: Debug,
191 D: Debug,
192 S: Debug,
193{
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("LBService")
196 .field("load_balancer", &self.load_balance)
197 .field("discover", &self.discover)
198 .finish()
199 }
200}