1use std::{cell, fmt, future::Future, marker, pin::Pin, rc::Rc, task::Context, task::Poll};
2
3use crate::{IntoService, Service, ServiceCtx, ctx::WaitersRef};
4
/// Handle that owns a service `S` and builds the `ServiceCtx` needed to drive it.
///
/// Clones share the same service and waiters through `state`; every handle
/// has its own `index` inside those waiters.
#[derive(Debug)]
pub struct Pipeline<S> {
    // Slot of this handle in the shared `WaitersRef`: obtained in `new`/`clone`,
    // released again in `Drop`.
    index: u32,
    // Service and waiters shared by all clones of this handle.
    state: Rc<PipelineState<S>>,
}
13
/// State shared by every clone of a `Pipeline`: the service and its waiters.
struct PipelineState<S> {
    // The wrapped service.
    svc: S,
    // Waiter bookkeeping shared by all handles (one slot per handle).
    waiters: WaitersRef,
}
18
19impl<S> PipelineState<S> {
20 pub(crate) fn waiters_ref(&self) -> &WaitersRef {
21 &self.waiters
22 }
23}
24
25impl<S> Pipeline<S> {
26 #[inline]
27 pub fn new(svc: S) -> Self {
29 let (index, waiters) = WaitersRef::new();
30 Pipeline {
31 index,
32 state: Rc::new(PipelineState { svc, waiters }),
33 }
34 }
35
36 #[inline]
37 pub fn get_ref(&self) -> &S {
39 &self.state.svc
40 }
41
42 #[inline]
43 pub async fn ready<R>(&self) -> Result<(), S::Error>
45 where
46 S: Service<R>,
47 {
48 ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
49 .ready(&self.state.svc)
50 .await
51 }
52
53 #[inline]
54 pub async fn call<R>(&self, req: R) -> Result<S::Response, S::Error>
57 where
58 S: Service<R>,
59 {
60 ServiceCtx::<'_, S>::new(self.index, self.state.waiters_ref())
61 .call(&self.state.svc, req)
62 .await
63 }
64
65 #[inline]
66 pub fn call_static<R>(&self, req: R) -> PipelineCall<S, R>
69 where
70 S: Service<R> + 'static,
71 R: 'static,
72 {
73 let pl = self.clone();
74
75 PipelineCall {
76 fut: Box::pin(async move {
77 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
78 .call(&pl.state.svc, req)
79 .await
80 }),
81 }
82 }
83
84 #[inline]
85 pub fn call_nowait<R>(&self, req: R) -> PipelineCall<S, R>
89 where
90 S: Service<R> + 'static,
91 R: 'static,
92 {
93 let pl = self.clone();
94
95 PipelineCall {
96 fut: Box::pin(async move {
97 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
98 .call_nowait(&pl.state.svc, req)
99 .await
100 }),
101 }
102 }
103
104 #[inline]
105 pub fn is_shutdown(&self) -> bool {
107 self.state.waiters.is_shutdown()
108 }
109
110 #[inline]
111 pub async fn shutdown<R>(&self)
113 where
114 S: Service<R>,
115 {
116 self.state.svc.shutdown().await
117 }
118
119 #[inline]
120 pub fn poll<R>(&self, cx: &mut Context<'_>) -> Result<(), S::Error>
121 where
122 S: Service<R>,
123 {
124 self.state.svc.poll(cx)
125 }
126
127 #[inline]
128 pub fn bind<R>(self) -> PipelineBinding<S, R>
130 where
131 S: Service<R> + 'static,
132 R: 'static,
133 {
134 PipelineBinding::new(self)
135 }
136}
137
138impl<S> From<S> for Pipeline<S> {
139 #[inline]
140 fn from(svc: S) -> Self {
141 Pipeline::new(svc)
142 }
143}
144
145impl<S> Clone for Pipeline<S> {
146 fn clone(&self) -> Self {
147 Pipeline {
148 index: self.state.waiters.insert(),
149 state: self.state.clone(),
150 }
151 }
152}
153
154impl<S> Drop for Pipeline<S> {
155 #[inline]
156 fn drop(&mut self) {
157 self.state.waiters.remove(self.index);
158 }
159}
160
161impl<S: fmt::Debug> fmt::Debug for PipelineState<S> {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.debug_struct("PipelineState")
164 .field("svc", &self.svc)
165 .field("waiters", &self.waiters.get().len())
166 .finish()
167 }
168}
169
/// Adapter that exposes a `Pipeline<S>` as a `Service`.
#[derive(Debug)]
pub struct PipelineSvc<S> {
    // Pipeline that every service method is forwarded to.
    inner: Pipeline<S>,
}
175
176impl<S> PipelineSvc<S> {
177 #[inline]
178 pub fn new(inner: Pipeline<S>) -> Self {
180 Self { inner }
181 }
182}
183
184impl<S, Req> Service<Req> for PipelineSvc<S>
185where
186 S: Service<Req>,
187{
188 type Response = S::Response;
189 type Error = S::Error;
190
191 #[inline]
192 async fn call(
193 &self,
194 req: Req,
195 _: ServiceCtx<'_, Self>,
196 ) -> Result<Self::Response, Self::Error> {
197 self.inner.call(req).await
198 }
199
200 #[inline]
201 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
202 self.inner.ready().await
203 }
204
205 #[inline]
206 async fn shutdown(&self) {
207 self.inner.shutdown().await
208 }
209
210 #[inline]
211 fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
212 self.inner.poll(cx)
213 }
214}
215
216impl<S> From<S> for PipelineSvc<S> {
217 #[inline]
218 fn from(svc: S) -> Self {
219 PipelineSvc {
220 inner: Pipeline::new(svc),
221 }
222 }
223}
224
225impl<S> Clone for PipelineSvc<S> {
226 fn clone(&self) -> Self {
227 PipelineSvc {
228 inner: self.inner.clone(),
229 }
230 }
231}
232
233impl<S, R> IntoService<PipelineSvc<S>, R> for Pipeline<S>
234where
235 S: Service<R>,
236{
237 #[inline]
238 fn into_service(self) -> PipelineSvc<S> {
239 PipelineSvc::new(self)
240 }
241}
242
/// A `Pipeline` bound to the request type `R`, with poll-based readiness and
/// shutdown.
pub struct PipelineBinding<S, R>
where
    S: Service<R>,
{
    // The bound pipeline. `poll_ready`/`poll_shutdown` hand a lifetime-extended
    // reference to this field to the future stored in `st`.
    pl: Pipeline<S>,
    // Current readiness/shutdown future. Mutated through `&self`, hence the
    // `UnsafeCell`. The `Drop` impl resets it before the fields are dropped.
    st: cell::UnsafeCell<State<S::Error>>,
}
251
/// Which future, if any, a `PipelineBinding` is currently driving.
enum State<E> {
    // No future has been created yet.
    New,
    // Readiness check in progress (set by `poll_ready`).
    Readiness(Pin<Box<dyn Future<Output = Result<(), E>> + 'static>>),
    // Shutdown in progress (set by `poll_shutdown`).
    Shutdown(Pin<Box<dyn Future<Output = ()> + 'static>>),
}
257
258impl<S, R> PipelineBinding<S, R>
259where
260 S: Service<R> + 'static,
261 R: 'static,
262{
263 fn new(pl: Pipeline<S>) -> Self {
264 PipelineBinding {
265 pl,
266 st: cell::UnsafeCell::new(State::New),
267 }
268 }
269
270 #[inline]
271 pub fn get_ref(&self) -> &S {
273 &self.pl.state.svc
274 }
275
276 #[inline]
277 pub fn pipeline(&self) -> Pipeline<S> {
279 self.pl.clone()
280 }
281
282 #[inline]
283 pub fn poll(&self, cx: &mut Context<'_>) -> Result<(), S::Error> {
284 self.pl.poll(cx)
285 }
286
287 #[inline]
288 pub fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), S::Error>> {
292 let st = unsafe { &mut *self.st.get() };
293
294 match st {
295 State::New => {
296 let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
300 let fut = Box::pin(CheckReadiness {
301 fut: None,
302 f: ready,
303 _t: marker::PhantomData,
304 pl,
305 });
306 *st = State::Readiness(fut);
307 self.poll_ready(cx)
308 }
309 State::Readiness(fut) => Pin::new(fut).poll(cx),
310 State::Shutdown(_) => panic!("Pipeline is shutding down"),
311 }
312 }
313
314 #[inline]
315 pub fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> {
317 let st = unsafe { &mut *self.st.get() };
318
319 match st {
320 State::New | State::Readiness(_) => {
321 let pl: &'static Pipeline<S> = unsafe { std::mem::transmute(&self.pl) };
325 *st = State::Shutdown(Box::pin(async move { pl.shutdown().await }));
326 pl.state.waiters.shutdown();
327 self.poll_shutdown(cx)
328 }
329 State::Shutdown(fut) => Pin::new(fut).poll(cx),
330 }
331 }
332
333 #[inline]
334 pub fn call(&self, req: R) -> PipelineCall<S, R> {
337 let pl = self.pl.clone();
338
339 PipelineCall {
340 fut: Box::pin(async move {
341 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
342 .call(&pl.state.svc, req)
343 .await
344 }),
345 }
346 }
347
348 #[inline]
349 pub fn call_nowait(&self, req: R) -> PipelineCall<S, R> {
353 let pl = self.pl.clone();
354
355 PipelineCall {
356 fut: Box::pin(async move {
357 ServiceCtx::<S>::new(pl.index, pl.state.waiters_ref())
358 .call_nowait(&pl.state.svc, req)
359 .await
360 }),
361 }
362 }
363
364 #[inline]
365 pub fn is_shutdown(&self) -> bool {
367 self.pl.state.waiters.is_shutdown()
368 }
369
370 #[inline]
371 pub async fn shutdown(&self) {
373 self.pl.state.svc.shutdown().await
374 }
375}
376
377impl<S, R> Drop for PipelineBinding<S, R>
378where
379 S: Service<R>,
380{
381 fn drop(&mut self) {
382 self.st = cell::UnsafeCell::new(State::New);
383 }
384}
385
386impl<S, R> Clone for PipelineBinding<S, R>
387where
388 S: Service<R>,
389{
390 #[inline]
391 fn clone(&self) -> Self {
392 Self {
393 pl: self.pl.clone(),
394 st: cell::UnsafeCell::new(State::New),
395 }
396 }
397}
398
399impl<S, R> fmt::Debug for PipelineBinding<S, R>
400where
401 S: Service<R> + fmt::Debug,
402{
403 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404 f.debug_struct("PipelineBinding")
405 .field("pipeline", &self.pl)
406 .finish()
407 }
408}
409
/// Boxed future returned by the `call_static`/`call_nowait` style methods of
/// `Pipeline` and by `PipelineBinding::call`/`call_nowait`.
#[must_use = "futures do nothing unless polled"]
pub struct PipelineCall<S, R>
where
    S: Service<R>,
    R: 'static,
{
    // Type-erased call future resolving to the service's response or error.
    fut: Call<S::Response, S::Error>,
}
419
// Pinned, boxed, `'static` future resolving to `Result<R, E>`; the payload of `PipelineCall`.
type Call<R, E> = Pin<Box<dyn Future<Output = Result<R, E>> + 'static>>;
421
422impl<S, R> Future for PipelineCall<S, R>
423where
424 S: Service<R>,
425{
426 type Output = Result<S::Response, S::Error>;
427
428 #[inline]
429 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
430 Pin::new(&mut self.as_mut().fut).poll(cx)
431 }
432}
433
434impl<S, R> fmt::Debug for PipelineCall<S, R>
435where
436 S: Service<R>,
437{
438 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
439 f.debug_struct("PipelineCall").finish()
440 }
441}
442
443fn ready<S, R>(pl: &'static Pipeline<S>) -> impl Future<Output = Result<(), S::Error>>
444where
445 S: Service<R>,
446 R: 'static,
447{
448 pl.state
449 .svc
450 .ready(ServiceCtx::<'_, S>::new(pl.index, pl.state.waiters_ref()))
451}
452
/// Future stored by `PipelineBinding::poll_ready` to drive a readiness check.
struct CheckReadiness<S: Service<R> + 'static, R, F, Fut> {
    // Factory that creates a readiness future for the pipeline.
    f: F,
    // Readiness future currently in flight; `None` before the first poll and
    // after a future has completed.
    fut: Option<Fut>,
    // Pipeline being checked (lifetime-extended by the creator).
    pl: &'static Pipeline<S>,
    // Ties the request type `R` to the struct without storing a value.
    _t: marker::PhantomData<R>,
}
459
// Declared `Unpin` for every `Fut`, so `poll` can reach the fields through
// `Pin<&mut Self>`.
// NOTE(review): `poll` pins `fut` with `Pin::new_unchecked`, so values must
// not be moved once polled; the construction site visible here (`poll_ready`)
// boxes the value immediately.
impl<S: Service<R>, R, F, Fut> Unpin for CheckReadiness<S, R, F, Fut> {}
461
462impl<S: Service<R>, R, F, Fut> Drop for CheckReadiness<S, R, F, Fut> {
463 fn drop(&mut self) {
464 if self.fut.is_some() {
466 self.pl.state.waiters.notify();
467 }
468 }
469}
470
471impl<S, R, F, Fut> Future for CheckReadiness<S, R, F, Fut>
472where
473 S: Service<R>,
474 F: Fn(&'static Pipeline<S>) -> Fut,
475 Fut: Future<Output = Result<(), S::Error>>,
476{
477 type Output = Result<(), S::Error>;
478
479 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
480 let mut slf = self.as_mut();
481
482 slf.pl.poll(cx)?;
483
484 slf.pl.state.waiters.run(slf.pl.index, cx, |cx| {
485 if slf.fut.is_none() {
486 slf.fut = Some((slf.f)(slf.pl));
487 }
488 let fut = slf.fut.as_mut().unwrap();
489 let result = unsafe { Pin::new_unchecked(fut) }.poll(cx);
490 if result.is_ready() {
491 let _ = slf.fut.take();
492 }
493 result
494 })
495 }
496}
497
#[cfg(test)]
mod tests {
    use std::{cell::Cell, rc::Rc};

    use super::*;

    /// Test service: always ready, always succeeds, counts `shutdown` calls.
    #[derive(Debug, Default, Clone)]
    struct Srv(Rc<Cell<usize>>);

    impl Service<()> for Srv {
        type Response = ();
        type Error = ();

        async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn call(&self, _: (), _: ServiceCtx<'_, Self>) -> Result<(), ()> {
            Ok(())
        }

        async fn shutdown(&self) {
            let Srv(counter) = self;
            counter.set(counter.get() + 1);
        }
    }

    #[ntex::test]
    async fn pipeline_service() {
        // Case 1: a pipeline wrapped into a `PipelineSvc`, cloned, and wrapped
        // into another pipeline.
        let shutdowns = Rc::new(Cell::new(0));
        let inner = Pipeline::new(Srv(shutdowns.clone()).map(|_| "ok"));
        let pipeline = Pipeline::new(inner.into_service().clone());

        let response = pipeline.call(()).await;
        assert!(response.is_ok());
        assert_eq!(response.unwrap(), "ok");

        let readiness = pipeline.ready().await;
        assert_eq!(readiness, Ok(()));

        pipeline.shutdown().await;
        assert_eq!(shutdowns.get(), 1);
        let _ = format!("{pipeline:?}");

        // Case 2: a `PipelineSvc` built with `From` over a borrowed service.
        let shutdowns = Rc::new(Cell::new(0));
        let mapped = Srv(shutdowns.clone()).map(|_| "ok");
        let pipeline = Pipeline::new(PipelineSvc::from(&mapped));

        let response = pipeline.call(()).await;
        assert!(response.is_ok());
        assert_eq!(response.unwrap(), "ok");

        let readiness = pipeline.ready().await;
        assert_eq!(readiness, Ok(()));

        pipeline.shutdown().await;
        assert_eq!(shutdowns.get(), 1);
        let _ = format!("{pipeline:?}");
    }
}