use std::fmt;
use std::io::Write;
use std::mem::transmute;
use std::str::from_utf8;
use std::net::SocketAddr;
use sha1;
use rand;
use url;
use httparse;
use result::{Result, Error, Kind};
/// GUID that `hash_key` appends to the client's `Sec-WebSocket-Key` before hashing.
static WS_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Standard base64 alphabet (`A-Z a-z 0-9 + /`) indexed by `encode_base64`.
static BASE64: &'static [u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Size of the `httparse` header buffer used by `Request::parse` and `Response::parse`,
/// i.e. the most headers a single request/response can be parsed with.
const MAX_HEADERS: usize = 124;
/// Produces a fresh `Sec-WebSocket-Key` value: 16 random bytes, base64-encoded.
fn generate_key() -> String {
    let nonce = rand::random::<(u64, u64)>();
    // SAFETY: `(u64, u64)` occupies exactly 16 bytes with no padding, and every
    // bit pattern is a valid `[u8; 16]`, so reinterpreting it as bytes is sound.
    let raw: [u8; 16] = unsafe { transmute(nonce) };
    encode_base64(&raw)
}
/// Derives the `Sec-WebSocket-Accept` value for a client key: the base64
/// encoding of the SHA-1 digest of `key` followed by `WS_GUID`.
pub fn hash_key(key: &[u8]) -> String {
    let mut sha = sha1::Sha1::new();
    for part in &[key, WS_GUID.as_bytes()] {
        sha.update(*part);
    }
    let digest = sha.digest().bytes();
    encode_base64(&digest)
}
/// Encodes `data` as standard, `=`-padded base64 using the `BASE64` alphabet.
///
/// Every 3-byte group becomes four output symbols. A trailing group of one or
/// two bytes yields two or three symbols followed by `=` padding, so the output
/// length is always a multiple of four. Empty input yields an empty string.
fn encode_base64(data: &[u8]) -> String {
    let mut out = Vec::with_capacity((data.len() + 2) / 3 * 4);
    for group in data.chunks(3) {
        // `chunks` never yields an empty slice, so index 0 always exists;
        // missing trailing bytes are treated as zero bits.
        let b0 = group[0] as u32;
        let b1 = if group.len() > 1 { group[1] as u32 } else { 0 };
        let b2 = if group.len() > 2 { group[2] as u32 } else { 0 };
        let bits = (b0 << 16) | (b1 << 8) | b2;
        // 1 input byte -> 2 symbols, 2 -> 3, 3 -> 4; the rest are padding.
        let symbols = group.len() + 1;
        for slot in 0..4 {
            if slot < symbols {
                let index = (bits >> (18 - 6 * slot)) & 63;
                out.push(BASE64[index as usize]);
            } else {
                out.push(b'=');
            }
        }
    }
    String::from_utf8(out).unwrap()
}
/// A WebSocket opening handshake: the HTTP upgrade request, the response to it,
/// and the socket addresses of the connection when they are known.
#[derive(Debug)]
pub struct Handshake {
/// The handshake request (e.g. parsed with `Request::parse` or built with `Request::from_url`).
pub request: Request,
/// The handshake response (e.g. built with `Response::from_request`).
pub response: Response,
/// Address of the peer socket, if known; `remote_addr` falls back to its IP
/// when no forwarding header is present.
pub peer_addr: Option<SocketAddr>,
/// Local socket address, if known (presumably our end of the connection — confirm with the caller).
pub local_addr: Option<SocketAddr>,
}
impl Handshake {
    /// Best-effort address of the remote client.
    ///
    /// Prefers the address advertised by a proxy (`X-Forwarded-For` or
    /// `Forwarded`, via `Request::client_addr`). Otherwise it falls back to the
    /// IP of `peer_addr`, and returns `None` when neither is available.
    ///
    /// # Errors
    /// Propagates the error from `client_addr` (e.g. a non-UTF-8 forwarding header).
    #[allow(dead_code)]
    pub fn remote_addr(&self) -> Result<Option<String>> {
        let forwarded = try!(self.request.client_addr());
        match forwarded {
            Some(addr) => Ok(Some(String::from(addr))),
            None => Ok(self.peer_addr.map(|addr| addr.ip().to_string())),
        }
    }
}
/// An HTTP request as used in the WebSocket opening handshake.
#[derive(Debug)]
pub struct Request {
/// Request target: the URL path plus `?query` when one is present (see `from_url`).
path: String,
/// HTTP method, e.g. `"GET"` for requests built by `from_url`.
method: String,
/// `(name, raw value)` header pairs in the order parsed or inserted; name
/// lookups through `header`/`header_mut` are case-insensitive.
headers: Vec<(String, Vec<u8>)>,
}
impl Request {
pub fn header(&self, header: &str) -> Option<&Vec<u8>> {
self.headers
.iter()
.find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
.map(|&(_, ref val)| val)
}
pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
self.headers
.iter_mut()
.find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
.map(|&mut (_, ref mut val)| val)
}
#[allow(dead_code)]
#[inline]
pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
&self.headers
}
#[allow(dead_code)]
#[inline]
pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
&mut self.headers
}
#[allow(dead_code)]
pub fn origin(&self) -> Result<Option<&str>> {
if let Some(origin) = self.header("origin") {
Ok(Some(try!(from_utf8(origin))))
} else {
Ok(None)
}
}
pub fn key(&self) -> Result<&Vec<u8>> {
self.header("sec-websocket-key")
.ok_or(Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
}
pub fn hashed_key(&self) -> Result<String> {
Ok(hash_key(try!(self.key())))
}
#[allow(dead_code)]
pub fn version(&self) -> Result<&str> {
if let Some(version) = self.header("sec-websocket-version") {
from_utf8(version).map_err(Error::from)
} else {
Err(Error::new(Kind::Protocol, "The Sec-WebSocket-Version header is missing."))
}
}
#[allow(dead_code)]
#[inline]
pub fn resource(&self) -> &str {
&self.path
}
#[allow(dead_code)]
pub fn protocols(&self) -> Result<Vec<&str>> {
if let Some(protos) = self.header("sec-websocket-protocol") {
Ok(try!(from_utf8(protos)).split(',').map(|proto| proto.trim()).collect())
} else {
Ok(Vec::new())
}
}
#[allow(dead_code)]
pub fn add_protocol(&mut self, protocol: &str) {
if let Some(protos) = self.header_mut("sec-websocket-protocol") {
protos.push(b","[0]);
protos.extend(protocol.as_bytes());
return
}
self.headers_mut().push(("Sec-WebSocket-Protocol".into(), protocol.into()))
}
#[allow(dead_code)]
pub fn remove_protocol(&mut self, protocol: &str) {
if let Some(protos) = self.header_mut("sec-websocket-protocol") {
let mut new_protos = Vec::with_capacity(protos.len());
if let Ok(protos_str) = from_utf8(protos) {
new_protos = protos_str
.split(',')
.filter(|proto| proto.trim() == protocol)
.collect::<Vec<&str>>()
.join(",").into();
}
if new_protos.len() < protos.len() {
*protos = new_protos
}
}
}
#[allow(dead_code)]
pub fn extensions(&self) -> Result<Vec<&str>> {
if let Some(exts) = self.header("sec-websocket-extensions") {
Ok(try!(from_utf8(exts)).split(',').map(|ext| ext.trim()).collect())
} else {
Ok(Vec::new())
}
}
#[allow(dead_code)]
pub fn add_extension(&mut self, ext: &str) {
if let Some(exts) = self.header_mut("sec-websocket-extensions") {
exts.push(b","[0]);
exts.extend(ext.as_bytes());
return
}
self.headers_mut().push(("Sec-WebSocket-Extensions".into(), ext.into()))
}
#[allow(dead_code)]
pub fn remove_extension(&mut self, ext: &str) {
if let Some(exts) = self.header_mut("sec-websocket-extensions") {
let mut new_exts = Vec::with_capacity(exts.len());
if let Ok(exts_str) = from_utf8(exts) {
new_exts = exts_str
.split(',')
.filter(|e| e.trim().starts_with(ext))
.collect::<Vec<&str>>()
.join(",").into();
}
if new_exts.len() < exts.len() {
*exts = new_exts
}
}
}
#[allow(dead_code)]
pub fn client_addr(&self) -> Result<Option<&str>> {
if let Some(x_forward) = self.header("x-forwarded-for") {
return Ok(try!(from_utf8(x_forward)).split(',').next())
}
if let Some(forward) = self.header("forwarded") {
if let Some(_for) = try!(from_utf8(forward))
.split(';')
.find(|f| f.trim().starts_with("for"))
{
if let Some(_for_eq) = _for.trim().split(',').next() {
let mut it = _for_eq.split('=');
it.next();
return Ok(it.next())
}
}
}
Ok(None)
}
pub fn parse(buf: &[u8]) -> Result<Option<Request>> {
let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut req = httparse::Request::new(&mut headers);
let parsed = try!(req.parse(buf));
if !parsed.is_partial() {
Ok(Some(Request {
path: req.path.unwrap().into(),
method: req.method.unwrap().into(),
headers: req.headers.iter().map(|h| (h.name.into(), h.value.into())).collect(),
}))
} else {
Ok(None)
}
}
pub fn from_url(url: &url::Url) -> Result<Request> {
let query = if let Some(q) = url.query() {
format!("?{}", q)
} else {
"".into()
};
let req = Request {
path: format!(
"{}{}",
url.path(),
query),
method: "GET".to_owned(),
headers: vec![
("Connection".into(), "Upgrade".into()),
(
"Host".into(),
format!(
"{}:{}",
try!(url.host_str().ok_or(
Error::new(Kind::Internal, "No host passed for WebSocket connection."))),
url.port_or_known_default().unwrap_or(80)).into(),
),
("Sec-WebSocket-Version".into(), "13".into()),
("Sec-WebSocket-Key".into(), generate_key().into()),
("Upgrade".into(), "websocket".into()),
],
};
debug!("Built request from URL:\n{}", req);
Ok(req)
}
pub fn format<W>(&self, w: &mut W) -> Result<()>
where W: Write,
{
try!(write!(w, "{} {} HTTP/1.1\r\n", self.method, self.path));
for &(ref key, ref val) in self.headers.iter() {
try!(write!(w, "{}: ", key));
try!(w.write(val));
try!(write!(w, "\r\n"));
}
try!(write!(w, "\r\n"));
Ok(())
}
}
impl fmt::Display for Request {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut s = Vec::with_capacity(2048);
try!(self.format(&mut s).map_err(|err| {
error!("{:?}", err);
fmt::Error
}));
write!(f, "{}", try!(from_utf8(&s).map_err(|err| {
error!("Unable to format request as utf8: {:?}", err);
fmt::Error
})))
}
}
/// An HTTP response as used in the WebSocket opening handshake.
#[derive(Debug)]
pub struct Response {
/// HTTP status code (`from_request` uses 101).
status: u16,
/// Reason phrase that follows the status code on the status line.
reason: String,
/// `(name, raw value)` header pairs in order; name lookups are case-insensitive.
headers: Vec<(String, Vec<u8>)>,
}
impl Response {
fn header(&self, header: &str) -> Option<&Vec<u8>> {
self.headers
.iter()
.find(|&&(ref key, _)| key.to_lowercase() == header.to_lowercase())
.map(|&(_, ref val)| val)
}
pub fn header_mut(&mut self, header: &str) -> Option<&mut Vec<u8>> {
self.headers
.iter_mut()
.find(|&&mut (ref key, _)| key.to_lowercase() == header.to_lowercase())
.map(|&mut (_, ref mut val)| val)
}
#[allow(dead_code)]
#[inline]
pub fn headers(&self) -> &Vec<(String, Vec<u8>)> {
&self.headers
}
#[allow(dead_code)]
#[inline]
pub fn headers_mut(&mut self) -> &mut Vec<(String, Vec<u8>)> {
&mut self.headers
}
#[allow(dead_code)]
#[inline]
pub fn status(&self) -> u16 {
self.status
}
#[allow(dead_code)]
#[inline]
pub fn set_status(&mut self, status: u16) {
self.status = status
}
#[allow(dead_code)]
#[inline]
pub fn reason(&self) -> &str {
&self.reason
}
#[allow(dead_code)]
#[inline]
pub fn set_reason<R>(&mut self, reason: R)
where R: Into<String>
{
self.reason = reason.into()
}
pub fn key(&self) -> Result<&Vec<u8>> {
self.header("sec-websocket-accept").ok_or(Error::new(Kind::Protocol, "Unable to parse WebSocket key."))
}
#[allow(dead_code)]
pub fn protocol(&self) -> Result<Option<&str>> {
if let Some(proto) = self.header("sec-websocket-protocol") {
Ok(Some(try!(from_utf8(proto))))
} else {
Ok(None)
}
}
#[allow(dead_code)]
pub fn set_protocol(&mut self, protocol: &str) {
if let Some(proto) = self.header_mut("sec-websocket-protocol") {
*proto = protocol.into();
return
}
self.headers_mut().push(("Sec-WebSocket-Protocol".into(), protocol.into()))
}
#[allow(dead_code)]
pub fn extensions(&self) -> Result<Vec<&str>> {
if let Some(exts) = self.header("sec-websocket-extensions") {
Ok(try!(from_utf8(exts)).split(',').map(|proto| proto.trim()).collect())
} else {
Ok(Vec::new())
}
}
#[allow(dead_code)]
pub fn add_extension(&mut self, ext: &str) {
if let Some(exts) = self.header_mut("sec-websocket-extensions") {
exts.push(b","[0]);
exts.extend(ext.as_bytes());
return
}
self.headers_mut().push(("Sec-WebSocket-Extensions".into(), ext.into()))
}
#[allow(dead_code)]
pub fn remove_extension(&mut self, ext: &str) {
if let Some(exts) = self.header_mut("sec-websocket-extensions") {
let mut new_exts = Vec::with_capacity(exts.len());
if let Ok(exts_str) = from_utf8(exts) {
new_exts = exts_str
.split(',')
.filter(|e| e.trim().starts_with(ext))
.collect::<Vec<&str>>()
.join(",").into();
}
if new_exts.len() < exts.len() {
*exts = new_exts
}
}
}
pub fn parse(buf: &[u8]) -> Result<Option<Response>> {
let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
let mut res = httparse::Response::new(&mut headers);
let parsed = try!(res.parse(buf));
if !parsed.is_partial() {
Ok(Some(Response {
status: res.code.unwrap(),
reason: res.reason.unwrap().into(),
headers: res.headers.iter().map(|h| (h.name.into(), h.value.into())).collect(),
}))
} else {
Ok(None)
}
}
pub fn from_request(req: &Request) -> Result<Response> {
let res = Response {
status: 101,
reason: "Switching Protocols".into(),
headers: vec![
("Connection".into(), "Upgrade".into()),
("Sec-WebSocket-Accept".into(), try!(req.hashed_key()).into()),
("Upgrade".into(), "websocket".into()),
],
};
debug!("Built response from request:\n{}", res);
Ok(res)
}
pub fn format<W>(&self, w: &mut W) -> Result<()>
where W: Write
{
try!(write!(w, "HTTP/1.1 {} {}\r\n", self.status, self.reason));
for &(ref key, ref val) in self.headers.iter() {
try!(write!(w, "{}: ", key));
try!(w.write(val));
try!(write!(w, "\r\n"));
}
try!(write!(w, "\r\n"));
Ok(())
}
}
impl fmt::Display for Response {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut s = Vec::with_capacity(2048);
try!(self.format(&mut s).map_err(|err| {
error!("{:?}", err);
fmt::Error
}));
write!(f, "{}", try!(from_utf8(&s).map_err(|err| {
error!("Unable to format response as utf8: {:?}", err);
fmt::Error
})))
}
}
// Compile these tests only under `cargo test`; without the cfg the module
// (and its imports) was built into every non-test build as well.
#[cfg(test)]
mod test {
    #![allow(unused_imports, unused_variables, dead_code)]
    use std::io::Write;
    use std::net::SocketAddr;
    use std::str::FromStr;

