extern crate alloc;
use http::{
header::{
HeaderMap, HeaderName, HeaderValue, ALLOW, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
SEC_WEBSOCKET_VERSION, UPGRADE,
},
request::Request,
response::{Builder, Response},
uri::Uri,
Method, StatusCode, Version,
};
mod codec;
mod error;
mod frame;
mod mask;
mod proto;
pub use self::{
codec::{Codec, Item, Message},
error::{HandshakeError, ProtocolError},
proto::{hash_key, CloseCode, CloseReason, OpCode},
};
// Clippy's `declare_interior_mutable_const` lint fires on `HeaderName`/`HeaderValue` consts
// because it considers those types interior-mutable. Every use of a `const` instantiates a
// fresh value, which is the intended behaviour for these shared header constants.
#[allow(clippy::declare_interior_mutable_const)]
mod const_header {
    use super::{HeaderName, HeaderValue};

    /// Value sent as `Sec-WebSocket-Version` by the client request builder.
    pub(super) const SEC_WEBSOCKET_VERSION_VALUE: HeaderValue = HeaderValue::from_static("13");

    /// Value used for the `Connection` header of HTTP/1.1 upgrade requests and responses.
    pub(super) const UPGRADE_VALUE: HeaderValue = HeaderValue::from_static("upgrade");

    /// `websocket` token, used both as the `Upgrade` header value (HTTP/1.1) and as the
    /// value of the `PROTOCOL` header (HTTP/2 client requests).
    pub(super) const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

