1use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use crate::{ctx::WaitersRef, Service, ServiceCtx};
4
/// Shared, cloneable handle to a service `S`.
///
/// All clones point at the same reference-counted `PipelineState` (an `Rc`, so
/// handles are not `Send`). Every handle owns its own waiter `index`:
/// - `new` takes it from `WaitersRef::new`;
/// - `clone` obtains a fresh one via `WaitersRef::insert`;
/// - `drop` hands it back via `WaitersRef::remove`.
#[derive(Debug)]
pub struct Pipeline<S> {
    // Waiter slot owned by this particular handle (not shared with clones).
    index: u32,
    // Service plus waiter bookkeeping shared by every clone of this pipeline.
    state: Rc<PipelineState<S>>,
}
13
/// State shared by all clones of a `Pipeline`: the service itself and the
/// `WaitersRef` that readiness checks consult (`can_check` / `register` /
/// `notify` in `CheckReadiness::poll`).
struct PipelineState<S> {
    svc: S,
    waiters: WaitersRef,
}

impl<S> PipelineState<S> {
    /// Borrows the shared waiter registry; used to build `ServiceCtx` values.
    pub(crate) fn waiters_ref(&self) -> &WaitersRef {
        &self.waiters
    }
}
24
25impl<S> Pipeline<S> {
26 #[inline]
27 pub fn new(svc: S) -> Self {
29 let (index, waiters) = WaitersRef::new();
30 Pipeline {
31 index,
32 state: Rc::new(PipelineState { svc, waiters }),
33 }
34 }
35
36 #[inline]
37 pub fn get_ref(&self) -> &S {
39 &self.state.svc
40 }
41
42 #[inline]
43 pub async fn ready<R>(&self) -> Result<(), S::Error>
45 where
46 S: Service<R>,
47 {
48 ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
49 .ready(&self.state.svc)
50 .await
51 }
52
53 #[doc(hidden)]
54 #[deprecated]
55 pub async fn not_ready<R>(&self)
57 where
58 S: Service<R>,
59 {
60 std::future::pending().await
61 }
62
63 #[inline]
64 pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
67 where
68 S: Service<R>,
69 {
70 ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
71 .call(&self.state.svc, req)
72 .await
73 }
74
75 #[inline]
76 pub fn call_static<R>(&self, req: R) -> PipelineCall<S, R>
79 where
80 S: Service<R> + 'static,
81 R: 'static,
82 {
83 let pl = self.clone();
84
85 PipelineCall {
86 fut: Box::pin(async move {
87 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
88 .call(&pl.state.svc, req)
89 .await
90 }),
91 }
92 }
93
94 #[inline]
95 pub fn call_nowait<R>(&self, req: R) -> PipelineCall<S, R>
99 where
100 S: Service<R> + 'static,
101 R: 'static,
102 {
103 let pl = self.clone();
104
105 PipelineCall {
106 fut: Box::pin(async move {
107 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
108 .call_nowait(&pl.state.svc, req)
109 .await
110 }),
111 }
112 }
113
114 #[inline]
115 pub fn is_shutdown(&self) -> bool {
117 self.state.waiters.is_shutdown()
118 }
119
120 #[inline]
121 pub async fn shutdown<R>(&self)
123 where
124 S: Service<R>,
125 {
126 self.state.svc.shutdown().await
127 }
128
129 #[inline]
130 pub fn poll<R>(&self, cx: &mut Context<'_>) -> Result<(), S::Error>
131 where
132 S: Service<R>,
133 {
134 self.state.svc.poll(cx)
135 }
136
137 #[inline]
138 pub fn bind<R>(self) -> PipelineBinding<S, R>
140 where
141 S: Service<R> + 'static,
142 R: 'static,
143 {
144 PipelineBinding::new(self)
145 }
146}
147
148impl<S> From<S> for Pipeline<S> {
149 #[inline]
150 fn from(svc: S) -> Self {
151 Pipeline::new(svc)
152 }
153}
154
155impl<S> Clone for Pipeline<S> {
156 fn clone(&self) -> Self {
157 Pipeline {
158 index: self.state.waiters.insert(),
159 state: self.state.clone(),
160 }
161 }
162}
163
164impl<S> Drop for Pipeline<S> {
165 #[inline]
166 fn drop(&mut self) {
167 self.state.waiters.remove(self.index);
168 }
169}
170
171impl<S: fmt::Debug> fmt::Debug for PipelineState<S> {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 f.debug_struct("PipelineState")
174 .field("svc", &self.svc)
175 .field("waiters", &self.waiters.get().len())
176 .finish()
177 }
178}
179
/// A `Pipeline` bound to a single request type `R`, exposing poll-style
/// readiness and shutdown methods.
///
/// `st` caches the boxed readiness or shutdown future between polls. It lives
/// in an `UnsafeCell` because the `poll_*` methods only take `&self`.
///
/// Field order matters. Fields are dropped in declaration order (`pl` before
/// `st`), but the futures stored in `st` hold a `'static` reference to `pl`.
/// The `Drop` impl therefore resets `st` first.
pub struct PipelineBinding<S, R>
where
    S: Service<R>,
{
    pl: Pipeline<S>,
    st: cell::UnsafeCell<State<S::Error>>,
}
188
/// Internal state machine of a `PipelineBinding`.
enum State<E> {
    /// No readiness or shutdown future has been created yet.
    New,
    /// A boxed readiness check (built from `CheckReadiness`), re-polled by every
    /// `poll_ready` call.
    Readiness(Pin<Box<dyn Future<Output = Result<(), E>> + 'static>>),
    /// A boxed shutdown future; once entered, `poll_ready` panics.
    Shutdown(Pin<Box<dyn Future<Output = ()> + 'static>>),
}
194
impl<S, R> PipelineBinding<S, R>
where
    S: Service<R> + 'static,
    R: 'static,
{
    /// Wraps `pl`; no readiness or shutdown future exists yet.
    fn new(pl: Pipeline<S>) -> Self {
        PipelineBinding {
            pl,
            st: cell::UnsafeCell::new(State::New),
        }
    }

    /// Borrows the wrapped service.
    #[inline]
    pub fn get_ref(&self) -> &S {
        &self.pl.state.svc
    }

    /// Returns a new `Pipeline` handle cloned from this binding's pipeline.
    #[inline]
    pub fn pipeline(&self) -> Pipeline<S> {
        self.pl.clone()
    }

    /// Forwards to `Pipeline::poll`.
    #[inline]
    pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> {
        self.pl.poll(cx)
    }

    /// Polls the service for readiness.
    ///
    /// On the first call (state `New`) a boxed `CheckReadiness` future is created
    /// and stored in `st`. Later calls keep polling that same future, which
    /// re-creates its inner readiness future after each completion.
    ///
    /// # Panics
    ///
    /// Panics if called after `poll_shutdown` has moved the binding into the
    /// shutdown state.
    #[inline]
    pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
        // SAFETY: `st` is only reached through `&self` by `poll_ready`,
        // `poll_shutdown` and `Drop`. `UnsafeCell` makes the type `!Sync`, so there
        // is no cross-thread aliasing.
        // NOTE(review): a re-entrant call on this same binding from inside the polled
        // future would alias this `&mut`; presumably that never happens — confirm.
        let st = unsafe { &mut *self.st.get() };

        match st {
            State::New => {
                // SAFETY(review): extends `&self.pl` to `'static` so the boxed future
                // can be stored in `st`. `Drop for PipelineBinding` clears `st` before
                // `pl` is dropped.
                // However, this reference targets the binding's inline `pl` field, not
                // the `Rc` heap allocation. If the binding is moved while `st` holds
                // this future, the reference dangles. Confirm callers never move a
                // binding after polling it.
                let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
                let fut = Box::pin(CheckReadiness {
                    fut: None,
                    f: ready,
                    _t: marker::PhantomData,
                    pl,
                });
                *st = State::Readiness(fut);
                // Re-enter once to poll the freshly stored future.
                self.poll_ready(cx)
            }
            State::Readiness(ref mut fut) => Pin::new(fut).poll(cx),
            State::Shutdown(_) => panic!("Pipeline is shutding down"),
        }
    }

    /// Deprecated: always returns `Poll::Pending`.
    #[doc(hidden)]
    #[deprecated]
    #[inline]
    pub fn poll_not_ready(&self, _: &mut Context<'_>) -> Poll<()> {
        Poll::Pending
    }

    /// Starts (on the first call) and then drives the service's shutdown.
    ///
    /// On the first call, any readiness future held in `st` is dropped and
    /// replaced by a boxed shutdown future. The shared waiters are marked as shut
    /// down at the same time.
    #[inline]
    pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
        // SAFETY: same single-threaded, non-reentrant access to `st` as in
        // `poll_ready`.
        let st = unsafe { &mut *self.st.get() };

        match st {
            State::New | State::Readiness(_) => {
                // SAFETY(review): same `'static` extension of `&self.pl`, with the same
                // move hazard, as in `poll_ready`.
                let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
                // Assigning drops the previous state first. An in-flight
                // `CheckReadiness` is torn down here, and its `Drop` notifies waiters.
                *st = State::Shutdown(Box::pin(async move { pl.shutdown().await }));
                pl.state.waiters.shutdown();
                // Re-enter once to poll the freshly stored shutdown future.
                self.poll_shutdown(cx)
            }
            State::Shutdown(ref mut fut) => Pin::new(fut).poll(cx),
        }
    }

