use async_io::Timer;
use async_net::TcpStream;
use bytes::Bytes;
use futures_lite::future;
use futures_lite::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use sha1::{Digest, Sha1};
use std::future::{Future, IntoFuture};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use crate::connection::tcp::connect_happy_eyeballs;
#[cfg(feature = "btls-backend")]
use crate::connection::tls::build_boring_tls_connector_for_protocols;
#[cfg(feature = "native-tls")]
use crate::connection::tls::build_native_tls_connector_for_protocols;
#[cfg(feature = "rustls")]
use crate::connection::tls::connect_async_tls_with_config;
#[cfg(feature = "btls-backend")]
use crate::connection::tls::connect_boring_tls_with_connector;
#[cfg(feature = "native-tls")]
use crate::connection::tls::connect_native_tls_with_connector;
use crate::cookie::CookieJar;
use crate::error::{Error, ErrorKind, Result};
use crate::header::HeaderMap;
use crate::request::TimeoutConfig;
use crate::tls::{TlsBackend, TlsConfig};
use crate::url::Url;
/// GUID that is appended to the client's `Sec-WebSocket-Key` before hashing to
/// produce the expected `Sec-WebSocket-Accept` value (RFC 6455 §1.3).
const WS_ACCEPT_MAGIC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Process-wide counter mixed into `fill_random_bytes` so that successive calls
/// differ even when the clock reading has not changed.
static NEXT_RANDOM_FALLBACK: AtomicU64 = AtomicU64::new(1);
/// Object-safe umbrella trait for any async byte stream (plain TCP or a TLS
/// wrapper) that a websocket can run over.
trait IoStream: AsyncRead + AsyncWrite + Unpin + Send {}
impl<T> IoStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
/// Type-erased transport, so plain and TLS connections share one code path.
type BoxedStream = Box<dyn IoStream>;
/// Configures and opens a client websocket connection.
///
/// The builder implements [`IntoFuture`], so the connection is established by
/// `.await`ing it; a builder that is dropped without being awaited opens
/// nothing.
#[must_use = "a WebSocketBuilder does nothing until it is awaited"]
pub struct WebSocketBuilder {
    /// Target URL; resolved against `base_url` when it contains no `://`.
    url: String,
    /// Base used to resolve relative `url` values.
    base_url: Option<Url>,
    /// Extra handshake headers supplied by the caller.
    headers: HeaderMap,
    /// Cookies sent in addition to those from the jar or from `headers`.
    cookies: Vec<(String, String)>,
    /// Subprotocols offered in `Sec-WebSocket-Protocol`.
    protocols: Vec<String>,
    /// Value for the `Origin` header; overrides a user-supplied `Origin`.
    origin: Option<String>,
    /// Total, connect, read and write timeouts.
    timeout_config: TimeoutConfig,
    /// TLS settings used for `wss://` URLs.
    tls_config: TlsConfig,
    /// Jar that supplies request cookies and stores `Set-Cookie` responses.
    cookie_jar: Option<CookieJar>,
    /// Local address the TCP socket is bound to, if any.
    local_addr: Option<SocketAddr>,
}
impl WebSocketBuilder {
    /// Creates a builder targeting `url` with no extra headers, cookies,
    /// subprotocols or origin, and default timeout and TLS settings.
    pub(crate) fn new(url: impl Into<String>) -> Self {
        Self {
            url: url.into(),
            base_url: None,
            headers: HeaderMap::new(),
            cookies: Vec::new(),
            protocols: Vec::new(),
            origin: None,
            timeout_config: TimeoutConfig::default(),
            tls_config: TlsConfig::default(),
            cookie_jar: None,
            local_addr: None,
        }
    }

    /// Replaces the builder's connection settings with those inherited from
    /// the owning client.
    ///
    /// Every argument overwrites the corresponding field; the URL, the offered
    /// subprotocols and the origin are left untouched.
    pub(crate) fn with_client_defaults(
        mut self,
        base_url: Option<Url>,
        headers: HeaderMap,
        cookies: Vec<(String, String)>,
        timeout_config: TimeoutConfig,
        tls_config: TlsConfig,
        cookie_jar: Option<CookieJar>,
        local_addr: Option<SocketAddr>,
    ) -> Self {
        self.base_url = base_url;
        self.headers = headers;
        self.cookies = cookies;
        self.timeout_config = timeout_config;
        self.tls_config = tls_config;
        self.cookie_jar = cookie_jar;
        self.local_addr = local_addr;
        self
    }

    /// Binds the outgoing TCP connection to the given local address.
    pub fn local_addr(mut self, addr: SocketAddr) -> Self {
        self.local_addr = Some(addr);
        self
    }

    /// Sets the base URL that relative websocket URLs are resolved against.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` cannot be parsed.
    pub fn base_url(mut self, url: impl AsRef<str>) -> Result<Self> {
        self.base_url = Some(Url::parse(url)?);
        Ok(self)
    }

    /// Appends a header to the handshake request.
    ///
    /// Headers the handshake controls (`Host`, `Connection`, `Upgrade` and the
    /// `Sec-WebSocket-*` family) are skipped when the request is written, and
    /// `Cookie` headers are merged into a single `Cookie` line.
    ///
    /// # Errors
    ///
    /// Returns an error if the header name or value is rejected.
    pub fn header(mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
        self.headers.append(name, value)?;
        Ok(self)
    }

