1use std::sync::{Arc, Mutex};
24use std::task::{self, Poll};
25
26use tokio::sync::oneshot;
27use tower_service::Service;
28
29use self::internal::{DitchGuard, SingletonError, SingletonFuture, State};
30
/// Type-erased, `Send + Sync` error; this is the payload carried by `SingletonError`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
32
33#[cfg(docsrs)]
34pub use self::internal::Singled;
35
/// A service that makes one value with an inner service and hands clones of that
/// value to every caller.
///
/// The shared slot (`state`) records whether the value is absent, currently being
/// made, or already made.
#[derive(Debug)]
pub struct Singleton<M, Dst>
where
    M: Service<Dst>,
{
    // Inner service that is called to make the shared value.
    mk_svc: M,
    // Shared slot; the futures and handles produced by this type keep weak
    // references to it.
    state: Arc<Mutex<State<M::Response>>>,
}
49
50impl<M, Target> Singleton<M, Target>
51where
52 M: Service<Target>,
53 M::Response: Clone,
54{
55 pub fn new(mk_svc: M) -> Self {
57 Singleton {
58 mk_svc,
59 state: Arc::new(Mutex::new(State::Empty)),
60 }
61 }
62
63 pub fn retain<F>(&mut self, mut predicate: F)
67 where
68 F: FnMut(&mut M::Response) -> bool,
69 {
70 let mut locked = self.state.lock().unwrap();
71 match *locked {
72 State::Empty => {}
73 State::Making(..) => {}
74 State::Made(ref mut svc) => {
75 if !predicate(svc) {
76 *locked = State::Empty;
77 }
78 }
79 }
80 }
81
82 pub fn is_empty(&self) -> bool {
87 matches!(*self.state.lock().unwrap(), State::Empty)
88 }
89}
90
91impl<M, Target> Service<Target> for Singleton<M, Target>
92where
93 M: Service<Target>,
94 M::Response: Clone,
95 M::Error: Into<BoxError>,
96{
97 type Response = internal::Singled<M::Response>;
98 type Error = SingletonError;
99 type Future = SingletonFuture<M::Future, M::Response>;
100
101 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
102 if let State::Empty = *self.state.lock().unwrap() {
103 return self
104 .mk_svc
105 .poll_ready(cx)
106 .map_err(|e| SingletonError(e.into()));
107 }
108 Poll::Ready(Ok(()))
109 }
110
111 fn call(&mut self, dst: Target) -> Self::Future {
112 let mut locked = self.state.lock().unwrap();
113 match *locked {
114 State::Empty => {
115 let fut = self.mk_svc.call(dst);
116 *locked = State::Making(Vec::new());
117 SingletonFuture::Driving {
118 future: fut,
119 singleton: DitchGuard(Arc::downgrade(&self.state)),
120 }
121 }
122 State::Making(ref mut waiters) => {
123 let (tx, rx) = oneshot::channel();
124 waiters.push(tx);
125 SingletonFuture::Waiting {
126 rx,
127 state: Arc::downgrade(&self.state),
128 }
129 }
130 State::Made(ref svc) => SingletonFuture::Made {
131 svc: Some(svc.clone()),
132 state: Arc::downgrade(&self.state),
133 },
134 }
135 }
136}
137
138impl<M, Target> Clone for Singleton<M, Target>
139where
140 M: Service<Target> + Clone,
141{
142 fn clone(&self) -> Self {
143 Self {
144 mk_svc: self.mk_svc.clone(),
145 state: self.state.clone(),
146 }
147 }
148}
149
150mod internal {
152 use std::future::Future;
153 use std::pin::Pin;
154 use std::sync::{Mutex, Weak};
155 use std::task::{self, ready, Poll};
156
157 use pin_project_lite::pin_project;
158 use tokio::sync::oneshot;
159 use tower_service::Service;
160
161 use super::BoxError;
162
    pin_project! {
        // Future returned by `Singleton::call`. The variant records which role
        // this particular call plays.
        #[project = SingletonFutureProj]
        pub enum SingletonFuture<F, S> {
            // The call found the slot empty: it owns and polls the inner future.
            Driving {
                #[pin]
                future: F,
                // Resets the slot to empty on drop unless its weak handle has
                // been swapped out first.
                singleton: DitchGuard<S>,
            },
            // Another call is already driving: wait for the value on a channel.
            Waiting {
                rx: oneshot::Receiver<S>,
                state: Weak<Mutex<State<S>>>,
            },
            // The value was already made: `svc` holds a clone taken at call time
            // and is handed out (taken) when the future is polled.
            Made {
                svc: Option<S>,
                state: Weak<Mutex<State<S>>>,
            },
        }
    }
181
    /// Contents of the shared slot.
    #[derive(Debug)]
    pub enum State<S> {
        /// Nothing made and nothing in flight.
        Empty,
        /// A driving future is in flight; the senders belong to calls waiting for
        /// its result.
        Making(Vec<oneshot::Sender<S>>),
        /// The made value; each caller receives a clone.
        Made(S),
    }
189
    /// Guard carried by the driving future. When dropped it resets the shared slot
    /// to `State::Empty`, provided the slot still exists and its lock is not
    /// poisoned. Replacing the inner `Weak` with an empty one disarms it.
    pub struct DitchGuard<S>(pub(super) Weak<Mutex<State<S>>>);
192
    /// The value handed out by the singleton's futures: the made service plus a
    /// weak reference to the shared slot, used to clear the slot when the service
    /// reports an error from `poll_ready`.
    #[derive(Debug)]
    pub struct Singled<S> {
        // The made (cloned) service that requests are forwarded to.
        inner: S,
        // Weak so that handles do not keep the slot alive on their own.
        state: Weak<Mutex<State<S>>>,
    }
211
212 impl<F, S, E> Future for SingletonFuture<F, S>
213 where
214 F: Future<Output = Result<S, E>>,
215 E: Into<BoxError>,
216 S: Clone,
217 {
218 type Output = Result<Singled<S>, SingletonError>;
219
220 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
221 match self.project() {
222 SingletonFutureProj::Driving { future, singleton } => {
223 match ready!(future.poll(cx)) {
224 Ok(svc) => {
225 if let Some(state) = singleton.0.upgrade() {
226 let mut locked = state.lock().unwrap();
227 match std::mem::replace(&mut *locked, State::Made(svc.clone())) {
228 State::Making(waiters) => {
229 for tx in waiters {
230 let _ = tx.send(svc.clone());
231 }
232 }
233 State::Empty | State::Made(_) => {
234 unreachable!()
236 }
237 }
238 }
239 let state = std::mem::replace(&mut singleton.0, Weak::new());
241 Poll::Ready(Ok(Singled::new(svc, state)))
242 }
243 Err(e) => {
244 if let Some(state) = singleton.0.upgrade() {
245 let mut locked = state.lock().unwrap();
246 singleton.0 = Weak::new();
247 *locked = State::Empty;
248 }
249 Poll::Ready(Err(SingletonError(e.into())))
250 }
251 }
252 }
253 SingletonFutureProj::Waiting { rx, state } => match ready!(Pin::new(rx).poll(cx)) {
254 Ok(svc) => Poll::Ready(Ok(Singled::new(svc, state.clone()))),
255 Err(_canceled) => Poll::Ready(Err(SingletonError(Canceled.into()))),
256 },
257 SingletonFutureProj::Made { svc, state } => {
258 Poll::Ready(Ok(Singled::new(svc.take().unwrap(), state.clone())))
259 }
260 }
261 }
262 }
263
264 impl<S> Drop for DitchGuard<S> {
265 fn drop(&mut self) {
266 if let Some(state) = self.0.upgrade() {
267 if let Ok(mut locked) = state.lock() {
268 *locked = State::Empty;
269 }
270 }
271 }
272 }
273
274 impl<S> Singled<S> {
275 fn new(inner: S, state: Weak<Mutex<State<S>>>) -> Self {
276 Singled { inner, state }
277 }
278 }
279
280 impl<S, Req> Service<Req> for Singled<S>
281 where
282 S: Service<Req>,
283 {
284 type Response = S::Response;
285 type Error = S::Error;
286 type Future = S::Future;
287
288 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
289 match self.inner.poll_ready(cx) {
291 Poll::Ready(Err(err)) => {
292 if let Some(state) = self.state.upgrade() {
293 *state.lock().unwrap() = State::Empty;
294 }
295 Poll::Ready(Err(err))
296 }
297 other => other,
298 }
299 }
300
301 fn call(&mut self, req: Req) -> Self::Future {
302 self.inner.call(req)
303 }
304 }
305
306 #[derive(Debug)]
310 pub struct SingletonError(pub(super) BoxError);
311
312 impl std::fmt::Display for SingletonError {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 f.write_str("singleton connection error")
315 }
316 }
317
318 impl std::error::Error for SingletonError {
319 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
320 Some(&*self.0)
321 }
322 }
323
324 #[derive(Debug)]
325 struct Canceled;
326
327 impl std::fmt::Display for Canceled {
328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329 f.write_str("singleton connection canceled")
330 }
331 }
332
333 impl std::error::Error for Canceled {}
334}
335
336#[cfg(test)]
337mod tests {
338 use std::future::Future;
339 use std::pin::Pin;
340 use std::task::Poll;
341
342 use tower_service::Service;
343
344 use super::Singleton;
345
    // Two calls issued back to back are both satisfied by a single request to the
    // inner service.
    #[tokio::test]
    async fn first_call_drives_subsequent_wait() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();

        let mut singleton = Singleton::new(mock_svc);

        handle.allow(1);
        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        // First call: the slot is empty, so this future drives the inner request.
        let fut1 = singleton.call(());
        // Second call: a request is already in flight, so this future only waits.
        let fut2 = singleton.call(());

        // Answer the one request that reached the mock.
        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        // The driver resolves first and publishes the value.
        fut1.await.unwrap();
        // The waiter then resolves from what the driver sent it.
        fut2.await.unwrap();
    }
