use std::convert::TryFrom;
use std::fmt::{Debug, Error as FmtError, Formatter};
use native_tls::{
TlsConnector as NativeTlsTlsConnector, TlsConnectorBuilder as NativeTlsTlsConnectorBuilder,
};
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use tokio::io::{self, BufReader, BufWriter};
use tokio::net::TcpStream;
use super::handshake::Handshake;
use super::parsed_addr::ParsedAddr;
use super::split::{WebSocketReadHalf, WebSocketWriteHalf};
use super::stream::Stream;
use super::FrameType;
use super::WebSocket;
use crate::error::WebSocketError;
use crate::secure::{TlsCertificate, TlsIdentity, TlsProtocol};
/// Builder that collects handshake and TLS settings and then opens a
/// [`WebSocket`] connection via `connect`.
///
/// The settings are stored in the builder and only read when `connect` is
/// called.
pub struct WebSocketBuilder {
/// Extra `(name, value)` header pairs handed to the handshake in addition
/// to whatever the handshake itself generates.
additional_handshake_headers: Vec<(String, String)>,
/// Subprotocol names handed to the handshake, in insertion order.
subprotocols: Vec<String>,
/// `native_tls` connector builder holding the TLS settings; it is only
/// built into a connector when connecting to a `wss` URL.
tls_connector_builder: NativeTlsTlsConnectorBuilder,
}
/// Hand-written `Debug` implementation that prints only the type name.
///
/// None of the builder's fields are included in the output.
// NOTE(review): presumably written by hand because the native_tls connector
// builder does not implement `Debug` — confirm before switching to a derive.
impl Debug for WebSocketBuilder {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> {
        write!(formatter, "WebSocketBuilder")
    }
}
impl WebSocketBuilder {
    /// Creates a builder with no additional handshake headers, no
    /// subprotocols and a fresh `native_tls` connector builder.
    pub(super) fn new() -> Self {
        Self {
            additional_handshake_headers: Vec::new(),
            subprotocols: Vec::new(),
            tls_connector_builder: NativeTlsTlsConnector::builder(),
        }
    }

    /// Opens a connection to `url` and performs the WebSocket handshake.
    ///
    /// The URL is parsed into a [`ParsedAddr`]. The scheme `ws` uses the
    /// plain TCP stream; `wss` upgrades the stream to TLS using a connector
    /// built from this builder's TLS settings. The configured additional
    /// headers and subprotocols are passed to the handshake.
    ///
    /// The builder is only borrowed, so it can be reused for further
    /// connections afterwards.
    ///
    /// # Errors
    ///
    /// * any error returned while parsing `url` into a [`ParsedAddr`];
    /// * [`WebSocketError::SchemeError`] if the scheme is neither `ws` nor
    ///   `wss` (reported before any network connection is attempted);
    /// * [`WebSocketError::TcpConnectionError`] if the TCP connection fails;
    /// * [`WebSocketError::TlsBuilderError`] if the TLS connector cannot be
    ///   built, or any error from the TLS upgrade itself;
    /// * any error from sending the handshake request or checking the
    ///   response. If checking the response fails, the connection is shut
    ///   down on a best-effort basis and the handshake error is returned.
    pub async fn connect(&mut self, url: &str) -> Result<WebSocket, WebSocketError> {
        let parsed_addr = ParsedAddr::try_from(url)?;

        // Reject unsupported schemes before touching the network, so an
        // unusable URL never opens (and then abandons) a TCP connection.
        let use_tls = match &parsed_addr.scheme[..] {
            "ws" => false,
            "wss" => true,
            _ => return Err(WebSocketError::SchemeError),
        };

        let stream = Stream::Plain(
            TcpStream::connect(parsed_addr.addr)
                .await
                .map_err(WebSocketError::TcpConnectionError)?,
        );
        let stream = if use_tls {
            let tls_config = self
                .tls_connector_builder
                .build()
                .map_err(WebSocketError::TlsBuilderError)?;
            stream.into_tls(&parsed_addr.host, tls_config).await?
        } else {
            stream
        };

        // Split the stream into buffered halves; the channel's sender lives
        // in the read half and its receiver in the write half.
        let (read_half, write_half) = io::split(stream);
        let (sender, receiver) = flume::unbounded();
        let mut ws = WebSocket {
            read_half: WebSocketReadHalf {
                stream: BufReader::new(read_half),
                last_frame_type: FrameType::default(),
                sender,
            },
            write_half: WebSocketWriteHalf {
                shutdown: false,
                sent_closed: false,
                stream: BufWriter::new(write_half),
                rng: ChaCha20Rng::from_entropy(),
                receiver,
            },
            accepted_subprotocol: None,
            handshake_response_headers: None,
        };

        let handshake = Handshake::new(
            &parsed_addr,
            &self.additional_handshake_headers,
            &self.subprotocols,
        );
        handshake.send_request(&mut ws).await?;
        match handshake.check_response(&mut ws).await {
            Ok(_) => Ok(ws),
            Err(handshake_err) => {
                // Shutdown is best-effort here: the handshake failure is the
                // root cause, so a secondary shutdown error must not replace
                // it in what the caller sees.
                let _ = ws.shutdown().await;
                Err(handshake_err)
            }
        }
    }

