use header::extensions::Extension;
use header::{Origin, WebSocketExtensions, WebSocketKey, WebSocketProtocol, WebSocketVersion};
use hyper::header::{Header, HeaderFormat, Headers};
use hyper::version::HttpVersion;
use std::borrow::Cow;
use std::convert::Into;
pub use url::{ParseError, Url};
#[cfg(any(feature = "sync", feature = "async"))]
mod common_imports {
pub use header::WebSocketAccept;
pub use hyper::buffer::BufReader;
pub use hyper::header::{Connection, ConnectionOption, Host, Protocol, ProtocolName, Upgrade};
pub use hyper::http::h1::parse_response;
pub use hyper::http::h1::Incoming;
pub use hyper::http::RawStatus;
pub use hyper::method::Method;
pub use hyper::status::StatusCode;
pub use hyper::uri::RequestUri;
pub use result::{WSUrlErrorKind, WebSocketError, WebSocketResult};
pub use std::net::TcpStream;
pub use std::net::ToSocketAddrs;
pub use stream::{self, Stream};
pub use unicase::UniCase;
pub use url::Position;
}
#[cfg(any(feature = "sync", feature = "async"))]
use self::common_imports::*;
#[cfg(feature = "sync")]
use super::sync::Client;
#[cfg(feature = "sync-ssl")]
use stream::sync::NetworkStream;
#[cfg(any(feature = "sync-ssl", feature = "async-ssl"))]
use native_tls::TlsConnector;
#[cfg(feature = "sync-ssl")]
use native_tls::TlsStream;
#[cfg(feature = "async")]
mod async_imports {
pub use super::super::async;
pub use codec::ws::{Context, MessageCodec};
pub use futures::future;
pub use futures::Stream as FutureStream;
pub use futures::{Future, IntoFuture, Sink};
pub use tokio::codec::FramedParts;
pub use tokio::codec::{Decoder, Framed};
pub use tokio::net::TcpStream as TcpStreamNew;
pub use tokio::reactor::Handle;
#[cfg(feature = "async-ssl")]
pub use tokio_tls::TlsConnector as TlsConnectorExt;
pub use ws::util::update_framed_codec;
}
#[cfg(feature = "async")]
use self::async_imports::*;
/// Builder for a WebSocket client connection.
///
/// Collects the target URL and the HTTP headers of the opening handshake
/// request; one of the `connect*` / `async_connect*` methods then opens the
/// connection and performs the handshake.
#[derive(Clone, Debug)]
pub struct ClientBuilder<'u> {
    /// Target URL: borrowed when built with `from_url`, owned when built with `new`.
    url: Cow<'u, Url>,
    /// HTTP version written on the request line of the handshake.
    version: HttpVersion,
    /// Headers sent with the handshake request.
    headers: Headers,
    /// True once `version()` has chosen a `Sec-WebSocket-Version`; otherwise
    /// the request is built with `WebSocketVersion::WebSocket13`.
    version_set: bool,
    /// True once `key()` has chosen a `Sec-WebSocket-Key`; otherwise a key from
    /// `WebSocketKey::new()` is inserted when the request is built.
    key_set: bool,
}
impl<'u> ClientBuilder<'u> {
    /// Creates a builder that borrows an already-parsed `Url`.
    pub fn from_url(address: &'u Url) -> Self {
        ClientBuilder::init(Cow::Borrowed(address))
    }

    /// Parses `address` and creates a builder that owns the resulting `Url`.
    ///
    /// # Errors
    ///
    /// Returns the `ParseError` from `Url::parse` if `address` is not a valid URL.
    // `new` deliberately returns `Result<Self, _>`; older clippy versions flag
    // that with `new_ret_no_self`, so the lint is silenced (not escalated) here.
    #[cfg_attr(feature = "cargo-clippy", allow(new_ret_no_self))]
    pub fn new(address: &str) -> Result<Self, ParseError> {
        let url = Url::parse(address)?;
        Ok(ClientBuilder::init(Cow::Owned(url)))
    }