    /// Appends every `(name, value)` pair as a handshake header.
    ///
    /// # Errors
    ///
    /// Returns the first error reported for an invalid name or value.
    pub fn headers<I, K, V>(mut self, headers: I) -> Result<Self>
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        for (name, value) in headers {
            self.headers.append(name, value)?;
        }
        Ok(self)
    }

    /// Appends an `Authorization: Bearer <token>` header.
    ///
    /// # Errors
    ///
    /// Returns an error if the resulting header value is rejected.
    pub fn bearer_auth(self, token: impl AsRef<str>) -> Result<Self> {
        self.header("authorization", format!("Bearer {}", token.as_ref()))
    }

    /// Appends an `Authorization: Basic` header carrying the base64-encoded
    /// `username:password` pair.
    ///
    /// # Errors
    ///
    /// Returns an error if the resulting header value is rejected.
    pub fn basic_auth(self, username: impl AsRef<str>, password: impl AsRef<str>) -> Result<Self> {
        let raw = format!("{}:{}", username.as_ref(), password.as_ref());
        let encoded = encode_base64(raw.as_bytes());
        self.header("authorization", format!("Basic {encoded}"))
    }

    /// Adds a cookie to the handshake's `Cookie` header.
    pub fn cookie(mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Self {
        self.cookies
            .push((name.as_ref().to_owned(), value.as_ref().to_owned()));
        self
    }

    /// Adds several cookies to the handshake's `Cookie` header.
    pub fn cookies<I, K, V>(mut self, cookies: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        self.cookies.extend(
            cookies
                .into_iter()
                .map(|(name, value)| (name.as_ref().to_owned(), value.as_ref().to_owned())),
        );
        self
    }

    /// Offers a subprotocol in `Sec-WebSocket-Protocol`.
    ///
    /// The server's selection is rejected unless it is one of the offered
    /// subprotocols.
    pub fn protocol(mut self, name: impl AsRef<str>) -> Self {
        self.protocols.push(name.as_ref().to_owned());
        self
    }

    /// Offers several subprotocols in `Sec-WebSocket-Protocol`, in order.
    pub fn protocols<I, S>(mut self, values: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.protocols
            .extend(values.into_iter().map(|value| value.as_ref().to_owned()));
        self
    }

    /// Sets the `Origin` header, replacing any `Origin` supplied via `header`.
    pub fn origin(mut self, value: impl AsRef<str>) -> Self {
        self.origin = Some(value.as_ref().to_owned());
        self
    }

    /// Sets an overall deadline for opening the connection.
    ///
    /// The deadline covers DNS resolution, TCP connect, TLS and the HTTP
    /// upgrade exchange. It does not apply to traffic after the handshake.
    pub fn timeout(mut self, duration: Duration) -> Self {
        self.timeout_config.total = Some(duration);
        self
    }

    /// Limits how long TCP connection establishment and the TLS handshake may
    /// take.
    pub fn connect_timeout(mut self, duration: Duration) -> Self {
        self.timeout_config.connect = Some(duration);
        self
    }

    /// Limits each individual socket read, both during the handshake and for
    /// frames afterwards.
    pub fn read_timeout(mut self, duration: Duration) -> Self {
        self.timeout_config.read = Some(duration);
        self
    }

    /// Limits each individual socket write and flush, both during the
    /// handshake and for frames afterwards.
    pub fn write_timeout(mut self, duration: Duration) -> Self {
        self.timeout_config.write = Some(duration);
        self
    }

    /// Replaces the whole TLS configuration used for `wss://` URLs.
    pub fn tls_config(mut self, tls_config: TlsConfig) -> Self {
        self.tls_config = tls_config;
        self
    }

    /// Controls whether invalid TLS certificates are accepted.
    ///
    /// As the name says, enabling this weakens TLS security; avoid it outside
    /// testing.
    pub fn danger_accept_invalid_certs(mut self, enabled: bool) -> Self {
        // `self` is owned, so the field can be moved out and reassigned
        // instead of cloning the whole TLS configuration.
        self.tls_config = self.tls_config.danger_accept_invalid_certs(enabled);
        self
    }

    /// Selects which TLS backend handles `wss://` connections.
    pub fn tls_backend(mut self, backend: TlsBackend) -> Self {
        self.tls_config = self.tls_config.backend(backend);
        self
    }

    /// Sets the ALPN protocols offered during the TLS handshake.
    ///
    /// The TLS connectors validate these against an HTTP/1.1-only policy.
    pub fn alpn_protocols<I, S>(mut self, protocols: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.tls_config = self.tls_config.alpn_protocols(protocols);
        self
    }

    /// Disables ALPN on the TLS connection.
    pub fn disable_alpn(mut self) -> Self {
        self.tls_config = self.tls_config.disable_alpn();
        self
    }

    /// Opens the connection, bounding the whole of `connect_inner` by the
    /// total timeout when one is configured.
    async fn connect(self) -> Result<WebSocket> {
        let total_timeout = self.timeout_config.total;
        // `with_timeout` just awaits the future when `total_timeout` is `None`;
        // the outer `?` surfaces the timeout, the inner result is returned as is.
        with_timeout(
            total_timeout,
            self.connect_inner(),
            ErrorKind::Timeout,
            "websocket handshake timed out",
        )
        .await?
    }

    /// Performs the full opening sequence: URL resolution, TCP (and TLS)
    /// connect, the HTTP upgrade request, and response validation.
    async fn connect_inner(self) -> Result<WebSocket> {
        let url = resolve_url(&self.base_url, &self.url)?;
        match url.scheme() {
            "ws" | "wss" => {}
            scheme => {
                return Err(Error::new(
                    ErrorKind::InvalidUrl,
                    format!("websocket url scheme must be ws or wss, got {scheme}"),
                ));
            }
        }
        // 16-byte nonce, base64-encoded, for `Sec-WebSocket-Key`.
        let mut websocket_key_bytes = [0_u8; 16];
        fill_random_bytes(&mut websocket_key_bytes);
        let websocket_key = encode_base64(&websocket_key_bytes);
        let cookie_from_jar = self
            .cookie_jar
            .as_ref()
            .and_then(|jar| jar.get_cookie_header(&url));
        let request = build_handshake_request(
            &url,
            &self.headers,
            &self.cookies,
            cookie_from_jar.as_deref(),
            &self.protocols,
            self.origin.as_deref(),
            &websocket_key,
        )?;
        let tcp = connect_tcp_socket(
            url.host(),
            url.effective_port(),
            self.local_addr,
            self.timeout_config,
        )
        .await?;
        let mut stream = connect_socket(tcp, &url, &self.tls_config, self.timeout_config).await?;
        with_timeout_io(
            self.timeout_config.write,
            stream.write_all(request.as_bytes()),
            "write timed out",
        )
        .await?;
        with_timeout_io(self.timeout_config.write, stream.flush(), "write timed out").await?;
        // Bytes that arrived after the response head already belong to the
        // frame stream, so they seed the connection's read buffer.
        let (response_headers, leftover) =
            read_handshake_response(&mut *stream, self.timeout_config).await?;
        let selected_protocol =
            validate_handshake_response(&response_headers, &websocket_key, &self.protocols)?;
        if let Some(jar) = &self.cookie_jar {
            jar.store_set_cookies(&url, response_headers.get_all("set-cookie"))?;
        }
        Ok(WebSocket {
            stream,
            selected_protocol,
            timeout_config: self.timeout_config,
            read_buffer: leftover,
            write_buffer: Vec::new(),
            mask_seed: initial_mask_seed(),
            fragmented_message: None,
            sent_close: false,
            terminated: false,
        })
    }
}
async fn connect_tcp_socket(
host: &str,
port: u16,
local_addr: Option<SocketAddr>,
timeout_config: TimeoutConfig,
) -> Result<TcpStream> {
let addrs: Vec<SocketAddr> = async_net::resolve((host, port))
.await
.map_err(|err| {
Error::with_source(
ErrorKind::Transport,
"failed to resolve websocket host",
err,
)
})?
.into_iter()
.collect();
let (primary, fallback): (Vec<_>, Vec<_>) = addrs.into_iter().partition(|a| a.is_ipv6());
let (primary, fallback) = if primary.is_empty() {
(fallback, vec![])
} else {
(primary, fallback)
};
connect_happy_eyeballs(primary, fallback, local_addr, timeout_config.connect).await
}
/// Lets a configured builder be `.await`ed directly to open the connection.
impl IntoFuture for WebSocketBuilder {
    type Output = Result<WebSocket>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;

