use std::{future::poll_fn, io, pin::Pin, str::FromStr};
use base64::{engine::general_purpose, Engine};
use futures_core::Stream;
use http::{header::HeaderName, HeaderMap, HeaderValue, Uri};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
net::TcpStream,
};
use tokio_util::codec::FramedRead;
use crate::{
proto::{Config, Limits, Role},
resolver::{self, Resolver},
upgrade::{self, server_response},
Connector, Error, MaybeTlsStream, WebSocketStream,
};
/// Produces a fresh `Sec-WebSocket-Key` value.
///
/// The random bytes from `crate::rand::get_key` are base64-encoded (standard
/// alphabet, with padding) into a fixed 24-byte buffer, which is returned.
pub(crate) fn make_key() -> [u8; 24] {
    let nonce = crate::rand::get_key();
    let mut encoded: [u8; 24] = [0; 24];

    // SAFETY: `encode_slice` only fails when the output slice is too small.
    // This relies on `get_key` returning the 16-byte nonce that RFC 6455 §4.1
    // asks for (TODO(review): confirm in `crate::rand`). Padded base64 of 16
    // bytes is exactly 24 bytes, so `encoded` always fits and the `Err` arm
    // is unreachable.
    let _written = unsafe {
        general_purpose::STANDARD
            .encode_slice(nonce, &mut encoded)
            .unwrap_unchecked()
    };

    encoded
}
/// Returns the port to use for `uri`.
///
/// An explicit port in the URI wins. Otherwise the port is derived from the
/// scheme: 443 for `https`/`wss`, 80 for `http`/`ws`. Any other (or missing)
/// scheme yields `None`.
fn default_port(uri: &Uri) -> Option<u16> {
    uri.port_u16().or_else(|| match uri.scheme_str()? {
        "https" | "wss" => Some(443),
        "http" | "ws" => Some(80),
        _ => None,
    })
}
fn build_request(uri: &Uri, key: &[u8], headers: &HeaderMap) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(b"GET ");
buf.extend_from_slice(uri.path().as_bytes());
if let Some(query) = uri.query() {
buf.extend_from_slice(b"?");
buf.extend_from_slice(query.as_bytes());
}
buf.extend_from_slice(b" HTTP/1.1\r\n");
if let Some(host) = uri.host() {
buf.extend_from_slice(b"Host: ");
buf.extend_from_slice(host.as_bytes());
if let Some(port) = default_port(uri) {
buf.extend_from_slice(b":");
buf.extend_from_slice(port.to_string().as_bytes());
}
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ");
buf.extend_from_slice(key);
buf.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n");
for (name, value) in headers {
buf.extend_from_slice(name.as_str().as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
buf
}
/// Configures and establishes a client-side WebSocket connection.
///
/// `R` is the [`Resolver`] that [`Builder::connect`] uses to turn the URI's
/// host into an address to connect to. It defaults to [`resolver::Gai`].
pub struct Builder<'a, R: Resolver = resolver::Gai> {
    /// Target URI; `connect` and `connect_on` fail with `NoUriConfigured` without it.
    uri: Option<Uri>,
    /// Connector for `wss` URIs. When `None`, `connect` creates one with `Connector::new`.
    connector: Option<&'a Connector>,
    /// Resolver used by `connect` for the URI host.
    resolver: R,
    /// Configuration handed to the resulting `WebSocketStream`.
    config: Config,
    /// Limits handed to the resulting `WebSocketStream`.
    limits: Limits,
    /// Extra headers appended to the handshake request.
    headers: HeaderMap,
}
impl<'a> Builder<'a> {
#[must_use]
pub fn new() -> Self {
Self {
uri: None,
connector: None,
resolver: resolver::Gai,
config: Config::default(),
limits: Limits::default(),
headers: HeaderMap::new(),
}
}
#[must_use]
pub fn from_uri(uri: Uri) -> Self {
Self {
uri: Some(uri),
connector: None,
resolver: resolver::Gai,
config: Config::default(),
limits: Limits::default(),
headers: HeaderMap::new(),
}
}
}
impl<'a, R: Resolver> Builder<'a, R> {
pub fn uri(mut self, uri: &str) -> Result<Self, http::uri::InvalidUri> {
self.uri = Some(Uri::from_str(uri)?);
Ok(self)
}
#[must_use]
pub fn connector(mut self, connector: &'a Connector) -> Self {
self.connector = Some(connector);
self
}
#[must_use]
pub fn resolver<NewR: Resolver>(self, resolver: NewR) -> Builder<'a, NewR> {
let Builder {
uri,
connector,
resolver: _,
config,
limits,
headers,
} = self;
Builder {
uri,
connector,
resolver,
config,
limits,
headers,
}
}
#[must_use]
pub fn config(mut self, config: Config) -> Self {
self.config = config;
self
}
#[must_use]
pub fn limits(mut self, limits: Limits) -> Self {
self.limits = limits;
self
}
#[must_use]
pub fn add_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
self.headers.insert(name, value);
self
}
pub async fn connect(
&self,
) -> Result<
(
WebSocketStream<MaybeTlsStream<TcpStream>>,
upgrade::Response,
),
Error,
> {
let uri = self.uri.as_ref().ok_or(Error::NoUriConfigured)?;
let host = uri
.host()
.ok_or(Error::CannotResolveHost)?
.trim_start_matches('[')
.trim_end_matches(']');
let port = default_port(uri).unwrap_or(80);
let addr = self.resolver.resolve(host, port).await?;
let stream = TcpStream::connect(&addr).await?;
let stream = if uri.scheme_str() == Some("wss") {
if let Some(connector) = self.connector {
connector.wrap(host, stream).await?
} else {
let connector = Connector::new()?;
connector.wrap(host, stream).await?
}
} else if uri.scheme_str() == Some("ws") {
Connector::Plain.wrap(host, stream).await?
} else {
return Err(Error::UnsupportedScheme);
};
self.connect_on(stream).await
}
pub async fn connect_on<S: AsyncRead + AsyncWrite + Unpin>(
&self,
mut stream: S,
) -> Result<(WebSocketStream<S>, upgrade::Response), Error> {
let uri = self.uri.as_ref().ok_or(Error::NoUriConfigured)?;
let key_base64 = make_key();
let upgrade_codec = server_response::Codec::new(&key_base64);
let request = build_request(uri, &key_base64, &self.headers);
stream.write_all(&request).await?;
let mut framed = FramedRead::new(stream, upgrade_codec);
let res = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx))
.await
.ok_or(Error::Io(io::ErrorKind::UnexpectedEof.into()))??;
Ok((
WebSocketStream::from_framed(framed, Role::Client, self.config, self.limits),
res,
))
}
pub fn take_over<S: AsyncRead + AsyncWrite + Unpin>(&self, stream: S) -> WebSocketStream<S> {
WebSocketStream::from_raw_stream(stream, Role::Client, self.config, self.limits)
}
}
impl<'a> Default for Builder<'a> {
fn default() -> Self {
Self::new()
}
}