    /// Shared constructor: HTTP/1.1, no headers, and neither the WebSocket
    /// version nor the key explicitly chosen yet.
    fn init(url: Cow<'u, Url>) -> Self {
        ClientBuilder {
            url,
            version: HttpVersion::Http11,
            version_set: false,
            key_set: false,
            headers: Headers::new(),
        }
    }

    /// Adds one subprotocol to the `Sec-WebSocket-Protocol` header.
    ///
    /// The protocol is appended to the existing header value, or the header
    /// is created if absent.
    pub fn add_protocol<P>(mut self, protocol: P) -> Self
    where
        P: Into<String>,
    {
        upsert_header!(self.headers; WebSocketProtocol; {
            Some(protos) => protos.0.push(protocol.into()),
            None => WebSocketProtocol(vec![protocol.into()])
        });
        self
    }

    /// Adds several subprotocols to the `Sec-WebSocket-Protocol` header,
    /// preserving the iteration order of `protocols`.
    pub fn add_protocols<I, S>(mut self, protocols: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut protocols: Vec<String> = protocols.into_iter().map(Into::into).collect();
        upsert_header!(self.headers; WebSocketProtocol; {
            Some(protos) => protos.0.append(&mut protocols),
            None => WebSocketProtocol(protocols)
        });
        self
    }

    /// Removes the `Sec-WebSocket-Protocol` header entirely.
    pub fn clear_protocols(mut self) -> Self {
        self.headers.remove::<WebSocketProtocol>();
        self
    }

    /// Adds one extension to the `Sec-WebSocket-Extensions` header.
    pub fn add_extension(mut self, extension: Extension) -> Self {
        upsert_header!(self.headers; WebSocketExtensions; {
            Some(protos) => protos.0.push(extension),
            None => WebSocketExtensions(vec![extension])
        });
        self
    }

    /// Adds several extensions to the `Sec-WebSocket-Extensions` header,
    /// preserving the iteration order of `extensions`.
    pub fn add_extensions<I>(mut self, extensions: I) -> Self
    where
        I: IntoIterator<Item = Extension>,
    {
        let mut extensions: Vec<Extension> = extensions.into_iter().collect();
        upsert_header!(self.headers; WebSocketExtensions; {
            Some(protos) => protos.0.append(&mut extensions),
            None => WebSocketExtensions(extensions)
        });
        self
    }

    /// Removes the `Sec-WebSocket-Extensions` header entirely.
    pub fn clear_extensions(mut self) -> Self {
        self.headers.remove::<WebSocketExtensions>();
        self
    }

    /// Uses `key` as the `Sec-WebSocket-Key` instead of generating one.
    pub fn key(mut self, key: [u8; 16]) -> Self {
        self.headers.set(WebSocketKey(key));
        self.key_set = true;
        self
    }

    /// Forgets any explicitly chosen key; a new one is generated when the
    /// request is built.
    pub fn clear_key(mut self) -> Self {
        self.headers.remove::<WebSocketKey>();
        self.key_set = false;
        self
    }

    /// Sets the `Sec-WebSocket-Version` header explicitly.
    pub fn version(mut self, version: WebSocketVersion) -> Self {
        self.headers.set(version);
        self.version_set = true;
        self
    }

    /// Forgets any explicitly chosen version; version 13 is used when the
    /// request is built.
    pub fn clear_version(mut self) -> Self {
        self.headers.remove::<WebSocketVersion>();
        self.version_set = false;
        self
    }

    /// Sets the `Origin` header.
    pub fn origin(mut self, origin: String) -> Self {
        self.headers.set(Origin(origin));
        self
    }

    /// Removes the `Origin` header.
    pub fn clear_origin(mut self) -> Self {
        self.headers.remove::<Origin>();
        self
    }

    /// Copies every header in `custom_headers` into the handshake request.
    ///
    /// `Host`, `Connection` and `Upgrade` are unconditionally overwritten when
    /// the request is built.
    pub fn custom_headers(mut self, custom_headers: &Headers) -> Self {
        self.headers.extend(custom_headers.iter());
        self
    }

