1use std::sync::{Arc, Mutex};
24use std::task::{self, Poll};
25
26use tokio::sync::oneshot;
27use tower_service::Service;
28
29use self::internal::{DitchGuard, SingletonError, SingletonFuture, State};
30
/// Type-erased, thread-safe error used to carry make-service failures.
///
/// `'static` is spelled out here; it is already the default object lifetime
/// for a boxed trait object, so this is the same type as without it.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
32
33#[cfg(docsrs)]
34pub use self::internal::Singled;
35
36#[derive(Debug)]
42pub struct Singleton<M, Dst>
43where
44 M: Service<Dst>,
45{
46 mk_svc: M,
47 state: Arc<Mutex<State<M::Response>>>,
48}
49
50impl<M, Target> Singleton<M, Target>
51where
52 M: Service<Target>,
53 M::Response: Clone,
54{
55 pub fn new(mk_svc: M) -> Self {
57 Singleton {
58 mk_svc,
59 state: Arc::new(Mutex::new(State::Empty)),
60 }
61 }
62
63 pub fn retain<F>(&mut self, mut predicate: F)
67 where
68 F: FnMut(&mut M::Response) -> bool,
69 {
70 let mut locked = self.state.lock().unwrap();
71 match *locked {
72 State::Empty => {}
73 State::Making(..) => {}
74 State::Made(ref mut svc) => {
75 if !predicate(svc) {
76 *locked = State::Empty;
77 }
78 }
79 }
80 }
81
82 pub fn is_empty(&self) -> bool {
87 matches!(*self.state.lock().unwrap(), State::Empty)
88 }
89}
90
91impl<M, Target> Service<Target> for Singleton<M, Target>
92where
93 M: Service<Target>,
94 M::Response: Clone,
95 M::Error: Into<BoxError>,
96{
97 type Response = internal::Singled<M::Response>;
98 type Error = SingletonError;
99 type Future = SingletonFuture<M::Future, M::Response>;
100
101 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
102 if let State::Empty = *self.state.lock().unwrap() {
103 return self
104 .mk_svc
105 .poll_ready(cx)
106 .map_err(|e| SingletonError(e.into()));
107 }
108 Poll::Ready(Ok(()))
109 }
110
111 fn call(&mut self, dst: Target) -> Self::Future {
112 let mut locked = self.state.lock().unwrap();
113 match *locked {
114 State::Empty => {
115 let fut = self.mk_svc.call(dst);
116 *locked = State::Making(Vec::new());
117 SingletonFuture::Driving {
118 future: fut,
119 singleton: DitchGuard(Arc::downgrade(&self.state)),
120 }
121 }
122 State::Making(ref mut waiters) => {
123 let (tx, rx) = oneshot::channel();
124 waiters.push(tx);
125 SingletonFuture::Waiting {
126 rx,
127 state: Arc::downgrade(&self.state),
128 }
129 }
130 State::Made(ref svc) => SingletonFuture::Made {
131 svc: Some(svc.clone()),
132 state: Arc::downgrade(&self.state),
133 },
134 }
135 }
136}
137
138impl<M, Target> Clone for Singleton<M, Target>
139where
140 M: Service<Target> + Clone,
141{
142 fn clone(&self) -> Self {
143 Self {
144 mk_svc: self.mk_svc.clone(),
145 state: self.state.clone(),
146 }
147 }
148}
149
150mod internal {
152 use std::future::Future;
153 use std::pin::Pin;
154 use std::sync::{Mutex, Weak};
155 use std::task::{self, Poll};
156
157 use futures_core::ready;
158 use pin_project_lite::pin_project;
159 use tokio::sync::oneshot;
160 use tower_service::Service;
161
162 use super::BoxError;
163
    pin_project! {
        // Future returned by `Singleton::call`. The variant records which
        // state the pool was in when the call was made.
        #[project = SingletonFutureProj]
        pub enum SingletonFuture<F, S> {
            // The pool was empty: this caller polls the make-service future.
            // The guard resets the pool to `Empty` if this future is dropped
            // before it completes.
            Driving {
                #[pin]
                future: F,
                singleton: DitchGuard<S>,
            },
            // A service was already being made: wait for the driving future
            // to send a clone over this oneshot channel.
            Waiting {
                rx: oneshot::Receiver<S>,
                state: Weak<Mutex<State<S>>>,
            },
            // A service was already cached: resolve immediately with this
            // clone (taken out of the `Option` on first poll).
            Made {
                svc: Option<S>,
                state: Weak<Mutex<State<S>>>,
            },
        }
    }
182
    /// Shared state of a singleton pool.
    // NOTE(review): `pub` appears to be needed only because this type is named
    // by the public `SingletonFuture` enum — confirm before narrowing.
    #[derive(Debug)]
    pub enum State<S> {
        /// Nothing is cached and nothing is being made.
        Empty,
        /// A service is being made; each sender belongs to a caller waiting
        /// for a clone of the result.
        Making(Vec<oneshot::Sender<S>>),
        /// A service has been made and is cached for cloning.
        Made(S),
    }

    /// Held by the future that drives the make-service.
    ///
    /// If it is dropped while still pointing at the shared state (i.e. before
    /// the driving future finished and cleared it), the state is reset to
    /// `Empty` so an abandoned attempt does not leave the pool stuck in
    /// `Making`.
    pub struct DitchGuard<S>(pub(super) Weak<Mutex<State<S>>>);

    /// A service handed out by a `Singleton`.
    ///
    /// Calls are delegated to the inner service. If `poll_ready` fails, the
    /// owning pool's cached service may be cleared so that a later call makes
    /// a new one.
    #[derive(Debug)]
    pub struct Singled<S> {
        /// The wrapped (cloned) service.
        inner: S,
        /// Weak link back to the pool's shared state; does not keep it alive.
        state: Weak<Mutex<State<S>>>,
    }
212
213 impl<F, S, E> Future for SingletonFuture<F, S>
214 where
215 F: Future<Output = Result<S, E>>,
216 E: Into<BoxError>,
217 S: Clone,
218 {
219 type Output = Result<Singled<S>, SingletonError>;
220
221 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
222 match self.project() {
223 SingletonFutureProj::Driving { future, singleton } => {
224 match ready!(future.poll(cx)) {
225 Ok(svc) => {
226 if let Some(state) = singleton.0.upgrade() {
227 let mut locked = state.lock().unwrap();
228 match std::mem::replace(&mut *locked, State::Made(svc.clone())) {
229 State::Making(waiters) => {
230 for tx in waiters {
231 let _ = tx.send(svc.clone());
232 }
233 }
234 State::Empty | State::Made(_) => {
235 unreachable!()
237 }
238 }
239 }
240 let state = std::mem::replace(&mut singleton.0, Weak::new());
242 Poll::Ready(Ok(Singled::new(svc, state)))
243 }
244 Err(e) => {
245 if let Some(state) = singleton.0.upgrade() {
246 let mut locked = state.lock().unwrap();
247 singleton.0 = Weak::new();
248 *locked = State::Empty;
249 }
250 Poll::Ready(Err(SingletonError(e.into())))
251 }
252 }
253 }
254 SingletonFutureProj::Waiting { rx, state } => match ready!(Pin::new(rx).poll(cx)) {
255 Ok(svc) => Poll::Ready(Ok(Singled::new(svc, state.clone()))),
256 Err(_canceled) => Poll::Ready(Err(SingletonError(Canceled.into()))),
257 },
258 SingletonFutureProj::Made { svc, state } => {
259 Poll::Ready(Ok(Singled::new(svc.take().unwrap(), state.clone())))
260 }
261 }
262 }
263 }
264
265 impl<S> Drop for DitchGuard<S> {
266 fn drop(&mut self) {
267 if let Some(state) = self.0.upgrade() {
268 if let Ok(mut locked) = state.lock() {
269 *locked = State::Empty;
270 }
271 }
272 }
273 }
274
275 impl<S> Singled<S> {
276 fn new(inner: S, state: Weak<Mutex<State<S>>>) -> Self {
277 Singled { inner, state }
278 }
279 }
280
281 impl<S, Req> Service<Req> for Singled<S>
282 where
283 S: Service<Req>,
284 {
285 type Response = S::Response;
286 type Error = S::Error;
287 type Future = S::Future;
288
289 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
290 match self.inner.poll_ready(cx) {
292 Poll::Ready(Err(err)) => {
293 if let Some(state) = self.state.upgrade() {
294 *state.lock().unwrap() = State::Empty;
295 }
296 Poll::Ready(Err(err))
297 }
298 other => other,
299 }
300 }
301
302 fn call(&mut self, req: Req) -> Self::Future {
303 self.inner.call(req)
304 }
305 }
306
307 #[derive(Debug)]
311 pub struct SingletonError(pub(super) BoxError);
312
313 impl std::fmt::Display for SingletonError {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.write_str("singleton connection error")
316 }
317 }
318
319 impl std::error::Error for SingletonError {
320 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
321 Some(&*self.0)
322 }
323 }
324
325 #[derive(Debug)]
326 struct Canceled;
327
328 impl std::fmt::Display for Canceled {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 f.write_str("singleton connection canceled")
331 }
332 }
333
334 impl std::error::Error for Canceled {}
335}
336
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Poll;

    use tower_service::Service;

    use super::Singleton;

    // The first call drives the make-service. A second call issued before it
    // resolves waits on the same attempt, and the mock serves one request.
    #[tokio::test]
    async fn first_call_drives_subsequent_wait() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();

        let mut singleton = Singleton::new(mock_svc);

        handle.allow(1);
        crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        // Empty -> Making: this future drives the inner call.
        let fut1 = singleton.call(());
        // Making: this future waits for the driver's result.
        let fut2 = singleton.call(());

        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        fut1.await.unwrap();
        fut2.await.unwrap();
    }

    // Once a service is cached, later calls resolve from the `Made` state
    // without calling the mock make-service again.
    #[tokio::test]
    async fn made_state_returns_immediately() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        handle.allow(1);
        crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let fut1 = singleton.call(());
        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");
        fut1.await.unwrap();

        singleton.call(()).await.unwrap();
    }

    // A `poll_ready` error on a handed-out service clears the pool, so the
    // next call goes back to the outer make-service for a new inner service.
    #[tokio::test]
    async fn cached_service_poll_ready_error_clears_singleton() {
        // The outer mock "makes" inner mock services.
        let (outer, mut outer_handle) =
            tower_test::mock::pair::<(), tower_test::mock::Mock<(), &'static str>>();
        let mut singleton = Singleton::new(outer);

        outer_handle.allow(2);
        crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();

        let fut1 = singleton.call(());
        let ((), send_inner) = outer_handle.next_request().await.unwrap();
        let (inner, mut inner_handle) = tower_test::mock::pair::<(), &'static str>();
        send_inner.send_response(inner);
        let mut cached = fut1.await.unwrap();

        inner_handle.allow(1);

        // Make the cached service's next readiness check fail.
        inner_handle.send_error(std::io::Error::new(
            std::io::ErrorKind::Other,
            "cached poll_ready failed",
        ));

        let err = crate::common::future::poll_fn(|cx| cached.poll_ready(cx))
            .await
            .err()
            .expect("expected poll_ready error");
        assert_eq!(err.to_string(), "cached poll_ready failed");

        // The pool is empty again: a new call issues a fresh outer request.
        outer_handle.allow(1);
        crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let fut2 = singleton.call(());
        let ((), send_inner2) = outer_handle.next_request().await.unwrap();
        let (inner2, mut inner_handle2) = tower_test::mock::pair::<(), &'static str>();
        send_inner2.send_response(inner2);
        let mut cached2 = fut2.await.unwrap();

        inner_handle2.allow(1);
        crate::common::future::poll_fn(|cx| cached2.poll_ready(cx))
            .await
            .expect("expected poll_ready");
        let cfut2 = cached2.call(());
        let ((), send_cached2) = inner_handle2.next_request().await.unwrap();
        send_cached2.send_response("svc2");
        cfut2.await.unwrap();
    }

    // Dropping a waiting future only drops its receiver. The driver's send to
    // it fails silently and the driver still completes.
    #[tokio::test]
    async fn cancel_waiter_does_not_affect_others() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let fut1 = singleton.call(());
        let fut2 = singleton.call(());
        drop(fut2);
        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        fut1.await.unwrap();
    }

    // Dropping the driving future fires its `DitchGuard`, which resets the
    // pool to `Empty`. That drops the waiters' senders, so the waiting future
    // resolves with the "canceled" error.
    #[tokio::test]
    async fn cancel_driver_cancels_all() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        crate::common::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let mut fut1 = singleton.call(());
        let fut2 = singleton.call(());

        // Poll the driver once; `fut1` is moved into this closure and is
        // dropped together with the `poll_fn` future at the end of the
        // statement.
        crate::common::future::poll_fn(move |cx| {
            let _ = Pin::new(&mut fut1).poll(cx);
            Poll::Ready(())
        })
        .await;

        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        assert_eq!(
            fut2.await.unwrap_err().0.to_string(),
            "singleton connection canceled"
        );
    }
}