use std::borrow::Cow;
pub use url::{Url, ParseError};
use header::extensions::Extension;
use header::{WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin};
use hyper::header::{Headers, Header, HeaderFormat};
use hyper::version::HttpVersion;
#[cfg(any(feature="sync", feature="async"))]
mod common_imports {
pub use std::net::TcpStream;
pub use std::net::ToSocketAddrs;
pub use url::Position;
pub use hyper::http::h1::Incoming;
pub use hyper::http::RawStatus;
pub use hyper::status::StatusCode;
pub use hyper::buffer::BufReader;
pub use hyper::method::Method;
pub use hyper::uri::RequestUri;
pub use hyper::http::h1::parse_response;
pub use hyper::header::{Host, Connection, ConnectionOption, Upgrade, Protocol, ProtocolName};
pub use unicase::UniCase;
pub use header::WebSocketAccept;
pub use result::{WSUrlErrorKind, WebSocketResult, WebSocketError};
pub use stream::{self, Stream};
}
#[cfg(any(feature="sync", feature="async"))]
use self::common_imports::*;
#[cfg(feature="sync")]
use super::sync::Client;
#[cfg(feature="sync-ssl")]
use stream::sync::NetworkStream;
#[cfg(any(feature="sync-ssl", feature="async-ssl"))]
use native_tls::TlsConnector;
#[cfg(feature="sync-ssl")]
use native_tls::TlsStream;
#[cfg(feature="async")]
mod async_imports {
pub use super::super::async;
pub use tokio_io::codec::Framed;
pub use tokio_core::net::TcpStreamNew;
pub use tokio_core::reactor::Handle;
pub use futures::{Future, Sink};
pub use futures::future;
pub use futures::Stream as FutureStream;
pub use codec::ws::{MessageCodec, Context};
#[cfg(feature="async-ssl")]
pub use tokio_tls::TlsConnectorExt;
}
#[cfg(feature="async")]
use self::async_imports::*;
/// Builder for a WebSocket client handshake.
///
/// Holds the target URL (borrowed or owned), the HTTP version written on the
/// request line, and the headers to send. The `*_set` flags record whether the
/// caller explicitly chose a WebSocket version or key; when they did not,
/// defaults are filled in when the handshake request is built.
#[derive(Clone, Debug)]
pub struct ClientBuilder<'u> {
/// Target URL; `Cow` lets `from_url` borrow a `Url` while `new` owns the parsed one.
url: Cow<'u, Url>,
/// HTTP version written on the handshake request line (initialised to HTTP/1.1).
version: HttpVersion,
/// Headers sent with the handshake request.
headers: Headers,
/// True once `version()` has set `Sec-WebSocket-Version` explicitly.
version_set: bool,
/// True once `key()` has set `Sec-WebSocket-Key` explicitly.
key_set: bool,
}
impl<'u> ClientBuilder<'u> {
    /// Creates a builder that borrows an already-parsed `Url`.
    ///
    /// The builder starts with HTTP/1.1 on the request line, no headers, and
    /// neither a WebSocket version nor a key chosen.
    pub fn from_url(address: &'u Url) -> Self {
        ClientBuilder::init(Cow::Borrowed(address))
    }

    /// Parses `address` and creates a builder that owns the resulting `Url`.
    ///
    /// # Errors
    /// Returns the `ParseError` from `Url::parse` when `address` is not a valid URL.
    pub fn new(address: &str) -> Result<Self, ParseError> {
        let url = Url::parse(address)?;
        Ok(ClientBuilder::init(Cow::Owned(url)))
    }

    /// Shared constructor: HTTP/1.1, empty headers, version and key not set.
    fn init(url: Cow<'u, Url>) -> Self {
        ClientBuilder {
            url: url,
            version: HttpVersion::Http11,
            version_set: false,
            key_set: false,
            headers: Headers::new(),
        }
    }

    /// Appends `protocol` to the `Sec-WebSocket-Protocol` header, creating the
    /// header if it is not present yet.
    pub fn add_protocol<P>(mut self, protocol: P) -> Self
        where P: Into<String>
    {
        upsert_header!(self.headers; WebSocketProtocol; {
            Some(protos) => protos.0.push(protocol.into()),
            None => WebSocketProtocol(vec![protocol.into()])
        });
        self
    }