    /// Removes the header of type `H`, if present.
    pub fn clear_header<H>(mut self) -> Self
    where
        H: Header + HeaderFormat,
    {
        self.headers.remove::<H>();
        self
    }

    /// Returns the header of type `H` currently configured, if any.
    pub fn get_header<H>(&self) -> Option<&H>
    where
        H: Header + HeaderFormat,
    {
        self.headers.get::<H>()
    }

    /// Connects over TCP and performs the handshake.
    ///
    /// The stream is wrapped in TLS when the URL scheme is `wss` or `https`.
    /// `ssl_config` is the TLS connector to use; when `None`, a default
    /// connector is built.
    #[cfg(feature = "sync-ssl")]
    pub fn connect(
        &mut self,
        ssl_config: Option<TlsConnector>,
    ) -> WebSocketResult<Client<Box<NetworkStream + Send>>> {
        let tcp_stream = self.establish_tcp(None)?;
        let boxed_stream: Box<NetworkStream + Send> = if self.is_secure_url() {
            Box::new(self.wrap_ssl(tcp_stream, ssl_config)?)
        } else {
            Box::new(tcp_stream)
        };
        self.connect_on(boxed_stream)
    }

    /// Connects over plain TCP, regardless of the URL scheme, and performs the
    /// handshake.
    #[cfg(feature = "sync")]
    pub fn connect_insecure(&mut self) -> WebSocketResult<Client<TcpStream>> {
        let tcp_stream = self.establish_tcp(Some(false))?;
        self.connect_on(tcp_stream)
    }

    /// Connects over TCP, always wraps the stream in TLS, and performs the
    /// handshake.
    #[cfg(feature = "sync-ssl")]
    pub fn connect_secure(
        &mut self,
        ssl_config: Option<TlsConnector>,
    ) -> WebSocketResult<Client<TlsStream<TcpStream>>> {
        let tcp_stream = self.establish_tcp(Some(true))?;
        let ssl_stream = self.wrap_ssl(tcp_stream, ssl_config)?;
        self.connect_on(ssl_stream)
    }

    /// Performs the handshake over an already-established `stream`.
    ///
    /// Writes the GET request, parses the HTTP response, and validates it.
    ///
    /// # Errors
    ///
    /// Fails on I/O errors, unparsable responses, or a response that does not
    /// complete the WebSocket handshake.
    #[cfg(feature = "sync")]
    pub fn connect_on<S>(&mut self, mut stream: S) -> WebSocketResult<Client<S>>
    where
        S: Stream,
    {
        let resource = self.build_request();
        let data = format!("GET {} {}\r\n{}\r\n", resource, self.version, self.headers);
        stream.write_all(data.as_bytes())?;
        let mut reader = BufReader::new(stream);
        let response = parse_response(&mut reader)?;
        self.validate(&response)?;
        Ok(Client::unchecked(reader, response.headers, true, false))
    }

    /// Asynchronous counterpart of `connect`.
    ///
    /// Uses TLS when the URL scheme is `wss` or `https`, and boxes the
    /// resulting stream either way.
    #[cfg(feature = "async-ssl")]
    pub fn async_connect(
        self,
        ssl_config: Option<TlsConnector>,
    ) -> async::ClientNew<Box<stream::async::Stream + Send>> {
        let tcp_stream = self.async_tcpstream(None);
        let builder = self.into_static();
        if builder.is_secure_url() {
            // Scoped so the borrow of `builder` ends before it moves into the future.
            let (host, connector) = {
                match builder.extract_host_ssl_conn(ssl_config) {
                    Ok((h, conn)) => (h.to_string(), TlsConnectorExt::from(conn)),
                    Err(e) => return Box::new(future::err(e)),
                }
            };
            let future = tcp_stream
                .and_then(move |s| connector.connect(&host, s).map_err(Into::into))
                .and_then(move |stream| {
                    let stream: Box<stream::async::Stream + Send> = Box::new(stream);
                    builder.async_connect_on(stream)
                });
            Box::new(future)
        } else {
            let future = tcp_stream.and_then(move |stream| {
                let stream: Box<stream::async::Stream + Send> = Box::new(stream);
                builder.async_connect_on(stream)
            });
            Box::new(future)
        }
    }