    use super::*;

    /// Known-answer test from RFC 6455 §1.3 for the accept-key derivation.
    #[test]
    fn hash_key_matches_rfc6455_example() {
        assert_eq!(hash_key(b"dGhlIHNhbXBsZSBub25jZQ=="), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    /// With no forwarding headers, the peer socket's IP is reported.
    #[test]
    fn remote_addr() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
             Connection: Upgrade\r\n\
             Upgrade: websocket\r\n\
             Sec-WebSocket-Version: 13\r\n\
             Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();
        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: Some(SocketAddr::from_str("127.0.0.1:8888").unwrap()),
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "127.0.0.1");
    }

    /// The left-most `X-Forwarded-For` entry wins.
    #[test]
    fn remote_addr_x_forwarded_for() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
             Connection: Upgrade\r\n\
             Upgrade: websocket\r\n\
             X-Forwarded-For: 192.168.1.1, 192.168.1.2, 192.168.1.3\r\n\
             Sec-WebSocket-Version: 13\r\n\
             Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();
        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: None,
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.168.1.1");
    }

    /// The first `for=` parameter of `Forwarded` wins.
    #[test]
    fn remote_addr_forwarded() {
        let mut buf = Vec::with_capacity(2048);
        write!(
            &mut buf,
            "GET / HTTP/1.1\r\n\
             Connection: Upgrade\r\n\
             Upgrade: websocket\r\n\
             Forwarded: by=192.168.1.1; for=192.0.2.43, for=\"[2001:db8:cafe::17]\", for=unknown\r\n\
             Sec-WebSocket-Version: 13\r\n\
             Sec-WebSocket-Key: q16eN37NCfVwUChPvBdk4g==\r\n\r\n").unwrap();
        let req = Request::parse(&buf).unwrap().unwrap();
        let res = Response::from_request(&req).unwrap();
        let shake = Handshake {
            request: req,
            response: res,
            peer_addr: None,
            local_addr: None,
        };
        assert_eq!(shake.remote_addr().unwrap().unwrap(), "192.0.2.43");
    }
}