1use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use crate::{ctx::WaitersRef, Service, ServiceCtx};
4
5#[derive(Debug)]
6pub struct Pipeline<S> {
10 index: u32,
11 state: Rc<PipelineState<S>>,
12}
13
14struct PipelineState<S> {
15 svc: S,
16 waiters: WaitersRef,
17}
18
19impl<S> PipelineState<S> {
20 pub(crate) fn waiters_ref(&self) -> &WaitersRef {
21 &self.waiters
22 }
23}
24
25impl<S> Pipeline<S> {
26 #[inline]
27 pub fn new(svc: S) -> Self {
29 let (index, waiters) = WaitersRef::new();
30 Pipeline {
31 index,
32 state: Rc::new(PipelineState { svc, waiters }),
33 }
34 }
35
36 #[inline]
37 pub fn get_ref(&self) -> &S {
39 &self.state.svc
40 }
41
42 #[inline]
43 pub async fn ready<R>(&self) -> Result<(), S::Error>
45 where
46 S: Service<R>,
47 {
48 ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
49 .ready(&self.state.svc)
50 .await
51 }
52
53 #[inline]
54 pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
57 where
58 S: Service<R>,
59 {
60 ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
61 .call(&self.state.svc, req)
62 .await
63 }
64
65 #[inline]
66 pub fn call_static<R>(&self, req: R) -> PipelineCall<S, R>
69 where
70 S: Service<R> + 'static,
71 R: 'static,
72 {
73 let pl = self.clone();
74
75 PipelineCall {
76 fut: Box::pin(async move {
77 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
78 .call(&pl.state.svc, req)
79 .await
80 }),
81 }
82 }
83
84 #[inline]
85 pub fn call_nowait<R>(&self, req: R) -> PipelineCall<S, R>
89 where
90 S: Service<R> + 'static,
91 R: 'static,
92 {
93 let pl = self.clone();
94
95 PipelineCall {
96 fut: Box::pin(async move {
97 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
98 .call_nowait(&pl.state.svc, req)
99 .await
100 }),
101 }
102 }
103
104 #[inline]
105 pub fn is_shutdown(&self) -> bool {
107 self.state.waiters.is_shutdown()
108 }
109
110 #[inline]
111 pub async fn shutdown<R>(&self)
113 where
114 S: Service<R>,
115 {
116 self.state.svc.shutdown().await
117 }
118
119 #[inline]
120 pub fn poll<R>(&self, cx: &mut Context<'_>) -> Result<(), S::Error>
121 where
122 S: Service<R>,
123 {
124 self.state.svc.poll(cx)
125 }
126
127 #[inline]
128 pub fn bind<R>(self) -> PipelineBinding<S, R>
130 where
131 S: Service<R> + 'static,
132 R: 'static,
133 {
134 PipelineBinding::new(self)
135 }
136}
137
138impl<S> From<S> for Pipeline<S> {
139 #[inline]
140 fn from(svc: S) -> Self {
141 Pipeline::new(svc)
142 }
143}
144
145impl<S> Clone for Pipeline<S> {
146 fn clone(&self) -> Self {
147 Pipeline {
148 index: self.state.waiters.insert(),
149 state: self.state.clone(),
150 }
151 }
152}
153
154impl<S> Drop for Pipeline<S> {
155 #[inline]
156 fn drop(&mut self) {
157 self.state.waiters.remove(self.index);
158 }
159}
160
161impl<S: fmt::Debug> fmt::Debug for PipelineState<S> {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.debug_struct("PipelineState")
164 .field("svc", &self.svc)
165 .field("waiters", &self.waiters.get().len())
166 .finish()
167 }
168}
169
170pub struct PipelineBinding<S, R>
172where
173 S: Service<R>,
174{
175 pl: Pipeline<S>,
176 st: cell::UnsafeCell<State<S::Error>>,
177}
178
/// Which future, if any, a `PipelineBinding` is currently driving.
enum State<E> {
    /// No future has been created yet.
    New,
    /// A boxed readiness future resolving to `Result<(), E>`.
    Readiness(Pin<Box<dyn Future<Output = Result<(), E>> + 'static>>),
    /// A boxed shutdown future resolving to `()`.
    Shutdown(Pin<Box<dyn Future<Output = ()> + 'static>>),
}
184
185impl<S, R> PipelineBinding<S, R>
186where
187 S: Service<R> + 'static,
188 R: 'static,
189{
190 fn new(pl: Pipeline<S>) -> Self {
191 PipelineBinding {
192 pl,
193 st: cell::UnsafeCell::new(State::New),
194 }
195 }
196
197 #[inline]
198 pub fn get_ref(&self) -> &S {
200 &self.pl.state.svc
201 }
202
203 #[inline]
204 pub fn pipeline(&self) -> Pipeline<S> {
206 self.pl.clone()
207 }
208
209 #[inline]
210 pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> {
211 self.pl.poll(cx)
212 }
213
214 #[inline]
215 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
219 let st = unsafe { &mut *self.st.get() };
220
221 match st {
222 State::New => {
223 let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
227 let fut = Box::pin(CheckReadiness {
228 fut: None,
229 f: ready,
230 _t: marker::PhantomData,
231 pl,
232 });
233 *st = State::Readiness(fut);
234 self.poll_ready(cx)
235 }
236 State::Readiness(ref mut fut) => Pin::new(fut).poll(cx),
237 State::Shutdown(_) => panic!("Pipeline is shutding down"),
238 }
239 }
240
241 #[inline]
242 pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
244 let st = unsafe { &mut *self.st.get() };
245
246 match st {
247 State::New | State::Readiness(_) => {
248 let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
252 *st = State::Shutdown(Box::pin(async move { pl.shutdown().await }));
253 pl.state.waiters.shutdown();
254 self.poll_shutdown(cx)
255 }
256 State::Shutdown(ref mut fut) => Pin::new(fut).poll(cx),
257 }
258 }
259
260 #[inline]
261 pub fn call(&self, req: R) -> PipelineCall<S, R> {
264 let pl = self.pl.clone();
265
266 PipelineCall {
267 fut: Box::pin(async move {
268 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
269 .call(&pl.state.svc, req)
270 .await
271 }),
272 }
273 }
274
275 #[inline]
276 pub fn call_nowait(&self, req: R) -> PipelineCall<S, R> {
280 let pl = self.pl.clone();
281
282 PipelineCall {
283 fut: Box::pin(async move {
284 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
285 .call_nowait(&pl.state.svc, req)
286 .await
287 }),
288 }
289 }
290
291 #[inline]
292 pub fn is_shutdown(&self) -> bool {
294 self.pl.state.waiters.is_shutdown()
295 }
296
297 #[inline]
298 pub async fn shutdown(&self) {
300 self.pl.state.svc.shutdown().await
301 }
302}
303
304impl<S, R> Drop for PipelineBinding<S, R>
305where
306 S: Service<R>,
307{
308 fn drop(&mut self) {
309 self.st = cell::UnsafeCell::new(State::New);
310 }
311}
312
313impl<S, R> Clone for PipelineBinding<S, R>
314where
315 S: Service<R>,
316{
317 #[inline]
318 fn clone(&self) -> Self {
319 Self {
320 pl: self.pl.clone(),
321 st: cell::UnsafeCell::new(State::New),
322 }
323 }
324}
325
326impl<S, R> fmt::Debug for PipelineBinding<S, R>
327where
328 S: Service<R> + fmt::Debug,
329{
330 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
331 f.debug_struct("PipelineBinding")
332 .field("pipeline", &self.pl)
333 .finish()
334 }
335}
336
337#[must_use = "futures do nothing unless polled"]
338pub struct PipelineCall<S, R>
340where
341 S: Service<R>,
342 R: 'static,
343{
344 fut: Call<S::Response, S::Error>,
345}
346
/// Pinned, boxed, type-erased `'static` future resolving to `Result<R, E>`.
type Call<R, E> = Pin<Box<dyn Future<Output = Result<R, E>> + 'static>>;
348
349impl<S, R> Future for PipelineCall<S, R>
350where
351 S: Service<R>,
352{
353 type Output = Result<S::Response, S::Error>;
354
355 #[inline]
356 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
357 Pin::new(&mut self.as_mut().fut).poll(cx)
358 }
359}
360
361impl<S, R> fmt::Debug for PipelineCall<S, R>
362where
363 S: Service<R>,
364{
365 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366 f.debug_struct("PipelineCall").finish()
367 }
368}
369
370fn ready<S, R>(pl: &'static Pipeline<S>) -> impl Future<Output = Result<(), S::Error>>
371where
372 S: Service<R>,
373 R: 'static,
374{
375 pl.state
376 .svc
377 .ready(ServiceCtx::<'_, S>::new(pl.index, pl.state.waiters_ref()))
378}
379
380struct CheckReadiness<S: Service<R> + 'static, R, F, Fut> {
381 f: F,
382 fut: Option<Fut>,
383 pl: &'static Pipeline<S>,
384 _t: marker::PhantomData<R>,
385}
386
387impl<S: Service<R>, R, F, Fut> Unpin for CheckReadiness<S, R, F, Fut> {}
388
389impl<S: Service<R>, R, F, Fut> Drop for CheckReadiness<S, R, F, Fut> {
390 fn drop(&mut self) {
391 if self.fut.is_some() {
393 self.pl.state.waiters.notify();
394 }
395 }
396}
397
398impl<S, R, F, Fut> Future for CheckReadiness<S, R, F, Fut>
399where
400 S: Service<R>,
401 F: Fn(&'static Pipeline<S>) -> Fut,
402 Fut: Future<Output = Result<(), S::Error>>,
403{
404 type Output = Result<(), S::Error>;
405
406 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
407 let mut slf = self.as_mut();
408
409 slf.pl.poll(cx)?;
410
411 if slf.pl.state.waiters.can_check(slf.pl.index, cx) {
412 if slf.fut.is_none() {
413 slf.fut = Some((slf.f)(slf.pl));
414 }
415 let fut = slf.fut.as_mut().unwrap();
416 match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
417 Poll::Pending => {
418 slf.pl.state.waiters.register(slf.pl.index, cx);
419 Poll::Pending
420 }
421 Poll::Ready(res) => {
422 let _ = slf.fut.take();
423 slf.pl.state.waiters.notify();
424 Poll::Ready(res)
425 }
426 }
427 } else {
428 Poll::Pending
429 }
430 }
431}