trillium_websockets/
lib.rs1#![forbid(unsafe_code)]
2#![deny(
3 clippy::dbg_macro,
4 missing_copy_implementations,
5 rustdoc::missing_crate_level_docs,
6 missing_debug_implementations,
7 missing_docs,
8 nonstandard_style,
9 unused_qualifications
10)]
11
12mod bidirectional_stream;
57mod websocket_connection;
58mod websocket_handler;
59
60use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
61use bidirectional_stream::{BidirectionalStream, Direction};
62use futures_lite::stream::StreamExt;
63use sha1::{Digest, Sha1};
64use std::{
65 net::IpAddr,
66 ops::{Deref, DerefMut},
67};
68use trillium::{
69 Conn, Handler,
70 KnownHeaderName::{
71 Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketProtocol, SecWebsocketVersion,
72 Upgrade as UpgradeHeader,
73 },
74 Status, Upgrade,
75};
76
77pub use async_tungstenite::{
78 self,
79 tungstenite::{
80 self,
81 protocol::{Role, WebSocketConfig},
82 Message,
83 },
84};
85pub use trillium::async_trait;
86pub use websocket_connection::WebSocketConn;
87pub use websocket_handler::WebSocketHandler;
88
/// The fixed GUID defined by RFC 6455 §1.3 that is appended to the client's
/// `Sec-WebSocket-Key` before SHA-1 hashing in `websocket_accept_hash`.
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
90
91#[derive(thiserror::Error, Debug)]
92#[non_exhaustive]
93pub enum Error {
96 #[error(transparent)]
97 WebSocket(#[from] tungstenite::Error),
99
100 #[cfg(feature = "json")]
101 #[error(transparent)]
102 Json(#[from] serde_json::Error),
104}
105
106pub type Result<T = Message> = std::result::Result<T, Error>;
108
109#[cfg(feature = "json")]
110mod json;
111
112#[cfg(feature = "json")]
113pub use json::{json_websocket, JsonHandler, JsonWebSocketHandler};
114
115#[derive(Debug)]
120pub struct WebSocket<H> {
121 handler: H,
122 protocols: Vec<String>,
123 config: Option<WebSocketConfig>,
124 required: bool,
125}
126
127impl<H> Deref for WebSocket<H> {
128 type Target = H;
129
130 fn deref(&self) -> &Self::Target {
131 &self.handler
132 }
133}
134
135impl<H> DerefMut for WebSocket<H> {
136 fn deref_mut(&mut self) -> &mut Self::Target {
137 &mut self.handler
138 }
139}
140
141pub fn websocket<H>(websocket_handler: H) -> WebSocket<H>
146where
147 H: WebSocketHandler,
148{
149 WebSocket::new(websocket_handler)
150}
151
152impl<H> WebSocket<H>
153where
154 H: WebSocketHandler,
155{
156 pub fn new(handler: H) -> Self {
159 Self {
160 handler,
161 protocols: Default::default(),
162 config: None,
163 required: false,
164 }
165 }
166
167 pub fn with_protocols(self, protocols: &[&str]) -> Self {
171 Self {
172 protocols: protocols.iter().map(ToString::to_string).collect(),
173 ..self
174 }
175 }
176
177 pub fn with_protocol_config(self, config: WebSocketConfig) -> Self {
179 Self {
180 config: Some(config),
181 ..self
182 }
183 }
184
185 pub fn required(mut self) -> Self {
188 self.required = true;
189 self
190 }
191}
192
/// Marker stored in conn state by `run` after a successful handshake; its
/// presence is what makes `has_upgrade` claim the resulting upgrade.
struct IsWebsocket;
194
195#[cfg(test)]
196mod tests;
197
/// State carrying the peer IP read from the conn in `run`; `upgrade` takes it
/// back out of the upgrade state and passes it to `WebSocketConn::set_peer_ip`.
struct WebsocketPeerIp(Option<IpAddr>);
202
203#[async_trait]
204impl<H> Handler for WebSocket<H>
205where
206 H: WebSocketHandler,
207{
208 async fn run(&self, mut conn: Conn) -> Conn {
209 if !upgrade_requested(&conn) {
210 if self.required {
211 return conn.with_status(Status::UpgradeRequired).halt();
212 } else {
213 return conn;
214 }
215 }
216
217 let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
218
219 let Some(sec_websocket_key) = conn.request_headers().get_str(SecWebsocketKey) else {
220 return conn.with_status(Status::BadRequest).halt();
221 };
222 let sec_websocket_accept = websocket_accept_hash(sec_websocket_key);
223
224 let protocol = websocket_protocol(&conn, &self.protocols);
225
226 let headers = conn.response_headers_mut();
227
228 headers.extend([
229 (UpgradeHeader, "websocket"),
230 (Connection, "Upgrade"),
231 (SecWebsocketVersion, "13"),
232 ]);
233
234 headers.insert(SecWebsocketAccept, sec_websocket_accept);
235
236 if let Some(protocol) = protocol {
237 headers.insert(SecWebsocketProtocol, protocol);
238 }
239
240 conn.halt()
241 .with_state(websocket_peer_ip)
242 .with_state(IsWebsocket)
243 .with_status(Status::SwitchingProtocols)
244 }
245
246 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
247 upgrade.state().contains::<IsWebsocket>()
248 }
249
250 async fn upgrade(&self, mut upgrade: Upgrade) {
251 let peer_ip = upgrade.state.take::<WebsocketPeerIp>().and_then(|i| i.0);
252 let mut conn = WebSocketConn::new(upgrade, self.config, Role::Server).await;
253 conn.set_peer_ip(peer_ip);
254
255 let Some((mut conn, outbound)) = self.handler.connect(conn).await else {
256 return;
257 };
258
259 let inbound = conn.take_inbound_stream();
260
261 let mut stream = std::pin::pin!(BidirectionalStream { inbound, outbound });
262 while let Some(message) = stream.next().await {
263 match message {
264 Direction::Inbound(Ok(Message::Close(close_frame))) => {
265 self.handler.disconnect(&mut conn, close_frame).await;
266 break;
267 }
268
269 Direction::Inbound(Ok(message)) => {
270 self.handler.inbound(message, &mut conn).await;
271 }
272
273 Direction::Outbound(message) => {
274 if let Err(e) = self.handler.send(message, &mut conn).await {
275 log::warn!("outbound websocket error: {:?}", e);
276 break;
277 }
278 }
279
280 _ => {
281 self.handler.disconnect(&mut conn, None).await;
282 break;
283 }
284 }
285 }
286
287 if let Some(err) = conn.close().await.err() {
288 log::warn!("websocket close error: {:?}", err);
289 };
290 }
291}
292
293fn websocket_protocol(conn: &Conn, protocols: &[String]) -> Option<String> {
294 conn.request_headers()
295 .get_str(SecWebsocketProtocol)
296 .and_then(|value| {
297 value
298 .split(',')
299 .map(str::trim)
300 .find(|req_p| protocols.iter().any(|x| x == req_p))
301 .map(|s| s.to_owned())
302 })
303}
304
305fn connection_is_upgrade(conn: &Conn) -> bool {
306 conn.request_headers()
307 .get_str(Connection)
308 .map(|connection| {
309 connection
310 .split(',')
311 .any(|c| c.trim().eq_ignore_ascii_case("upgrade"))
312 })
313 .unwrap_or(false)
314}
315
316fn upgrade_to_websocket(conn: &Conn) -> bool {
317 conn.request_headers()
318 .eq_ignore_ascii_case(UpgradeHeader, "websocket")
319}
320
321fn upgrade_requested(conn: &Conn) -> bool {
322 connection_is_upgrade(conn) && upgrade_to_websocket(conn)
323}
324
325pub fn websocket_key() -> String {
327 BASE64.encode(fastrand::u128(..).to_ne_bytes())
328}
329
330pub fn websocket_accept_hash(websocket_key: &str) -> String {
332 let hash = Sha1::new()
333 .chain_update(websocket_key)
334 .chain_update(WEBSOCKET_GUID)
335 .finalize();
336 BASE64.encode(&hash[..])
337}