hydra_websockets/
websocket_handler.rs1use std::future::Future;
2use std::net::SocketAddr;
3
4use futures_util::SinkExt;
5use futures_util::StreamExt;
6use futures_util::stream;
7
8use tokio::io::AsyncRead;
9use tokio::io::AsyncWrite;
10
11use tokio_tungstenite::WebSocketStream;
12use tokio_tungstenite::tungstenite::Error;
13use tokio_tungstenite::tungstenite::protocol::CloseFrame;
14
15use hydra::ExitReason;
16use hydra::Message;
17use hydra::Process;
18use hydra::Receivable;
19
20use crate::WebsocketCommand;
21use crate::WebsocketCommands;
22use crate::WebsocketMessage;
23use crate::WebsocketRequest;
24use crate::WebsocketResponse;
25
26pub trait WebsocketHandler
28where
29 Self: Sized,
30{
31 type Message: Receivable;
33
34 fn accept(
38 address: SocketAddr,
39 request: &WebsocketRequest,
40 response: WebsocketResponse,
41 ) -> Result<(WebsocketResponse, Self), ExitReason>;
42
43 fn websocket_init(
47 &mut self,
48 ) -> impl Future<Output = Result<Option<WebsocketCommands>, ExitReason>> + Send {
49 async move { Ok(None) }
50 }
51
52 fn websocket_handle(
54 &mut self,
55 message: WebsocketMessage,
56 ) -> impl Future<Output = Result<Option<WebsocketCommands>, ExitReason>> + Send;
57
58 fn websocket_info(
60 &mut self,
61 info: Message<Self::Message>,
62 ) -> impl Future<Output = Result<Option<WebsocketCommands>, ExitReason>> + Send {
63 async move {
64 let _ = info;
65
66 Ok(None)
67 }
68 }
69
70 fn terminate(&mut self, reason: ExitReason) -> impl Future<Output = ()> + Send {
80 async move {
81 let _ = reason;
82 }
83 }
84}
85
86async fn websocket_process_commands<T, S>(
88 commands: WebsocketCommands,
89 handler: &mut T,
90 stream: &mut WebSocketStream<S>,
91) where
92 T: WebsocketHandler + Send + 'static,
93 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
94{
95 let mut close_command: Option<WebsocketCommand> = None;
96
97 let sends = commands.buffer.into_iter().filter_map(|command| {
98 if close_command.is_some() {
99 return None;
100 }
101
102 match command {
103 WebsocketCommand::Send(message) => Some(Ok(message)),
104 WebsocketCommand::Close(_, _) => {
105 close_command = Some(command);
106 None
107 }
108 }
109 });
110
111 let mut sends = stream::iter(sends);
112
113 if let Err(error) = stream.send_all(&mut sends).await {
114 handler.terminate(error_to_reason(&error)).await;
115
116 Process::exit(Process::current(), error_to_reason(&error))
117 }
118
119 if let Some(WebsocketCommand::Close(code, reason)) = close_command {
120 if let Err(error) = stream
121 .close(Some(CloseFrame {
122 code,
123 reason: reason.into(),
124 }))
125 .await
126 {
127 handler.terminate(error_to_reason(&error)).await;
128
129 Process::exit(Process::current(), error_to_reason(&error));
130 } else {
131 Process::exit(Process::current(), ExitReason::Normal);
132 }
133 }
134}
135
136pub(crate) async fn start_websocket_handler<T, S>(mut handler: T, mut stream: WebSocketStream<S>)
138where
139 T: WebsocketHandler + Send + 'static,
140 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
141{
142 match handler.websocket_init().await {
143 Ok(commands) => {
144 if let Some(commands) = commands {
145 websocket_process_commands(commands, &mut handler, &mut stream).await;
146 }
147 }
148 Err(reason) => {
149 return Process::exit(Process::current(), reason);
150 }
151 }
152
153 loop {
154 tokio::select! {
155 message = Process::receive::<T::Message>() => {
156 match handler.websocket_info(message).await {
157 Ok(commands) => {
158 if let Some(commands) = commands {
159 websocket_process_commands(commands, &mut handler, &mut stream).await;
160 }
161 }
162 Err(reason) => {
163 handler.terminate(reason.clone()).await;
164
165 return Process::exit(Process::current(), reason);
166 }
167 }
168 }
169 ws_message = stream.next() => {
170 let Some(ws_message) = ws_message else {
171 panic!("Websocket closed without close frame!");
172 };
173
174 match ws_message {
175 Ok(message) => {
176 let mut should_close = false;
177
178 match &message {
179 WebsocketMessage::Ping(data) => {
180 if let Err(error) = stream.send(WebsocketMessage::Pong(data.clone())).await {
181 handler.terminate(error_to_reason(&error)).await;
182
183 return Process::exit(Process::current(), error_to_reason(&error));
184 }
185 }
186 WebsocketMessage::Close(_) => {
187 should_close = true;
188 }
189 _ => {
190 }
192 }
193
194 match handler.websocket_handle(message).await {
195 Ok(commands) => {
196 if let Some(commands) = commands {
197 websocket_process_commands(commands, &mut handler, &mut stream).await;
198 }
199 }
200 Err(reason) => {
201 handler.terminate(reason.clone()).await;
202
203 return Process::exit(Process::current(), reason);
204 }
205 }
206
207 if should_close {
208 handler.terminate(ExitReason::from("connection_closed")).await;
209
210 return Process::exit(Process::current(), ExitReason::from("connection_closed"));
211 }
212 }
213 Err(error) => {
214 handler.terminate(error_to_reason(&error)).await;
215
216 return Process::exit(Process::current(), error_to_reason(&error));
217 }
218 }
219 }
220 }
221 }
222}
223
224fn error_to_reason(error: &Error) -> ExitReason {
226 match error {
227 Error::AlreadyClosed | Error::ConnectionClosed => ExitReason::from("connection_closed"),
228 Error::Io(_) => ExitReason::from("io_error"),
229 Error::Tls(_) => ExitReason::from("tls_error"),
230 Error::Utf8(_) => ExitReason::from("utf8_error"),
231 Error::AttackAttempt => ExitReason::from("attack_attempt"),
232 _ => ExitReason::from("unknown"),
233 }
234}