use crate::header::extensions::Extension;
use crate::header::{
Origin, WebSocketExtensions, WebSocketKey, WebSocketProtocol, WebSocketVersion,
};
use hyper::header::{Authorization, Basic, Header, HeaderFormat, Headers};
use hyper::version::HttpVersion;
use std::borrow::Cow;
use std::convert::Into;
pub use url::{ParseError, Url};
/// Initial `max_dataframe_size` of a freshly created `ClientBuilder`
/// (100 * 2^20; presumably a byte count -- the limit itself is enforced elsewhere).
const DEFAULT_MAX_DATAFRAME_SIZE: usize = 100 * 1024 * 1024;
/// Initial `max_message_size` of a freshly created `ClientBuilder`
/// (200 * 2^20; presumably a byte count -- the limit itself is enforced elsewhere).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 200 * 1024 * 1024;
#[cfg(any(feature = "sync", feature = "async"))]
mod common_imports {
pub use crate::header::WebSocketAccept;
pub use crate::result::{WSUrlErrorKind, WebSocketError, WebSocketOtherError, WebSocketResult};
pub use crate::stream::{self, Stream};
pub use hyper::buffer::BufReader;
pub use hyper::header::{Connection, ConnectionOption, Host, Protocol, ProtocolName, Upgrade};
pub use hyper::http::h1::parse_response;
pub use hyper::http::h1::Incoming;
pub use hyper::http::RawStatus;
pub use hyper::method::Method;
pub use hyper::status::StatusCode;
pub use hyper::uri::RequestUri;
pub use std::net::TcpStream;
pub use std::net::ToSocketAddrs;
pub use unicase::UniCase;
pub use url::Position;
}
#[cfg(any(feature = "sync", feature = "async"))]
use self::common_imports::*;
#[cfg(feature = "sync")]
use super::sync::Client;
#[cfg(feature = "sync-ssl")]
use crate::stream::sync::NetworkStream;
#[cfg(any(feature = "sync-ssl", feature = "async-ssl"))]
use native_tls::TlsConnector;
#[cfg(feature = "sync-ssl")]
use native_tls::TlsStream;
#[cfg(feature = "async")]
mod async_imports {
pub use super::super::r#async;
pub use crate::codec::ws::{Context, MessageCodec};
pub use crate::ws::util::update_framed_codec;
pub use futures::future;
pub use futures::Stream as FutureStream;
pub use futures::{Future, IntoFuture, Sink};
pub use tokio_codec::FramedParts;
pub use tokio_codec::{Decoder, Framed};
pub use tokio_reactor::Handle;
pub use tokio_tcp::TcpStream as TcpStreamNew;
#[cfg(feature = "async-ssl")]
pub use tokio_tls::TlsConnector as TlsConnectorExt;
}
#[cfg(feature = "async")]
use self::async_imports::*;
use crate::result::towse;
/// Accumulates everything needed to perform a WebSocket client handshake:
/// the target URL, the HTTP version, the request headers and the size limits
/// handed to the resulting client.
#[derive(Clone, Debug)]
pub struct ClientBuilder<'u> {
    /// Target URL; borrowed when built with `from_url`, owned when built with `new`.
    url: Cow<'u, Url>,
    /// HTTP version written on the request line of the handshake.
    version: HttpVersion,
    /// Headers sent with the handshake request.
    headers: Headers,
    /// `true` once the caller picked a `WebSocketVersion` explicitly; while
    /// `false`, `build_request` sets `WebSocketVersion::WebSocket13`.
    version_set: bool,
    /// `true` once the caller picked a `WebSocketKey` explicitly; while
    /// `false`, `build_request` sets a fresh `WebSocketKey::new()`.
    key_set: bool,
    /// Dataframe size limit forwarded to the client / message codec.
    max_dataframe_size: usize,
    /// Message size limit forwarded to the client / message codec.
    max_message_size: usize,
}
impl<'u> ClientBuilder<'u> {
pub fn from_url(address: &'u Url) -> Self {
ClientBuilder::init(Cow::Borrowed(address))
}
#[warn(clippy::new_ret_no_self)]
pub fn new(address: &str) -> Result<Self, ParseError> {
let url = Url::parse(address)?;
Ok(ClientBuilder::init(Cow::Owned(url)))
}
fn init(url: Cow<'u, Url>) -> Self {
ClientBuilder {
url,
version: HttpVersion::Http11,
version_set: false,
key_set: false,
headers: Headers::new(),
max_dataframe_size: DEFAULT_MAX_DATAFRAME_SIZE,
max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
}
}
/// Adds one sub-protocol to the `WebSocketProtocol` request header,
/// creating the header when it is not present yet.
pub fn add_protocol<P>(mut self, protocol: P) -> Self
where
    P: Into<String>,
{
    // Convert once up front; exactly one of the two arms consumes the value.
    let name: String = protocol.into();
    upsert_header!(self.headers; WebSocketProtocol; {
        Some(existing) => existing.0.push(name),
        None => WebSocketProtocol(vec![name])
    });
    self
}
/// Adds several sub-protocols to the `WebSocketProtocol` request header,
/// creating the header when it is not present yet.
///
/// The new entries are placed after any protocols already registered.
pub fn add_protocols<I, S>(mut self, protocols: I) -> Self
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut incoming: Vec<String> = protocols.into_iter().map(|proto| proto.into()).collect();
    upsert_header!(self.headers; WebSocketProtocol; {
        Some(existing) => existing.0.append(&mut incoming),
        None => WebSocketProtocol(incoming)
    });
    self
}
/// Removes the `WebSocketProtocol` header, discarding every sub-protocol
/// added so far.
pub fn clear_protocols(self) -> Self {
    let mut builder = self;
    builder.headers.remove::<WebSocketProtocol>();
    builder
}
/// Adds one extension to the `WebSocketExtensions` request header,
/// creating the header when it is not present yet.
pub fn add_extension(mut self, extension: Extension) -> Self {
    upsert_header!(self.headers; WebSocketExtensions; {
        Some(existing) => existing.0.push(extension),
        None => WebSocketExtensions(vec![extension])
    });
    self
}
/// Adds several extensions to the `WebSocketExtensions` request header,
/// creating the header when it is not present yet.
///
/// The new entries are placed after any extensions already registered.
pub fn add_extensions<I>(mut self, extensions: I) -> Self
where
    I: IntoIterator<Item = Extension>,
{
    let mut incoming: Vec<Extension> = extensions.into_iter().collect();
    upsert_header!(self.headers; WebSocketExtensions; {
        Some(existing) => existing.0.append(&mut incoming),
        None => WebSocketExtensions(incoming)
    });
    self
}
/// Removes the `WebSocketExtensions` header, discarding every extension
/// added so far.
pub fn clear_extensions(self) -> Self {
    let mut builder = self;
    builder.headers.remove::<WebSocketExtensions>();
    builder
}
pub fn max_dataframe_size(mut self, value: usize) -> Self {
self.max_dataframe_size = value;
self
}
pub fn max_message_size(mut self, value: usize) -> Self {
self.max_message_size = value;
self
}
/// Uses the given 16 bytes as the `WebSocketKey` header and marks the key
/// as explicitly chosen, so `build_request` will not generate a new one.
pub fn key(self, key: [u8; 16]) -> Self {
    let mut builder = self;
    builder.headers.set(WebSocketKey::from_array(key));
    builder.key_set = true;
    builder
}
/// Removes the `WebSocketKey` header and clears the "explicitly chosen"
/// flag, so `build_request` generates a fresh key again.
pub fn clear_key(self) -> Self {
    let mut builder = self;
    builder.headers.remove::<WebSocketKey>();
    builder.key_set = false;
    builder
}
/// Sets the `WebSocketVersion` header and marks the version as explicitly
/// chosen, so `build_request` will not replace it with `WebSocket13`.
pub fn version(self, version: WebSocketVersion) -> Self {
    let mut builder = self;
    builder.headers.set(version);
    builder.version_set = true;
    builder
}
/// Removes the `WebSocketVersion` header and clears the "explicitly
/// chosen" flag, so `build_request` falls back to `WebSocket13` again.
pub fn clear_version(self) -> Self {
    let mut builder = self;
    builder.headers.remove::<WebSocketVersion>();
    builder.version_set = false;
    builder
}
/// Sets the `Origin` header of the handshake request to `origin`.
pub fn origin(self, origin: String) -> Self {
    let mut builder = self;
    builder.headers.set(Origin(origin));
    builder
}
/// Removes the `Origin` header from the handshake request.
pub fn clear_origin(self) -> Self {
    let mut builder = self;
    builder.headers.remove::<Origin>();
    builder
}
/// Copies every header in `custom_headers` into the handshake headers.
///
/// NOTE: this does not touch the `key_set` / `version_set` flags, so a
/// `WebSocketKey` or `WebSocketVersion` supplied this way is still subject
/// to the defaults applied in `build_request`; use `key` / `version` to pin
/// those values.
pub fn custom_headers(self, custom_headers: &Headers) -> Self {
    let mut builder = self;
    builder.headers.extend(custom_headers.iter());
    builder
}
/// Removes the header of type `H` from the handshake request.
///
/// NOTE: the `key_set` / `version_set` flags are left untouched; prefer
/// `clear_key` / `clear_version` for those two headers.
pub fn clear_header<H>(self) -> Self
where
    H: Header + HeaderFormat,
{
    let mut builder = self;
    builder.headers.remove::<H>();
    builder
}
/// Returns the header of type `H` currently stored for the handshake
/// request, or `None` when it has not been set.
pub fn get_header<H>(&self) -> Option<&H>
where
    H: Header + HeaderFormat,
{
    let headers = &self.headers;
    headers.get::<H>()
}
/// Opens a TCP connection to the URL's host and performs the handshake,
/// wrapping the stream in TLS first when the URL scheme is secure
/// (see `is_secure_url`).
///
/// `ssl_config` is used for the TLS session when given; otherwise
/// `extract_host_ssl_conn` builds a default connector.
///
/// # Errors
/// Propagates failures from establishing the TCP connection, from the TLS
/// setup, and from the handshake itself.
#[cfg(feature = "sync-ssl")]
pub fn connect(
    &mut self,
    ssl_config: Option<TlsConnector>,
) -> WebSocketResult<Client<Box<dyn NetworkStream + Send>>> {
    let tcp = self.establish_tcp(None)?;

    let stream: Box<dyn NetworkStream + Send> = if self.is_secure_url() {
        let tls = self.wrap_ssl(tcp, ssl_config)?;
        Box::new(tls)
    } else {
        Box::new(tcp)
    };

    self.connect_on(stream)
}
/// Opens a plain TCP connection (insecure default port requested from
/// `establish_tcp`) and performs the handshake on it, regardless of the
/// URL scheme.
///
/// # Errors
/// Propagates failures from the TCP connection and from the handshake.
#[cfg(feature = "sync")]
pub fn connect_insecure(&mut self) -> WebSocketResult<Client<TcpStream>> {
    self.establish_tcp(Some(false))
        .and_then(|tcp| self.connect_on(tcp))
}
/// Opens a TCP connection (secure default port requested from
/// `establish_tcp`), wraps it in TLS and performs the handshake on the
/// encrypted stream, regardless of the URL scheme.
///
/// # Errors
/// Propagates failures from the TCP connection, the TLS setup and the
/// handshake.
#[cfg(feature = "sync-ssl")]
pub fn connect_secure(
    &mut self,
    ssl_config: Option<TlsConnector>,
) -> WebSocketResult<Client<TlsStream<TcpStream>>> {
    let tcp = self.establish_tcp(Some(true))?;
    let tls = self.wrap_ssl(tcp, ssl_config)?;
    self.connect_on(tls)
}
/// Performs the WebSocket handshake over an already connected `stream`.
///
/// The request line and headers produced by `build_request` are written to
/// the stream, the HTTP response is parsed from it and checked with
/// `validate`; on success the (buffered) stream is wrapped in a `Client`
/// configured with this builder's dataframe and message size limits.
///
/// # Errors
/// Returns an error when writing the request fails, when the response
/// cannot be parsed, or when `validate` rejects the response.
#[cfg(feature = "sync")]
pub fn connect_on<S>(&mut self, mut stream: S) -> WebSocketResult<Client<S>>
where
    S: Stream,
{
    // Send the upgrade request.
    let resource = self.build_request();
    let data = format!("GET {} {}\r\n{}\r\n", resource, self.version, self.headers);
    stream.write_all(data.as_bytes())?;

    // Read and check the server's answer.
    let mut reader = BufReader::new(stream);
    let response = parse_response(&mut reader).map_err(towse)?;
    self.validate(&response)?;

    // The two boolean flags are forwarded unchanged; their meaning is
    // defined by `Client::unchecked_with_limits`.
    // The last argument is the *message* limit: passing the dataframe limit
    // twice would silently ignore `max_message_size`.
    Ok(Client::unchecked_with_limits(
        reader,
        response.headers,
        true,
        false,
        self.max_dataframe_size,
        self.max_message_size,
    ))
}
/// Asynchronously connects to the URL's host and performs the handshake,
/// inserting a TLS layer when the URL scheme is secure (see
/// `is_secure_url`). The resulting stream is boxed so both cases share one
/// return type.
///
/// The TCP connection future is created first; an error while preparing
/// the TLS connector is reported through an already failed future.
#[cfg(feature = "async-ssl")]
pub fn async_connect(
    self,
    ssl_config: Option<TlsConnector>,
) -> r#async::ClientNew<Box<dyn stream::r#async::Stream + Send>> {
    let tcp_stream = self.async_tcpstream(None);

    // Re-create the builder with an owned URL so it can move into the future.
    let builder = ClientBuilder {
        url: Cow::Owned(self.url.into_owned()),
        version: self.version,
        headers: self.headers,
        version_set: self.version_set,
        key_set: self.key_set,
        max_dataframe_size: self.max_dataframe_size,
        max_message_size: self.max_message_size,
    };

    if !builder.is_secure_url() {
        // Plain connection: handshake directly on the TCP stream.
        let handshake = tcp_stream.and_then(move |tcp| {
            let boxed: Box<dyn stream::r#async::Stream + Send> = Box::new(tcp);
            builder.async_connect_on(boxed)
        });
        return Box::new(handshake);
    }

    let (host, connector) = match builder.extract_host_ssl_conn(ssl_config) {
        Ok((name, conn)) => (name.to_string(), TlsConnectorExt::from(conn)),
        Err(err) => return Box::new(future::err(err)),
    };

    let handshake = tcp_stream
        .and_then(move |tcp| connector.connect(&host, tcp).map_err(towse))
        .and_then(move |tls| {
            let boxed: Box<dyn stream::r#async::Stream + Send> = Box::new(tls);
            builder.async_connect_on(boxed)
        });
    Box::new(handshake)
}
/// Like `async_connect`, but the TCP connection is produced by `cb`, which
/// receives the resolved socket addresses of the URL's host and returns the
/// future that yields the connected stream.
///
/// A TLS layer is inserted when the URL scheme is secure (see
/// `is_secure_url`); the resulting stream is boxed in both cases.
#[cfg(feature = "async-ssl")]
pub fn async_connect_with_cb(
    self,
    ssl_config: Option<TlsConnector>,
    cb: impl FnOnce(url::SocketAddrs) -> Box<dyn future::Future<Item=TcpStreamNew, Error=WebSocketError> + Send>,
) -> r#async::ClientNew<Box<dyn stream::r#async::Stream + Send>> {
    let tcp_stream = self.async_tcpstream_with_cb(None, cb);

    // Re-create the builder with an owned URL so it can move into the future.
    let builder = ClientBuilder {
        url: Cow::Owned(self.url.into_owned()),
        version: self.version,
        headers: self.headers,
        version_set: self.version_set,
        key_set: self.key_set,
        max_dataframe_size: self.max_dataframe_size,
        max_message_size: self.max_message_size,
    };

    if !builder.is_secure_url() {
        // Plain connection: handshake directly on the TCP stream.
        let handshake = tcp_stream.and_then(move |tcp| {
            let boxed: Box<dyn stream::r#async::Stream + Send> = Box::new(tcp);
            builder.async_connect_on(boxed)
        });
        return Box::new(handshake);
    }

    let (host, connector) = match builder.extract_host_ssl_conn(ssl_config) {
        Ok((name, conn)) => (name.to_string(), TlsConnectorExt::from(conn)),
        Err(err) => return Box::new(future::err(err)),
    };

    let handshake = tcp_stream
        .and_then(move |tcp| connector.connect(&host, tcp).map_err(towse))
        .and_then(move |tls| {
            let boxed: Box<dyn stream::r#async::Stream + Send> = Box::new(tls);
            builder.async_connect_on(boxed)
        });
    Box::new(handshake)
}
/// Asynchronously connects over TLS, regardless of the URL scheme, and
/// performs the handshake on the encrypted stream.
///
/// The TCP connection future is created first (with the secure default
/// port requested); an error while preparing the TLS connector is reported
/// through an already failed future.
#[cfg(feature = "async-ssl")]
pub fn async_connect_secure(
    self,
    ssl_config: Option<TlsConnector>,
) -> r#async::ClientNew<r#async::TlsStream<r#async::TcpStream>> {
    let tcp_stream = self.async_tcpstream(Some(true));

    // Copy the host name out so the borrow of `self` ends before it is moved.
    let tls_setup = self
        .extract_host_ssl_conn(ssl_config)
        .map(|(name, conn)| (name.to_string(), TlsConnectorExt::from(conn)));
    let (host, connector) = match tls_setup {
        Ok(pair) => pair,
        Err(err) => return Box::new(future::err(err)),
    };

    // Re-create the builder with an owned URL so it can move into the future.
    let builder = ClientBuilder {
        url: Cow::Owned(self.url.into_owned()),
        version: self.version,
        headers: self.headers,
        version_set: self.version_set,
        key_set: self.key_set,
        max_dataframe_size: self.max_dataframe_size,
        max_message_size: self.max_message_size,
    };

    let handshake = tcp_stream
        .and_then(move |tcp| connector.connect(&host, tcp).map_err(towse))
        .and_then(move |tls| builder.async_connect_on(tls));
    Box::new(handshake)
}
/// Asynchronously connects over plain TCP, regardless of the URL scheme
/// (the insecure default port is requested), and performs the handshake on
/// that stream.
#[cfg(feature = "async")]
pub fn async_connect_insecure(self) -> r#async::ClientNew<r#async::TcpStream> {
    let tcp_stream = self.async_tcpstream(Some(false));

    // Re-create the builder with an owned URL so it can move into the future.
    let builder = ClientBuilder {
        url: Cow::Owned(self.url.into_owned()),
        version: self.version,
        headers: self.headers,
        version_set: self.version_set,
        key_set: self.key_set,
        max_dataframe_size: self.max_dataframe_size,
        max_message_size: self.max_message_size,
    };

    let handshake = tcp_stream.and_then(move |tcp| builder.async_connect_on(tcp));
    Box::new(handshake)
}
/// Asynchronously performs the WebSocket handshake over an already
/// connected `stream`.
///
/// The returned future sends the request built by `build_request`, waits
/// for the first response, checks it with `validate`, and then swaps the
/// HTTP codec for a `MessageCodec` configured with this builder's size
/// limits. It resolves to the framed client together with the response
/// headers.
///
/// The future fails with a `ProtocolError` when the stream ends before a
/// response arrives, or with whatever error sending, reading or
/// validation produced.
#[cfg(feature = "async")]
pub fn async_connect_on<S>(self, stream: S) -> r#async::ClientNew<S>
where
    S: stream::r#async::Stream + Send + 'static,
{
    // Plain integers: copy them out before `self` is taken apart.
    let max_dataframe_size = self.max_dataframe_size;
    let max_message_size = self.max_message_size;

    // Re-create the builder with an owned URL so it can move into the future.
    let mut builder = ClientBuilder {
        url: Cow::Owned(self.url.into_owned()),
        version: self.version,
        headers: self.headers,
        version_set: self.version_set,
        key_set: self.key_set,
        max_dataframe_size,
        max_message_size,
    };

    let resource = builder.build_request();
    let framed = crate::codec::http::HttpClientCodec.framed(stream);
    let request = Incoming {
        version: builder.version,
        headers: builder.headers.clone(),
        subject: (Method::Get, RequestUri::AbsolutePath(resource)),
    };

    let handshake = framed
        .send(request)
        .map_err(::std::convert::Into::into)
        .and_then(|http| http.into_future().map_err(|e| towse(e.0)))
        .and_then(move |(response, http)| match response {
            Some(response) => builder.validate(&response).map(|()| (response, http)),
            None => Err(WebSocketError::ProtocolError(
                "Connection closed before handshake could complete.",
            )),
        })
        .map(move |(response, http)| {
            let codec = MessageCodec::new_with_limits(
                Context::Client,
                max_dataframe_size,
                max_message_size,
            );
            (update_framed_codec(http, codec), response.headers)
        });
    Box::new(handshake)
}
/// Resolves the URL's host and port and returns a future that connects a
/// TCP stream to the first resolved address.
///
/// `secure` is forwarded to `extract_host_port` to pick the default port.
/// Resolution happens synchronously inside this call (via
/// `to_socket_addrs`, which may block on name resolution); only the first
/// address is tried. Resolution errors, and an empty address list
/// (`NoHostName`), are reported through an already failed future.
#[cfg(feature = "async")]
fn async_tcpstream(
    &self,
    secure: Option<bool>,
) -> Box<dyn future::Future<Item = TcpStreamNew, Error = WebSocketError> + Send> {
    let resolved = self
        .extract_host_port(secure)
        .and_then(|host_port| Ok(host_port.to_socket_addrs()?));

    let mut candidates = match resolved {
        Ok(addrs) => addrs,
        Err(err) => return Box::new(Err(err).into_future()),
    };

    let address = match candidates.next() {
        Some(addr) => addr,
        None => {
            let no_host = towse(WebSocketOtherError::WebSocketUrlError(
                WSUrlErrorKind::NoHostName,
            ));
            return Box::new(Err(no_host).into_future());
        }
    };

    Box::new(TcpStreamNew::connect(&address).map_err(Into::into))
}
/// Resolves the URL's host and port and hands the resolved addresses to
/// `cb`, which is responsible for producing the connection future.
///
/// `secure` is forwarded to `extract_host_port` to pick the default port.
/// When resolution fails, `cb` is not called and an already failed future
/// is returned instead.
#[cfg(feature = "async")]
fn async_tcpstream_with_cb(
    &self,
    secure: Option<bool>,
    cb: impl FnOnce(url::SocketAddrs) -> Box<dyn future::Future<Item=TcpStreamNew, Error=WebSocketError> + Send>,
) -> Box<dyn future::Future<Item = TcpStreamNew, Error = WebSocketError> + Send> {
    let resolved = self
        .extract_host_port(secure)
        .and_then(|host_port| Ok(host_port.to_socket_addrs()?));

    let addrs = match resolved {
        Ok(addrs) => addrs,
        Err(err) => return Box::new(Err(err).into_future()),
    };
    cb(addrs)
}
/// Fills in the headers required for the upgrade request and returns the
/// request target (the URL from its path through its query).
///
/// Headers set here: `Host` (when the URL has a host), `Authorization`
/// with basic credentials (when the URL carries a user name),
/// `Connection: Upgrade`, `Upgrade: websocket`, and -- unless the caller
/// chose them explicitly -- `WebSocketVersion::WebSocket13` and a newly
/// generated `WebSocketKey`.
#[cfg(any(feature = "sync", feature = "async"))]
fn build_request(&mut self) -> String {
    if let Some(host) = self.url.host_str() {
        let host_header = Host {
            hostname: host.to_string(),
            port: self.url.port(),
        };
        self.headers.set(host_header);
    }

    let username = self.url.username();
    if !username.is_empty() {
        let credentials = Basic {
            username: username.to_owned(),
            password: self.url.password().map(|password| password.to_owned()),
        };
        self.headers.set(Authorization(credentials));
    }

    let upgrade_option = ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()));
    self.headers.set(Connection(vec![upgrade_option]));

    let websocket_protocol = Protocol {
        name: ProtocolName::WebSocket,
        version: None,
    };
    self.headers.set(Upgrade(vec![websocket_protocol]));

    // Only fill in defaults the caller has not pinned via `version` / `key`.
    if !self.version_set {
        self.headers.set(WebSocketVersion::WebSocket13);
    }
    if !self.key_set {
        self.headers.set(WebSocketKey::new());
    }

    self.url[Position::BeforePath..Position::AfterQuery].to_owned()
}
/// Checks that `response` is a valid answer to the handshake request held
/// by this builder.
///
/// The checks, in order:
/// 1. the status must be `101 Switching Protocols` (a redirect carrying a
///    `Location` header is reported as `RedirectError`, anything else as
///    `StatusCodeError`);
/// 2. the request must contain a `WebSocketKey`, and the response's
///    `WebSocketAccept` must equal the value derived from that key;
/// 3. the response's `Upgrade` header must be exactly `websocket`;
/// 4. the response's `Connection` header must contain the `Upgrade` token
///    (compared case-insensitively through `UniCase`).
///
/// # Errors
/// Returns the corresponding `WebSocketOtherError` (converted to
/// `WebSocketError`) for the first check that fails.
#[cfg(any(feature = "sync", feature = "async"))]
fn validate(&self, response: &Incoming<RawStatus>) -> WebSocketResult<()> {
    let status = StatusCode::from_u16(response.subject.0);

    if status != StatusCode::SwitchingProtocols {
        if status.is_redirection() {
            if let Some(location) = response.headers.get::<hyper::header::Location>() {
                return Err(WebSocketOtherError::RedirectError(status, location.to_string()))
                    .map_err(towse);
            }
        }
        return Err(WebSocketOtherError::StatusCodeError(status)).map_err(towse);
    }

    let key = self
        .headers
        .get::<WebSocketKey>()
        .ok_or(WebSocketOtherError::RequestError(
            "Request Sec-WebSocket-Key was invalid",
        ))?;

    if response.headers.get() != Some(&(WebSocketAccept::new(key))) {
        return Err(WebSocketOtherError::ResponseError(
            "Sec-WebSocket-Accept is invalid",
        ))
        .map_err(towse);
    }

    let expected_upgrade = Upgrade(vec![Protocol {
        name: ProtocolName::WebSocket,
        version: None,
    }]);
    if response.headers.get() != Some(&expected_upgrade) {
        return Err(WebSocketOtherError::ResponseError(
            "Upgrade field must be WebSocket",
        ))
        .map_err(towse);
    }

    // This must inspect the *response*: the request's own `Connection`
    // header is always set by `build_request`, so checking `self.headers`
    // could never fail. A server may list further options alongside the
    // `Upgrade` token, so containment is tested rather than equality.
    let upgrade_token = ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()));
    let connection_has_upgrade = response
        .headers
        .get::<Connection>()
        .map_or(false, |connection| {
            connection.0.iter().any(|option| *option == upgrade_token)
        });
    if !connection_has_upgrade {
        return Err(WebSocketOtherError::ResponseError(
            "Connection field must be 'Upgrade'",
        ))
        .map_err(towse);
    }

    Ok(())
}
/// Returns `true` when the URL scheme is `wss` or `https`, i.e. when the
/// scheme-driven `connect` / `async_connect` should add a TLS layer.
#[cfg(any(feature = "sync-ssl", feature = "async-ssl"))]
fn is_secure_url(&self) -> bool {
    const SECURE_SCHEMES: [&str; 2] = ["wss", "https"];
    SECURE_SCHEMES.contains(&self.url.scheme())
}
/// Returns the host and port to connect to.
///
/// The port comes from `Url::with_default_port`; the fallback supplied to
/// it is 443 when `secure` is `Some(true)`, 80 when it is `Some(false)`,
/// and for `None` it is 443 only if the scheme is `wss`, otherwise 80.
///
/// # Errors
/// Fails with `WSUrlErrorKind::NoHostName` when the URL has no host, and
/// propagates any error reported by `with_default_port`.
#[cfg(any(feature = "sync", feature = "async"))]
fn extract_host_port(&self, secure: Option<bool>) -> WebSocketResult<::url::HostAndPort<&str>> {
    const SECURE_PORT: u16 = 443;
    const INSECURE_PORT: u16 = 80;
    const SECURE_WS_SCHEME: &str = "wss";

    if self.url.host().is_none() {
        return Err(towse(WebSocketOtherError::WebSocketUrlError(
            WSUrlErrorKind::NoHostName,
        )));
    }

    let host_and_port = self.url.with_default_port(|url| {
        // An explicit request wins; otherwise decide from the scheme.
        let wants_secure_port = secure.unwrap_or_else(|| url.scheme() == SECURE_WS_SCHEME);
        Ok(if wants_secure_port {
            SECURE_PORT
        } else {
            INSECURE_PORT
        })
    })?;
    Ok(host_and_port)
}
/// Opens a blocking TCP connection to the host and port chosen by
/// `extract_host_port(secure)`.
///
/// # Errors
/// Propagates URL errors from `extract_host_port` and I/O errors from
/// `TcpStream::connect`.
#[cfg(feature = "sync")]
fn establish_tcp(&mut self, secure: Option<bool>) -> WebSocketResult<TcpStream> {
    let host_and_port = self.extract_host_port(secure)?;
    let tcp = TcpStream::connect(host_and_port)?;
    Ok(tcp)
}
/// Returns the URL's host name together with the TLS connector to use:
/// the one passed in, or a default one built with
/// `TlsConnector::builder().build()` when `connector` is `None`.
///
/// # Errors
/// Fails with `WSUrlErrorKind::NoHostName` when the URL has no host (this
/// is checked before any connector is built), or with the error from
/// building the default connector.
#[cfg(any(feature = "sync-ssl", feature = "async-ssl"))]
fn extract_host_ssl_conn(
    &self,
    connector: Option<TlsConnector>,
) -> WebSocketResult<(&str, TlsConnector)> {
    let host = self.url.host_str().ok_or_else(|| {
        towse(WebSocketOtherError::WebSocketUrlError(
            WSUrlErrorKind::NoHostName,
        ))
    })?;

    let connector = if let Some(provided) = connector {
        provided
    } else {
        TlsConnector::builder().build().map_err(towse)?
    };

    Ok((host, connector))
}
/// Runs the TLS handshake for the URL's host over `tcp_stream` and returns
/// the encrypted stream.
///
/// # Errors
/// Propagates errors from `extract_host_ssl_conn` and from the TLS
/// handshake itself.
#[cfg(feature = "sync-ssl")]
fn wrap_ssl(
    &self,
    tcp_stream: TcpStream,
    connector: Option<TlsConnector>,
) -> WebSocketResult<TlsStream<TcpStream>> {
    let (host, tls_connector) = self.extract_host_ssl_conn(connector)?;
    tls_connector.connect(host, tcp_stream).map_err(towse)
}
}
// Unit tests for `ClientBuilder`. Gated on `cfg(test)` so the module is
// only compiled for test builds.
#[cfg(test)]
mod tests {
    use super::*;