    /// Asynchronous counterpart of `connect_secure`: always uses TLS.
    #[cfg(feature = "async-ssl")]
    pub fn async_connect_secure(
        self,
        ssl_config: Option<TlsConnector>,
    ) -> async::ClientNew<async::TlsStream<async::TcpStream>> {
        let tcp_stream = self.async_tcpstream(Some(true));
        // Scoped so the borrow of `self` ends before `self` is consumed below.
        let (host, connector) = {
            match self.extract_host_ssl_conn(ssl_config) {
                Ok((h, conn)) => (h.to_string(), TlsConnectorExt::from(conn)),
                Err(e) => return Box::new(future::err(e)),
            }
        };
        let builder = self.into_static();
        let future = tcp_stream
            .and_then(move |s| connector.connect(&host, s).map_err(Into::into))
            .and_then(move |stream| builder.async_connect_on(stream));
        Box::new(future)
    }

    /// Asynchronous counterpart of `connect_insecure`: plain TCP only.
    #[cfg(feature = "async")]
    pub fn async_connect_insecure(self) -> async::ClientNew<async::TcpStream> {
        let tcp_stream = self.async_tcpstream(Some(false));
        let builder = self.into_static();
        let future = tcp_stream.and_then(move |stream| builder.async_connect_on(stream));
        Box::new(future)
    }

    /// Asynchronous counterpart of `connect_on`: performs the handshake over
    /// an already-established `stream`.
    ///
    /// On success, yields the framed message stream together with the
    /// response headers.
    #[cfg(feature = "async")]
    pub fn async_connect_on<S>(self, stream: S) -> async::ClientNew<S>
    where
        S: stream::async::Stream + Send + 'static,
    {
        let mut builder = self.into_static();
        let resource = builder.build_request();
        let framed = ::codec::http::HttpClientCodec.framed(stream);
        let request = Incoming {
            version: builder.version,
            // Cloned because `builder` is still needed to validate the response.
            headers: builder.headers.clone(),
            subject: (Method::Get, RequestUri::AbsolutePath(resource)),
        };
        let future = framed
            .send(request)
            .map_err(::std::convert::Into::into)
            .and_then(|stream| stream.into_future().map_err(|e| e.0.into()))
            .and_then(move |(message, stream)| {
                message
                    .ok_or(WebSocketError::ProtocolError(
                        "Connection closed before handshake could complete.",
                    ))
                    .and_then(|message| builder.validate(&message).map(|()| (message, stream)))
            })
            .map(|(message, stream)| {
                let codec = MessageCodec::default(Context::Client);
                let client = update_framed_codec(stream, codec);
                (client, message.headers)
            });
        Box::new(future)
    }