369
    // Once a value has been made, a later call resolves without another request
    // being answered on the mock.
    #[tokio::test]
    async fn made_state_returns_immediately() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        handle.allow(1);
        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        // Drive the singleton to the made state.
        let fut1 = singleton.call(());
        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");
        fut1.await.unwrap();

        // No response is sent for this call; it must be served from the cache.
        singleton.call(()).await.unwrap();
    }
388
    // When the cached service reports an error from `poll_ready`, the slot is
    // cleared and the next call makes a fresh service.
    #[tokio::test]
    async fn cached_service_poll_ready_error_clears_singleton() {
        // The outer mock plays the make-service: each response is itself a mock
        // service.
        let (outer, mut outer_handle) =
            tower_test::mock::pair::<(), tower_test::mock::Mock<(), &'static str>>();
        let mut singleton = Singleton::new(outer);

        outer_handle.allow(2);
        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();

        // Make the first cached service.
        let fut1 = singleton.call(());
        let ((), send_inner) = outer_handle.next_request().await.unwrap();
        let (inner, mut inner_handle) = tower_test::mock::pair::<(), &'static str>();
        send_inner.send_response(inner);
        let mut cached = fut1.await.unwrap();

        inner_handle.allow(1);

        // Inject an error for the cached service to surface from `poll_ready`.
        inner_handle.send_error(std::io::Error::new(
            std::io::ErrorKind::Other,
            "cached poll_ready failed",
        ));

        // The handle passes the inner error through unchanged.
        let err = std::future::poll_fn(|cx| cached.poll_ready(cx))
            .await
            .err()
            .expect("expected poll_ready error");
        assert_eq!(err.to_string(), "cached poll_ready failed");

        // The slot is empty again, so the singleton goes back to the outer
        // service and a second inner service is requested from it.
        outer_handle.allow(1);
        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let fut2 = singleton.call(());
        let ((), send_inner2) = outer_handle.next_request().await.unwrap();
        let (inner2, mut inner_handle2) = tower_test::mock::pair::<(), &'static str>();
        send_inner2.send_response(inner2);
        let mut cached2 = fut2.await.unwrap();

        // The replacement service works end to end.
        inner_handle2.allow(1);
        std::future::poll_fn(|cx| cached2.poll_ready(cx))
            .await
            .expect("expected poll_ready");
        let cfut2 = cached2.call(());
        let ((), send_cached2) = inner_handle2.next_request().await.unwrap();
        send_cached2.send_response("svc2");
        cfut2.await.unwrap();
    }
