use std::io::{Read, Write};
use server::Response;
use result::{WebSocketResult, WebSocketError};
use header::{WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin};
pub use hyper::uri::RequestUri;
use hyper::buffer::BufReader;
use hyper::version::HttpVersion;
use hyper::header::Headers;
use hyper::header::{Connection, ConnectionOption};
use hyper::header::{Upgrade, ProtocolName};
use hyper::http::h1::parse_request;
use hyper::method::Method;
use unicase::UniCase;
/// An incoming HTTP request that may be upgraded to a WebSocket connection.
///
/// Holds the parsed request line and headers together with the reader and
/// writer the connection is carried over.
pub struct Request<R, W>
where
    R: Read,
    W: Write,
{
    /// HTTP method from the request line.
    pub method: Method,
    /// Request target from the request line.
    pub url: RequestUri,
    /// HTTP version from the request line.
    pub version: HttpVersion,
    /// All parsed request headers.
    pub headers: Headers,
    /// Read half of the connection.
    reader: R,
    /// Write half of the connection.
    writer: W,
}
// SAFETY: only `R` and `W` are required to be `Send`; the remaining fields
// (`Method`, `RequestUri`, `HttpVersion`, `Headers`) are unconstrained here.
// NOTE(review): this is sound only if those hyper types are themselves safe to
// move across threads. If they are all `Send`, the compiler derives this impl
// automatically and this `unsafe impl` could be removed — confirm against the
// hyper version in use.
unsafe impl<R, W> Send for Request<R, W> where R: Read + Send, W: Write + Send { }
impl<R: Read, W: Write> Request<R, W> {
    /// Returns the parsed `Sec-WebSocket-Key` header, if present.
    pub fn key(&self) -> Option<&WebSocketKey> {
        self.headers.get()
    }

    /// Returns the parsed [`WebSocketVersion`] header, if present.
    pub fn version(&self) -> Option<&WebSocketVersion> {
        self.headers.get()
    }

    /// Returns the parsed [`WebSocketProtocol`] header, if present.
    pub fn protocol(&self) -> Option<&WebSocketProtocol> {
        self.headers.get()
    }

    /// Returns the parsed [`WebSocketExtensions`] header, if present.
    pub fn extensions(&self) -> Option<&WebSocketExtensions> {
        self.headers.get()
    }

    /// Returns the parsed [`Origin`] header, if present.
    pub fn origin(&self) -> Option<&Origin> {
        self.headers.get()
    }

    /// Returns a shared reference to the underlying reader.
    pub fn get_reader(&self) -> &R {
        &self.reader
    }

    /// Returns a shared reference to the underlying writer.
    pub fn get_writer(&self) -> &W {
        &self.writer
    }

    /// Returns a mutable reference to the underlying reader.
    pub fn get_mut_reader(&mut self) -> &mut R {
        &mut self.reader
    }

    /// Returns a mutable reference to the underlying writer.
    pub fn get_mut_writer(&mut self) -> &mut W {
        &mut self.writer
    }

    /// Consumes the request and returns its `(reader, writer)` pair.
    pub fn into_inner(self) -> (R, W) {
        (self.reader, self.writer)
    }

    /// Reads and parses an HTTP request head from `reader`, pairing it with `writer`.
    ///
    /// # Errors
    ///
    /// Returns an error if `parse_request` fails to parse the request.
    ///
    /// NOTE(review): `reader` is wrapped in a `BufReader` for parsing and then
    /// unwrapped with `into_inner()`. Any bytes the buffer read beyond the request
    /// head are presumably discarded with it. Conforming clients wait for the
    /// handshake response before sending data, but confirm this is acceptable for
    /// your peers.
    pub fn read(reader: R, writer: W) -> WebSocketResult<Request<R, W>> {
        let mut reader = BufReader::new(reader);
        let request = try!(parse_request(&mut reader));
        Ok(Request {
            method: request.subject.0,
            url: request.subject.1,
            version: request.version,
            headers: request.headers,
            reader: reader.into_inner(),
            writer: writer,
        })
    }

    /// Checks that this request is a well-formed WebSocket opening handshake.
    ///
    /// # Errors
    ///
    /// Returns `WebSocketError::RequestError` when any of the following holds:
    /// - the method is not `GET`;
    /// - the HTTP version is 0.9 or 1.0;
    /// - the WebSocket version header is missing or is not version 13;
    /// - the `Sec-WebSocket-Key` header is missing;
    /// - the `Upgrade` header is missing or lists no WebSocket protocol;
    /// - the `Connection` header is missing or has no `Upgrade` option.
    pub fn validate(&self) -> WebSocketResult<()> {
        if self.method != Method::Get {
            return Err(WebSocketError::RequestError("Request method must be GET"));
        }
        if self.version == HttpVersion::Http09 || self.version == HttpVersion::Http10 {
            return Err(WebSocketError::RequestError("Unsupported request HTTP version"));
        }
        if self.version() != Some(&WebSocketVersion::WebSocket13) {
            return Err(WebSocketError::RequestError("Unsupported WebSocket version"));
        }
        if self.key().is_none() {
            return Err(WebSocketError::RequestError("Missing Sec-WebSocket-Key header"));
        }
        match self.headers.get() {
            Some(&Upgrade(ref upgrade)) => {
                // At least one listed protocol must be WebSocket; stop at the first hit.
                if !upgrade.iter().any(|u| u.name == ProtocolName::WebSocket) {
                    return Err(WebSocketError::RequestError("Invalid Upgrade WebSocket header"));
                }
            }
            None => {
                return Err(WebSocketError::RequestError("Missing Upgrade WebSocket header"));
            }
        }
        match self.headers.get() {
            Some(&Connection(ref connection)) => {
                // `UniCase` makes the comparison against "Upgrade" case-insensitive.
                let upgrade =
                    ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()));
                if !connection.contains(&upgrade) {
                    return Err(WebSocketError::RequestError("Invalid Connection WebSocket header"));
                }
            }
            None => {
                return Err(WebSocketError::RequestError("Missing Connection WebSocket header"));
            }
        }
        Ok(())
    }

    /// Consumes the request and produces the handshake response.
    ///
    /// If [`validate`](#method.validate) succeeds, returns `Response::new(self)`.
    /// Otherwise returns the result of [`fail`](#method.fail), and the specific
    /// validation error is discarded.
    pub fn accept(self) -> Response<R, W> {
        match self.validate() {
            Ok(()) => Response::new(self),
            Err(_) => self.fail(),
        }
    }

    /// Consumes the request and produces a bad-request response via
    /// `Response::bad_request`.
    pub fn fail(self) -> Response<R, W> {
        Response::bad_request(self)
    }
}