    /// Calls the service with `req` and returns a `'static` future that owns a
    /// clone of the pipeline. The body matches `Pipeline::call_static`.
    #[inline]
    pub fn call(&self, req: R) -> PipelineCall<S, R> {
        let pl = self.pl.clone();

        PipelineCall {
            fut: Box::pin(async move {
                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
                    .call(&pl.state.svc, req)
                    .await
            }),
        }
    }

    /// Like `call`, but dispatches through `ServiceCtx::call_nowait`. The body
    /// matches `Pipeline::call_nowait`.
    #[inline]
    pub fn call_nowait(&self, req: R) -> PipelineCall<S, R> {
        let pl = self.pl.clone();

        PipelineCall {
            fut: Box::pin(async move {
                ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
                    .call_nowait(&pl.state.svc, req)
                    .await
            }),
        }
    }

    /// Reports whether the shared `WaitersRef` has been marked as shut down
    /// (as `poll_shutdown` does).
    #[inline]
    pub fn is_shutdown(&self) -> bool {
        self.pl.state.waiters.is_shutdown()
    }

    /// Runs the service's own `shutdown` directly.
    ///
    /// Unlike `poll_shutdown`, this neither changes `st` nor marks the waiters as
    /// shut down.
    #[inline]
    pub async fn shutdown(&self) {
        self.pl.state.svc.shutdown().await
    }
}
321
322impl<S, R> Drop for PipelineBinding<S, R>
323where
324 S: Service<R>,
325{
326 fn drop(&mut self) {
327 self.st = cell::UnsafeCell::new(State::New);
328 }
329}
330
331impl<S, R> Clone for PipelineBinding<S, R>
332where
333 S: Service<R>,
334{
335 #[inline]
336 fn clone(&self) -> Self {
337 Self {
338 pl: self.pl.clone(),
339 st: cell::UnsafeCell::new(State::New),
340 }
341 }
342}
343
344impl<S, R> fmt::Debug for PipelineBinding<S, R>
345where
346 S: Service<R> + fmt::Debug,
347{
348 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349 f.debug_struct("PipelineBinding")
350 .field("pipeline", &self.pl)
351 .finish()
352 }
353}
354
/// Boxed, `'static` future for a single service call.
///
/// In this file it is produced by `Pipeline::call_static`, `Pipeline::call_nowait`,
/// `PipelineBinding::call` and `PipelineBinding::call_nowait`. Each of these
/// moves a cloned `Pipeline` into the future, so it does not borrow its creator.
#[must_use = "futures do nothing unless polled"]
pub struct PipelineCall<S, R>
where
    S: Service<R>,
    R: 'static,
{
    fut: Call<S::Response, S::Error>,
}

