hydra_websockets/
websocket_handler.rs

1use std::future::Future;
2use std::net::SocketAddr;
3
4use futures_util::SinkExt;
5use futures_util::StreamExt;
6use futures_util::stream;
7
8use tokio::io::AsyncRead;
9use tokio::io::AsyncWrite;
10
11use tokio_tungstenite::WebSocketStream;
12use tokio_tungstenite::tungstenite::Error;
13use tokio_tungstenite::tungstenite::protocol::CloseFrame;
14
15use hydra::ExitReason;
16use hydra::Message;
17use hydra::Process;
18use hydra::Receivable;
19
20use crate::WebsocketCommand;
21use crate::WebsocketCommands;
22use crate::WebsocketMessage;
23use crate::WebsocketRequest;
24use crate::WebsocketResponse;
25
26/// A process that handles websocket messages.
27pub trait WebsocketHandler
28where
29    Self: Sized,
30{
31    /// The message type that this handler will use.
32    type Message: Receivable;
33
34    /// A callback used to accept or deny a request for a websocket upgrade.
35    ///
36    /// You can extract information from the request and put it in your handler state.
37    fn accept(
38        address: SocketAddr,
39        request: &WebsocketRequest,
40        response: WebsocketResponse,
41    ) -> Result<(WebsocketResponse, Self), ExitReason>;
42
43    /// An optional callback that happens before the first message is sent/received from the websocket.
44    ///
45    /// This is the first callback that happens in the process responsible for the websocket.
46    fn websocket_init(
47        &mut self,
48    ) -> impl Future<Output = Result<Option<WebsocketCommands>, ExitReason>> + Send {
49        async move { Ok(None) }
50    }
51
52    /// Invoked to handle messages received from the websocket.
53    fn websocket_handle(
54        &mut self,
55        message: WebsocketMessage,
56    ) -> impl Future<Output = Result<Option<WebsocketCommands>, ExitReason>> + Send;
57
58    /// Invoked to handle messages from processes and system messages.
59    fn websocket_info(
60        &mut self,
61        info: Message<Self::Message>,
62    ) -> impl Future<Output = Result<Option<WebsocketCommands>, ExitReason>> + Send {
63        async move {
64            let _ = info;
65
66            Ok(None)
67        }
68    }
69
70    /// Invoked when the handler is about to exit. It should do any cleanup required.
71    ///
72    /// `terminate` is useful for cleanup that requires access to the [WebsocketHandler]'s state. However, it is not
73    /// guaranteed that `terminate` is called when a [WebsocketHandler] exits. Therefore, important cleanup should be done
74    /// using process links and/or monitors. A monitoring process will receive the same `reason` that would be passed to `terminate`.
75    ///
76    /// `terminate` is called if:
77    /// - The websocket connection closes for whatever reason.
78    /// - A callback (except `accept`) returns stop with a given reason.
79    fn terminate(&mut self, reason: ExitReason) -> impl Future<Output = ()> + Send {
80        async move {
81            let _ = reason;
82        }
83    }
84}
85
86/// Internal routine to process commands from a [WebsocketHandler] callback.
87async fn websocket_process_commands<T, S>(
88    commands: WebsocketCommands,
89    handler: &mut T,
90    stream: &mut WebSocketStream<S>,
91) where
92    T: WebsocketHandler + Send + 'static,
93    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
94{
95    let mut close_command: Option<WebsocketCommand> = None;
96
97    let sends = commands.buffer.into_iter().filter_map(|command| {
98        if close_command.is_some() {
99            return None;
100        }
101
102        match command {
103            WebsocketCommand::Send(message) => Some(Ok(message)),
104            WebsocketCommand::Close(_, _) => {
105                close_command = Some(command);
106                None
107            }
108        }
109    });
110
111    let mut sends = stream::iter(sends);
112
113    if let Err(error) = stream.send_all(&mut sends).await {
114        handler.terminate(error_to_reason(&error)).await;
115
116        Process::exit(Process::current(), error_to_reason(&error))
117    }
118
119    if let Some(WebsocketCommand::Close(code, reason)) = close_command {
120        if let Err(error) = stream
121            .close(Some(CloseFrame {
122                code,
123                reason: reason.into(),
124            }))
125            .await
126        {
127            handler.terminate(error_to_reason(&error)).await;
128
129            Process::exit(Process::current(), error_to_reason(&error));
130        } else {
131            Process::exit(Process::current(), ExitReason::Normal);
132        }
133    }
134}
135
136/// Internal [WebsocketHandler] start routine.
137pub(crate) async fn start_websocket_handler<T, S>(mut handler: T, mut stream: WebSocketStream<S>)
138where
139    T: WebsocketHandler + Send + 'static,
140    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
141{
142    match handler.websocket_init().await {
143        Ok(commands) => {
144            if let Some(commands) = commands {
145                websocket_process_commands(commands, &mut handler, &mut stream).await;
146            }
147        }
148        Err(reason) => {
149            return Process::exit(Process::current(), reason);
150        }
151    }
152
153    loop {
154        tokio::select! {
155            message = Process::receive::<T::Message>() => {
156                match handler.websocket_info(message).await {
157                    Ok(commands) => {
158                        if let Some(commands) = commands {
159                            websocket_process_commands(commands, &mut handler, &mut stream).await;
160                        }
161                    }
162                    Err(reason) => {
163                        handler.terminate(reason.clone()).await;
164
165                        return Process::exit(Process::current(), reason);
166                    }
167                }
168            }
169            ws_message = stream.next() => {
170                let Some(ws_message) = ws_message else {
171                    panic!("Websocket closed without close frame!");
172                };
173
174                match ws_message {
175                    Ok(message) => {
176                        let mut should_close = false;
177
178                        match &message {
179                            WebsocketMessage::Ping(data) => {
180                                if let Err(error) = stream.send(WebsocketMessage::Pong(data.clone())).await {
181                                    handler.terminate(error_to_reason(&error)).await;
182
183                                    return Process::exit(Process::current(), error_to_reason(&error));
184                                }
185                            }
186                            WebsocketMessage::Close(_) => {
187                                should_close = true;
188                            }
189                            _ => {
190                                // No special handling.
191                            }
192                        }
193
194                        match handler.websocket_handle(message).await {
195                            Ok(commands) => {
196                                if let Some(commands) = commands {
197                                    websocket_process_commands(commands, &mut handler, &mut stream).await;
198                                }
199                            }
200                            Err(reason) => {
201                                handler.terminate(reason.clone()).await;
202
203                                return Process::exit(Process::current(), reason);
204                            }
205                        }
206
207                        if should_close {
208                            handler.terminate(ExitReason::from("connection_closed")).await;
209
210                            return Process::exit(Process::current(), ExitReason::from("connection_closed"));
211                        }
212                    }
213                    Err(error) => {
214                        handler.terminate(error_to_reason(&error)).await;
215
216                        return Process::exit(Process::current(), error_to_reason(&error));
217                    }
218                }
219            }
220        }
221    }
222}
223
224/// Converts a websocket error to an exit reason.
225fn error_to_reason(error: &Error) -> ExitReason {
226    match error {
227        Error::AlreadyClosed | Error::ConnectionClosed => ExitReason::from("connection_closed"),
228        Error::Io(_) => ExitReason::from("io_error"),
229        Error::Tls(_) => ExitReason::from("tls_error"),
230        Error::Utf8(_) => ExitReason::from("utf8_error"),
231        Error::AttackAttempt => ExitReason::from("attack_attempt"),
232        _ => ExitReason::from("unknown"),
233    }
234}