1use anyhow::Error;
2use futures::channel::mpsc;
3use futures::stream::Fuse;
4use futures::{select, Sink, SinkExt, Stream, StreamExt};
5use meio::prelude::{Action, ActionHandler, Actor, Address, StopReceiver};
6use meio_protocol::{ProtocolCodec, ProtocolData};
7use serde::ser::StdError;
8use std::fmt::Debug;
9use tungstenite::{error::Error as TungError, Message as TungMessage};
10
11#[derive(Debug)]
13pub struct WsIncoming<T: ProtocolData>(pub T);
14
15impl<T: ProtocolData> Action for WsIncoming<T> {}
16
/// Explains why a talker session finished.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TermReason {
    /// The session was stopped through its stop signal.
    Interrupted,
    /// The session ended because the connection and the outgoing channel both closed.
    Closed,
}

impl TermReason {
    /// Returns `true` when the session was interrupted rather than closed.
    pub fn is_interrupted(&self) -> bool {
        matches!(self, TermReason::Interrupted)
    }
}
32
33pub trait WsError: Debug + StdError + Sync + Send + 'static {}
35
36impl WsError for TungError {}
37
/// Abstraction over a WebSocket frame type, so the talker is not tied to one library.
pub trait WsMessage: Debug + Sized {
    /// Builds a binary frame that carries `data`.
    fn binary(data: Vec<u8>) -> Self;

    /// Returns `true` for text frames.
    fn is_text(&self) -> bool;
    /// Returns `true` for binary frames.
    fn is_binary(&self) -> bool;
    /// Returns `true` for ping frames.
    fn is_ping(&self) -> bool;
    /// Returns `true` for pong frames.
    fn is_pong(&self) -> bool;
    /// Returns `true` for close frames.
    fn is_close(&self) -> bool;

    /// Consumes the frame and returns its payload bytes.
    fn into_bytes(self) -> Vec<u8>;
}
57
58impl WsMessage for TungMessage {
59 fn binary(data: Vec<u8>) -> Self {
60 TungMessage::binary(data)
61 }
62 fn is_ping(&self) -> bool {
63 TungMessage::is_ping(self)
64 }
65 fn is_pong(&self) -> bool {
66 TungMessage::is_pong(self)
67 }
68 fn is_text(&self) -> bool {
69 TungMessage::is_text(self)
70 }
71 fn is_binary(&self) -> bool {
72 TungMessage::is_binary(self)
73 }
74 fn is_close(&self) -> bool {
75 TungMessage::is_close(self)
76 }
77 fn into_bytes(self) -> Vec<u8> {
78 TungMessage::into_data(self)
79 }
80}
81
82pub trait TalkerCompatible {
83 type WebSocket: Stream<Item = Result<Self::Message, Self::Error>>
84 + Sink<Self::Message, Error = Self::Error>
85 + Unpin;
86 type Message: WsMessage;
87 type Error: WsError;
88 type Actor: Actor + ActionHandler<WsIncoming<Self::Incoming>>;
89 type Codec: ProtocolCodec;
90 type Incoming: ProtocolData;
91 type Outgoing: ProtocolData;
92}
93
94pub struct Talker<T: TalkerCompatible> {
95 log_target: String,
96 address: Address<T::Actor>,
97 connection: Fuse<T::WebSocket>,
98 rx: mpsc::UnboundedReceiver<T::Outgoing>,
99 stop: StopReceiver,
100 rx_drained: bool,
101 connection_drained: bool,
102 interrupted: bool,
103}
104
105impl<T: TalkerCompatible> Talker<T> {
106 pub fn new(
107 base_log_target: &str,
108 address: Address<T::Actor>,
109 connection: T::WebSocket,
110 rx: mpsc::UnboundedReceiver<T::Outgoing>,
111 stop: StopReceiver,
112 ) -> Self {
113 let log_target = format!("{}::Talker", base_log_target);
114 Self {
115 log_target,
116 address,
117 connection: connection.fuse(),
118 rx,
119 stop,
120 rx_drained: false,
121 connection_drained: false,
122 interrupted: false,
123 }
124 }
125
126 fn is_done(&self) -> bool {
127 self.rx_drained && self.connection_drained
128 }
129
130 pub async fn routine(&mut self) -> Result<TermReason, Error> {
131 let mut done = self.stop.clone().into_future();
132 loop {
133 select! {
134 _ = done => {
135 self.interrupted = true;
136 self.rx.close();
138 }
139 request = self.connection.next() => {
140 let msg = request.transpose()?;
141 if let Some(msg) = msg {
142 if msg.is_text() || msg.is_binary() {
143 let decoded = T::Codec::decode(&msg.into_bytes())?;
144 log::trace!(target: &self.log_target, "MEIO-WS-RECV: {:?}", decoded);
145 let msg = WsIncoming(decoded);
146 self.address.act(msg)?;
147 } else if msg.is_ping() || msg.is_pong() {
148 } else if msg.is_close() {
150 log::trace!(target: &self.log_target, "Close message received. Draining the channel...");
151 self.rx.close();
154 } else {
155 log::warn!(target: &self.log_target, "Unhandled WebSocket message: {:?}", msg);
156 }
157 } else {
158 log::trace!(target: &self.log_target, "Connection phisically closed.");
160 self.connection_drained = true;
161 if self.is_done() {
162 break;
163 }
164 }
165 }
166 response = self.rx.next() => {
167 if let Some(msg) = response {
168 log::trace!(target: &self.log_target, "MEIO-WS-SEND: {:?}", msg);
169 let encoded = T::Codec::encode(&msg)?;
170 let message = T::Message::binary(encoded);
171 self.connection.send(message).await?;
172 } else {
173 log::trace!(target: &self.log_target, "Channel with outgoing data closed. Terminating a session with the client.");
174 log::trace!(target: &self.log_target, "Sending close notification to the client.");
175 self.connection.close().await?;
176 self.rx_drained = true;
177 if self.is_done() {
178 break;
179 }
180 }
181 }
182 }
183 }
184 if self.interrupted {
185 Ok(TermReason::Interrupted)
186 } else {
187 Ok(TermReason::Closed)
188 }
189 }
190}