use std::{
net::{SocketAddr, ToSocketAddrs},
str::FromStr,
};
use futures_util::StreamExt;
use http::{header::HeaderName, HeaderMap, HeaderValue, Uri};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
net::TcpStream,
};
use tokio_util::codec::Decoder;
use crate::{proto::Role, upgrade, Connector, Error, MaybeTlsStream, WebsocketStream};
/// Produces the base64-encoded `Sec-WebSocket-Key` value for a handshake.
///
/// If `key` is `Some`, those 16 bytes are encoded as-is; otherwise 16 bytes are
/// drawn one at a time from `fastrand`. The 24-character base64 (standard
/// alphabet) result is written into `key_base64`.
pub(crate) fn make_key(key: Option<[u8; 16]>, key_base64: &mut [u8; 24]) {
    let raw: [u8; 16] = match key {
        Some(provided) => provided,
        None => {
            let mut nonce = [0u8; 16];
            // Filled front to back, matching the evaluation order of an array literal.
            for byte in nonce.iter_mut() {
                *byte = fastrand::u8(0..=255);
            }
            nonce
        }
    };
    base64::encode_config_slice(raw, base64::STANDARD, key_base64);
}
/// Returns the port to connect to for `uri`.
///
/// An explicit port in the URI always wins. Otherwise `https`/`wss` map to 443
/// and `http`/`ws` map to 80; any other (or missing) scheme yields `None`.
fn default_port(uri: &Uri) -> Option<u16> {
    uri.port_u16().or_else(|| match uri.scheme_str()? {
        "https" | "wss" => Some(443),
        "http" | "ws" => Some(80),
        _ => None,
    })
}
fn build_request(uri: &Uri, key: &[u8], headers: &HeaderMap) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(b"GET ");
buf.extend_from_slice(uri.path().as_bytes());
if let Some(query) = uri.query() {
buf.extend_from_slice(b"?");
buf.extend_from_slice(query.as_bytes());
}
buf.extend_from_slice(b" HTTP/1.1\r\n");
if let Some(host) = uri.host() {
buf.extend_from_slice(b"Host: ");
buf.extend_from_slice(host.as_bytes());
if let Some(port) = default_port(uri) {
buf.extend_from_slice(b":");
buf.extend_from_slice(port.to_string().as_bytes());
}
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"Upgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ");
buf.extend_from_slice(key);
buf.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n");
for (name, value) in headers {
buf.extend_from_slice(name.as_str().as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(value.as_bytes());
buf.extend_from_slice(b"\r\n");
}
buf.extend_from_slice(b"\r\n");
buf
}
/// Resolves `host:port` to the first socket address returned by the system
/// resolver.
///
/// The lookup uses the blocking `std::net::ToSocketAddrs`, so it runs on
/// tokio's blocking thread pool.
///
/// # Errors
///
/// - Resolver I/O errors are converted into [`Error`].
/// - [`Error::CannotResolveHost`] if the lookup succeeds but yields no address.
/// - If the blocking task fails to complete (it panicked or was cancelled), the
///   failure is reported as an I/O error instead of panicking.
async fn resolve(host: String, port: u16) -> Result<SocketAddr, Error> {
    let task = tokio::task::spawn_blocking(move || {
        (host, port)
            .to_socket_addrs()?
            .next()
            .ok_or(Error::CannotResolveHost)
    });
    // A library must not panic just because the runtime cancelled or lost the
    // blocking task (e.g. during shutdown); surface it as an I/O error.
    task.await
        .map_err(|join_err| std::io::Error::new(std::io::ErrorKind::Other, join_err))?
}
/// Configures and establishes a client-side WebSocket connection.
///
/// Start from [`Builder::new`] or [`Builder::from_uri`] and adjust it with the
/// chained setters. Then call [`Builder::connect`], or [`Builder::connect_on`]
/// to handshake over a stream you already have.
pub struct Builder<'a> {
    /// Target of the handshake; `None` until a URI is supplied.
    uri: Option<Uri>,
    /// Extra headers appended to the upgrade request.
    headers: HeaderMap,
    /// Connector used to wrap the TCP stream in `connect`; when `None`, one is
    /// chosen based on the URI scheme.
    connector: Option<&'a Connector>,
    /// Forwarded to the resulting `WebsocketStream`. Presumably controls whether
    /// invalid UTF-8 in text messages is rejected early — confirm.
    fail_fast_on_invalid_utf8: bool,
}
impl<'a> Builder<'a> {
    /// Creates a builder with no URI, no custom connector and no extra headers.
    ///
    /// `fail_fast_on_invalid_utf8` starts out enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            uri: None,
            connector: None,
            headers: HeaderMap::new(),
            fail_fast_on_invalid_utf8: true,
        }
    }

    /// Parses `uri` and sets it as the connection target.
    ///
    /// # Errors
    ///
    /// Returns [`http::uri::InvalidUri`] if `uri` cannot be parsed.
    pub fn uri(mut self, uri: &str) -> Result<Self, http::uri::InvalidUri> {
        self.uri = Some(Uri::from_str(uri)?);
        Ok(self)
    }

    /// Creates a builder targeting an already-parsed [`Uri`].
    ///
    /// All other settings take their [`Builder::new`] defaults.
    #[must_use]
    pub fn from_uri(uri: Uri) -> Self {
        Self {
            uri: Some(uri),
            ..Self::new()
        }
    }

    /// Uses `connector` to wrap the TCP stream opened by [`Builder::connect`],
    /// regardless of the URI scheme.
    #[must_use]
    pub fn connector(mut self, connector: &'a Connector) -> Self {
        self.connector = Some(connector);
        self
    }

    /// Adds a header to the upgrade request.
    ///
    /// Any header previously added under the same `name` is replaced.
    #[must_use]
    pub fn add_header(mut self, name: HeaderName, value: HeaderValue) -> Self {
        self.headers.insert(name, value);
        self
    }

    /// Sets the flag forwarded to the resulting `WebsocketStream` as
    /// `fail_fast_on_invalid_utf8`.
    #[must_use]
    pub fn fail_fast_on_invalid_utf8(mut self, value: bool) -> Self {
        self.fail_fast_on_invalid_utf8 = value;
        self
    }

    /// Resolves the configured URI and opens a TCP connection to it.
    ///
    /// The stream is then wrapped and the WebSocket handshake performed on it:
    /// - With a custom connector set, that connector wraps the stream.
    /// - Otherwise `wss`/`https` URIs use a fresh [`Connector::new`], and every
    ///   other scheme uses [`Connector::Plain`].
    ///
    /// Without an explicit port, the scheme's default port is used, falling back
    /// to 80.
    ///
    /// # Errors
    ///
    /// - [`Error::NoUriConfigured`] if no URI was set.
    /// - [`Error::CannotResolveHost`] if the URI has no host or resolution yields
    ///   no address.
    /// - Any error from name resolution, the TCP connect, creating or applying
    ///   the connector, or the handshake done by [`Builder::connect_on`].
    pub async fn connect(&self) -> Result<WebsocketStream<MaybeTlsStream<TcpStream>>, Error> {
        let uri = self.uri.as_ref().ok_or(Error::NoUriConfigured)?;
        let host = uri.host().ok_or(Error::CannotResolveHost)?;
        // `Uri::host` keeps the brackets of IPv6 literals ("[::1]"). Neither
        // std's resolver nor a TLS server name accepts that form, so strip them.
        let bare_host = host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host);
        let port = default_port(uri).unwrap_or(80);
        let addr = resolve(bare_host.to_string(), port).await?;
        let stream = TcpStream::connect(&addr).await?;

        let stream = if let Some(connector) = self.connector {
            connector.wrap(bare_host, stream).await?
        } else if matches!(uri.scheme_str(), Some("wss" | "https")) {
            // `default_port` already treats both schemes as TLS (443); wrap both
            // so an `https` URI does not speak plaintext to a TLS port.
            let connector = Connector::new()?;
            connector.wrap(bare_host, stream).await?
        } else {
            Connector::Plain.wrap(bare_host, stream).await?
        };

        self.connect_on(stream).await
    }

    /// Performs the client WebSocket handshake over an already-established
    /// `stream`.
    ///
    /// The upgrade request is written, the server's response is awaited through
    /// `upgrade::ServerResponseCodec`, and on success the stream becomes a
    /// client-role [`WebsocketStream`].
    ///
    /// # Errors
    ///
    /// - [`Error::NoUriConfigured`] if no URI was set.
    /// - I/O errors while writing or flushing the request.
    /// - [`Error::NoUpgradeResponse`] if the stream ends before any response is
    ///   decoded.
    /// - Any error produced by the response codec.
    pub async fn connect_on<S: AsyncRead + AsyncWrite + Unpin>(
        &self,
        mut stream: S,
    ) -> Result<WebsocketStream<S>, Error> {
        let uri = self.uri.as_ref().ok_or(Error::NoUriConfigured)?;

        let mut key_base64 = [0; 24];
        make_key(None, &mut key_base64);

        let upgrade_codec = upgrade::ServerResponseCodec::new(&key_base64);
        let request = build_request(uri, &key_base64, &self.headers);
        stream.write_all(&request).await?;
        // `write_all` may leave bytes inside a buffering layer (a TLS session, a
        // `BufWriter`, ...). Flush so the server actually receives the handshake
        // before we block waiting for its reply.
        stream.flush().await?;

        let (opt, framed) = upgrade_codec.framed(stream).into_future().await;
        opt.ok_or(Error::NoUpgradeResponse)??;

        Ok(WebsocketStream::from_framed(
            framed,
            Role::Client,
            self.fail_fast_on_invalid_utf8,
        ))
    }

    /// Wraps `stream` as a client-role [`WebsocketStream`] without performing a
    /// handshake.
    pub fn take_over<S: AsyncRead + AsyncWrite + Unpin>(&self, stream: S) -> WebsocketStream<S> {
        WebsocketStream::from_raw_stream(stream, Role::Client, self.fail_fast_on_invalid_utf8)
    }
}
impl<'a> Default for Builder<'a> {
fn default() -> Self {
Self::new()
}
}