Skip to main content

actix_ioframe/
dispatcher.rs

//! Framed dispatcher service and related utilities
2use std::pin::Pin;
3use std::rc::Rc;
4use std::task::{Context, Poll};
5
6use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed};
7use actix_service::{IntoService, Service};
8use actix_utils::{mpsc, oneshot};
9use futures::{FutureExt, Stream};
10use log::debug;
11
12use crate::error::ServiceError;
13use crate::item::Item;
14use crate::sink::Sink;
15
16type Request<S, U> = Item<S, U>;
17type Response<U> = <U as Encoder>::Item;
18
19pub(crate) enum Message<T> {
20    Item(T),
21    WaitClose(oneshot::Sender<()>),
22    Close,
23}
24
25/// FramedTransport - is a future that reads frames from Framed object
26/// and pass then to the service.
27#[pin_project::pin_project]
28pub(crate) struct Dispatcher<St, S, T, U>
29where
30    St: Clone,
31    S: Service<Request = Request<St, U>, Response = Option<Response<U>>>,
32    S::Error: 'static,
33    S::Future: 'static,
34    T: AsyncRead + AsyncWrite,
35    U: Encoder + Decoder,
36    <U as Encoder>::Item: 'static,
37    <U as Encoder>::Error: std::fmt::Debug,
38{
39    service: S,
40    sink: Sink<<U as Encoder>::Item>,
41    state: St,
42    dispatch_state: FramedState<S, U>,
43    framed: Framed<T, U>,
44    rx: mpsc::Receiver<Result<Message<<U as Encoder>::Item>, S::Error>>,
45    tx: mpsc::Sender<Result<Message<<U as Encoder>::Item>, S::Error>>,
46    disconnect: Option<Rc<dyn Fn(St, bool)>>,
47}
48
49impl<St, S, T, U> Dispatcher<St, S, T, U>
50where
51    St: Clone,
52    S: Service<Request = Request<St, U>, Response = Option<Response<U>>>,
53    S::Error: 'static,
54    S::Future: 'static,
55    T: AsyncRead + AsyncWrite,
56    U: Decoder + Encoder,
57    <U as Encoder>::Item: 'static,
58    <U as Encoder>::Error: std::fmt::Debug,
59{
60    pub(crate) fn new<F: IntoService<S>>(
61        framed: Framed<T, U>,
62        state: St,
63        service: F,
64        sink: Sink<<U as Encoder>::Item>,
65        rx: mpsc::Receiver<Result<Message<<U as Encoder>::Item>, S::Error>>,
66        disconnect: Option<Rc<dyn Fn(St, bool)>>,
67    ) -> Self {
68        let tx = rx.sender();
69
70        Dispatcher {
71            framed,
72            state,
73            sink,
74            disconnect,
75            rx,
76            tx,
77            service: service.into_service(),
78            dispatch_state: FramedState::Processing,
79        }
80    }
81}
82
83enum FramedState<S: Service, U: Encoder + Decoder> {
84    Processing,
85    Error(ServiceError<S::Error, U>),
86    FramedError(ServiceError<S::Error, U>),
87    FlushAndStop(Vec<oneshot::Sender<()>>),
88    Stopping,
89}
90
91impl<S: Service, U: Encoder + Decoder> FramedState<S, U> {
92    fn stop(&mut self, tx: Option<oneshot::Sender<()>>) {
93        match self {
94            FramedState::FlushAndStop(ref mut vec) => {
95                if let Some(tx) = tx {
96                    vec.push(tx)
97                }
98            }
99            FramedState::Processing => {
100                *self = FramedState::FlushAndStop(if let Some(tx) = tx {
101                    vec![tx]
102                } else {
103                    Vec::new()
104                })
105            }
106            FramedState::Error(_) | FramedState::FramedError(_) | FramedState::Stopping => {
107                if let Some(tx) = tx {
108                    let _ = tx.send(());
109                }
110            }
111        }
112    }
113
114    fn take_error(&mut self) -> ServiceError<S::Error, U> {
115        match std::mem::replace(self, FramedState::Processing) {
116            FramedState::Error(err) => err,
117            _ => panic!(),
118        }
119    }
120
121    fn take_framed_error(&mut self) -> ServiceError<S::Error, U> {
122        match std::mem::replace(self, FramedState::Processing) {
123            FramedState::FramedError(err) => err,
124            _ => panic!(),
125        }
126    }
127}
128
129impl<St, S, T, U> Dispatcher<St, S, T, U>
130where
131    St: Clone,
132    S: Service<Request = Request<St, U>, Response = Option<Response<U>>>,
133    S::Error: 'static,
134    S::Future: 'static,
135    T: AsyncRead + AsyncWrite,
136    U: Decoder + Encoder,
137    <U as Encoder>::Item: 'static,
138    <U as Encoder>::Error: std::fmt::Debug,
139{
140    fn poll_read(&mut self, cx: &mut Context<'_>) -> bool {
141        loop {
142            match self.service.poll_ready(cx) {
143                Poll::Ready(Ok(_)) => {
144                    let item = match self.framed.next_item(cx) {
145                        Poll::Ready(Some(Ok(el))) => el,
146                        Poll::Ready(Some(Err(err))) => {
147                            self.dispatch_state =
148                                FramedState::FramedError(ServiceError::Decoder(err));
149                            return true;
150                        }
151                        Poll::Pending => return false,
152                        Poll::Ready(None) => {
153                            log::trace!("Client disconnected");
154                            self.dispatch_state = FramedState::Stopping;
155                            return true;
156                        }
157                    };
158
159                    let tx = self.tx.clone();
160                    actix_rt::spawn(
161                        self.service
162                            .call(Item::new(self.state.clone(), self.sink.clone(), item))
163                            .map(move |item| {
164                                let item = match item {
165                                    Ok(Some(item)) => Ok(Message::Item(item)),
166                                    Ok(None) => return,
167                                    Err(err) => Err(err),
168                                };
169                                let _ = tx.send(item);
170                            }),
171                    );
172                }
173                Poll::Pending => return false,
174                Poll::Ready(Err(err)) => {
175                    self.dispatch_state = FramedState::Error(ServiceError::Service(err));
176                    return true;
177                }
178            }
179        }
180    }
181
182    /// write to framed object
183    fn poll_write(&mut self, cx: &mut Context<'_>) -> bool {
184        loop {
185            while !self.framed.is_write_buf_full() {
186                match Pin::new(&mut self.rx).poll_next(cx) {
187                    Poll::Ready(Some(Ok(Message::Item(msg)))) => {
188                        if let Err(err) = self.framed.write(msg) {
189                            self.dispatch_state =
190                                FramedState::FramedError(ServiceError::Encoder(err));
191                            return true;
192                        }
193                    }
194                    Poll::Ready(Some(Ok(Message::Close))) => {
195                        self.dispatch_state.stop(None);
196                        return true;
197                    }
198                    Poll::Ready(Some(Ok(Message::WaitClose(tx)))) => {
199                        self.dispatch_state.stop(Some(tx));
200                        return true;
201                    }
202                    Poll::Ready(Some(Err(err))) => {
203                        self.dispatch_state = FramedState::Error(ServiceError::Service(err));
204                        return true;
205                    }
206                    Poll::Ready(None) | Poll::Pending => break,
207                }
208            }
209
210            if !self.framed.is_write_buf_empty() {
211                match self.framed.flush(cx) {
212                    Poll::Pending => break,
213                    Poll::Ready(Ok(_)) => (),
214                    Poll::Ready(Err(err)) => {
215                        debug!("Error sending data: {:?}", err);
216                        self.dispatch_state =
217                            FramedState::FramedError(ServiceError::Encoder(err));
218                        return true;
219                    }
220                }
221            } else {
222                break;
223            }
224        }
225        false
226    }
227
228    pub(crate) fn poll(
229        &mut self,
230        cx: &mut Context<'_>,
231    ) -> Poll<Result<(), ServiceError<S::Error, U>>> {
232        match self.dispatch_state {
233            FramedState::Processing => {
234                if self.poll_read(cx) || self.poll_write(cx) {
235                    self.poll(cx)
236                } else {
237                    Poll::Pending
238                }
239            }
240            FramedState::Error(_) => {
241                // flush write buffer
242                if !self.framed.is_write_buf_empty() {
243                    if let Poll::Pending = self.framed.flush(cx) {
244                        return Poll::Pending;
245                    }
246                }
247                if let Some(ref disconnect) = self.disconnect {
248                    (&*disconnect)(self.state.clone(), true);
249                }
250                Poll::Ready(Err(self.dispatch_state.take_error()))
251            }
252            FramedState::FlushAndStop(ref mut vec) => {
253                if !self.framed.is_write_buf_empty() {
254                    match self.framed.flush(cx) {
255                        Poll::Ready(Err(err)) => {
256                            debug!("Error sending data: {:?}", err);
257                        }
258                        Poll::Pending => {
259                            return Poll::Pending;
260                        }
261                        Poll::Ready(_) => (),
262                    }
263                };
264                for tx in vec.drain(..) {
265                    let _ = tx.send(());
266                }
267                if let Some(ref disconnect) = self.disconnect {
268                    (&*disconnect)(self.state.clone(), false);
269                }
270                Poll::Ready(Ok(()))
271            }
272            FramedState::FramedError(_) => {
273                if let Some(ref disconnect) = self.disconnect {
274                    (&*disconnect)(self.state.clone(), true);
275                }
276                Poll::Ready(Err(self.dispatch_state.take_framed_error()))
277            }
278            FramedState::Stopping => {
279                if let Some(ref disconnect) = self.disconnect {
280                    (&*disconnect)(self.state.clone(), false);
281                }
282                Poll::Ready(Ok(()))
283            }
284        }
285    }
286}