/// Type-erased, heap-pinned future resolving to a service call result.
type Call<R, E> = Pin<Box<dyn Future<Output = Result<R, E>> + 'static>>;
366
367impl<S, R> Future for PipelineCall<S, R>
368where
369 S: Service<R>,
370{
371 type Output = Result<S::Response, S::Error>;
372
373 #[inline]
374 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
375 Pin::new(&mut self.as_mut().fut).poll(cx)
376 }
377}
378
379impl<S, R> fmt::Debug for PipelineCall<S, R>
380where
381 S: Service<R>,
382{
383 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
384 f.debug_struct("PipelineCall").finish()
385 }
386}
387
388fn ready<S, R>(pl: &'static Pipeline<S>) -> impl Future<Output = Result<(), S::Error>>
389where
390 S: Service<R>,
391 R: 'static,
392{
393 pl.state
394 .svc
395 .ready(ServiceCtx::<'_, S>::new(pl.index, pl.state.waiters_ref()))
396}
397
/// Future that drives repeated readiness checks for a `PipelineBinding`.
///
/// `f` builds the service's readiness future (`Fut`) on demand. `fut` holds it
/// while it is in flight and is cleared once it completes, so the next poll
/// starts a fresh check.
struct CheckReadiness<S: Service<R> + 'static, R, F, Fut> {
    f: F,
    fut: Option<Fut>,
    // NOTE(review): not a true `'static` borrow. `PipelineBinding::poll_ready`
    // produces it by `transmute`-ing `&self.pl`, so it is valid only while that
    // binding is alive and unmoved.
    pl: &'static Pipeline<S>,
    _t: marker::PhantomData<R>,
}

// `Unpin` is asserted unconditionally even though `Fut` may be `!Unpin`; `poll`
// compensates with `Pin::new_unchecked`. This is sound only because a value is
// never moved once polled. In this file it is constructed solely inside
// `Box::pin` in `PipelineBinding::poll_ready`.
impl<S: Service<R>, R, F, Fut> Unpin for CheckReadiness<S, R, F, Fut> {}
406
407impl<S: Service<R>, R, F, Fut> Drop for CheckReadiness<S, R, F, Fut> {
408 fn drop(&mut self) {
409 if self.fut.is_some() {
411 self.pl.state.waiters.notify();
412 }
413 }
414}
415
impl<S, R, F, Fut> Future for CheckReadiness<S, R, F, Fut>
where
    S: Service<R>,
    F: Fn(&'static Pipeline<S>) -> Fut,
    Fut: Future<Output = Result<(), S::Error>>,
{
    type Output = Result<(), S::Error>;

    /// Performs one readiness-check step.
    ///
    /// 1. Calls the service's `poll`. An error is returned immediately; `fut` is
    ///    left untouched and no waiters are notified.
    /// 2. Asks the shared `WaitersRef` whether this handle's index may run a check
    ///    (`can_check`). If not, returns `Pending`.
    /// 3. Lazily creates the readiness future via `f` and polls it.
    ///    - While it is pending, the task is registered with the waiters.
    ///    - On completion, the future is cleared and the waiters are notified
    ///      before the result is returned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Deref through the manual `Unpin` impl.
        let mut slf = self.as_mut();

        slf.pl.poll(cx)?;

        if slf.pl.state.waiters.can_check(slf.pl.index, cx) {
            if slf.fut.is_none() {
                slf.fut = Some((slf.f)(slf.pl));
            }
            let fut = slf.fut.as_mut().unwrap();
            // SAFETY: `fut` lives inside `self`. Per the invariant behind the
            // manual `Unpin` impl, `self` is heap-pinned (`Box::pin` in
            // `PipelineBinding::poll_ready`) and never moved, so `fut` is not moved
            // while pinned either.
            match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
                Poll::Pending => {
                    slf.pl.state.waiters.register(slf.pl.index, cx);
                    Poll::Pending
                }
                Poll::Ready(res) => {
                    // Drop the finished future so the next poll starts a new check.
                    let _ = slf.fut.take();
                    slf.pl.state.waiters.notify();
                    Poll::Ready(res)
                }
            }
        } else {
            // NOTE(review): nothing registers `cx` on this path here, so presumably
            // `can_check` stores the waker itself — confirm in `ctx`.
            Poll::Pending
        }
    }
}