axum_tws/
lib.rs

1#![doc = include_str!("../README.md")]
2#![forbid(unsafe_code)]
3
4pub mod upgrade;
5pub mod websocket;
6
7use std::fmt::Display;
8
9use axum_core::body::Body;
10use axum_core::response::IntoResponse;
11use axum_core::response::Response;
12use http::StatusCode;
13
14pub use tokio_websockets::*;
15
16pub use crate::{upgrade::WebSocketUpgrade, websocket::WebSocket};
17
18#[derive(Debug)]
19pub enum WebSocketError {
20    ConnectionNotUpgradeable,
21    Internal(tokio_websockets::Error),
22    InvalidConnectionHeader,
23    /// For WebSocket over HTTP/2+
24    InvalidProtocolPseudoheader,
25    InvalidUpgradeHeader,
26    InvalidWebSocketVersionHeader,
27    /// Invalid method for WebSocket over HTTP/1.x
28    MethodNotGet,
29    /// Invalid method for WebSocket over HTTP/2+
30    MethodNotConnect,
31    UpgradeFailed(hyper::Error),
32}
33
34impl Display for WebSocketError {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            WebSocketError::ConnectionNotUpgradeable => {
38                write!(f, "connection is not upgradeable")
39            }
40            WebSocketError::Internal(e) => {
41                write!(f, "internal server error: {}", e)
42            }
43            WebSocketError::InvalidConnectionHeader => {
44                write!(f, "invalid `Connection` header")
45            }
46            WebSocketError::InvalidProtocolPseudoheader => {
47                write!(f, "invalid `:protocol` pseudoheader")
48            }
49            WebSocketError::InvalidUpgradeHeader => {
50                write!(f, "invalid `Upgrade` header")
51            }
52            WebSocketError::InvalidWebSocketVersionHeader => {
53                write!(f, "invalid `Sec-WebSocket-Version` header")
54            }
55            WebSocketError::MethodNotGet => {
56                write!(f, "http request method must be `GET`")
57            }
58            WebSocketError::MethodNotConnect => {
59                write!(f, "http2 request method must be `CONNECT`")
60            }
61            WebSocketError::UpgradeFailed(e) => {
62                write!(f, "upgrade failed: {}", e)
63            }
64        }
65    }
66}
67
68impl std::error::Error for WebSocketError {
69    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
70        match self {
71            WebSocketError::Internal(e) => Some(e),
72            WebSocketError::UpgradeFailed(e) => Some(e),
73            _ => None,
74        }
75    }
76}
77
78impl IntoResponse for WebSocketError {
79    fn into_response(self) -> Response<Body> {
80        let status = match self {
81            WebSocketError::ConnectionNotUpgradeable => StatusCode::UPGRADE_REQUIRED,
82
83            // Request headers are invalid or missing.
84            WebSocketError::InvalidConnectionHeader
85            | WebSocketError::InvalidUpgradeHeader
86            | WebSocketError::InvalidWebSocketVersionHeader => StatusCode::BAD_REQUEST,
87
88            // Invalid request method.
89            WebSocketError::MethodNotGet => StatusCode::METHOD_NOT_ALLOWED,
90
91            // All other errors will be treated as internal server errors.
92            _ => StatusCode::INTERNAL_SERVER_ERROR,
93        };
94
95        Response::builder()
96            .status(status)
97            .body(Body::empty())
98            .unwrap()
99    }
100}
101
102impl From<tokio_websockets::Error> for WebSocketError {
103    fn from(e: tokio_websockets::Error) -> Self {
104        WebSocketError::Internal(e)
105    }
106}
107
108impl From<hyper::Error> for WebSocketError {
109    fn from(e: hyper::Error) -> Self {
110        WebSocketError::UpgradeFailed(e)
111    }
112}