    /// Appends every item of `protocols` (in iteration order) to the
    /// `Sec-WebSocket-Protocol` header, creating the header if needed.
    pub fn add_protocols<I, S>(mut self, protocols: I) -> Self
        where I: IntoIterator<Item = S>,
              S: Into<String>
    {
        let mut protocols: Vec<String> =
            protocols.into_iter()
                     .map(Into::into)
                     .collect();
        upsert_header!(self.headers; WebSocketProtocol; {
            Some(protos) => protos.0.append(&mut protocols),
            None => WebSocketProtocol(protocols)
        });
        self
    }

    /// Removes the `Sec-WebSocket-Protocol` header entirely.
    pub fn clear_protocols(mut self) -> Self {
        self.headers.remove::<WebSocketProtocol>();
        self
    }

    /// Appends `extension` to the `Sec-WebSocket-Extensions` header, creating
    /// the header if it is not present yet.
    pub fn add_extension(mut self, extension: Extension) -> Self {
        upsert_header!(self.headers; WebSocketExtensions; {
            Some(protos) => protos.0.push(extension),
            None => WebSocketExtensions(vec![extension])
        });
        self
    }

    /// Appends every item of `extensions` to the `Sec-WebSocket-Extensions`
    /// header, creating the header if needed.
    pub fn add_extensions<I>(mut self, extensions: I) -> Self
        where I: IntoIterator<Item = Extension>
    {
        let mut extensions: Vec<Extension> =
            extensions.into_iter().collect();
        upsert_header!(self.headers; WebSocketExtensions; {
            Some(protos) => protos.0.append(&mut extensions),
            None => WebSocketExtensions(extensions)
        });
        self
    }

    /// Removes the `Sec-WebSocket-Extensions` header entirely.
    pub fn clear_extensions(mut self) -> Self {
        self.headers.remove::<WebSocketExtensions>();
        self
    }

    /// Uses `key` as the `Sec-WebSocket-Key`; without this a fresh
    /// `WebSocketKey::new()` is generated when the request is built.
    pub fn key(mut self, key: [u8; 16]) -> Self {
        self.headers.set(WebSocketKey(key));
        self.key_set = true;
        self
    }

    /// Removes any `Sec-WebSocket-Key` so a new one is generated at connect time.
    pub fn clear_key(mut self) -> Self {
        self.headers.remove::<WebSocketKey>();
        self.key_set = false;
        self
    }

    /// Sets `Sec-WebSocket-Version`; without this, version 13 is sent.
    pub fn version(mut self, version: WebSocketVersion) -> Self {
        self.headers.set(version);
        self.version_set = true;
        self
    }

    /// Removes any `Sec-WebSocket-Version` so the default (13) is sent.
    pub fn clear_version(mut self) -> Self {
        self.headers.remove::<WebSocketVersion>();
        self.version_set = false;
        self
    }

    /// Sets the `Origin` header.
    pub fn origin(mut self, origin: String) -> Self {
        self.headers.set(Origin(origin));
        self
    }

    /// Removes the `Origin` header.
    pub fn clear_origin(mut self) -> Self {
        self.headers.remove::<Origin>();
        self
    }

    /// Copies every header in `custom_headers` into this builder.
    ///
    /// Note: `Host`, `Connection` and `Upgrade` are overwritten when the request
    /// is built. A custom `Sec-WebSocket-Key`/`-Version` set this way is also
    /// overwritten unless `key()`/`version()` was used.
    pub fn custom_headers(mut self, custom_headers: &Headers) -> Self {
        self.headers.extend(custom_headers.iter());
        self
    }

    /// Removes the header of type `H`, if present.
    pub fn clear_header<H>(mut self) -> Self
        where H: Header + HeaderFormat
    {
        self.headers.remove::<H>();
        self
    }