    /// Converts the builder into one that owns its URL.
    ///
    /// The owned builder can be moved into `'static` futures.
    #[cfg(feature = "async")]
    fn into_static(self) -> ClientBuilder<'static> {
        ClientBuilder {
            url: Cow::Owned(self.url.into_owned()),
            version: self.version,
            headers: self.headers,
            version_set: self.version_set,
            key_set: self.key_set,
        }
    }

    /// Resolves the target host and starts a TCP connection to the first
    /// resolved address.
    ///
    /// Fails with `NoHostName` if resolution yields no addresses.
    #[cfg(feature = "async")]
    fn async_tcpstream(
        &self,
        secure: Option<bool>,
    ) -> Box<future::Future<Item = TcpStreamNew, Error = WebSocketError> + Send> {
        let address = match self
            .extract_host_port(secure)
            .and_then(|p| Ok(p.to_socket_addrs()?))
        {
            Ok(mut s) => match s.next() {
                Some(a) => a,
                None => {
                    return Box::new(
                        Err(WebSocketError::WebSocketUrlError(
                            WSUrlErrorKind::NoHostName,
                        ))
                        .into_future(),
                    );
                }
            },
            Err(e) => return Box::new(Err(e).into_future()),
        };
        Box::new(TcpStreamNew::connect(&address).map_err(Into::into))
    }

    /// Fills in the handshake headers and returns the request target.
    ///
    /// Sets `Host`, `Connection: Upgrade`, `Upgrade: websocket`, plus a default
    /// version and key unless they were chosen explicitly. Returns the path
    /// and query of the URL, used as the request target.
    #[cfg(any(feature = "sync", feature = "async"))]
    fn build_request(&mut self) -> String {
        if let Some(host) = self.url.host_str() {
            self.headers.set(Host {
                hostname: host.to_string(),
                port: self.url.port(),
            });
        }
        self.headers
            .set(Connection(vec![ConnectionOption::ConnectionHeader(
                UniCase("Upgrade".to_string()),
            )]));
        self.headers.set(Upgrade(vec![Protocol {
            name: ProtocolName::WebSocket,
            version: None,
        }]));
        if !self.version_set {
            self.headers.set(WebSocketVersion::WebSocket13);
        }
        if !self.key_set {
            self.headers.set(WebSocketKey::new());
        }
        self.url[Position::BeforePath..Position::AfterQuery].to_owned()
    }

    /// Checks that the server's `response` completes the WebSocket handshake.
    ///
    /// # Errors
    ///
    /// `ResponseError` if the status is not 101, the `Sec-WebSocket-Accept`
    /// value does not match our key, the `Upgrade` header is not `websocket`,
    /// or the response's `Connection` header lacks an `Upgrade` token.
    /// `RequestError` if no key was set on our own request.
    #[cfg(any(feature = "sync", feature = "async"))]
    fn validate(&self, response: &Incoming<RawStatus>) -> WebSocketResult<()> {
        let status = StatusCode::from_u16(response.subject.0);
        if status != StatusCode::SwitchingProtocols {
            return Err(WebSocketError::ResponseError(
                "Status code must be Switching Protocols",
            ));
        }

        // The expected accept value is derived from the key we sent.
        let key = self
            .headers
            .get::<WebSocketKey>()
            .ok_or(WebSocketError::RequestError(
                "Request Sec-WebSocket-Key was invalid",
            ))?;
        if response.headers.get() != Some(&(WebSocketAccept::new(key))) {
            return Err(WebSocketError::ResponseError(
                "Sec-WebSocket-Accept is invalid",
            ));
        }

        if response.headers.get()
            != Some(
                &(Upgrade(vec![Protocol {
                    name: ProtocolName::WebSocket,
                    version: None,
                }])),
            ) {
            return Err(WebSocketError::ResponseError(
                "Upgrade field must be WebSocket",
            ));
        }

        // Inspect the *response* here: our own request always carries
        // `Connection: Upgrade` (see build_request). RFC 6455 §4.1 only requires
        // the response's Connection header to contain a case-insensitive
        // "Upgrade" token, so other tokens (e.g. keep-alive) are tolerated.
        let has_upgrade_token = response
            .headers
            .get::<Connection>()
            .map_or(false, |conn| {
                conn.0.iter().any(|option| match *option {
                    ConnectionOption::ConnectionHeader(ref name) => {
                        name.0.eq_ignore_ascii_case("Upgrade")
                    }
                    _ => false,
                })
            });
        if !has_upgrade_token {
            return Err(WebSocketError::ResponseError(
                "Connection field must be 'Upgrade'",
            ));
        }
        Ok(())
    }

    /// True when the URL scheme is `wss` or `https`.
    #[cfg(any(feature = "sync-ssl", feature = "async-ssl"))]
    fn is_secure_url(&self) -> bool {
        let scheme = self.url.scheme();
        scheme == "wss" || scheme == "https"
    }

    /// Returns the host and port to connect to.
    ///
    /// `secure` picks the fallback port: `Some(true)` means 443,
    /// `Some(false)` means 80, and `None` means 443 only for the `wss` scheme.
    ///
    /// # Errors
    ///
    /// `NoHostName` if the URL has no host.
    // NOTE(review): per the `url` docs, `with_default_port` only calls the
    // closure when the URL has neither an explicit port nor a known default
    // for its scheme, so e.g. `connect_secure` on a `ws://` URL uses port 80.
    // Confirm that this is intended.
    #[cfg(any(feature = "sync", feature = "async"))]
    fn extract_host_port(&self, secure: Option<bool>) -> WebSocketResult<::url::HostAndPort<&str>> {
        if self.url.host().is_none() {
            return Err(WebSocketError::WebSocketUrlError(
                WSUrlErrorKind::NoHostName,
            ));
        }
        Ok(self.url.with_default_port(|url| {
            const SECURE_PORT: u16 = 443;
            const INSECURE_PORT: u16 = 80;
            const SECURE_WS_SCHEME: &str = "wss";
            Ok(match secure {
                None if url.scheme() == SECURE_WS_SCHEME => SECURE_PORT,
                None => INSECURE_PORT,
                Some(true) => SECURE_PORT,
                Some(false) => INSECURE_PORT,
            })
        })?)
    }

    /// Opens a blocking TCP connection to the host and port chosen by
    /// `extract_host_port`.
    #[cfg(feature = "sync")]
    fn establish_tcp(&self, secure: Option<bool>) -> WebSocketResult<TcpStream> {
        Ok(TcpStream::connect(self.extract_host_port(secure)?)?)
    }

    /// Returns the URL host (used as the TLS domain) and a TLS connector.
    ///
    /// When `connector` is `None`, a default connector is built.
    #[cfg(any(feature = "sync-ssl", feature = "async-ssl"))]
    fn extract_host_ssl_conn(
        &self,
        connector: Option<TlsConnector>,
    ) -> WebSocketResult<(&str, TlsConnector)> {
        let host = match self.url.host_str() {
            Some(h) => h,
            None => {
                return Err(WebSocketError::WebSocketUrlError(
                    WSUrlErrorKind::NoHostName,
                ));
            }
        };
        let connector = match connector {
            Some(c) => c,
            None => TlsConnector::builder().build()?,
        };
        Ok((host, connector))
    }

    /// Performs a blocking TLS handshake over `tcp_stream`, using the URL host
    /// as the domain.
    #[cfg(feature = "sync-ssl")]
    fn wrap_ssl(
        &self,
        tcp_stream: TcpStream,
        connector: Option<TlsConnector>,
    ) -> WebSocketResult<TlsStream<TcpStream>> {
        let (host, connector) = self.extract_host_ssl_conn(connector)?;
        let ssl_stream = connector.connect(host, tcp_stream)?;
        Ok(ssl_stream)
    }
}
/// Unit tests for `ClientBuilder`'s header manipulation.
// Gated on `cfg(test)` so the module is not compiled into normal builds.
#[cfg(test)]
mod tests {
    use super::*;

    /// Protocols can be added singly or in bulk, and `clear_protocols` drops
    /// everything added before it.
    #[test]
    fn build_client_with_protocols() {
        let builder = ClientBuilder::new("ws://127.0.0.1:8080/hello/world")
            .unwrap()
            .add_protocol("protobeard");
        let protos = &builder.headers.get::<WebSocketProtocol>().unwrap().0;
        assert!(protos.contains(&"protobeard".to_string()));
        assert_eq!(protos.len(), 1);

        let builder = ClientBuilder::new("ws://example.org/hello")
            .unwrap()
            .add_protocol("rust-websocket")
            .clear_protocols()
            .add_protocols(vec!["electric", "boogaloo"]);
        let protos = &builder.headers.get::<WebSocketProtocol>().unwrap().0;
        assert!(protos.contains(&"boogaloo".to_string()));
        assert!(protos.contains(&"electric".to_string()));
        assert!(!protos.contains(&"rust-websocket".to_string()));
        assert_eq!(protos.len(), 2);
    }
}