    /// Protocols can be added one at a time or in bulk, and
    /// `clear_protocols` drops everything added before it.
    #[test]
    fn build_client_with_protocols() {
        let builder = ClientBuilder::new("ws://127.0.0.1:8080/hello/world")
            .unwrap()
            .add_protocol("protobeard");

        let protos = &builder.headers.get::<WebSocketProtocol>().unwrap().0;
        assert!(protos.contains(&"protobeard".to_string()));
        assert_eq!(protos.len(), 1);

        let builder = ClientBuilder::new("ws://example.org/hello")
            .unwrap()
            .add_protocol("rust-websocket")
            .clear_protocols()
            .add_protocols(vec!["electric", "boogaloo"]);

        let protos = &builder.headers.get::<WebSocketProtocol>().unwrap().0;
        assert!(protos.contains(&"boogaloo".to_string()));
        assert!(protos.contains(&"electric".to_string()));
        assert!(!protos.contains(&"rust-websocket".to_string()));
    }

    /// Credentials embedded in the URL end up in a basic `Authorization`
    /// header. `build_request` only exists with the `sync` or `async`
    /// feature, so the test carries the same gate.
    #[cfg(any(feature = "sync", feature = "async"))]
    #[test]
    fn build_client_with_username_password() {
        let mut builder = ClientBuilder::new("ws://john:pswd@127.0.0.1:8080/hello").unwrap();
        let _request = builder.build_request();

        let auth = builder.headers.get::<Authorization<Basic>>().unwrap();
        assert_eq!(auth.username, "john");
        assert_eq!(auth.password, Some("pswd".to_owned()));
    }
}