1use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use crate::{IntoService, Service, ServiceCtx, ctx::WaitersRef};
4
/// Container for a service.
///
/// A `Pipeline` is a handle onto reference-counted state (the service plus
/// its waiters) together with the waiter index that belongs to this handle.
#[derive(Debug)]
pub struct Pipeline<S> {
    /// Waiter index owned by this handle (obtained from `WaitersRef::new` or
    /// `WaitersRef::insert`, released again in `Drop`).
    index: u32,
    /// Service and waiters shared with every clone of this pipeline.
    state: Rc<PipelineState<S>>,
}
13
/// State shared (through an `Rc`) by all clones of a [`Pipeline`].
struct PipelineState<S> {
    /// The wrapped service.
    svc: S,
    /// Waiters passed to every `ServiceCtx` that is created for `svc`.
    waiters: WaitersRef,
}
18
19impl<S> PipelineState<S> {
20 pub(crate) fn waiters_ref(&self) -> &WaitersRef {
21 &self.waiters
22 }
23}
24
25impl<S> Pipeline<S> {
26 #[inline]
27 pub fn new(svc: S) -> Self {
29 let (index, waiters) = WaitersRef::new();
30 Pipeline {
31 index,
32 state: Rc::new(PipelineState { svc, waiters }),
33 }
34 }
35
36 #[inline]
37 pub fn get_ref(&self) -> &S {
39 &self.state.svc
40 }
41
42 #[inline]
43 pub async fn ready<R>(&self) -> Result<(), S::Error>
45 where
46 S: Service<R>,
47 {
48 ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
49 .ready(&self.state.svc)
50 .await
51 }
52
53 #[inline]
54 pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
57 where
58 S: Service<R>,
59 {
60 ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
61 .call(&self.state.svc, req)
62 .await
63 }
64
65 #[inline]
66 pub fn call_static<R>(&self, req: R) -> PipelineCall<S, R>
69 where
70 S: Service<R> + 'static,
71 R: 'static,
72 {
73 let pl = self.clone();
74
75 PipelineCall {
76 fut: Box::pin(async move {
77 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
78 .call(&pl.state.svc, req)
79 .await
80 }),
81 }
82 }
83
84 #[inline]
85 pub fn call_nowait<R>(&self, req: R) -> PipelineCall<S, R>
89 where
90 S: Service<R> + 'static,
91 R: 'static,
92 {
93 let pl = self.clone();
94
95 PipelineCall {
96 fut: Box::pin(async move {
97 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
98 .call_nowait(&pl.state.svc, req)
99 .await
100 }),
101 }
102 }
103
104 #[inline]
105 pub fn is_shutdown(&self) -> bool {
107 self.state.waiters.is_shutdown()
108 }
109
110 #[inline]
111 pub async fn shutdown<R>(&self)
113 where
114 S: Service<R>,
115 {
116 self.state.svc.shutdown().await;
117 }
118
119 #[inline]
120 pub fn poll<R>(&self, cx: &mut Context<'_>) -> Result<(), S::Error>
121 where
122 S: Service<R>,
123 {
124 self.state.svc.poll(cx)
125 }
126
127 #[inline]
128 pub fn bind<R>(self) -> PipelineBinding<S, R>
130 where
131 S: Service<R> + 'static,
132 R: 'static,
133 {
134 PipelineBinding::new(self)
135 }
136}
137
138impl<S> From<S> for Pipeline<S> {
139 #[inline]
140 fn from(svc: S) -> Self {
141 Pipeline::new(svc)
142 }
143}
144
145impl<S> Clone for Pipeline<S> {
146 fn clone(&self) -> Self {
147 Pipeline {
148 index: self.state.waiters.insert(),
149 state: self.state.clone(),
150 }
151 }
152}
153
154impl<S> Drop for Pipeline<S> {
155 #[inline]
156 fn drop(&mut self) {
157 self.state.waiters.remove(self.index);
158 }
159}
160
161impl<S: fmt::Debug> fmt::Debug for PipelineState<S> {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.debug_struct("PipelineState")
164 .field("svc", &self.svc)
165 .field("waiters", &self.waiters.get().len())
166 .finish()
167 }
168}
169
/// Service wrapper around a [`Pipeline`].
///
/// Lets a pipeline be used wherever a `Service` is expected; every trait
/// method forwards to the inner pipeline.
#[derive(Debug)]
pub struct PipelineSvc<S> {
    /// Pipeline that all `Service` methods are forwarded to.
    inner: Pipeline<S>,
}
175
176impl<S> PipelineSvc<S> {
177 #[inline]
178 pub fn new(inner: Pipeline<S>) -> Self {
180 Self { inner }
181 }
182}
183
184impl<S, Req> Service<Req> for PipelineSvc<S>
185where
186 S: Service<Req>,
187{
188 type Response = S::Response;
189 type Error = S::Error;
190
191 #[inline]
192 async fn call(
193 &self,
194 req: Req,
195 _: ServiceCtx<'_, Self>,
196 ) -> Result<Self::Response, Self::Error> {
197 self.inner.call(req).await
198 }
199
200 #[inline]
201 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
202 self.inner.ready().await
203 }
204
205 #[inline]
206 async fn shutdown(&self) {
207 self.inner.shutdown().await;
208 }
209
210 #[inline]
211 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
212 self.inner.poll(cx)
213 }
214}
215
216impl<S> From<S> for PipelineSvc<S> {
217 #[inline]
218 fn from(svc: S) -> Self {
219 PipelineSvc {
220 inner: Pipeline::new(svc),
221 }
222 }
223}
224
225impl<S> Clone for PipelineSvc<S> {
226 fn clone(&self) -> Self {
227 PipelineSvc {
228 inner: self.inner.clone(),
229 }
230 }
231}
232
233impl<S, R> IntoService<PipelineSvc<S>, R> for Pipeline<S>
234where
235 S: Service<R>,
236{
237 #[inline]
238 fn into_service(self) -> PipelineSvc<S> {
239 PipelineSvc::new(self)
240 }
241}
242
/// A [`Pipeline`] bound to a specific request type `R`.
///
/// Besides the pipeline it keeps the lazily created readiness or shutdown
/// future that `poll_ready` / `poll_shutdown` drive.
pub struct PipelineBinding<S, R>
where
    S: Service<R>,
{
    /// Pipeline this binding operates on. The futures stored in `st` hold a
    /// reference to this field.
    pl: Pipeline<S>,
    /// Current polling state. Interior-mutable because the polling methods
    /// take `&self`.
    st: cell::UnsafeCell<State<S::Error>>,
}
251
/// Polling state of a [`PipelineBinding`].
enum State<E> {
    /// Neither a readiness nor a shutdown future has been created yet.
    New,
    /// Readiness future created by `poll_ready`.
    Readiness(Pin<Box<dyn Future<Output = Result<(), E>> + 'static>>),
    /// Shutdown future created by `poll_shutdown`.
    Shutdown(Pin<Box<dyn Future<Output = ()> + 'static>>),
}
257
258impl<S, R> PipelineBinding<S, R>
259where
260 S: Service<R> + 'static,
261 R: 'static,
262{
263 fn new(pl: Pipeline<S>) -> Self {
264 PipelineBinding {
265 pl,
266 st: cell::UnsafeCell::new(State::New),
267 }
268 }
269
270 #[inline]
271 pub fn get_ref(&self) -> &S {
273 &self.pl.state.svc
274 }
275
276 #[inline]
277 pub fn pipeline(&self) -> Pipeline<S> {
279 self.pl.clone()
280 }
281
282 #[inline]
283 pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> {
284 self.pl.poll(cx)
285 }
286
287 #[inline]
288 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
294 let st = unsafe { &mut *self.st.get() };
295
296 match st {
297 State::New => {
298 let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
302 let fut = Box::pin(CheckReadiness {
303 fut: None,
304 f: ready,
305 _t: marker::PhantomData,
306 pl,
307 });
308 *st = State::Readiness(fut);
309 self.poll_ready(cx)
310 }
311 State::Readiness(fut) => Pin::new(fut).poll(cx),
312 State::Shutdown(_) => panic!("Pipeline is shutding down"),
313 }
314 }
315
316 #[inline]
317 pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
319 let st = unsafe { &mut *self.st.get() };
320
321 match st {
322 State::New | State::Readiness(_) => {
323 let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
327 *st = State::Shutdown(Box::pin(async move { pl.shutdown().await }));
328 pl.state.waiters.shutdown();
329 self.poll_shutdown(cx)
330 }
331 State::Shutdown(fut) => Pin::new(fut).poll(cx),
332 }
333 }
334
335 #[inline]
336 pub fn call(&self, req: R) -> PipelineCall<S, R> {
339 let pl = self.pl.clone();
340
341 PipelineCall {
342 fut: Box::pin(async move {
343 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
344 .call(&pl.state.svc, req)
345 .await
346 }),
347 }
348 }
349
350 #[inline]
351 pub fn call_nowait(&self, req: R) -> PipelineCall<S, R> {
355 let pl = self.pl.clone();
356
357 PipelineCall {
358 fut: Box::pin(async move {
359 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
360 .call_nowait(&pl.state.svc, req)
361 .await
362 }),
363 }
364 }
365
366 #[inline]
367 pub fn is_shutdown(&self) -> bool {
369 self.pl.state.waiters.is_shutdown()
370 }
371
372 #[inline]
373 pub async fn shutdown(&self) {
375 self.pl.state.svc.shutdown().await;
376 }
377}
378
379impl<S, R> Drop for PipelineBinding<S, R>
380where
381 S: Service<R>,
382{
383 fn drop(&mut self) {
384 self.st = cell::UnsafeCell::new(State::New);
385 }
386}
387
388impl<S, R> Clone for PipelineBinding<S, R>
389where
390 S: Service<R>,
391{
392 #[inline]
393 fn clone(&self) -> Self {
394 Self {
395 pl: self.pl.clone(),
396 st: cell::UnsafeCell::new(State::New),
397 }
398 }
399}
400
401impl<S, R> fmt::Debug for PipelineBinding<S, R>
402where
403 S: Service<R> + fmt::Debug,
404{
405 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
406 f.debug_struct("PipelineBinding")
407 .field("pipeline", &self.pl)
408 .finish()
409 }
410}
411
/// Boxed `'static` future for a service call made through a pipeline.
///
/// Returned by the `call_static` / `call_nowait` methods of [`Pipeline`] and
/// the `call` / `call_nowait` methods of [`PipelineBinding`].
#[must_use = "futures do nothing unless polled"]
pub struct PipelineCall<S, R>
where
    S: Service<R>,
    R: 'static,
{
    /// The type-erased call future.
    fut: Call<S::Response, S::Error>,
}
421
/// Pinned, boxed, type-erased future resolving to `Result<R, E>`.
type Call<R, E> = Pin<Box<dyn Future<Output = Result<R, E>> + 'static>>;
423
424impl<S, R> Future for PipelineCall<S, R>
425where
426 S: Service<R>,
427{
428 type Output = Result<S::Response, S::Error>;
429
430 #[inline]
431 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
432 Pin::new(&mut self.as_mut().fut).poll(cx)
433 }
434}
435
436impl<S, R> fmt::Debug for PipelineCall<S, R>
437where
438 S: Service<R>,
439{
440 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441 f.debug_struct("PipelineCall").finish()
442 }
443}
444
445fn ready<S, R>(pl: &'static Pipeline<S>) -> impl Future<Output = Result<(), S::Error>>
446where
447 S: Service<R>,
448 R: 'static,
449{
450 pl.state
451 .svc
452 .ready(ServiceCtx::<'_, S>::new(pl.index, pl.state.waiters_ref()))
453}
454
/// Future that drives a readiness check for a pipeline.
///
/// The inner future is created lazily by `f` and re-created after each
/// completed check.
struct CheckReadiness<S: Service<R> + 'static, R, F, Fut> {
    /// Factory that produces the readiness future for `pl`.
    f: F,
    /// Readiness future currently in flight, if any.
    fut: Option<Fut>,
    /// Pipeline whose service is being checked.
    pl: &'static Pipeline<S>,
    /// Ties the request type `R` to this struct without storing a value.
    _t: marker::PhantomData<R>,
}
461
// Declared `Unpin` for every `Fut` so that `poll` can mutate fields through
// `Pin<&mut Self>`.
// NOTE(review): `Fut` itself may be `!Unpin`; soundness then relies on this
// struct being created inside `Box::pin` and never moved — confirm.
impl<S: Service<R>, R, F, Fut> Unpin for CheckReadiness<S, R, F, Fut> {}
463
464impl<S: Service<R>, R, F, Fut> Drop for CheckReadiness<S, R, F, Fut> {
465 fn drop(&mut self) {
466 if self.fut.is_some() {
468 self.pl.state.waiters.notify();
469 }
470 }
471}
472
473impl<S, R, F, Fut> Future for CheckReadiness<S, R, F, Fut>
474where
475 S: Service<R>,
476 F: Fn(&'static Pipeline<S>) -> Fut,
477 Fut: Future<Output = Result<(), S::Error>>,
478{
479 type Output = Result<(), S::Error>;
480
481 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
482 let mut this = self.as_mut();
483
484 this.pl.poll(cx)?;
485
486 this.pl.state.waiters.run(this.pl.index, cx, |cx| {
487 if this.fut.is_none() {
488 this.fut = Some((this.f)(this.pl));
489 }
490 let fut = this.fut.as_mut().unwrap();
491 let result = unsafe { Pin::new_unchecked(fut) }.poll(cx);
492 if result.is_ready() {
493 let _ = this.fut.take();
494 }
495 result
496 })
497 }
498}
499
#[cfg(test)]
mod tests {
    use std::{cell::Cell, future::poll_fn, rc::Rc};

    use super::*;

    /// Test service that is always ready, always succeeds, and counts how
    /// many times `shutdown` was called in the shared cell.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn call(&self, _m: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        async fn shutdown(&self) {
            // Record the shutdown so the test can assert on it.
            self.0.set(self.0.get() + 1);
        }
    }

    /// Checks that call/ready/shutdown/poll pass through nested pipelines
    /// and `PipelineSvc` wrappers to the underlying service.
    #[ntex::test]
    async fn pipeline_service() {
        // Case 1: a pipeline wrapping a cloned `PipelineSvc` built via
        // `into_service`.
        let cnt_sht = Rc::new(Cell::new(0));
        let srv = Pipeline::new(
            Pipeline::new(Srv(cnt_sht.clone()).map(|()| "ok"))
                .into_service()
                .clone(),
        );
        let res = srv.call(()).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), "ok");

        let res = srv.ready().await;
        assert_eq!(res, Ok(()));

        // Shutdown must reach `Srv` exactly once.
        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 1);
        // Exercise the `Debug` impls.
        let _ = format!("{srv:?}");

        // Case 2: a pipeline wrapping a `PipelineSvc` built with `From` from
        // a reference to the service.
        let cnt_sht = Rc::new(Cell::new(0));
        let svc = Srv(cnt_sht.clone()).map(|()| "ok");
        let srv = Pipeline::new(PipelineSvc::from(&svc));
        let res = srv.call(()).await;
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), "ok");

        let res = srv.ready().await;
        assert_eq!(res, Ok(()));

        // `Pipeline::poll` is synchronous, so wrap it to get a `Context`.
        let res = poll_fn(|cx| Poll::Ready(srv.poll(cx))).await;
        assert_eq!(res, Ok(()));

        srv.shutdown().await;
        assert_eq!(cnt_sht.get(), 1);
        let _ = format!("{srv:?}");
    }
}