1use crate::error;
2use futures_core::ready;
3use futures_util::future::{self, TryFutureExt};
4use pin_project::pin_project;
5use rand::{rngs::SmallRng, SeedableRng};
6use std::marker::PhantomData;
7use std::{
8 fmt,
9 future::Future,
10 pin::Pin,
11 task::{Context, Poll},
12};
13use tokio::sync::oneshot;
14use tower_discover::{Change, Discover};
15use tower_load::Load;
16use tower_ready_cache::{error::Failed, ReadyCache};
17use tower_service::Service;
18use tracing::{debug, trace};
19
/// Load-balances requests across the services yielded by a [`Discover`]
/// stream.
///
/// When more than one service is ready, two distinct ready services are
/// sampled at random and the one reporting the lower `Load` is used
/// ("power of two choices").
pub struct Balance<D: Discover, Req> {
    /// Stream of service insertions and removals; drained at the start of
    /// every `poll_ready`.
    discover: D,

    /// Cache of services keyed by `D::Key`, tracking both services that are
    /// still becoming ready and services that are ready.
    services: ReadyCache<D::Key, D::Service, Req>,
    /// Index into the ready set of `services` chosen by the most recent
    /// `poll_ready`; taken (reset to `None`) by `call`.
    ready_index: Option<usize>,

    /// Random source used to sample the two candidate services.
    rng: SmallRng,

    _req: PhantomData<Req>,
}
50
51impl<D: Discover, Req> fmt::Debug for Balance<D, Req>
52where
53 D: fmt::Debug,
54 D::Key: fmt::Debug,
55 D::Service: fmt::Debug,
56 Req: fmt::Debug,
57{
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 f.debug_struct("Balance")
60 .field("discover", &self.discover)
61 .field("services", &self.services)
62 .finish()
63 }
64}
65
/// A future that resolves when the wrapped service reports readiness, or
/// earlier if a `()` arrives on its `cancel` channel.
///
/// NOTE(review): nothing in this file constructs an `UnreadyService`;
/// readiness tracking appears to be delegated to `ReadyCache`. Confirm whether
/// this type (and `Error`) is still needed.
#[pin_project]
#[derive(Debug)]
struct UnreadyService<K, S, Req> {
    /// Key identifying the service; handed back with the result and taken on
    /// completion, so it is `None` afterwards.
    key: Option<K>,
    /// Cancellation signal, polled before the service on every poll.
    #[pin]
    cancel: oneshot::Receiver<()>,
    /// Service being driven to readiness; taken once `poll_ready` yields a
    /// result.
    service: Option<S>,

