use crate::header::extensions::Extension;
use crate::header::{
Origin, WebSocketAccept, WebSocketExtensions, WebSocketKey, WebSocketProtocol, WebSocketVersion,
};
use crate::stream::Stream;
use std::error::Error;
use std::fmt::{self, Display, Formatter};
use std::io;
use hyper::header::{Connection, ConnectionOption, Headers, Protocol, ProtocolName, Upgrade};
use hyper::http::h1::Incoming;
use hyper::method::Method;
use hyper::status::StatusCode;
use hyper::uri::RequestUri;
use unicase::UniCase;
#[cfg(any(feature = "sync", feature = "async"))]
use hyper::version::HttpVersion;
#[cfg(feature = "async")]
pub mod r#async;
#[cfg(feature = "sync")]
pub mod sync;
/// An incoming HTTP request whose subject is the pair of its method and
/// request URI.
pub type Request = Incoming<(Method, RequestUri)>;
/// A pending WebSocket upgrade: the client's parsed HTTP upgrade request,
/// the stream it arrived on, and the headers that will be sent back in the
/// handshake response.
pub struct WsUpgrade<S, B>
where
    S: Stream,
{
    /// Headers for the handshake response. Extended by `use_protocol`,
    /// `use_extension(s)` and `prepare_headers`.
    pub headers: Headers,
    /// The connection the upgrade request was read from.
    pub stream: S,
    /// The client's parsed upgrade request.
    pub request: Request,
    /// Extra buffered state associated with `stream`. Its concrete type and
    /// contents are chosen by the sync/async front-ends.
    /// NOTE(review): presumably bytes read past the end of the request — confirm.
    pub buffer: B,
}
/// Inspection of the client's upgrade request and construction of the
/// handshake response.
impl<S, B> WsUpgrade<S, B>
where
    S: Stream,
{
    /// Selects a subprotocol for the connection.
    ///
    /// Appends `protocol` to the response's `Sec-WebSocket-Protocol` header,
    /// creating the header if it is not set yet.
    pub fn use_protocol<P>(mut self, protocol: P) -> Self
    where
        P: Into<String>,
    {
        upsert_header!(self.headers; WebSocketProtocol; {
            Some(protos) => protos.0.push(protocol.into()),
            None => WebSocketProtocol(vec![protocol.into()])
        });
        self
    }

    /// Appends `extension` to the response's `Sec-WebSocket-Extensions`
    /// header, creating the header if it is not set yet.
    pub fn use_extension(mut self, extension: Extension) -> Self {
        upsert_header!(self.headers; WebSocketExtensions; {
            Some(protos) => protos.0.push(extension),
            None => WebSocketExtensions(vec![extension])
        });
        self
    }

    /// Appends every extension in `extensions`, in iteration order, to the
    /// response's `Sec-WebSocket-Extensions` header, creating the header if
    /// it is not set yet.
    pub fn use_extensions<I>(mut self, extensions: I) -> Self
    where
        I: IntoIterator<Item = Extension>,
    {
        let mut extensions: Vec<Extension> = extensions.into_iter().collect();
        upsert_header!(self.headers; WebSocketExtensions; {
            Some(protos) => protos.0.append(&mut extensions),
            None => WebSocketExtensions(extensions)
        });
        self
    }

    /// Rejects the upgrade by dropping it, together with the stream and
    /// buffer it owns.
    pub fn drop(self) {
        ::std::mem::drop(self);
    }

    /// Subprotocols offered by the client in its `Sec-WebSocket-Protocol`
    /// request header, or an empty slice if that header could not be read.
    pub fn protocols(&self) -> &[String] {
        self.request
            .headers
            .get::<WebSocketProtocol>()
            .map(|p| p.0.as_slice())
            .unwrap_or(&[])
    }

    /// Extensions offered by the client in its `Sec-WebSocket-Extensions`
    /// request header, or an empty slice if that header could not be read.
    pub fn extensions(&self) -> &[Extension] {
        self.request
            .headers
            .get::<WebSocketExtensions>()
            .map(|e| e.0.as_slice())
            .unwrap_or(&[])
    }

    /// The 16 raw bytes of the client's `Sec-WebSocket-Key`, if the header
    /// could be read.
    pub fn key(&self) -> Option<&[u8; 16]> {
        self.request.headers.get::<WebSocketKey>().map(|k| &(k.0).0)
    }

    /// The `Sec-WebSocket-Version` requested by the client, if present.
    pub fn version(&self) -> Option<&WebSocketVersion> {
        self.request.headers.get::<WebSocketVersion>()
    }

    /// The request URI, rendered through its `Display` implementation.
    pub fn uri(&self) -> String {
        self.request.subject.1.to_string()
    }

    /// The value of the client's `Origin` header, if present.
    pub fn origin(&self) -> Option<&str> {
        self.request.headers.get::<Origin>().map(|o| &o.0 as &str)
    }

    /// Writes the handshake response to the stream and flushes it.
    ///
    /// The response is a status line (the request's HTTP version followed by
    /// `status`), then the accumulated response headers, then a final `\r\n`.
    ///
    /// # Errors
    /// Returns any I/O error raised while writing or flushing.
    #[cfg(feature = "sync")]
    fn send(&mut self, status: StatusCode) -> io::Result<()> {
        let data = format!(
            "{} {}\r\n{}\r\n",
            self.request.version, status, self.headers
        );
        self.stream.write_all(data.as_bytes())?;
        // Make sure the response actually leaves any buffering layer in `S`.
        // Otherwise a buffered stream could hold the 101 reply back while
        // both peers wait on each other.
        self.stream.flush()
    }

    /// Fills in the mandatory handshake response headers and returns the
    /// status to reply with (`101 Switching Protocols`).
    ///
    /// Any `custom` headers are merged in first. `Sec-WebSocket-Accept`,
    /// `Connection: Upgrade` and `Upgrade` (WebSocket) are set afterwards,
    /// so they take precedence over custom values for those headers.
    ///
    /// # Panics
    /// Panics if the request has no readable `Sec-WebSocket-Key` header.
    /// `validate` rejects such requests, so this indicates a `WsUpgrade`
    /// built from a request that was never validated.
    #[doc(hidden)]
    pub fn prepare_headers(&mut self, custom: Option<&Headers>) -> StatusCode {
        if let Some(headers) = custom {
            self.headers.extend(headers.iter());
        }
        let key = self
            .request
            .headers
            .get::<WebSocketKey>()
            .expect("WsUpgrade requires a validated request carrying a Sec-WebSocket-Key header");
        self.headers.set(WebSocketAccept::new(key));
        self.headers
            .set(Connection(vec![ConnectionOption::ConnectionHeader(
                UniCase("Upgrade".to_string()),
            )]));
        self.headers
            .set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)]));
        StatusCode::SwitchingProtocols
    }
}
/// Reasons an incoming HTTP request could not be turned into a WebSocket
/// upgrade, plus I/O and HTTP parsing failures encountered along the way.
#[derive(Debug)]
pub enum HyperIntoWsError {
    /// The request method was not `GET`.
    MethodNotGet,
    /// The request used HTTP/0.9 or HTTP/1.0.
    UnsupportedHttpVersion,
    /// A `Sec-WebSocket-Version` header was present but was not version 13.
    UnsupportedWebsocketVersion,
    /// No `Sec-WebSocket-Key` header could be read from the request.
    NoSecWsKeyHeader,
    /// An `Upgrade` header was present but did not list the WebSocket protocol.
    NoWsUpgradeHeader,
    /// No `Upgrade` header could be read from the request.
    NoUpgradeHeader,
    /// A `Connection` header was present but had no `upgrade` option.
    NoWsConnectionHeader,
    /// No `Connection` header could be read from the request.
    NoConnectionHeader,
    /// An underlying I/O error.
    Io(io::Error),
    /// An HTTP parsing error reported by hyper.
    Parsing(::hyper::error::Error),
}
impl Display for HyperIntoWsError {
    /// Renders a short human-readable description of the error. Wrapped I/O
    /// and hyper errors are rendered through their own `Display` impls.
    fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> {
        match *self {
            HyperIntoWsError::MethodNotGet => fmt.write_str("Request method must be GET"),
            HyperIntoWsError::UnsupportedHttpVersion => {
                fmt.write_str("Unsupported request HTTP version")
            }
            HyperIntoWsError::UnsupportedWebsocketVersion => {
                fmt.write_str("Unsupported WebSocket version")
            }
            HyperIntoWsError::NoSecWsKeyHeader => fmt.write_str("Missing Sec-WebSocket-Key header"),
            HyperIntoWsError::NoWsUpgradeHeader => {
                fmt.write_str("Invalid Upgrade WebSocket header")
            }
            HyperIntoWsError::NoUpgradeHeader => fmt.write_str("Missing Upgrade WebSocket header"),
            HyperIntoWsError::NoWsConnectionHeader => {
                fmt.write_str("Invalid Connection WebSocket header")
            }
            HyperIntoWsError::NoConnectionHeader => {
                fmt.write_str("Missing Connection WebSocket header")
            }
            // Format the inner error straight into `fmt`. Going through
            // `to_string()` allocated a temporary `String` only to copy it out.
            HyperIntoWsError::Io(ref e) => write!(fmt, "{}", e),
            HyperIntoWsError::Parsing(ref e) => write!(fmt, "{}", e),
        }
    }
}
impl Error for HyperIntoWsError {
    /// Exposes the wrapped I/O or hyper error as the underlying cause.
    /// Handshake-validation failures have no source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            HyperIntoWsError::Io(io_err) => Some(io_err),
            HyperIntoWsError::Parsing(http_err) => Some(http_err),
            // Listed explicitly so a newly added variant forces a decision here.
            HyperIntoWsError::MethodNotGet
            | HyperIntoWsError::UnsupportedHttpVersion
            | HyperIntoWsError::UnsupportedWebsocketVersion
            | HyperIntoWsError::NoSecWsKeyHeader
            | HyperIntoWsError::NoWsUpgradeHeader
            | HyperIntoWsError::NoUpgradeHeader
            | HyperIntoWsError::NoWsConnectionHeader
            | HyperIntoWsError::NoConnectionHeader => None,
        }
    }
}
/// Lets `?` lift I/O failures into [`HyperIntoWsError::Io`].
impl From<io::Error> for HyperIntoWsError {
    fn from(source: io::Error) -> Self {
        HyperIntoWsError::Io(source)
    }
}
impl From<::hyper::error::Error> for HyperIntoWsError {
fn from(err: ::hyper::error::Error) -> Self {
HyperIntoWsError::Parsing(err)
}
}
#[cfg(feature = "async")]
impl From<crate::codec::http::HttpCodecError> for HyperIntoWsError {
    /// Maps the HTTP codec's I/O and parse failures onto the corresponding
    /// `HyperIntoWsError` variants.
    fn from(src: crate::codec::http::HttpCodecError) -> Self {
        use crate::codec::http::HttpCodecError as CodecError;

        match src {
            CodecError::Io(io_err) => HyperIntoWsError::Io(io_err),
            CodecError::Http(http_err) => HyperIntoWsError::Parsing(http_err),
        }
    }
}
/// Checks that an HTTP request is a well-formed WebSocket upgrade request.
///
/// Checks run in this order, and the first failure is returned:
/// 1. the method must be `GET` (`MethodNotGet`);
/// 2. the HTTP version must not be 0.9 or 1.0 (`UnsupportedHttpVersion`);
/// 3. `Sec-WebSocket-Version`, if present, must be 13
///    (`UnsupportedWebsocketVersion`); a missing version header is accepted;
/// 4. a `Sec-WebSocket-Key` header must be readable (`NoSecWsKeyHeader`);
/// 5. an `Upgrade` header must be readable (`NoUpgradeHeader`) and must list
///    the WebSocket protocol (`NoWsUpgradeHeader`);
/// 6. a `Connection` header must be readable (`NoConnectionHeader`) and must
///    contain an `upgrade` option, compared case-insensitively
///    (`NoWsConnectionHeader`).
#[cfg(any(feature = "sync", feature = "async"))]
pub fn validate(
    method: &Method,
    version: HttpVersion,
    headers: &Headers,
) -> Result<(), HyperIntoWsError> {
    if *method != Method::Get {
        return Err(HyperIntoWsError::MethodNotGet);
    }

    if [HttpVersion::Http09, HttpVersion::Http10].contains(&version) {
        return Err(HyperIntoWsError::UnsupportedHttpVersion);
    }

    let bad_ws_version = headers
        .get::<WebSocketVersion>()
        .map_or(false, |v| *v != WebSocketVersion::WebSocket13);
    if bad_ws_version {
        return Err(HyperIntoWsError::UnsupportedWebsocketVersion);
    }

    // `get` (rather than a bare presence check) also rejects keys that fail to parse.
    headers
        .get::<WebSocketKey>()
        .ok_or(HyperIntoWsError::NoSecWsKeyHeader)?;

    let upgrade = headers
        .get::<Upgrade>()
        .ok_or(HyperIntoWsError::NoUpgradeHeader)?;
    let offers_websocket = upgrade
        .0
        .iter()
        .any(|proto| proto.name == ProtocolName::WebSocket);
    if !offers_websocket {
        return Err(HyperIntoWsError::NoWsUpgradeHeader);
    }

    let connection = headers
        .get::<Connection>()
        .ok_or(HyperIntoWsError::NoConnectionHeader)?;
    let requests_upgrade = connection.0.iter().any(|option| match *option {
        ConnectionOption::ConnectionHeader(ref token) => {
            UniCase(token as &str) == UniCase("upgrade")
        }
        _ => false,
    });
    if !requests_upgrade {
        return Err(HyperIntoWsError::NoWsConnectionHeader);
    }

    Ok(())
}