446
    // Dropping a waiting future must not disturb the driving future.
    #[tokio::test]
    async fn cancel_waiter_does_not_affect_others() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let fut1 = singleton.call(());
        let fut2 = singleton.call(());
        // Cancel the waiter before the value is made.
        drop(fut2);
        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        // The driver still completes successfully.
        fut1.await.unwrap();
    }
464
    // Dropping the driving future before it completes cancels the waiters too.
    #[tokio::test]
    async fn cancel_driver_cancels_all() {
        let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>();
        let mut singleton = Singleton::new(mock_svc);

        std::future::poll_fn(|cx| singleton.poll_ready(cx))
            .await
            .unwrap();
        let mut fut1 = singleton.call(());
        let fut2 = singleton.call(());

        // Poll the driver exactly once; the `move` closure owns `fut1`, so the
        // driver is dropped as soon as this `poll_fn` future finishes.
        std::future::poll_fn(move |cx| {
            let _ = Pin::new(&mut fut1).poll(cx);
            Poll::Ready(())
        })
        .await;

        // The request still reaches the mock and is answered, but the driver that
        // would have published the value no longer exists.
        let ((), send_response) = handle.next_request().await.unwrap();
        send_response.send_response("svc");

        // The waiter observes the cancellation as the error's wrapped cause.
        assert_eq!(
            fut2.await.unwrap_err().0.to_string(),
            "singleton connection canceled"
        );
    }
492}