    fn into_future(self) -> Self::IntoFuture {
        // `connect` takes `self` by value, so the returned future owns the
        // builder and is `'static`; boxing gives it a nameable type.
        let handshake = self.connect();
        Box::pin(handshake)
    }
}
/// An open client websocket connection.
///
/// Produced by awaiting a [`WebSocketBuilder`]. Messages are received with
/// [`WebSocket::next`] and sent with the `send_*`, `ping`, `pong` and `close`
/// methods.
pub struct WebSocket {
    /// Transport after the HTTP upgrade: plain TCP or a TLS stream.
    stream: BoxedStream,
    /// Subprotocol accepted by the server, if any.
    selected_protocol: Option<String>,
    /// Timeouts; `read` and `write` bound each socket operation.
    timeout_config: TimeoutConfig,
    /// Received bytes not yet decoded into frames. It starts with any bytes
    /// that arrived together with the handshake response.
    read_buffer: Vec<u8>,
    /// Reusable scratch buffer for encoding outgoing frames.
    write_buffer: Vec<u8>,
    /// State of the generator that produces per-frame masking keys.
    mask_seed: u64,
    /// Data message whose fragments are still arriving.
    fragmented_message: Option<FragmentedMessage>,
    /// Whether this side has already sent a Close frame.
    sent_close: bool,
    /// Whether the connection has ended; `next` then returns `None`.
    terminated: bool,
}
impl WebSocket {
    /// Returns the subprotocol the server accepted during the handshake, or
    /// `None` if it did not select one.
    pub fn selected_protocol(&self) -> Option<&str> {
        self.selected_protocol.as_deref()
    }

    /// Sends `text` as a single, unfragmented text message.
    ///
    /// # Errors
    ///
    /// Fails if a Close frame has already been sent, the connection has
    /// terminated, or the write fails or times out.
    pub async fn send_text(&mut self, text: impl AsRef<str>) -> Result<()> {
        self.send_frame(0x1, text.as_ref().as_bytes()).await
    }

    /// Sends `bytes` as a single, unfragmented binary message.
    ///
    /// # Errors
    ///
    /// Same conditions as [`WebSocket::send_text`].
    pub async fn send_binary(&mut self, bytes: impl AsRef<[u8]>) -> Result<()> {
        self.send_frame(0x2, bytes.as_ref()).await
    }

    /// Sends a Ping control frame carrying `data`.
    ///
    /// # Errors
    ///
    /// Fails if `data` is longer than 125 bytes, or under the same conditions
    /// as [`WebSocket::send_text`].
    pub async fn ping(&mut self, data: impl AsRef<[u8]>) -> Result<()> {
        self.send_control_frame(0x9, data.as_ref()).await
    }

    /// Sends an unsolicited Pong control frame carrying `data`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`WebSocket::ping`].
    pub async fn pong(&mut self, data: impl AsRef<[u8]>) -> Result<()> {
        self.send_control_frame(0xA, data.as_ref()).await
    }

    /// Starts the closing handshake by sending a Close frame with status 1000.
    ///
    /// Does nothing if a Close frame was already sent or the connection has
    /// terminated. Keep calling [`WebSocket::next`] to receive the server's
    /// Close reply.
    ///
    /// # Errors
    ///
    /// Fails if the Close frame cannot be written.
    pub async fn close(&mut self) -> Result<()> {
        if self.terminated || self.sent_close {
            return Ok(());
        }
        self.send_close_frame(Some(1000), b"").await?;
        self.sent_close = true;
        Ok(())
    }

    /// Waits for the next message from the server.
    ///
    /// - Fragmented data messages are reassembled before they are returned.
    /// - Pings are returned to the caller. They are also answered with a Pong
    ///   automatically, unless this side has already sent a Close frame; after
    ///   that only the server's Close reply is awaited.
    /// - After a Close frame is received, or after any error, the connection
    ///   is terminated and later calls return `None`.
    pub async fn next(&mut self) -> Option<Result<WebSocketMessage>> {
        if self.terminated {
            return None;
        }
        let result = self.next_message().await;
        if result.is_err() {
            self.terminated = true;
        }
        Some(result)
    }

    /// Reads frames until one complete message or control frame is available.
    ///
    /// Callers must mark the connection terminated when this returns an error.
    async fn next_message(&mut self) -> Result<WebSocketMessage> {
        loop {
            let frame = self.read_frame().await?;
            let message = match frame.opcode {
                0x0 => self.handle_continuation(frame)?,
                0x1 | 0x2 => self.handle_data_frame(frame)?,
                0x8 => Some(self.handle_close_frame(frame).await?),
                0x9 => {
                    // After our Close frame, `send_frame` refuses further
                    // frames. Replying would then fail the connection while we
                    // are only waiting for the server's Close reply.
                    if !self.sent_close {
                        self.send_control_frame(0xA, &frame.payload).await?;
                    }
                    Some(WebSocketMessage::Ping(Bytes::from(frame.payload)))
                }
                0xA => Some(WebSocketMessage::Pong(Bytes::from(frame.payload))),
                opcode => {
                    return Err(Error::new(
                        ErrorKind::Transport,
                        format!("unsupported websocket opcode: {opcode}"),
                    ));
                }
            };
            if let Some(message) = message {
                return Ok(message);
            }
        }
    }