    /// Header name inserted by the client request builder for HTTP/2 extended CONNECT.
    pub(super) const PROTOCOL: HeaderName = HeaderName::from_static("protocol");
}
use const_header::*;
impl From<HandshakeError> for Builder {
    /// Maps a failed handshake to a response builder describing the error to the client.
    ///
    /// A request with the wrong method gets `405 Method Not Allowed` plus an `Allow` header
    /// naming the method the handshake requires: `GET` for the HTTP/1.1 upgrade
    /// ([`handshake`]) and `CONNECT` for the HTTP/2 variant ([`handshake_h2`]). Every other
    /// handshake error maps to `400 Bad Request`.
    fn from(e: HandshakeError) -> Self {
        match e {
            HandshakeError::GetMethodRequired => Response::builder()
                .status(StatusCode::METHOD_NOT_ALLOWED)
                .header(ALLOW, "GET"),
            // Keep HTTP/2 consistent with HTTP/1.1: a method mismatch is a 405, not a 400.
            HandshakeError::ConnectMethodRequired => Response::builder()
                .status(StatusCode::METHOD_NOT_ALLOWED)
                .header(ALLOW, "CONNECT"),
            _ => Response::builder().status(StatusCode::BAD_REQUEST),
        }
    }
}
pub fn client_request_from_uri<U, E>(uri: U, version: Version) -> Result<Request<()>, E>
where
Uri: TryFrom<U, Error = E>,
{
let uri = uri.try_into()?;
let mut req = Request::new(());
*req.uri_mut() = uri;
*req.version_mut() = version;
match version {
Version::HTTP_11 => {
req.headers_mut().insert(UPGRADE, WEBSOCKET);
req.headers_mut().insert(CONNECTION, UPGRADE_VALUE);
let input = rand::random::<[u8; 16]>();
let mut output = [0u8; 24];
#[allow(clippy::needless_borrow)] let n =
base64::engine::Engine::encode_slice(&base64::engine::general_purpose::STANDARD, input, &mut output)
.unwrap();
assert_eq!(n, output.len());
req.headers_mut()
.insert(SEC_WEBSOCKET_KEY, HeaderValue::from_bytes(&output).unwrap());
}
Version::HTTP_2 => {
*req.method_mut() = Method::CONNECT;
req.headers_mut().insert(PROTOCOL, WEBSOCKET);
}
_ => {}
}
req.headers_mut()
.insert(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE);
Ok(req)
}
pub fn handshake(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
let key = verify_handshake(method, headers)?;
let builder = handshake_response(key);
Ok(builder)
}
/// Validates an HTTP/2 WebSocket request, which must use the `CONNECT` method.
///
/// On success, returns a response builder with status `200 OK`.
///
/// # Errors
/// * [`HandshakeError::ConnectMethodRequired`] if the method is not `CONNECT`.
/// * Any error from the `Sec-WebSocket-Version` check (missing or unsupported version).
pub fn handshake_h2(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
    if method == Method::CONNECT {
        ws_version_check(headers).map(|()| Response::builder().status(StatusCode::OK))
    } else {
        Err(HandshakeError::ConnectMethodRequired)
    }
}
fn verify_handshake<'a>(method: &'a Method, headers: &'a HeaderMap) -> Result<&'a [u8], HandshakeError> {
if method != Method::GET {
return Err(HandshakeError::GetMethodRequired);
}
let has_upgrade_hd = headers
.get(UPGRADE)
.and_then(|hdr| hdr.to_str().ok())
.filter(|s| s.to_ascii_lowercase().contains("websocket"))
.is_some();
if !has_upgrade_hd {
return Err(HandshakeError::NoWebsocketUpgrade);
}
let has_connection_hd = headers
.get(CONNECTION)
.and_then(|hdr| hdr.to_str().ok())
.filter(|s| s.to_ascii_lowercase().contains("upgrade"))
.is_some();
if !has_connection_hd {
return Err(HandshakeError::NoConnectionUpgrade);
}
ws_version_check(headers)?;
let value = headers.get(SEC_WEBSOCKET_KEY).ok_or(HandshakeError::BadWebsocketKey)?;
Ok(value.as_bytes())
}
/// Builds the `101 Switching Protocols` response for an accepted HTTP/1.1 handshake.
///
/// Sets `Upgrade: websocket` and `Connection: upgrade`. Also sets `Sec-WebSocket-Accept`
/// to [`hash_key`] applied to the client's `key`.
///
/// # Panics
/// Panics if the output of `hash_key` is not a valid header value. This is assumed never to
/// happen for its (presumably base64) output.
fn handshake_response(key: &[u8]) -> Builder {
    let digest = hash_key(key);
    let accept = HeaderValue::from_bytes(&digest).unwrap();

    Response::builder()
        .status(StatusCode::SWITCHING_PROTOCOLS)
        .header(UPGRADE, WEBSOCKET)
        .header(CONNECTION, UPGRADE_VALUE)
        .header(SEC_WEBSOCKET_ACCEPT, accept)
}
/// Checks that the request carries a supported `Sec-WebSocket-Version` header.
///
/// The accepted values are exactly `13`, `8` and `7`.
///
/// # Errors
/// * [`HandshakeError::NoVersionHeader`] if the header is absent.
/// * [`HandshakeError::UnsupportedVersion`] for any other value.
fn ws_version_check(headers: &HeaderMap) -> Result<(), HandshakeError> {
    let version = headers
        .get(SEC_WEBSOCKET_VERSION)
        .ok_or(HandshakeError::NoVersionHeader)?;

    match version.as_bytes() {
        b"13" | b"8" | b"7" => Ok(()),
        _ => Err(HandshakeError::UnsupportedVersion),
    }
}
#[cfg(feature = "stream")]
pub mod stream;
#[cfg(feature = "stream")]
pub use self::stream::{RequestStream, ResponseSender, ResponseStream, ResponseWeakSender, WsError};
/// Output of [`ws`], as a tuple of three parts:
/// the [`RequestStream`] wrapping the request body,
/// the `101`/`200` [`Response`] whose body is a [`ResponseStream`], and
/// the [`ResponseSender`] returned alongside that stream, presumably used to push outgoing
/// messages into it.
#[cfg(feature = "stream")]
pub type WsOutput<B, E> = (RequestStream<B, E>, Response<ResponseStream>, ResponseSender);
/// Performs the server-side WebSocket handshake for `req` and wires up streaming I/O.
///
/// HTTP/2 requests are validated with [`handshake_h2`]; every other version goes through
/// [`handshake`]. On success, `body` is wrapped in a [`RequestStream`]. The response is
/// built around the stream returned by [`RequestStream::response_stream`].
///
/// # Errors
/// Returns the [`HandshakeError`] from the handshake validation.
///
/// # Panics
/// Panics if the handshake's response builder fails to produce a response. The handshake
/// functions are expected never to return such a builder.
#[cfg(feature = "stream")]
pub fn ws<ReqB, B, T, E>(req: &Request<ReqB>, body: B) -> Result<WsOutput<B, E>, HandshakeError>
where
    B: futures_core::Stream<Item = Result<T, E>>,
    T: AsRef<[u8]>,
{
    let (method, headers) = (req.method(), req.headers());
    let builder = if req.version() == Version::HTTP_2 {
        handshake_h2(method, headers)?
    } else {
        handshake(method, headers)?
    };

    let request_stream = RequestStream::new(body);
    let (response_body, sender) = request_stream.response_stream();
    let response = builder
        .body(response_body)
        .expect("handshake function failed to generate correct Response Builder");

    Ok((request_stream, response, sender))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `verify_handshake` on `req` and asserts it fails with exactly `expected`.
    fn assert_rejected(req: &Request<()>, expected: HandshakeError) {
        let err = verify_handshake(req.method(), req.headers()).unwrap_err();
        assert_eq!(expected, err);
    }

    #[test]
    fn test_handshake() {
        // Wrong method.
        let post = Request::builder().method(Method::POST).body(()).unwrap();
        assert_rejected(&post, HandshakeError::GetMethodRequired);

        // No `Upgrade` header at all, then an `Upgrade` header that isn't `websocket`.
        let bare = Request::builder().body(()).unwrap();
        assert_rejected(&bare, HandshakeError::NoWebsocketUpgrade);
        let wrong_upgrade = Request::builder()
            .header(UPGRADE, HeaderValue::from_static("test"))
            .body(())
            .unwrap();
        assert_rejected(&wrong_upgrade, HandshakeError::NoWebsocketUpgrade);

        // `Upgrade: websocket` present but no `Connection: upgrade`.
        let no_connection = Request::builder().header(UPGRADE, WEBSOCKET).body(()).unwrap();
        assert_rejected(&no_connection, HandshakeError::NoConnectionUpgrade);

        // Upgrade headers present; version header missing, then unsupported.
        let upgrade_only = || {
            Request::builder()
                .header(UPGRADE, WEBSOCKET)
                .header(CONNECTION, UPGRADE_VALUE)
        };
        assert_rejected(&upgrade_only().body(()).unwrap(), HandshakeError::NoVersionHeader);
        let bad_version = upgrade_only()
            .header(SEC_WEBSOCKET_VERSION, HeaderValue::from_static("5"))
            .body(())
            .unwrap();
        assert_rejected(&bad_version, HandshakeError::UnsupportedVersion);

        // Everything but the key, then a complete request that must be accepted.
        let versioned =
            || upgrade_only().header(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE);
        assert_rejected(&versioned().body(()).unwrap(), HandshakeError::BadWebsocketKey);
        let complete = versioned()
            .header(SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION_VALUE)
            .body(())
            .unwrap();
        let key = verify_handshake(complete.method(), complete.headers()).unwrap();
        let res = handshake_response(key).body(()).unwrap();
        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, res.status());
    }

    #[test]
    fn test_ws_error_http_response() {
        let status_of = |err: HandshakeError| Builder::from(err).body(()).unwrap().status();

        assert_eq!(
            status_of(HandshakeError::GetMethodRequired),
            StatusCode::METHOD_NOT_ALLOWED
        );
        assert_eq!(status_of(HandshakeError::NoWebsocketUpgrade), StatusCode::BAD_REQUEST);
        assert_eq!(status_of(HandshakeError::NoConnectionUpgrade), StatusCode::BAD_REQUEST);
        assert_eq!(status_of(HandshakeError::NoVersionHeader), StatusCode::BAD_REQUEST);
        assert_eq!(status_of(HandshakeError::UnsupportedVersion), StatusCode::BAD_REQUEST);
        assert_eq!(status_of(HandshakeError::BadWebsocketKey), StatusCode::BAD_REQUEST);
    }
}