    _req: PhantomData<Req>,
}
79
/// Reasons an `UnreadyService` future can fail to produce a ready service.
///
/// Derives `Debug` so results carrying it can be formatted, logged, or
/// `expect`ed.
#[derive(Debug)]
enum Error<E> {
    /// The wrapped service's `poll_ready` returned this error.
    Inner(E),
    /// A cancellation signal was received before the service became ready.
    Canceled,
}
84
85impl<D, Req> Balance<D, Req>
86where
87 D: Discover,
88 D::Service: Service<Req>,
89 <D::Service as Service<Req>>::Error: Into<error::Error>,
90{
91 pub fn new(discover: D, rng: SmallRng) -> Self {
93 Self {
94 rng,
95 discover,
96 services: ReadyCache::default(),
97 ready_index: None,
98
99 _req: PhantomData,
100 }
101 }
102
103 pub fn from_entropy(discover: D) -> Self {
105 Self::new(discover, SmallRng::from_entropy())
106 }
107
108 pub fn len(&self) -> usize {
110 self.services.len()
111 }
112}
113
114impl<D, Req> Balance<D, Req>
115where
116 D: Discover + Unpin,
117 D::Key: Clone,
118 D::Error: Into<error::Error>,
119 D::Service: Service<Req> + Load,
120 <D::Service as Load>::Metric: std::fmt::Debug,
121 <D::Service as Service<Req>>::Error: Into<error::Error>,
122{
123 fn update_pending_from_discover(
127 &mut self,
128 cx: &mut Context<'_>,
129 ) -> Poll<Result<(), error::Discover>> {
130 debug!("updating from discover");
131 loop {
132 match ready!(Pin::new(&mut self.discover).poll_discover(cx))
133 .map_err(|e| error::Discover(e.into()))?
134 {
135 Change::Remove(key) => {
136 trace!("remove");
137 self.services.evict(&key);
138 }
139 Change::Insert(key, svc) => {
140 trace!("insert");
141 self.services.push(key, svc);
144 }
145 }
146 }
147 }
148
149 fn promote_pending_to_ready(&mut self, cx: &mut Context<'_>) {
150 loop {
151 match self.services.poll_pending(cx) {
152 Poll::Ready(Ok(())) => {
153 debug_assert_eq!(self.services.pending_len(), 0);
155 break;
156 }
157 Poll::Pending => {
158 debug_assert!(self.services.pending_len() > 0);
160 break;
161 }
162 Poll::Ready(Err(error)) => {
163 debug!(%error, "dropping failed endpoint");
166 }
167 }
168 }
169 trace!(
170 ready = %self.services.ready_len(),
171 pending = %self.services.pending_len(),
172 "poll_unready"
173 );
174 }
175
176 fn p2c_ready_index(&mut self) -> Option<usize> {
178 match self.services.ready_len() {
179 0 => None,
180 1 => Some(0),
181 len => {
182 let idxs = rand::seq::index::sample(&mut self.rng, len, 2);
185
186 let aidx = idxs.index(0);
187 let bidx = idxs.index(1);
188 debug_assert_ne!(aidx, bidx, "random indices must be distinct");
189
190 let aload = self.ready_index_load(aidx);
191 let bload = self.ready_index_load(bidx);
192 let chosen = if aload <= bload { aidx } else { bidx };
193
194 trace!(
195 a.index = aidx,
196 a.load = ?aload,
197 b.index = bidx,
198 b.load = ?bload,
199 chosen = if chosen == aidx { "a" } else { "b" },
200 "p2c",
201 );
202 Some(chosen)
203 }
204 }
205 }
206
207 fn ready_index_load(&self, index: usize) -> <D::Service as Load>::Metric {
209 let (_, svc) = self.services.get_ready_index(index).expect("invalid index");
210 svc.load()
211 }
212
213 pub(crate) fn discover_mut(&mut self) -> &mut D {
214 &mut self.discover
215 }
216}
217
218impl<D, Req> Service<Req> for Balance<D, Req>
219where
220 D: Discover + Unpin,
221 D::Key: Clone,
222 D::Error: Into<error::Error>,
223 D::Service: Service<Req> + Load,
224 <D::Service as Load>::Metric: std::fmt::Debug,
225 <D::Service as Service<Req>>::Error: Into<error::Error>,
226{
227 type Response = <D::Service as Service<Req>>::Response;
228 type Error = error::Error;
229 type Future = future::MapErr<
230 <D::Service as Service<Req>>::Future,
231 fn(<D::Service as Service<Req>>::Error) -> error::Error,
232 >;
233
234 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
235 let _ = self.update_pending_from_discover(cx)?;
238 self.promote_pending_to_ready(cx);
239
240 loop {
241 if let Some(index) = self.ready_index.take() {
248 match self.services.check_ready_index(cx, index) {
249 Ok(true) => {
250 self.ready_index = Some(index);
252 return Poll::Ready(Ok(()));
253 }
254 Ok(false) => {
255 trace!("ready service became unavailable");
257 }
258 Err(Failed(_, error)) => {
259 debug!(%error, "endpoint failed");
262 }
263 }
264 }
265
266 self.ready_index = self.p2c_ready_index();
269 if self.ready_index.is_none() {
270 debug_assert_eq!(self.services.ready_len(), 0);
271 return Poll::Pending;
274 }
275 }
276 }
277
278 fn call(&mut self, request: Req) -> Self::Future {
279 let index = self.ready_index.take().expect("called before ready");
280 self.services
281 .call_ready_index(index, request)
282 .map_err(Into::into)
283 }
284}
285
286impl<K, S: Service<Req>, Req> Future for UnreadyService<K, S, Req> {
287 type Output = Result<(K, S), (K, Error<S::Error>)>;
288
289 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
290 let this = self.project();
291
292 if let Poll::Ready(Ok(())) = this.cancel.poll(cx) {
293 let key = this.key.take().expect("polled after ready");
294 return Poll::Ready(Err((key, Error::Canceled)));
295 }
296
297 let res = ready!(this
298 .service
299 .as_mut()
300 .expect("poll after ready")
301 .poll_ready(cx));
302
303 let key = this.key.take().expect("polled after ready");
304 let svc = this.service.take().expect("polled after ready");
305
306 match res {
307 Ok(()) => Poll::Ready(Ok((key, svc))),
308 Err(e) => Poll::Ready(Err((key, Error::Inner(e)))),
309 }
310 }
311}