    /// Handles a Close frame from the server.
    ///
    /// Echoes the Close payload back if this side has not sent a Close yet,
    /// then marks the connection terminated.
    async fn handle_close_frame(&mut self, frame: Frame) -> Result<WebSocketMessage> {
        let close_frame = parse_close_frame(&frame.payload)?;
        if !self.sent_close {
            self.send_close_reply(&frame.payload).await?;
            self.sent_close = true;
        }
        self.terminated = true;
        Ok(WebSocketMessage::Close(close_frame))
    }

    /// Handles a text or binary frame.
    ///
    /// Returns the message when the frame is final. Otherwise it starts
    /// collecting fragments and returns `None`.
    fn handle_data_frame(&mut self, frame: Frame) -> Result<Option<WebSocketMessage>> {
        if self.fragmented_message.is_some() {
            return Err(Error::new(
                ErrorKind::Transport,
                "received websocket data frame while fragmented message is in progress",
            ));
        }
        if frame.fin {
            Ok(Some(message_from_payload(frame.opcode, frame.payload)?))
        } else {
            self.fragmented_message = Some(FragmentedMessage {
                opcode: frame.opcode,
                payload: frame.payload,
            });
            Ok(None)
        }
    }

    /// Appends a continuation frame to the in-progress message.
    ///
    /// Returns the reassembled message once the final fragment arrives.
    fn handle_continuation(&mut self, frame: Frame) -> Result<Option<WebSocketMessage>> {
        let Some(mut fragmented) = self.fragmented_message.take() else {
            return Err(Error::new(
                ErrorKind::Transport,
                "received websocket continuation frame without an active fragmented message",
            ));
        };
        fragmented.payload.extend_from_slice(&frame.payload);
        if !frame.fin {
            // More fragments to come: put the partial message back.
            self.fragmented_message = Some(fragmented);
            return Ok(None);
        }
        message_from_payload(fragmented.opcode, fragmented.payload).map(Some)
    }

