1#![warn(
63 clippy::all,
64 clippy::dbg_macro,
65 clippy::todo,
66 clippy::empty_enum,
67 clippy::enum_glob_use,
68 clippy::mem_forget,
69 clippy::unused_self,
70 clippy::filter_map_next,
71 clippy::needless_continue,
72 clippy::needless_borrow,
73 clippy::match_wildcard_for_single_variants,
74 clippy::if_let_mutex,
75 clippy::mismatched_target_os,
76 clippy::await_holding_lock,
77 clippy::match_on_vec_items,
78 clippy::imprecise_flops,
79 clippy::suboptimal_flops,
80 clippy::lossy_float_literal,
81 clippy::rest_pat_in_fully_bound_structs,
82 clippy::fn_params_excessive_bools,
83 clippy::exit,
84 clippy::inefficient_to_string,
85 clippy::linkedlist,
86 clippy::macro_use_imports,
87 clippy::option_option,
88 clippy::verbose_file_reads,
89 clippy::unnested_or_patterns,
90 clippy::str_to_string,
91 rust_2018_idioms,
92 future_incompatible,
93 nonstandard_style,
94 missing_debug_implementations,
95 missing_docs
96)]
97#![deny(unreachable_pub, private_in_public)]
98#![allow(elided_lifetimes_in_paths, clippy::type_complexity)]
99#![forbid(unsafe_code)]
100#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]
101#![cfg_attr(test, allow(clippy::float_cmp))]
102
103use self::rejection::*;
104use async_trait::async_trait;
105use axum_core::{
106 extract::FromRequestParts,
107 response::{IntoResponse, Response},
108};
109use bytes::Bytes;
110use futures_util::{
111 sink::{Sink, SinkExt},
112 stream::{Stream, StreamExt},
113};
114use http::{
115 header::{self, HeaderMap, HeaderName, HeaderValue},
116 request::Parts,
117 Method, StatusCode,
118};
119use hyper::upgrade::{OnUpgrade, Upgraded};
120use sha1::{Digest, Sha1};
121use std::{
122 borrow::Cow,
123 future::Future,
124 pin::Pin,
125 task::{Context, Poll},
126};
127use tokio_tungstenite::{
128 tungstenite::protocol::{self, WebSocketConfig},
129 WebSocketStream,
130};
131
132#[doc(no_inline)]
133pub use tokio_tungstenite::tungstenite::error::{
134 CapacityError, Error, ProtocolError, TlsError, UrlError,
135};
136#[doc(no_inline)]
137pub use tokio_tungstenite::tungstenite::Message;
138
139#[derive(Debug)]
143pub struct WebSocketUpgrade<F = DefaultOnFailedUpdgrade> {
144 config: WebSocketConfig,
145 protocol: Option<HeaderValue>,
147 sec_websocket_key: HeaderValue,
148 on_upgrade: OnUpgrade,
149 on_failed_upgrade: F,
150 sec_websocket_protocol: Option<HeaderValue>,
151}
152
153impl<C> WebSocketUpgrade<C> {
154 pub fn write_buffer_size(mut self, size: usize) -> Self {
164 self.config.write_buffer_size = size;
165 self
166 }
167
168 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
180 self.config.max_write_buffer_size = max;
181 self
182 }
183
184 pub fn max_message_size(mut self, max: usize) -> Self {
186 self.config.max_message_size = Some(max);
187 self
188 }
189
190 pub fn max_frame_size(mut self, max: usize) -> Self {
192 self.config.max_frame_size = Some(max);
193 self
194 }
195
196 pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
198 self.config.accept_unmasked_frames = accept;
199 self
200 }
201
202 pub fn protocols<I>(mut self, protocols: I) -> Self
212 where
213 I: IntoIterator,
214 I::Item: Into<Cow<'static, str>>,
215 {
216 if let Some(req_protocols) = self
217 .sec_websocket_protocol
218 .as_ref()
219 .and_then(|p| p.to_str().ok())
220 {
221 self.protocol = protocols
222 .into_iter()
223 .map(Into::into)
224 .find(|protocol| {
225 req_protocols
226 .split(',')
227 .any(|req_protocol| req_protocol.trim() == protocol)
228 })
229 .map(|protocol| match protocol {
230 Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
231 Cow::Borrowed(s) => HeaderValue::from_static(s),
232 });
233 }
234
235 self
236 }
237
238 pub fn on_upgrade<F, Fut>(self, callback: F) -> Response
245 where
246 F: FnOnce(WebSocket) -> Fut + Send + 'static,
247 Fut: Future<Output = ()> + Send + 'static,
248 C: OnFailedUpdgrade,
249 {
250 let on_upgrade = self.on_upgrade;
251 let config = self.config;
252 let on_failed_upgrade = self.on_failed_upgrade;
253
254 let protocol = self.protocol.clone();
255
256 tokio::spawn(async move {
257 let upgraded = match on_upgrade.await {
258 Ok(upgraded) => upgraded,
259 Err(err) => {
260 on_failed_upgrade.call(err);
261 return;
262 }
263 };
264
265 let socket =
266 WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
267 .await;
268 let socket = WebSocket {
269 inner: socket,
270 protocol,
271 };
272 callback(socket).await;
273 });
274
275 #[allow(clippy::declare_interior_mutable_const)]
276 const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
277 #[allow(clippy::declare_interior_mutable_const)]
278 const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
279
280 let mut headers = HeaderMap::new();
281 headers.insert(header::CONNECTION, UPGRADE);
282 headers.insert(header::UPGRADE, WEBSOCKET);
283 headers.insert(
284 header::SEC_WEBSOCKET_ACCEPT,
285 sign(self.sec_websocket_key.as_bytes()),
286 );
287
288 if let Some(protocol) = self.protocol {
289 headers.insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
290 }
291
292 (StatusCode::SWITCHING_PROTOCOLS, headers).into_response()
293 }
294
295 pub fn on_failed_upgrade<C2>(self, callback: C2) -> WebSocketUpgrade<C2>
318 where
319 C2: OnFailedUpdgrade,
320 {
321 WebSocketUpgrade {
322 config: self.config,
323 protocol: self.protocol,
324 sec_websocket_key: self.sec_websocket_key,
325 on_upgrade: self.on_upgrade,
326 on_failed_upgrade: callback,
327 sec_websocket_protocol: self.sec_websocket_protocol,
328 }
329 }
330}
331
332#[async_trait]
333impl<S> FromRequestParts<S> for WebSocketUpgrade
334where
335 S: Sync,
336{
337 type Rejection = WebSocketUpgradeRejection;
338
339 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
340 if parts.method != Method::GET {
341 return Err(MethodNotGet.into());
342 }
343
344 if !header_contains(parts, header::CONNECTION, "upgrade") {
345 return Err(InvalidConnectionHeader.into());
346 }
347
348 if !header_eq(parts, header::UPGRADE, "websocket") {
349 return Err(InvalidUpgradeHeader.into());
350 }
351
352 if !header_eq(parts, header::SEC_WEBSOCKET_VERSION, "13") {
353 return Err(InvalidWebSocketVersionHeader.into());
354 }
355
356 let sec_websocket_key = if let Some(key) = parts.headers.remove(header::SEC_WEBSOCKET_KEY) {
357 key
358 } else {
359 return Err(WebSocketKeyHeaderMissing.into());
360 };
361
362 let on_upgrade = parts.extensions.remove::<OnUpgrade>().unwrap();
363
364 let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
365
366 Ok(Self {
367 config: Default::default(),
368 protocol: None,
369 sec_websocket_key,
370 on_upgrade,
371 on_failed_upgrade: DefaultOnFailedUpdgrade,
372 sec_websocket_protocol,
373 })
374 }
375}
376
377fn header_eq(req: &Parts, key: HeaderName, value: &'static str) -> bool {
378 if let Some(header) = req.headers.get(&key) {
379 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
380 } else {
381 false
382 }
383}
384
385fn header_contains(req: &Parts, key: HeaderName, value: &'static str) -> bool {
386 let header = if let Some(header) = req.headers.get(&key) {
387 header
388 } else {
389 return false;
390 };
391
392 if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
393 header.to_ascii_lowercase().contains(value)
394 } else {
395 false
396 }
397}
398
399#[derive(Debug)]
401pub struct WebSocket {
402 inner: WebSocketStream<Upgraded>,
403 protocol: Option<HeaderValue>,
404}
405
406impl WebSocket {
407 pub fn into_inner(self) -> WebSocketStream<Upgraded> {
409 self.inner
410 }
411
412 pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
416 self.next().await
417 }
418
419 pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
421 self.inner.send(msg).await
422 }
423
424 pub async fn close(mut self) -> Result<(), Error> {
426 self.inner.close(None).await
427 }
428
429 pub fn protocol(&self) -> Option<&HeaderValue> {
431 self.protocol.as_ref()
432 }
433}
434
435impl Stream for WebSocket {
436 type Item = Result<Message, Error>;
437
438 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
439 self.inner.poll_next_unpin(cx)
440 }
441}
442
443impl Sink<Message> for WebSocket {
444 type Error = Error;
445
446 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
447 Pin::new(&mut self.inner).poll_ready(cx)
448 }
449
450 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
451 Pin::new(&mut self.inner).start_send(item)
452 }
453
454 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
455 Pin::new(&mut self.inner).poll_flush(cx)
456 }
457
458 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
459 Pin::new(&mut self.inner).poll_close(cx)
460 }
461}
462
463fn sign(key: &[u8]) -> HeaderValue {
464 use base64::engine::Engine as _;
465
466 let mut sha1 = Sha1::default();
467 sha1.update(key);
468 sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
469 let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
470 HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
471}
472
473pub trait OnFailedUpdgrade: Send + 'static {
477 fn call(self, error: hyper::Error);
479}
480
481impl<F> OnFailedUpdgrade for F
482where
483 F: FnOnce(hyper::Error) + Send + 'static,
484{
485 fn call(self, error: hyper::Error) {
486 self(error)
487 }
488}
489
490#[non_exhaustive]
494#[derive(Debug)]
495pub struct DefaultOnFailedUpdgrade;
496
497impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
498 #[inline]
499 fn call(self, _error: hyper::Error) {}
500}
501
502pub mod rejection {
503 use super::*;
506
507 macro_rules! define_rejection {
508 (
509 #[status = $status:ident]
510 #[body = $body:expr]
511 $(#[$m:meta])*
512 pub struct $name:ident;
513 ) => {
514 $(#[$m])*
515 #[derive(Debug)]
516 #[non_exhaustive]
517 pub struct $name;
518
519 impl IntoResponse for $name {
520 fn into_response(self) -> Response {
521 (http::StatusCode::$status, $body).into_response()
522 }
523 }
524
525 impl std::fmt::Display for $name {
526 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527 write!(f, "{}", $body)
528 }
529 }
530
531 impl std::error::Error for $name {}
532 };
533 }
534
535 define_rejection! {
536 #[status = METHOD_NOT_ALLOWED]
537 #[body = "Request method must be `GET`"]
538 pub struct MethodNotGet;
540 }
541
542 define_rejection! {
543 #[status = BAD_REQUEST]
544 #[body = "Connection header did not include 'upgrade'"]
545 pub struct InvalidConnectionHeader;
547 }
548
549 define_rejection! {
550 #[status = BAD_REQUEST]
551 #[body = "`Upgrade` header did not include 'websocket'"]
552 pub struct InvalidUpgradeHeader;
554 }
555
556 define_rejection! {
557 #[status = BAD_REQUEST]
558 #[body = "`Sec-WebSocket-Version` header did not include '13'"]
559 pub struct InvalidWebSocketVersionHeader;
561 }
562
563 define_rejection! {
564 #[status = BAD_REQUEST]
565 #[body = "`Sec-WebSocket-Key` header missing"]
566 pub struct WebSocketKeyHeaderMissing;
568 }
569
570 macro_rules! composite_rejection {
571 (
572 $(#[$m:meta])*
573 pub enum $name:ident {
574 $($variant:ident),+
575 $(,)?
576 }
577 ) => {
578 $(#[$m])*
579 #[derive(Debug)]
580 #[non_exhaustive]
581 pub enum $name {
582 $(
583 #[allow(missing_docs)]
584 $variant($variant)
585 ),+
586 }
587
588 impl IntoResponse for $name {
589 fn into_response(self) -> Response {
590 match self {
591 $(
592 Self::$variant(inner) => inner.into_response(),
593 )+
594 }
595 }
596 }
597
598 $(
599 impl From<$variant> for $name {
600 fn from(inner: $variant) -> Self {
601 Self::$variant(inner)
602 }
603 }
604 )+
605
606 impl std::fmt::Display for $name {
607 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
608 match self {
609 $(
610 Self::$variant(inner) => write!(f, "{}", inner),
611 )+
612 }
613 }
614 }
615
616 impl std::error::Error for $name {
617 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
618 match self {
619 $(
620 Self::$variant(inner) => Some(inner),
621 )+
622 }
623 }
624 }
625 };
626 }
627
628 composite_rejection! {
629 pub enum WebSocketUpgradeRejection {
634 MethodNotGet,
635 InvalidConnectionHeader,
636 InvalidUpgradeHeader,
637 InvalidWebSocketVersionHeader,
638 WebSocketKeyHeaderMissing,
639 }
640 }
641}