    /// Returns the currently configured header of type `H`, if any.
    pub fn get_header<H>(&self) -> Option<&H>
        where H: Header + HeaderFormat
    {
        self.headers.get::<H>()
    }

    /// Connects and performs the handshake.
    ///
    /// TLS is used when the URL scheme is `wss`; otherwise plain TCP is used.
    /// Without an explicit port, 443 is used for `wss` and 80 otherwise.
    /// `ssl_config` supplies the TLS connector; `None` builds a default one.
    ///
    /// # Errors
    /// Fails on a missing host, TCP/TLS errors, I/O errors, or an invalid
    /// handshake response.
    #[cfg(feature="sync-ssl")]
    pub fn connect(
        &mut self,
        ssl_config: Option<TlsConnector>,
    ) -> WebSocketResult<Client<Box<NetworkStream + Send>>> {
        let tcp_stream = self.establish_tcp(None)?;
        let boxed_stream: Box<NetworkStream + Send> = if
            self.url.scheme() == "wss" {
            Box::new(self.wrap_ssl(tcp_stream, ssl_config)?)
        } else {
            Box::new(tcp_stream)
        };
        self.connect_on(boxed_stream)
    }

    /// Connects over plain TCP (port 80 unless the URL has one) and performs
    /// the handshake, regardless of the URL scheme.
    #[cfg(feature="sync")]
    pub fn connect_insecure(&mut self) -> WebSocketResult<Client<TcpStream>> {
        let tcp_stream = self.establish_tcp(Some(false))?;
        self.connect_on(tcp_stream)
    }

    /// Connects over TLS (port 443 unless the URL has one) and performs the
    /// handshake, regardless of the URL scheme. `None` builds a default connector.
    #[cfg(feature="sync-ssl")]
    pub fn connect_secure(
        &mut self,
        ssl_config: Option<TlsConnector>,
    ) -> WebSocketResult<Client<TlsStream<TcpStream>>> {
        let tcp_stream = self.establish_tcp(Some(true))?;
        let ssl_stream = self.wrap_ssl(tcp_stream, ssl_config)?;
        self.connect_on(ssl_stream)
    }

    /// Sends the handshake request over an already-connected `stream`, reads
    /// the HTTP response, and validates it.
    ///
    /// # Errors
    /// Fails on write/read errors, an unparsable response, or a response that
    /// does not complete the WebSocket handshake.
    #[cfg(feature="sync")]
    pub fn connect_on<S>(&mut self, mut stream: S) -> WebSocketResult<Client<S>>
        where S: Stream
    {
        let resource = self.build_request();
        // Render the request line and header block once, then send it with a
        // single write rather than many small unbuffered writes on the raw
        // stream. The final "\r\n" is the blank line ending the header section.
        let request = format!("GET {} {}\r\n{}\r\n", resource, self.version, self.headers);
        stream.write_all(request.as_bytes())?;
        // Make sure the request leaves any buffering in `S` before we block
        // waiting for the response.
        stream.flush()?;
        let mut reader = BufReader::new(stream);
        let response = parse_response(&mut reader)?;
        self.validate(&response)?;
        Ok(Client::unchecked(reader, response.headers, true, false))
    }

    /// Returns a future that connects and performs the handshake.
    ///
    /// TLS is used when the URL scheme is `wss`; otherwise plain TCP is used.
    /// The stream is boxed so both cases share one type. `None` for
    /// `ssl_config` builds a default connector.
    #[cfg(feature="async-ssl")]
    pub fn async_connect(
        self,
        ssl_config: Option<TlsConnector>,
        handle: &Handle,
    ) -> async::ClientNew<Box<stream::async::Stream + Send>> {
        let tcp_stream = match self.async_tcpstream(None, handle) {
            Ok(t) => t,
            Err(e) => return Box::new(future::err(e)),
        };
        let builder = self.into_static();
        if builder.url.scheme() == "wss" {
            let (host, connector) = {
                match builder.extract_host_ssl_conn(ssl_config) {
                    Ok((h, conn)) => (h.to_string(), conn),
                    Err(e) => return Box::new(future::err(e)),
                }
            };
            let future = tcp_stream.map_err(|e| e.into())
                .and_then(move |s| {
                    connector.connect_async(&host, s)
                        .map_err(|e| e.into())
                })
                .and_then(move |stream| {
                    let stream: Box<stream::async::Stream + Send> = Box::new(stream);
                    builder.async_connect_on(stream)
                });
            Box::new(future)
        } else {
            let future = tcp_stream.map_err(|e| e.into())
                .and_then(move |stream| {
                    let stream: Box<stream::async::Stream + Send> = Box::new(stream);
                    builder.async_connect_on(stream)
                });
            Box::new(future)
        }
    }

