actix_ioframe/
service.rs

1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::rc::Rc;
5use std::task::{Context, Poll};
6
7use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
8use actix_service::{IntoService, IntoServiceFactory, Service, ServiceFactory};
9use actix_utils::mpsc;
10use either::Either;
11use futures::future::{FutureExt, LocalBoxFuture};
12use pin_project::project;
13
14use crate::connect::{Connect, ConnectResult};
15use crate::dispatcher::{Dispatcher, Message};
16use crate::error::ServiceError;
17use crate::item::Item;
18use crate::sink::Sink;
19
20type RequestItem<S, U> = Item<S, U>;
21type ResponseItem<U> = Option<<U as Encoder>::Item>;
22type ServiceResult<U, E> = Result<Message<<U as Encoder>::Item>, E>;
23
24/// Service builder - structure that follows the builder pattern
25/// for building instances for framed services.
26pub struct Builder<St, Codec>(PhantomData<(St, Codec)>);
27
28impl<St: Clone, Codec> Default for Builder<St, Codec> {
29    fn default() -> Builder<St, Codec> {
30        Builder::new()
31    }
32}
33
34impl<St: Clone, Codec> Builder<St, Codec> {
35    pub fn new() -> Builder<St, Codec> {
36        Builder(PhantomData)
37    }
38
39    /// Construct framed handler service with specified connect service
40    pub fn service<Io, C, F>(self, connect: F) -> ServiceBuilder<St, C, Io, Codec>
41    where
42        F: IntoService<C>,
43        Io: AsyncRead + AsyncWrite,
44        C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
45        Codec: Decoder + Encoder,
46    {
47        ServiceBuilder {
48            connect: connect.into_service(),
49            disconnect: None,
50            _t: PhantomData,
51        }
52    }
53
54    /// Construct framed handler new service with specified connect service
55    pub fn factory<Io, C, F>(self, connect: F) -> NewServiceBuilder<St, C, Io, Codec>
56    where
57        F: IntoServiceFactory<C>,
58        Io: AsyncRead + AsyncWrite,
59        C: ServiceFactory<
60            Config = (),
61            Request = Connect<Io, Codec>,
62            Response = ConnectResult<Io, St, Codec>,
63        >,
64        C::Error: 'static,
65        C::Future: 'static,
66        Codec: Decoder + Encoder,
67    {
68        NewServiceBuilder {
69            connect: connect.into_factory(),
70            disconnect: None,
71            _t: PhantomData,
72        }
73    }
74}
75
76pub struct ServiceBuilder<St, C, Io, Codec> {
77    connect: C,
78    disconnect: Option<Rc<dyn Fn(St, bool)>>,
79    _t: PhantomData<(St, Io, Codec)>,
80}
81
82impl<St, C, Io, Codec> ServiceBuilder<St, C, Io, Codec>
83where
84    St: Clone,
85    C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
86    Io: AsyncRead + AsyncWrite,
87    Codec: Decoder + Encoder,
88    <Codec as Encoder>::Item: 'static,
89    <Codec as Encoder>::Error: std::fmt::Debug,
90{
91    /// Callback to execute on disconnect
92    ///
93    /// Second parameter indicates error occured during disconnect.
94    pub fn disconnect<F, Out>(mut self, disconnect: F) -> Self
95    where
96        F: Fn(St, bool) + 'static,
97    {
98        self.disconnect = Some(Rc::new(disconnect));
99        self
100    }
101
102    /// Provide stream items handler service and construct service factory.
103    pub fn finish<F, T>(self, service: F) -> FramedServiceImpl<St, C, T, Io, Codec>
104    where
105        F: IntoServiceFactory<T>,
106        T: ServiceFactory<
107            Config = St,
108            Request = RequestItem<St, Codec>,
109            Response = ResponseItem<Codec>,
110            Error = C::Error,
111            InitError = C::Error,
112        >,
113    {
114        FramedServiceImpl {
115            connect: self.connect,
116            handler: Rc::new(service.into_factory()),
117            disconnect: self.disconnect.clone(),
118            _t: PhantomData,
119        }
120    }
121}
122
123pub struct NewServiceBuilder<St, C, Io, Codec> {
124    connect: C,
125    disconnect: Option<Rc<dyn Fn(St, bool)>>,
126    _t: PhantomData<(St, Io, Codec)>,
127}
128
129impl<St, C, Io, Codec> NewServiceBuilder<St, C, Io, Codec>
130where
131    St: Clone,
132    Io: AsyncRead + AsyncWrite,
133    C: ServiceFactory<
134        Config = (),
135        Request = Connect<Io, Codec>,
136        Response = ConnectResult<Io, St, Codec>,
137    >,
138    C::Error: 'static,
139    C::Future: 'static,
140    Codec: Decoder + Encoder,
141    <Codec as Encoder>::Item: 'static,
142    <Codec as Encoder>::Error: std::fmt::Debug,
143{
144    /// Callback to execute on disconnect
145    ///
146    /// Second parameter indicates error occured during disconnect.
147    pub fn disconnect<F>(mut self, disconnect: F) -> Self
148    where
149        F: Fn(St, bool) + 'static,
150    {
151        self.disconnect = Some(Rc::new(disconnect));
152        self
153    }
154
155    pub fn finish<F, T, Cfg>(self, service: F) -> FramedService<St, C, T, Io, Codec, Cfg>
156    where
157        F: IntoServiceFactory<T>,
158        T: ServiceFactory<
159                Config = St,
160                Request = RequestItem<St, Codec>,
161                Response = ResponseItem<Codec>,
162                Error = C::Error,
163                InitError = C::Error,
164            > + 'static,
165    {
166        FramedService {
167            connect: self.connect,
168            handler: Rc::new(service.into_factory()),
169            disconnect: self.disconnect,
170            _t: PhantomData,
171        }
172    }
173}
174
175pub struct FramedService<St, C, T, Io, Codec, Cfg> {
176    connect: C,
177    handler: Rc<T>,
178    disconnect: Option<Rc<dyn Fn(St, bool)>>,
179    _t: PhantomData<(St, Io, Codec, Cfg)>,
180}
181
182impl<St, C, T, Io, Codec, Cfg> ServiceFactory for FramedService<St, C, T, Io, Codec, Cfg>
183where
184    St: Clone + 'static,
185    Io: AsyncRead + AsyncWrite,
186    C: ServiceFactory<
187        Config = (),
188        Request = Connect<Io, Codec>,
189        Response = ConnectResult<Io, St, Codec>,
190    >,
191    C::Error: 'static,
192    C::Future: 'static,
193    T: ServiceFactory<
194            Config = St,
195            Request = RequestItem<St, Codec>,
196            Response = ResponseItem<Codec>,
197            Error = C::Error,
198            InitError = C::Error,
199        > + 'static,
200    <T::Service as Service>::Future: 'static,
201    Codec: Decoder + Encoder,
202    <Codec as Encoder>::Item: 'static,
203    <Codec as Encoder>::Error: std::fmt::Debug,
204{
205    type Config = Cfg;
206    type Request = Io;
207    type Response = ();
208    type Error = ServiceError<C::Error, Codec>;
209    type InitError = C::InitError;
210    type Service = FramedServiceImpl<St, C::Service, T, Io, Codec>;
211    type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
212
213    fn new_service(&self, _: Cfg) -> Self::Future {
214        let handler = self.handler.clone();
215        let disconnect = self.disconnect.clone();
216
217        // create connect service and then create service impl
218        self.connect
219            .new_service(())
220            .map(move |result| {
221                result.map(move |connect| FramedServiceImpl {
222                    connect,
223                    handler,
224                    disconnect,
225                    _t: PhantomData,
226                })
227            })
228            .boxed_local()
229    }
230}
231
232pub struct FramedServiceImpl<St, C, T, Io, Codec> {
233    connect: C,
234    handler: Rc<T>,
235    disconnect: Option<Rc<dyn Fn(St, bool)>>,
236    _t: PhantomData<(St, Io, Codec)>,
237}
238
239impl<St, C, T, Io, Codec> Service for FramedServiceImpl<St, C, T, Io, Codec>
240where
241    St: Clone,
242    Io: AsyncRead + AsyncWrite,
243    C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
244    C::Error: 'static,
245    T: ServiceFactory<
246        Config = St,
247        Request = RequestItem<St, Codec>,
248        Response = ResponseItem<Codec>,
249        Error = C::Error,
250        InitError = C::Error,
251    >,
252    <T::Service as Service>::Future: 'static,
253    Codec: Decoder + Encoder,
254    <Codec as Encoder>::Item: 'static,
255    <Codec as Encoder>::Error: std::fmt::Debug,
256{
257    type Request = Io;
258    type Response = ();
259    type Error = ServiceError<C::Error, Codec>;
260    type Future = FramedServiceImplResponse<St, Io, Codec, C, T>;
261
262    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
263        self.connect.poll_ready(cx).map_err(|e| e.into())
264    }
265
266    fn call(&mut self, req: Io) -> Self::Future {
267        let (tx, rx) = mpsc::channel();
268        let sink = Sink::new(Rc::new(move |msg| {
269            let _ = tx.send(Ok(msg));
270        }));
271        FramedServiceImplResponse {
272            inner: FramedServiceImplResponseInner::Connect(
273                self.connect.call(Connect::new(req, sink.clone())),
274                self.handler.clone(),
275                self.disconnect.clone(),
276                Some(rx),
277            ),
278        }
279    }
280}
281
282#[pin_project::pin_project]
283pub struct FramedServiceImplResponse<St, Io, Codec, C, T>
284where
285    St: Clone,
286    C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
287    C::Error: 'static,
288    T: ServiceFactory<
289        Config = St,
290        Request = RequestItem<St, Codec>,
291        Response = ResponseItem<Codec>,
292        Error = C::Error,
293        InitError = C::Error,
294    >,
295    <T::Service as Service>::Future: 'static,
296    Io: AsyncRead + AsyncWrite,
297    Codec: Encoder + Decoder,
298    <Codec as Encoder>::Item: 'static,
299    <Codec as Encoder>::Error: std::fmt::Debug,
300{
301    #[pin]
302    inner: FramedServiceImplResponseInner<St, Io, Codec, C, T>,
303}
304
305impl<St, Io, Codec, C, T> Future for FramedServiceImplResponse<St, Io, Codec, C, T>
306where
307    St: Clone,
308    C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
309    C::Error: 'static,
310    T: ServiceFactory<
311        Config = St,
312        Request = RequestItem<St, Codec>,
313        Response = ResponseItem<Codec>,
314        Error = C::Error,
315        InitError = C::Error,
316    >,
317    <T::Service as Service>::Future: 'static,
318    Io: AsyncRead + AsyncWrite,
319    Codec: Encoder + Decoder,
320    <Codec as Encoder>::Item: 'static,
321    <Codec as Encoder>::Error: std::fmt::Debug,
322{
323    type Output = Result<(), ServiceError<C::Error, Codec>>;
324
325    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
326        let mut this = self.as_mut().project();
327
328        loop {
329            match this.inner.poll(cx) {
330                Either::Left(new) => {
331                    this = self.as_mut().project();
332                    this.inner.set(new)
333                }
334                Either::Right(poll) => return poll,
335            };
336        }
337    }
338}
339
340#[pin_project::pin_project]
341enum FramedServiceImplResponseInner<St, Io, Codec, C, T>
342where
343    St: Clone,
344    C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
345    C::Error: 'static,
346    T: ServiceFactory<
347        Config = St,
348        Request = RequestItem<St, Codec>,
349        Response = ResponseItem<Codec>,
350        Error = C::Error,
351        InitError = C::Error,
352    >,
353    <T::Service as Service>::Future: 'static,
354    Io: AsyncRead + AsyncWrite,
355    Codec: Encoder + Decoder,
356    <Codec as Encoder>::Item: 'static,
357    <Codec as Encoder>::Error: std::fmt::Debug,
358{
359    Connect(
360        #[pin] C::Future,
361        Rc<T>,
362        Option<Rc<dyn Fn(St, bool)>>,
363        Option<mpsc::Receiver<ServiceResult<Codec, C::Error>>>,
364    ),
365    Handler(
366        #[pin] T::Future,
367        Option<ConnectResult<Io, St, Codec>>,
368        Option<Rc<dyn Fn(St, bool)>>,
369        Option<mpsc::Receiver<ServiceResult<Codec, C::Error>>>,
370    ),
371    Dispatcher(Dispatcher<St, T::Service, Io, Codec>),
372}
373
374impl<St, Io, Codec, C, T> FramedServiceImplResponseInner<St, Io, Codec, C, T>
375where
376    St: Clone,
377    C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
378    C::Error: 'static,
379    T: ServiceFactory<
380        Config = St,
381        Request = RequestItem<St, Codec>,
382        Response = ResponseItem<Codec>,
383        Error = C::Error,
384        InitError = C::Error,
385    >,
386    <T::Service as Service>::Future: 'static,
387    Io: AsyncRead + AsyncWrite,
388    Codec: Encoder + Decoder,
389    <Codec as Encoder>::Item: 'static,
390    <Codec as Encoder>::Error: std::fmt::Debug,
391{
392    #[project]
393    fn poll(
394        self: Pin<&mut Self>,
395        cx: &mut Context<'_>,
396    ) -> Either<
397        FramedServiceImplResponseInner<St, Io, Codec, C, T>,
398        Poll<Result<(), ServiceError<C::Error, Codec>>>,
399    > {
400        #[project]
401        match self.project() {
402            FramedServiceImplResponseInner::Connect(fut, handler, disconnect, rx) => {
403                match fut.poll(cx) {
404                    Poll::Ready(Ok(res)) => {
405                        Either::Left(FramedServiceImplResponseInner::Handler(
406                            handler.new_service(res.state.clone()),
407                            Some(res),
408                            disconnect.take(),
409                            rx.take(),
410                        ))
411                    }
412                    Poll::Pending => Either::Right(Poll::Pending),
413                    Poll::Ready(Err(e)) => Either::Right(Poll::Ready(Err(e.into()))),
414                }
415            }
416            FramedServiceImplResponseInner::Handler(fut, res, disconnect, rx) => {
417                match fut.poll(cx) {
418                    Poll::Ready(Ok(handler)) => {
419                        let res = res.take().unwrap();
420                        Either::Left(FramedServiceImplResponseInner::Dispatcher(
421                            Dispatcher::new(
422                                res.framed,
423                                res.state,
424                                handler,
425                                res.sink,
426                                rx.take().unwrap(),
427                                disconnect.take(),
428                            ),
429                        ))
430                    }
431                    Poll::Pending => Either::Right(Poll::Pending),
432                    Poll::Ready(Err(e)) => Either::Right(Poll::Ready(Err(e.into()))),
433                }
434            }
435            FramedServiceImplResponseInner::Dispatcher(ref mut fut) => {
436                Either::Right(fut.poll(cx))
437            }
438        }
439    }
440}