    /// Returns the next complete frame, reading from the socket as needed.
    /// Each read is bounded by the read timeout.
    async fn read_frame(&mut self) -> Result<Frame> {
        loop {
            if let Some(frame) = try_parse_frame(&mut self.read_buffer)? {
                return Ok(frame);
            }
            let mut scratch = [0_u8; 8192];
            let read = with_timeout_io(
                self.timeout_config.read,
                self.stream.read(&mut scratch),
                "read timed out",
            )
            .await?;
            if read == 0 {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "websocket connection closed unexpectedly",
                ));
            }
            self.read_buffer.extend_from_slice(&scratch[..read]);
        }
    }

    /// Writes one masked frame, refusing once a Close was sent or the
    /// connection terminated.
    async fn send_frame(&mut self, opcode: u8, payload: &[u8]) -> Result<()> {
        if self.terminated || self.sent_close {
            return Err(Error::new(
                ErrorKind::Transport,
                "websocket connection is closing",
            ));
        }
        write_client_frame(
            &mut *self.stream,
            &mut self.write_buffer,
            &mut self.mask_seed,
            self.timeout_config.write,
            opcode,
            payload,
        )
        .await
    }

    /// Sends a control frame after enforcing the 125-byte payload limit.
    async fn send_control_frame(&mut self, opcode: u8, payload: &[u8]) -> Result<()> {
        if payload.len() > 125 {
            return Err(Error::new(
                ErrorKind::Transport,
                "websocket control frame payload must be 125 bytes or less",
            ));
        }
        self.send_frame(opcode, payload).await
    }

    /// Writes a Close frame with an optional status code and reason.
    ///
    /// The reason is only included when a code is present.
    async fn send_close_frame(&mut self, code: Option<u16>, reason: &[u8]) -> Result<()> {
        let mut payload = Vec::new();
        if let Some(code) = code {
            payload.extend_from_slice(&code.to_be_bytes());
            payload.extend_from_slice(reason);
        }
        if payload.len() > 125 {
            return Err(Error::new(
                ErrorKind::Transport,
                "websocket close frame payload must be 125 bytes or less",
            ));
        }
        write_client_frame(
            &mut *self.stream,
            &mut self.write_buffer,
            &mut self.mask_seed,
            self.timeout_config.write,
            0x8,
            &payload,
        )
        .await
    }

    /// Writes a Close frame echoing `payload`, unless a Close was already sent
    /// or the connection terminated.
    async fn send_close_reply(&mut self, payload: &[u8]) -> Result<()> {
        if self.terminated || self.sent_close {
            return Ok(());
        }
        write_client_frame(
            &mut *self.stream,
            &mut self.write_buffer,
            &mut self.mask_seed,
            self.timeout_config.write,
            0x8,
            payload,
        )
        .await
    }
}
/// Status code and reason carried by a Close frame received from the server.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CloseFrame {
    /// Close status code, decoded big-endian from the first two payload bytes.
    pub code: u16,
    /// UTF-8 reason text that follows the status code; may be empty.
    pub reason: String,
}
/// A message or control frame returned by [`WebSocket::next`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WebSocketMessage {
    /// A complete text message (reassembled if fragmented), validated as UTF-8.
    Text(String),
    /// A complete binary message (reassembled if fragmented).
    Binary(Bytes),
    /// A Ping from the server; a Pong is normally sent in reply automatically.
    Ping(Bytes),
    /// A Pong from the server.
    Pong(Bytes),
    /// The server's Close frame; `None` when its payload was empty.
    Close(Option<CloseFrame>),
}
/// A data message whose fragments are still being received.
struct FragmentedMessage {
    /// Opcode of the first fragment (text `0x1` or binary `0x2`).
    opcode: u8,
    /// Payload bytes accumulated so far.
    payload: Vec<u8>,
}
/// One decoded, unmasked frame received from the server.
struct Frame {
    /// FIN bit: set on the final fragment of a message.
    fin: bool,
    /// Low four bits of the first header byte.
    opcode: u8,
    /// Frame payload bytes.
    payload: Vec<u8>,
}
/// Turns the configured URL string into a [`Url`].
///
/// A string without `://` is joined onto `base_url` when one is set. Anything
/// else is parsed on its own.
fn resolve_url(base_url: &Option<Url>, raw_url: &str) -> Result<Url> {
    let is_absolute = raw_url.contains("://");
    match base_url.as_ref().filter(|_| !is_absolute) {
        Some(base) => base.join(raw_url),
        None => Url::parse(raw_url),
    }
}
/// Renders the HTTP/1.1 upgrade request that opens the websocket.
///
/// The request contains, in this order:
/// - the request line and the mandatory `Host`, `Connection`, `Upgrade`,
///   `Sec-WebSocket-Version` and `Sec-WebSocket-Key` headers;
/// - the caller's headers, minus the ones the handshake controls;
/// - `Origin`, if set, which replaces any user-supplied `Origin`;
/// - `Sec-WebSocket-Protocol`, if subprotocols are offered;
/// - one `Cookie` header that merges user `Cookie` headers, jar cookies and
///   builder cookies.
fn build_handshake_request(
    url: &Url,
    headers: &HeaderMap,
    cookies: &[(String, String)],
    jar_cookie_header: Option<&str>,
    protocols: &[String],
    origin: Option<&str>,
    websocket_key: &str,
) -> Result<String> {
    // Headers owned by the handshake itself; caller-supplied copies are dropped.
    const RESERVED: [&str; 8] = [
        "host",
        "connection",
        "upgrade",
        "sec-websocket-key",
        "sec-websocket-version",
        "sec-websocket-accept",
        "sec-websocket-protocol",
        "sec-websocket-extensions",
    ];

    let mut lines: Vec<String> = vec![
        format!("GET {} HTTP/1.1", url.path_and_query()),
        format!("Host: {}", url.authority()),
        "Connection: Upgrade".to_owned(),
        "Upgrade: websocket".to_owned(),
        "Sec-WebSocket-Version: 13".to_owned(),
        format!("Sec-WebSocket-Key: {websocket_key}"),
    ];
    let mut cookie_parts: Vec<String> = Vec::new();

    for (name, value) in headers.iter() {
        let lowered = name.as_str().to_ascii_lowercase();
        if RESERVED.contains(&lowered.as_str()) {
            continue;
        }
        if lowered == "origin" && origin.is_some() {
            continue;
        }
        if lowered == "cookie" {
            cookie_parts.push(value.as_str().to_owned());
            continue;
        }
        lines.push(format!("{}: {}", name.as_str(), value.as_str()));
    }

    if let Some(origin) = origin {
        lines.push(format!("Origin: {origin}"));
    }
    if !protocols.is_empty() {
        lines.push(format!("Sec-WebSocket-Protocol: {}", protocols.join(", ")));
    }

    if let Some(from_jar) = jar_cookie_header {
        cookie_parts.push(from_jar.to_owned());
    }
    let explicit: Vec<String> = cookies
        .iter()
        .map(|(name, value)| format!("{name}={value}"))
        .collect();
    if !explicit.is_empty() {
        cookie_parts.push(explicit.join("; "));
    }
    if !cookie_parts.is_empty() {
        lines.push(format!("Cookie: {}", cookie_parts.join("; ")));
    }

    // Each line ends in CRLF and an empty line terminates the head.
    let mut request = lines.join("\r\n");
    request.push_str("\r\n\r\n");
    Ok(request)
}
/// Reads and parses the server's HTTP response to the upgrade request.
///
/// Returns the response headers, plus any bytes received after the blank line
/// that ends the head; those bytes already belong to the websocket frame
/// stream.
///
/// # Errors
///
/// Fails if any of the following happens:
/// - the stream ends before the head is complete;
/// - the head grows past 64 KiB without a terminating blank line;
/// - the head is not valid UTF-8;
/// - the status line is malformed or its status is not `101`;
/// - a header line lacks a `:`;
/// - a read fails or times out.
async fn read_handshake_response(
    stream: &mut dyn IoStream,
    timeout_config: TimeoutConfig,
) -> Result<(HeaderMap, Vec<u8>)> {
    // Upper bound on the response head. Without it, a peer that never sends
    // the blank line would make us buffer without limit.
    const MAX_HEADER_BYTES: usize = 64 * 1024;

    let mut buffer = Vec::new();
    let header_end = loop {
        let mut scratch = [0_u8; 1024];
        let read = with_timeout_io(
            timeout_config.read,
            stream.read(&mut scratch),
            "read timed out",
        )
        .await?;
        if read == 0 {
            return Err(Error::new(
                ErrorKind::Transport,
                "websocket handshake response headers are incomplete",
            ));
        }
        // The CRLFCRLF terminator can straddle two reads. Re-scan only the last
        // 3 old bytes together with the new ones instead of the whole buffer.
        let search_from = buffer.len().saturating_sub(3);
        buffer.extend_from_slice(&scratch[..read]);
        if let Some(offset) = find_headers_end(&buffer[search_from..]) {
            break search_from + offset;
        }
        if buffer.len() > MAX_HEADER_BYTES {
            return Err(Error::new(
                ErrorKind::Transport,
                "websocket handshake response headers are too large",
            ));
        }
    };
    // Everything after "\r\n\r\n" is frame data that arrived early.
    let body = buffer.split_off(header_end + 4);
    let head = &buffer[..header_end];
    let text = std::str::from_utf8(head).map_err(|err| {
        Error::with_source(
            ErrorKind::Decode,
            "websocket handshake response headers are not valid utf-8",
            err,
        )
    })?;
    let mut lines = text.split("\r\n");
    let status_line = lines
        .next()
        .ok_or_else(|| Error::new(ErrorKind::Transport, "missing websocket status line"))?;
    let status = parse_status_code(status_line)?;
    if status != 101 {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("websocket handshake failed with status {status}"),
        ));
    }
    let mut headers = HeaderMap::new();
    for line in lines {
        if line.is_empty() {
            continue;
        }
        let (name, value) = line.split_once(':').ok_or_else(|| {
            Error::new(
                ErrorKind::Transport,
                format!("invalid websocket handshake header line: {line}"),
            )
        })?;
        headers.append(name.trim(), value.trim())?;
    }
    Ok((headers, body))
}
/// Checks that the server's `101` response really completes a websocket
/// upgrade for this request.
///
/// Checks are done in this order:
/// 1. `Upgrade: websocket`.
/// 2. An `upgrade` token in `Connection`.
/// 3. No `Sec-WebSocket-Extensions` header, because none were requested.
/// 4. `Sec-WebSocket-Accept` matches the value derived from `websocket_key`.
/// 5. Any `Sec-WebSocket-Protocol` names one of the offered `protocols`.
///
/// Returns the selected subprotocol, if any.
fn validate_handshake_response(
    headers: &HeaderMap,
    websocket_key: &str,
    protocols: &[String],
) -> Result<Option<String>> {
    let upgrades_to_websocket = matches!(
        headers.get("upgrade"),
        Some(value) if value.eq_ignore_ascii_case("websocket")
    );
    if !upgrades_to_websocket {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket handshake response is missing Upgrade: websocket",
        ));
    }

    let mut connection_upgrade = false;
    for value in headers.get_all("connection") {
        if contains_header_token(value, "upgrade") {
            connection_upgrade = true;
            break;
        }
    }
    if !connection_upgrade {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket handshake response is missing Connection: Upgrade",
        ));
    }

    if headers.get("sec-websocket-extensions").is_some() {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket extensions are not supported yet",
        ));
    }

    let Some(accept) = headers.get("sec-websocket-accept") else {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket handshake response is missing Sec-WebSocket-Accept",
        ));
    };
    if accept.trim() != websocket_accept_value(websocket_key) {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket handshake response returned an invalid Sec-WebSocket-Accept",
        ));
    }

    let Some(selected) = headers.get("sec-websocket-protocol") else {
        return Ok(None);
    };
    let selected = selected.trim();
    if protocols.is_empty() {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket server selected a subprotocol that was not requested",
        ));
    }
    if protocols.iter().all(|offered| offered != selected) {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("websocket server selected an unexpected subprotocol: {selected}"),
        ));
    }
    Ok(Some(selected.to_owned()))
}
async fn connect_socket(
stream: TcpStream,
url: &Url,
tls_config: &TlsConfig,
timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
match url.scheme() {
"ws" => Ok(Box::new(stream) as BoxedStream),
"wss" => connect_secure_socket(stream, url, tls_config, timeout_config).await,
_ => Err(Error::new(
ErrorKind::InvalidUrl,
format!("unsupported websocket scheme: {}", url.scheme()),
)),
}
}
/// Performs a rustls handshake over `stream` for `url`'s host, bounded by the
/// connect timeout.
#[cfg(feature = "rustls")]
async fn connect_async_tls(
    stream: TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
    // The upgrade request is HTTP/1.1, so the client config is built for an
    // HTTP/1.1-only protocol policy.
    let client_config =
        match tls_config.build_client_config(crate::request::ProtocolPolicy::Http1Only) {
            Ok(config) => config,
            Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
        };
    let host = url.host();
    let tls_stream =
        connect_async_tls_with_config(stream, host, client_config, timeout_config.connect)
            .await?;
    let boxed: BoxedStream = Box::new(tls_stream);
    Ok(boxed)
}
/// Performs a native-tls handshake over `stream` for `url`'s host, bounded by
/// the connect timeout.
#[cfg(feature = "native-tls")]
async fn connect_native_tls(
    stream: TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
    // ALPN must be compatible with the HTTP/1.1 upgrade request.
    let alpn = match tls_config.validate_http1_alpn(crate::request::ProtocolPolicy::Http1Only) {
        Ok(protocols) => protocols,
        Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
    };
    let connector = build_native_tls_connector_for_protocols(tls_config, &alpn)?;
    let host = url.host();
    let tls_stream =
        connect_native_tls_with_connector(stream, host, connector, timeout_config.connect)
            .await?;
    let boxed: BoxedStream = Box::new(tls_stream);
    Ok(boxed)
}
/// Performs a BoringSSL handshake over `stream` for `url`'s host, bounded by
/// the connect timeout.
#[cfg(feature = "btls-backend")]
async fn connect_boring_tls(
    stream: TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
    // ALPN must be compatible with the HTTP/1.1 upgrade request.
    let alpn = match tls_config.validate_http1_alpn(crate::request::ProtocolPolicy::Http1Only) {
        Ok(protocols) => protocols,
        Err(message) => return Err(Error::new(ErrorKind::Transport, message)),
    };
    let connector = build_boring_tls_connector_for_protocols(tls_config, &alpn)?;
    let host = url.host();
    let connect_timeout = timeout_config.connect;
    let tls_stream =
        connect_boring_tls_with_connector(stream, host, connector, tls_config, connect_timeout)
            .await?;
    let boxed: BoxedStream = Box::new(tls_stream);
    Ok(boxed)
}
/// Wraps `stream` in TLS for a `wss://` URL using the configured backend.
///
/// When `tls_config.backend` is unset, the first compiled-in backend is used,
/// in the order rustls, native-tls, boring.
#[cfg(any(feature = "rustls", feature = "native-tls", feature = "btls-backend"))]
async fn connect_secure_socket(
    stream: TcpStream,
    url: &Url,
    tls_config: &TlsConfig,
    timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
    // The three cfg blocks below are mutually exclusive, and the enclosing
    // `any(...)` guarantees that one of them is compiled. Exactly one `return`
    // therefore survives cfg expansion.
    fn default_backend() -> TlsBackend {
        #[cfg(feature = "rustls")]
        {
            return TlsBackend::Rustls;
        }
        #[cfg(all(not(feature = "rustls"), feature = "native-tls"))]
        {
            return TlsBackend::Native;
        }
        #[cfg(all(
            not(feature = "rustls"),
            not(feature = "native-tls"),
            feature = "btls-backend"
        ))]
        {
            return TlsBackend::Boring;
        }
    }
    let backend = tls_config.backend.unwrap_or_else(default_backend);
    // NOTE(review): exhaustiveness here presumably relies on `TlsBackend`
    // variants being cfg-gated the same way — confirm in crate::tls.
    match backend {
        #[cfg(feature = "rustls")]
        TlsBackend::Rustls => connect_async_tls(stream, url, tls_config, timeout_config).await,
        #[cfg(feature = "native-tls")]
        TlsBackend::Native => connect_native_tls(stream, url, tls_config, timeout_config).await,
        #[cfg(feature = "btls-backend")]
        TlsBackend::Boring => connect_boring_tls(stream, url, tls_config, timeout_config).await,
    }
}
#[cfg(not(any(feature = "rustls", feature = "native-tls", feature = "btls-backend")))]
async fn connect_secure_socket(
_stream: TcpStream,
_url: &Url,
_tls_config: &TlsConfig,
_timeout_config: TimeoutConfig,
) -> Result<BoxedStream> {
Err(Error::new(
ErrorKind::Transport,
"wss requires a TLS backend feature",
))
}
/// Encodes `payload` as one final (FIN set) client frame and writes it.
///
/// Client-to-server frames must be masked (RFC 6455 §5.3), so a fresh 4-byte
/// mask is drawn from `mask_seed` for every frame. `frame` is a reusable
/// scratch buffer whose previous contents are discarded. The write and the
/// flush are each bounded by `timeout`.
async fn write_client_frame(
    stream: &mut dyn IoStream,
    frame: &mut Vec<u8>,
    mask_seed: &mut u64,
    timeout: Option<Duration>,
    opcode: u8,
    payload: &[u8],
) -> Result<()> {
    // Longest client header: 2 base bytes + 8 extended-length bytes + 4 mask bytes.
    const MAX_HEADER_LEN: usize = 14;

    frame.clear();
    // `reserve` counts from `len()`, which is 0 after `clear()`. Reserving the
    // full size guarantees the frame is encoded without reallocating (a no-op
    // when capacity already suffices).
    frame.reserve(payload.len() + MAX_HEADER_LEN);

    frame.push(0x80 | (opcode & 0x0F));
    // Every length form sets the MASK bit (0x80) in the second byte.
    if payload.len() <= 125 {
        frame.push(0x80 | payload.len() as u8);
    } else if payload.len() <= u16::MAX as usize {
        frame.push(0x80 | 126);
        frame.extend_from_slice(&(payload.len() as u16).to_be_bytes());
    } else {
        frame.push(0x80 | 127);
        frame.extend_from_slice(&(payload.len() as u64).to_be_bytes());
    }

    let mask = next_mask(mask_seed);
    frame.extend_from_slice(&mask);
    let payload_start = frame.len();
    frame.extend_from_slice(payload);

    // XOR four bytes at a time. Native-endian conversion on both sides keeps
    // each byte paired with the mask byte at the same offset.
    let mask_word = u32::from_ne_bytes(mask);
    let mut words = frame[payload_start..].chunks_exact_mut(4);
    for word in &mut words {
        let masked = u32::from_ne_bytes([word[0], word[1], word[2], word[3]]) ^ mask_word;
        word.copy_from_slice(&masked.to_ne_bytes());
    }
    // The tail starts at a multiple of 4, so it lines up with the mask from
    // its first byte.
    for (byte, key) in words.into_remainder().iter_mut().zip(mask) {
        *byte ^= key;
    }

    with_timeout_io(timeout, stream.write_all(frame), "write timed out").await?;
    with_timeout_io(timeout, stream.flush(), "write timed out").await?;
    Ok(())
}
/// Tries to decode one complete server frame from the front of `buffer`.
///
/// Returns `Ok(None)` when more bytes are needed. On success, the frame's bytes
/// are removed from `buffer`.
///
/// # Errors
///
/// Rejects any of the following:
/// - frames with RSV bits set;
/// - masked frames (servers must not mask, RFC 6455 §5.1);
/// - 64-bit lengths with the most significant bit set (§5.2);
/// - lengths that do not fit in memory on this platform;
/// - Close/Ping/Pong frames that are fragmented or longer than 125 bytes (§5.5).
fn try_parse_frame(buffer: &mut Vec<u8>) -> Result<Option<Frame>> {
    if buffer.len() < 2 {
        return Ok(None);
    }
    let first = buffer[0];
    let second = buffer[1];
    if first & 0b0111_0000 != 0 {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket RSV bits are not supported",
        ));
    }
    let fin = first & 0x80 != 0;
    let opcode = first & 0x0F;
    let masked = second & 0x80 != 0;
    if masked {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket server frames must not be masked",
        ));
    }
    let (header_len, payload_len) = match second & 0x7F {
        len @ 0..=125 => (2_usize, usize::from(len)),
        126 => {
            if buffer.len() < 4 {
                return Ok(None);
            }
            (4, usize::from(u16::from_be_bytes([buffer[2], buffer[3]])))
        }
        127 => {
            if buffer.len() < 10 {
                return Ok(None);
            }
            let len = u64::from_be_bytes([
                buffer[2], buffer[3], buffer[4], buffer[5], buffer[6], buffer[7], buffer[8],
                buffer[9],
            ]);
            // RFC 6455 §5.2: the most significant bit of a 64-bit length MUST be 0.
            if len >> 63 != 0 {
                return Err(Error::new(
                    ErrorKind::Transport,
                    "websocket frame payload length must not have the most significant bit set",
                ));
            }
            let len = usize::try_from(len).map_err(|_| {
                Error::new(
                    ErrorKind::Transport,
                    "websocket frame payload is too large for this platform",
                )
            })?;
            (10, len)
        }
        _ => unreachable!("a 7-bit length field is always in 0..=127"),
    };
    let is_control = matches!(opcode, 0x8 | 0x9 | 0xA);
    if is_control && (!fin || payload_len > 125) {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket control frames must be final and 125 bytes or less",
        ));
    }
    // Checked: a hostile length (for example close to `u32::MAX` on 32-bit
    // targets) could otherwise wrap the sum, and the slice below would panic.
    let Some(frame_len) = header_len.checked_add(payload_len) else {
        return Err(Error::new(
            ErrorKind::Transport,
            "websocket frame payload is too large for this platform",
        ));
    };
    if buffer.len() < frame_len {
        return Ok(None);
    }
    let payload = buffer[header_len..frame_len].to_vec();
    buffer.drain(..frame_len);
    Ok(Some(Frame {
        fin,
        opcode,
        payload,
    }))
}
fn message_from_payload(opcode: u8, payload: Vec<u8>) -> Result<WebSocketMessage> {
match opcode {
0x1 => String::from_utf8(payload)
.map(WebSocketMessage::Text)
.map_err(|err| {
Error::with_source(
ErrorKind::Decode,
"websocket text frame is not valid utf-8",
err,
)
}),
0x2 => Ok(WebSocketMessage::Binary(Bytes::from(payload))),
_ => Err(Error::new(
ErrorKind::Transport,
format!("unsupported websocket data opcode: {opcode}"),
)),
}
}
fn parse_close_frame(payload: &[u8]) -> Result<Option<CloseFrame>> {
if payload.is_empty() {
return Ok(None);
}
if payload.len() == 1 {
return Err(Error::new(
ErrorKind::Transport,
"websocket close frame payload must be empty or at least 2 bytes",
));
}
let code = u16::from_be_bytes([payload[0], payload[1]]);
let reason = String::from_utf8(payload[2..].to_vec()).map_err(|err| {
Error::with_source(
ErrorKind::Decode,
"websocket close frame reason is not valid utf-8",
err,
)
})?;
Ok(Some(CloseFrame { code, reason }))
}
fn parse_status_code(status_line: &str) -> Result<u16> {
let code = status_line
.split_whitespace()
.nth(1)
.ok_or_else(|| Error::new(ErrorKind::Transport, "invalid websocket status line"))?
.parse::<u16>()
.map_err(|err| {
Error::with_source(ErrorKind::Transport, "invalid websocket status code", err)
})?;
Ok(code)
}
/// Computes the `Sec-WebSocket-Accept` value the server must return for
/// `websocket_key`. This is base64(SHA-1(key followed by the RFC 6455 GUID)).
fn websocket_accept_value(websocket_key: &str) -> String {
    let mut sha = Sha1::new();
    for part in [websocket_key.as_bytes(), WS_ACCEPT_MAGIC.as_bytes()] {
        sha.update(part);
    }
    let digest = sha.finalize();
    encode_base64(&digest)
}
/// Reports whether the comma-separated header `value` contains `token`.
/// Surrounding whitespace is ignored and ASCII case is not significant.
fn contains_header_token(value: &str, token: &str) -> bool {
    for candidate in value.split(',') {
        if candidate.trim().eq_ignore_ascii_case(token) {
            return true;
        }
    }
    false
}
/// Returns the index of the first `\r\n\r\n` in `buffer`, i.e. where an HTTP
/// head ends, or `None` if it does not occur.
fn find_headers_end(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8; 4] = b"\r\n\r\n";
    let last_start = buffer.len().saturating_sub(TERMINATOR.len() - 1);
    (0..last_start).find(|&start| &buffer[start..start + TERMINATOR.len()] == TERMINATOR)
}
/// Fills `bytes` with unpredictable values for the `Sec-WebSocket-Key` nonce
/// and the masking-key seed.
///
/// RFC 6455 requires both to be random. Each 8-byte chunk is the output of a
/// freshly created `RandomState` hasher, whose SipHash keys the standard
/// library chooses randomly. The hasher runs over a process-wide counter, the
/// current time and the chunk index, so an observer cannot derive outputs from
/// the clock alone. This is not a dedicated CSPRNG, but it avoids adding a
/// dependency.
fn fill_random_bytes(bytes: &mut [u8]) {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let counter = NEXT_RANDOM_FALLBACK.fetch_add(1, Ordering::Relaxed);
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    for (index, chunk) in bytes.chunks_mut(8).enumerate() {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(counter);
        hasher.write_u128(nanos);
        hasher.write_usize(index);
        let random = hasher.finish().to_be_bytes();
        let len = chunk.len();
        chunk.copy_from_slice(&random[..len]);
    }
}
/// Produces the starting state for the per-connection masking-key generator
/// from 8 random bytes, read as a big-endian integer.
fn initial_mask_seed() -> u64 {
    let mut raw = [0_u8; 8];
    fill_random_bytes(&mut raw);
    raw.iter().fold(0_u64, |acc, &byte| (acc << 8) | u64::from(byte))
}
/// Advances the masking generator (a 64-bit linear congruential step) and
/// returns the high 32 bits of the new state as a big-endian 4-byte mask.
fn next_mask(seed: &mut u64) -> [u8; 4] {
    const MULTIPLIER: u64 = 6364136223846793005;
    let advanced = seed.wrapping_mul(MULTIPLIER).wrapping_add(1);
    *seed = advanced;
    // The high half of an LCG state has longer-period bits than the low half.
    ((advanced >> 32) as u32).to_be_bytes()
}
/// Base64-encodes `bytes` by delegating to `crate::util::encode_base64`.
///
/// Used here for the `Sec-WebSocket-Key` nonce, the expected
/// `Sec-WebSocket-Accept` value and HTTP basic-auth credentials.
fn encode_base64(bytes: &[u8]) -> String {
    crate::util::encode_base64(bytes)
}
/// Runs `future`, giving up with `timeout_kind` / `timeout_message` if it has
/// not finished within `timeout`.
///
/// With `timeout == None` the future simply runs to completion. The outer
/// `Result` only carries the timeout error; the future's own output is
/// returned unchanged inside `Ok`.
async fn with_timeout<F, T>(
    timeout: Option<Duration>,
    future: F,
    timeout_kind: ErrorKind,
    timeout_message: &'static str,
) -> Result<T>
where
    F: Future<Output = T>,
{
    let Some(duration) = timeout else {
        return Ok(future.await);
    };
    // `future::or` pin-projects both halves, so no heap allocation is needed.
    // It polls its first argument first. Putting the work there means a
    // result that is ready on the same poll as the timer is not reported as a
    // timeout.
    future::or(async move { Ok(future.await) }, async move {
        Timer::after(duration).await;
        Err(Error::new(timeout_kind, timeout_message))
    })
    .await
}
async fn with_timeout_io<F, T>(
timeout: Option<Duration>,
future: F,
timeout_message: &'static str,
) -> Result<T>
where
F: Future<Output = std::io::Result<T>>,
{
match with_timeout(timeout, future, ErrorKind::Timeout, timeout_message).await {
Ok(Ok(value)) => Ok(value),
Ok(Err(err)) => Err(Error::with_source(
ErrorKind::Transport,
"io operation failed",
err,
)),
Err(err) => Err(err),
}
}