    /// Returns a future that connects over TLS (port 443 unless the URL has one)
    /// and performs the handshake, regardless of the URL scheme.
    #[cfg(feature="async-ssl")]
    pub fn async_connect_secure(
        self,
        ssl_config: Option<TlsConnector>,
        handle: &Handle,
    ) -> async::ClientNew<async::TlsStream<async::TcpStream>> {
        let tcp_stream = match self.async_tcpstream(Some(true), handle) {
            Ok(t) => t,
            Err(e) => return Box::new(future::err(e)),
        };
        let (host, connector) = {
            match self.extract_host_ssl_conn(ssl_config) {
                Ok((h, conn)) => (h.to_string(), conn),
                Err(e) => return Box::new(future::err(e)),
            }
        };
        let builder = self.into_static();
        let future =
            tcp_stream.map_err(|e| e.into())
                .and_then(move |s| {
                    connector.connect_async(&host, s)
                        .map_err(|e| e.into())
                })
                .and_then(move |stream| builder.async_connect_on(stream));
        Box::new(future)
    }

    /// Returns a future that connects over plain TCP (port 80 unless the URL has
    /// one) and performs the handshake, regardless of the URL scheme.
    #[cfg(feature="async")]
    pub fn async_connect_insecure(self, handle: &Handle) -> async::ClientNew<async::TcpStream> {
        let tcp_stream = match self.async_tcpstream(Some(false), handle) {
            Ok(t) => t,
            Err(e) => return Box::new(future::err(e)),
        };
        let builder = self.into_static();
        let future =
            tcp_stream.map_err(|e| e.into())
                .and_then(move |stream| builder.async_connect_on(stream));
        Box::new(future)
    }

    /// Returns a future that sends the handshake over `stream`, validates the
    /// response, and resolves to a client framed with `MessageCodec` (client
    /// context) together with the response headers.
    #[cfg(feature="async")]
    pub fn async_connect_on<S>(self, stream: S) -> async::ClientNew<S>
        where S: stream::async::Stream + Send + 'static
    {
        let mut builder = self.into_static();
        let resource = builder.build_request();
        let framed = stream.framed(::codec::http::HttpClientCodec);
        let request = Incoming {
            version: builder.version,
            headers: builder.headers.clone(),
            subject: (Method::Get, RequestUri::AbsolutePath(resource)),
        };
        let future = framed
            .send(request).map_err(::std::convert::Into::into)
            .and_then(|stream| stream.into_future().map_err(|e| e.0.into()))
            .and_then(move |(message, stream)| {
                message
                    .ok_or(WebSocketError::ProtocolError(
                        "Connection closed before handshake could complete."))
                    .and_then(|message| builder.validate(&message).map(|()| (message, stream)))
            })
            .map(|(message, stream)| {
                let codec = MessageCodec::default(Context::Client);
                let client = Framed::from_parts(stream.into_parts(), codec);
                (client, message.headers)
            });
        Box::new(future)
    }

