meio_connect/
talker.rs

1use anyhow::Error;
2use futures::channel::mpsc;
3use futures::stream::Fuse;
4use futures::{select, Sink, SinkExt, Stream, StreamExt};
5use meio::prelude::{Action, ActionHandler, Actor, Address, StopReceiver};
6use meio_protocol::{ProtocolCodec, ProtocolData};
7use serde::ser::StdError;
8use std::fmt::Debug;
9use tungstenite::{error::Error as TungError, Message as TungMessage};
10
11/// Incoming message of the choosen `Protocol`.
12#[derive(Debug)]
13pub struct WsIncoming<T: ProtocolData>(pub T);
14
15impl<T: ProtocolData> Action for WsIncoming<T> {}
16
/// Why a WebSocket session terminated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TermReason {
    /// The connection was interrupted (e.g. the talker received a stop signal).
    Interrupted,
    /// The connection was closed (properly).
    Closed,
}

impl TermReason {
    /// Returns `true` when the session ended because it was interrupted.
    pub fn is_interrupted(&self) -> bool {
        matches!(self, TermReason::Interrupted)
    }
}
32
33/// The abstract error for WebSockets.
34pub trait WsError: Debug + StdError + Sync + Send + 'static {}
35
36impl WsError for TungError {}
37
/// The abstract `WebSocket` message.
///
/// The trait provides generic methods to `Talker`s to work with messages:
/// the payload of text and binary messages is decoded, ping/pong messages are
/// ignored, and a close message starts a graceful shutdown.
pub trait WsMessage: Debug + Sized {
    /// Creates a message instance from binary data.
    fn binary(data: Vec<u8>) -> Self;
    /// Is it the ping message?
    fn is_ping(&self) -> bool;
    /// Is it the pong message?
    fn is_pong(&self) -> bool;
    /// Is it the text message?
    fn is_text(&self) -> bool;
    /// Is it the binary message?
    fn is_binary(&self) -> bool;
    /// Is it the closing signal message?
    fn is_close(&self) -> bool;
    /// Converts the message into its payload bytes, consuming it.
    fn into_bytes(self) -> Vec<u8>;
}
57
58impl WsMessage for TungMessage {
59    fn binary(data: Vec<u8>) -> Self {
60        TungMessage::binary(data)
61    }
62    fn is_ping(&self) -> bool {
63        TungMessage::is_ping(self)
64    }
65    fn is_pong(&self) -> bool {
66        TungMessage::is_pong(self)
67    }
68    fn is_text(&self) -> bool {
69        TungMessage::is_text(self)
70    }
71    fn is_binary(&self) -> bool {
72        TungMessage::is_binary(self)
73    }
74    fn is_close(&self) -> bool {
75        TungMessage::is_close(self)
76    }
77    fn into_bytes(self) -> Vec<u8> {
78        TungMessage::into_data(self)
79    }
80}
81
82pub trait TalkerCompatible {
83    type WebSocket: Stream<Item = Result<Self::Message, Self::Error>>
84        + Sink<Self::Message, Error = Self::Error>
85        + Unpin;
86    type Message: WsMessage;
87    type Error: WsError;
88    type Actor: Actor + ActionHandler<WsIncoming<Self::Incoming>>;
89    type Codec: ProtocolCodec;
90    type Incoming: ProtocolData;
91    type Outgoing: ProtocolData;
92}
93
94pub struct Talker<T: TalkerCompatible> {
95    log_target: String,
96    address: Address<T::Actor>,
97    connection: Fuse<T::WebSocket>,
98    rx: mpsc::UnboundedReceiver<T::Outgoing>,
99    stop: StopReceiver,
100    rx_drained: bool,
101    connection_drained: bool,
102    interrupted: bool,
103}
104
105impl<T: TalkerCompatible> Talker<T> {
106    pub fn new(
107        base_log_target: &str,
108        address: Address<T::Actor>,
109        connection: T::WebSocket,
110        rx: mpsc::UnboundedReceiver<T::Outgoing>,
111        stop: StopReceiver,
112    ) -> Self {
113        let log_target = format!("{}::Talker", base_log_target);
114        Self {
115            log_target,
116            address,
117            connection: connection.fuse(),
118            rx,
119            stop,
120            rx_drained: false,
121            connection_drained: false,
122            interrupted: false,
123        }
124    }
125
126    fn is_done(&self) -> bool {
127        self.rx_drained && self.connection_drained
128    }
129
130    pub async fn routine(&mut self) -> Result<TermReason, Error> {
131        let mut done = self.stop.clone().into_future();
132        loop {
133            select! {
134                _ = done => {
135                    self.interrupted = true;
136                    // Just close the channel and wait when it will be drained
137                    self.rx.close();
138                }
139                request = self.connection.next() => {
140                    let msg = request.transpose()?;
141                    if let Some(msg) = msg {
142                        if msg.is_text() || msg.is_binary() {
143                            let decoded = T::Codec::decode(&msg.into_bytes())?;
144                            log::trace!(target: &self.log_target, "MEIO-WS-RECV: {:?}", decoded);
145                            let msg = WsIncoming(decoded);
146                            self.address.act(msg)?;
147                        } else if msg.is_ping() || msg.is_pong() {
148                            // Ignore Ping and Pong messages
149                        } else if msg.is_close() {
150                            log::trace!(target: &self.log_target, "Close message received. Draining the channel...");
151                            // Start draining that will close the connection.
152                            // No more messages expected. The receiver can be safely closed.
153                            self.rx.close();
154                        } else {
155                            log::warn!(target: &self.log_target, "Unhandled WebSocket message: {:?}", msg);
156                        }
157                    } else {
158                        // The connection was closed, further interaction doesn't make sense
159                        log::trace!(target: &self.log_target, "Connection phisically closed.");
160                        self.connection_drained = true;
161                        if self.is_done() {
162                            break;
163                        }
164                    }
165                }
166                response = self.rx.next() => {
167                    if let Some(msg) = response {
168                        log::trace!(target: &self.log_target, "MEIO-WS-SEND: {:?}", msg);
169                        let encoded = T::Codec::encode(&msg)?;
170                        let message = T::Message::binary(encoded);
171                        self.connection.send(message).await?;
172                    } else {
173                        log::trace!(target: &self.log_target, "Channel with outgoing data closed. Terminating a session with the client.");
174                        log::trace!(target: &self.log_target, "Sending close notification to the client.");
175                        self.connection.close().await?;
176                        self.rx_drained = true;
177                        if self.is_done() {
178                            break;
179                        }
180                    }
181                }
182            }
183        }
184        if self.interrupted {
185            Ok(TermReason::Interrupted)
186        } else {
187            Ok(TermReason::Closed)
188        }
189    }
190}