    /// Appends a header to be passed to the handshake.
    ///
    /// Headers are stored in insertion order; adding the same name twice
    /// stores two entries.
    pub fn add_header(&mut self, header_name: &str, header_value: &str) -> &mut Self {
        self.additional_handshake_headers
            .push((header_name.to_string(), header_value.to_string()));
        self
    }

    /// Removes every previously added header whose name equals
    /// `header_name`.
    ///
    /// The comparison is an exact (case-sensitive) string match.
    pub fn remove_header(&mut self, header_name: &str) -> &mut Self {
        self.additional_handshake_headers
            .retain(|header| header.0 != header_name);
        self
    }

    /// Appends a subprotocol to be passed to the handshake.
    pub fn add_subprotocol(&mut self, subprotocol: &str) -> &mut Self {
        self.subprotocols.push(subprotocol.to_string());
        self
    }

    /// Removes every previously added subprotocol equal to `subprotocol`
    /// (exact string match).
    pub fn remove_subprotocol(&mut self, subprotocol: &str) -> &mut Self {
        self.subprotocols.retain(|s| s != subprotocol);
        self
    }

    /// Forwards `accept_invalid_certs` to the `native_tls` connector
    /// builder's `danger_accept_invalid_certs` setting.
    ///
    /// See the `native_tls` documentation for the security implications.
    pub fn tls_danger_accept_invalid_certs(&mut self, accept_invalid_certs: bool) -> &mut Self {
        self.tls_connector_builder
            .danger_accept_invalid_certs(accept_invalid_certs);
        self
    }

    /// Forwards `accept_invalid_hostnames` to the `native_tls` connector
    /// builder's `danger_accept_invalid_hostnames` setting.
    ///
    /// See the `native_tls` documentation for the security implications.
    pub fn tls_danger_accept_invalid_hostnames(
        &mut self,
        accept_invalid_hostnames: bool,
    ) -> &mut Self {
        self.tls_connector_builder
            .danger_accept_invalid_hostnames(accept_invalid_hostnames);
        self
    }

    /// Adds `cert` as a root certificate on the `native_tls` connector
    /// builder.
    pub fn tls_add_root_certificate(&mut self, cert: TlsCertificate) -> &mut Self {
        self.tls_connector_builder.add_root_certificate(cert.0);
        self
    }

    /// Forwards `disable` to the `native_tls` connector builder's
    /// `disable_built_in_roots` setting.
    pub fn tls_disable_built_in_roots(&mut self, disable: bool) -> &mut Self {
        self.tls_connector_builder.disable_built_in_roots(disable);
        self
    }

    /// Sets `identity` as the identity on the `native_tls` connector
    /// builder.
    pub fn tls_identity(&mut self, identity: TlsIdentity) -> &mut Self {
        self.tls_connector_builder.identity(identity.0);
        self
    }

    /// Sets the maximum TLS protocol version on the `native_tls` connector
    /// builder.
    ///
    /// `protocol` is passed through unchanged, including `None`; see the
    /// `native_tls` documentation for what `None` selects.
    pub fn tls_max_protocol_version(&mut self, protocol: Option<TlsProtocol>) -> &mut Self {
        self.tls_connector_builder.max_protocol_version(protocol);
        self
    }

    /// Sets the minimum TLS protocol version on the `native_tls` connector
    /// builder.
    ///
    /// `protocol` is passed through unchanged, including `None`; see the
    /// `native_tls` documentation for what `None` selects.
    pub fn tls_min_protocol_version(&mut self, protocol: Option<TlsProtocol>) -> &mut Self {
        self.tls_connector_builder.min_protocol_version(protocol);
        self
    }

    /// Forwards `use_sni` to the `native_tls` connector builder's `use_sni`
    /// setting.
    pub fn tls_use_sni(&mut self, use_sni: bool) -> &mut Self {
        self.tls_connector_builder.use_sni(use_sni);
        self
    }
}