    /// Converts this builder into one that owns its URL, so it can be moved
    /// into `'static` futures.
    #[cfg(feature="async")]
    fn into_static(self) -> ClientBuilder<'static> {
        ClientBuilder {
            url: Cow::Owned(self.url.into_owned()),
            version: self.version,
            headers: self.headers,
            version_set: self.version_set,
            key_set: self.key_set,
        }
    }

    /// Resolves the URL's host/port and starts an async TCP connect to the
    /// first resolved address.
    ///
    /// # Errors
    /// Fails when the URL has no host, resolution fails, or resolution yields
    /// no addresses (reported as `NoHostName`).
    #[cfg(feature="async")]
    fn async_tcpstream(
        &self,
        secure: Option<bool>,
        handle: &Handle,
    ) -> WebSocketResult<TcpStreamNew> {
        let mut addresses = self.extract_host_port(secure)?.to_socket_addrs()?;
        let address = match addresses.next() {
            Some(a) => a,
            None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)),
        };
        Ok(async::TcpStream::connect(&address, handle))
    }

    /// Fills in the handshake headers and returns the request target.
    ///
    /// Always sets `Host` (when the URL has one), `Connection: Upgrade` and
    /// `Upgrade: websocket`. It also sets `Sec-WebSocket-Version: 13` and a fresh
    /// `Sec-WebSocket-Key` unless the caller chose them explicitly. The returned
    /// string is the URL's path plus query.
    #[cfg(any(feature="sync", feature="async"))]
    fn build_request(&mut self) -> String {
        if let Some(host) = self.url.host_str() {
            self.headers
                .set(Host {
                    hostname: host.to_string(),
                    port: self.url.port(),
                });
        }
        self.headers
            .set(Connection(vec![
                ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()))
            ]));
        self.headers
            .set(Upgrade(vec![
                Protocol {
                    name: ProtocolName::WebSocket,
                    version: None,
                },
            ]));
        if !self.version_set {
            self.headers.set(WebSocketVersion::WebSocket13);
        }
        if !self.key_set {
            self.headers.set(WebSocketKey::new());
        }
        self.url[Position::BeforePath..Position::AfterQuery].to_owned()
    }

    /// Checks that `response` completes the WebSocket handshake for the request
    /// held in `self.headers`.
    ///
    /// # Errors
    /// Returns `ResponseError` when the status is not 101, the accept hash does
    /// not match our key, `Upgrade` is not exactly `websocket`, or the response's
    /// `Connection` header lacks the `Upgrade` token. Returns `RequestError` if
    /// no key was sent.
    #[cfg(any(feature="sync", feature="async"))]
    fn validate(&self, response: &Incoming<RawStatus>) -> WebSocketResult<()> {
        let status = StatusCode::from_u16(response.subject.0);
        if status != StatusCode::SwitchingProtocols {
            return Err(WebSocketError::ResponseError("Status code must be Switching Protocols"));
        }
        // The expected accept value derives from the key *we* sent, so it is
        // read from the request headers.
        let key =
            self.headers
                .get::<WebSocketKey>()
                .ok_or(WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid"))?;
        if response.headers.get() != Some(&(WebSocketAccept::new(key))) {
            return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid"));
        }
        if response.headers.get() !=
            Some(&(Upgrade(vec![
                Protocol {
                    name: ProtocolName::WebSocket,
                    version: None,
                },
            ]))) {
            return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket"));
        }
        // RFC 6455 §4.1: the *response* must contain a Connection header with the
        // token "Upgrade" (ASCII case-insensitive; `UniCase` compares that way).
        // Checking our own request headers here made the check a no-op, because
        // `build_request` always sets them.
        let upgrade_token = ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()));
        let has_upgrade_token = response.headers
            .get::<Connection>()
            .map_or(false, |connection| connection.0.contains(&upgrade_token));
        if !has_upgrade_token {
            return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'"));
        }
        Ok(())
    }

    /// Returns the host and port to connect to.
    ///
    /// An explicit URL port always wins. Otherwise `secure` picks 443 or 80.
    /// When `secure` is `None`, the scheme decides (`wss` → 443, anything else → 80).
    /// IPv6 literals are returned without the surrounding brackets.
    ///
    /// # Errors
    /// Returns `NoHostName` when the URL has no host.
    #[cfg(any(feature="sync", feature="async"))]
    fn extract_host_port(&self, secure: Option<bool>) -> WebSocketResult<(&str, u16)> {
        let port = match (self.url.port(), secure) {
            (Some(port), _) => port,
            (None, None) if self.url.scheme() == "wss" => 443,
            (None, None) => 80,
            (None, Some(true)) => 443,
            (None, Some(false)) => 80,
        };
        let host = match self.url.host_str() {
            Some(h) => h,
            None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)),
        };
        // `Url::host_str` renders IPv6 literals as "[::1]". `(&str, u16)` socket
        // address resolution only accepts the bare "::1", so drop the brackets.
        let host = if host.len() > 1 && host.starts_with('[') && host.ends_with(']') {
            &host[1..host.len() - 1]
        } else {
            host
        };
        Ok((host, port))
    }

    /// Opens a blocking TCP connection to the host/port from `extract_host_port`.
    #[cfg(feature="sync")]
    fn establish_tcp(&mut self, secure: Option<bool>) -> WebSocketResult<TcpStream> {
        Ok(TcpStream::connect(self.extract_host_port(secure)?)?)
    }

    /// Returns the URL host (for the TLS handshake) and the connector to use.
    /// When `connector` is `None`, a default `TlsConnector` is built.
    #[cfg(any(feature="sync-ssl", feature="async-ssl"))]
    fn extract_host_ssl_conn(
        &self,
        connector: Option<TlsConnector>,
    ) -> WebSocketResult<(&str, TlsConnector)> {
        let host = match self.url.host_str() {
            Some(h) => h,
            None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)),
        };
        let connector = match connector {
            Some(c) => c,
            None => TlsConnector::builder()?.build()?,
        };
        Ok((host, connector))
    }

    /// Performs a blocking TLS handshake over `tcp_stream` against the URL host.
    #[cfg(feature="sync-ssl")]
    fn wrap_ssl(
        &self,
        tcp_stream: TcpStream,
        connector: Option<TlsConnector>,
    ) -> WebSocketResult<TlsStream<TcpStream>> {
        let (host, connector) = self.extract_host_ssl_conn(connector)?;
        let ssl_stream = connector.connect(host, tcp_stream)?;
        Ok(ssl_stream)
    }
}
/// Unit tests for `ClientBuilder`'s header-manipulating builder methods.
#[cfg(test)]
mod tests {
    /// Protocols are appended, and `clear_protocols` drops earlier ones.
    #[test]
    fn build_client_with_protocols() {
        use super::*;
        let builder = ClientBuilder::new("ws://127.0.0.1:8080/hello/world")
            .unwrap()
            .add_protocol("protobeard");
        let protos = &builder.headers.get::<WebSocketProtocol>().unwrap().0;
        assert!(protos.contains(&"protobeard".to_string()));
        assert_eq!(protos.len(), 1);
        let builder = ClientBuilder::new("ws://example.org/hello")
            .unwrap()
            .add_protocol("rust-websocket")
            .clear_protocols()
            .add_protocols(vec!["electric", "boogaloo"]);
        let protos = &builder.headers.get::<WebSocketProtocol>().unwrap().0;
        assert!(protos.contains(&"boogaloo".to_string()));
        assert!(protos.contains(&"electric".to_string()));
        assert!(!protos.contains(&"rust-websocket".to_string()));
        assert_eq!(protos.len(), 2);
    }

    /// An unparsable address is reported as an error instead of panicking.
    #[test]
    fn new_rejects_invalid_url() {
        use super::*;
        assert!(ClientBuilder::new("not a url").is_err());
    }

    /// `key`/`version` set both the header and the flag; the clear methods undo both.
    #[test]
    fn key_and_version_setters_track_flags() {
        use super::*;
        let builder = ClientBuilder::new("ws://example.org/")
            .unwrap()
            .key([0; 16])
            .version(WebSocketVersion::WebSocket13);
        assert!(builder.key_set);
        assert!(builder.version_set);
        assert!(builder.get_header::<WebSocketKey>().is_some());
        assert!(builder.get_header::<WebSocketVersion>().is_some());
        let builder = builder.clear_key().clear_version();
        assert!(!builder.key_set);
        assert!(!builder.version_set);
        assert!(builder.get_header::<WebSocketKey>().is_none());
        assert!(builder.get_header::<WebSocketVersion>().is_none());
    }
}