1use futures::FutureExt;
2use std::marker::PhantomData;
3use std::task::{Context, Poll};
4use tokio::task::{JoinError, JoinHandle};
5use tower::balance::p2c::Balance;
6use tower::discover::ServiceList;
7use tower::{BoxError, Service, ServiceExt};
8
9use crate::util::BoxFuture;
10
11#[derive(Debug, thiserror::Error)]
12pub enum Error {
13 #[error("failed to create service pool")]
14 Create(
15 #[from]
16 #[source]
17 BoxError,
18 ),
19 #[error("failed to join created services pool")]
20 CreateJoin(#[source] JoinError),
21 #[error("failed to create service pool")]
22 Failed,
23}
24
25type CreateHandle<T, E> = JoinHandle<Result<T, E>>;
26type ServiceSet<S, Req> = Balance<ServiceList<Vec<S>>, Req>;
27
28enum CreateFuture<MS, Target, Req>
29where
30 MS: Service<Target>,
31 MS::Response: tower::Service<Req, Error = BoxError>,
32{
33 Pending {
34 handle: CreateHandle<ServiceSet<MS::Response, Req>, MS::Error>,
35 },
36 Ready {
37 services: ServiceSet<MS::Response, Req>,
38 },
39 Failed,
40}
41
42pub struct Pool<MS, Target, Req>
44where
45 MS: Service<Target>,
46 MS::Response: tower::Service<Req, Error = BoxError>,
47{
48 services: CreateFuture<MS, Target, Req>,
49 _p: PhantomData<Target>,
50}
51
/// A [`tower::Layer`] that wraps a make-service in a pool of `size` services,
/// each created for a clone of `target`.
pub struct Layer<Target, Req> {
    size: usize,
    target: Target,
    _p: PhantomData<Req>,
}

// Manual impl: `#[derive(Clone)]` would also demand `Req: Clone`, although
// `Req` only appears inside `PhantomData`.
impl<Target: Clone, Req> Clone for Layer<Target, Req> {
    fn clone(&self) -> Self {
        Self {
            size: self.size,
            target: self.target.clone(),
            _p: PhantomData,
        }
    }
}

// Manual impl for the same reason: no `Req: Debug` bound is needed.
impl<Target: std::fmt::Debug, Req> std::fmt::Debug for Layer<Target, Req> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Layer")
            .field("size", &self.size)
            .field("target", &self.target)
            .finish()
    }
}
57
58impl<MS, Target, Req> Pool<MS, Target, Req>
59where
60 Target: Clone + Send + 'static,
61 Req: Send + 'static,
62 MS: Service<Target> + Send + 'static,
63 MS::Response: tower::Service<Req, Error = BoxError> + Send,
64 MS::Error: Send,
65 MS::Future: Send,
66{
67 pub fn with_size(size: usize, mut make_service: MS, target: Target) -> Self {
69 tracing::debug!(message = "creating service pool", size);
70 let handle = tokio::spawn(async move {
71 let mut services = Vec::with_capacity(size);
72 for _ in 0..size {
73 let target = target.clone();
74 let service = make_service.ready().await?.call(target).await?;
75 services.push(service);
76 }
77 tracing::debug!(message = "service pool created", size);
78 Ok(Balance::new(ServiceList::new(services)))
79 });
80
81 Self {
82 services: CreateFuture::Pending { handle },
83 _p: PhantomData,
84 }
85 }
86}
87
88impl<MS, Target, Req> tower::Service<Req> for Pool<MS, Target, Req>
89where
90 MS: Service<Target>,
91 MS::Response: tower::Service<Req, Error = BoxError> + tower::load::Load,
92 MS::Error: std::error::Error + Send + Sync + 'static,
93 <MS::Response as tower::Service<Req>>::Future: Send + 'static,
94 <MS::Response as tower::load::Load>::Metric: std::fmt::Debug,
95{
96 type Response = <MS::Response as tower::Service<Req>>::Response;
97 type Error = BoxError;
98 type Future = BoxFuture<Result<Self::Response, Self::Error>>;
99
100 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101 match self.services {
102 CreateFuture::Pending { ref mut handle } => match handle.poll_unpin(cx) {
103 Poll::Ready(Ok(Ok(services))) => {
104 self.services = CreateFuture::Ready { services };
105 self.poll_ready(cx)
106 }
107 Poll::Ready(Ok(Err(err))) => {
108 self.services = CreateFuture::Failed;
109 Poll::Ready(Err(err.into()))
110 }
111 Poll::Ready(Err(err)) => {
112 self.services = CreateFuture::Failed;
113 Poll::Ready(Err(err.into()))
114 }
115 Poll::Pending => Poll::Pending,
116 },
117 CreateFuture::Ready { ref mut services } => services.poll_ready(cx),
118 CreateFuture::Failed => Poll::Ready(Err(Error::Failed.into())),
119 }
120 }
121
122 fn call(&mut self, req: Req) -> Self::Future {
123 match self.services {
124 CreateFuture::Ready { ref mut services } => Box::pin(services.call(req)),
125 _ => unimplemented!("called before ready"),
126 }
127 }
128}
129
130impl<Target, Req> Layer<Target, Req> {
131 #[must_use]
132 pub fn with_size(size: usize, target: Target) -> Self {
133 Self {
134 size,
135 target,
136 _p: PhantomData,
137 }
138 }
139}
140
141impl<MS, Target, Req> tower::Layer<MS> for Layer<Target, Req>
142where
143 Req: Send + 'static,
144 MS: Service<Target> + Send + 'static,
145 MS::Response: tower::Service<Req, Error = BoxError> + Send,
146 MS::Error: Send,
147 MS::Future: Send,
148 Target: Clone + Send + 'static,
149{
150 type Service = Pool<MS, Target, Req>;
151
152 fn layer(&self, inner: MS) -> Self::Service {
153 let target = self.target.clone();
154 Pool::with_size(self.size, inner, target)
155 }
156}