pub async fn connect(url: &str) -> Result<self::Online, Box<dyn std::error::Error + Send + Sync>> {
self::Offline::new().connect(url).await
}
pub struct Offline {
vectored: bool,
auto_close: bool,
auto_pong: bool,
max_message_size: usize,
writev_threshold: usize,
auto_apply_mask: bool,
custom_headers: crate::HeaderMap,
#[cfg(feature = "proxy")]
proxy: Option<crate::proxy::Proxy>,
}
impl Default for Offline {
fn default() -> Self {
Self::new()
}
}
impl Offline {
#[must_use]
pub fn new() -> Self {
Self {
vectored: true,
auto_close: true,
auto_pong: true,
max_message_size: 64 << 20,
writev_threshold: 1024,
auto_apply_mask: true,
custom_headers: crate::HeaderMap::new(),
#[cfg(feature = "proxy")]
proxy: None,
}
}
pub fn add_header<K, V>(&mut self, key: K, value: V) -> &mut Self
where
K: std::convert::TryInto<hyper::header::HeaderName>,
K::Error: std::fmt::Debug,
V: std::convert::TryInto<hyper::header::HeaderValue>,
V::Error: std::fmt::Debug,
{
let name = key.try_into().expect("invalid header name");
let val = value.try_into().expect("invalid header value");
self.custom_headers.insert(name, val);
self
}
pub fn set_writev(&mut self, vectored: bool) -> &mut Self {
self.vectored = vectored;
self
}
pub fn set_writev_threshold(&mut self, threshold: usize) -> &mut Self {
self.writev_threshold = threshold;
self
}
pub fn set_auto_close(&mut self, auto_close: bool) -> &mut Self {
self.auto_close = auto_close;
self
}
pub fn set_auto_pong(&mut self, auto_pong: bool) -> &mut Self {
self.auto_pong = auto_pong;
self
}
pub fn set_max_message_size(&mut self, max_message_size: usize) -> &mut Self {
self.max_message_size = max_message_size;
self
}
pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) -> &mut Self {
self.auto_apply_mask = auto_apply_mask;
self
}
#[cfg(feature = "proxy")]
pub fn set_proxy(&mut self, proxy: Option<crate::proxy::Proxy>) -> &mut Self {
self.proxy = proxy;
self
}
pub async fn connect(
&mut self,
url: &str,
) -> Result<Online, Box<dyn std::error::Error + Send + Sync>> {
let url = url::Url::parse(url).expect("invalid url");
let host = url.host_str().expect("invalid host").to_owned();
let port = url.port_or_known_default().expect("the port is unknown");
let address = format!("{host}:{port}");
#[cfg(feature = "proxy")]
let tcp_stream = if let Some(proxy) = &self.proxy {
proxy.tunnel(&host, port).await?
} else {
tokio::net::TcpStream::connect(&address).await?
};
#[cfg(not(feature = "proxy"))]
let tcp_stream = tokio::net::TcpStream::connect(&address).await?;
let mut req_builder = hyper::Request::builder()
.method("GET")
.uri(url.to_string())
.header("Host", &address)
.header(hyper::header::UPGRADE, "websocket")
.header(hyper::header::CONNECTION, "upgrade")
.header(
"Sec-WebSocket-Key",
fastwebsockets::handshake::generate_key(),
)
.header("Sec-WebSocket-Version", "13");
for (key, value) in self.custom_headers.iter() {
req_builder = req_builder.header(key, value);
}
let request = req_builder.body(http_body_util::Empty::<hyper::body::Bytes>::new())?;
let (mut ws, _) = match url.scheme() {
"wss" | "https" => {
let tls_connector = crate::tls_connector::get_tls_connector();
let server_name = rustls_pki_types::ServerName::try_from(host).map_err(|_| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid server name")
})?;
let tls_stream = tls_connector.connect(server_name, tcp_stream).await?;
fastwebsockets::handshake::client(&SpawnExecutor, request, tls_stream).await?
}
_ => fastwebsockets::handshake::client(&SpawnExecutor, request, tcp_stream).await?,
};
ws.set_writev(self.vectored);
ws.set_writev_threshold(self.writev_threshold);
ws.set_auto_close(self.auto_close);
ws.set_auto_pong(self.auto_pong);
ws.set_max_message_size(self.max_message_size);
ws.set_auto_apply_mask(self.auto_apply_mask);
Ok(Online(fastwebsockets::FragmentCollector::new(ws)))
}
}
pub struct Online(
fastwebsockets::FragmentCollector<hyper_util::rt::tokio::TokioIo<hyper::upgrade::Upgraded>>,
);
impl Online {
#[inline]
pub async fn receive_frame(
&mut self,
) -> Result<fastwebsockets::Frame<'_>, Box<dyn std::error::Error + Send + Sync>> {
Ok(self.0.read_frame().await?)
}
#[inline]
async fn _send_frame(
&mut self,
frame: fastwebsockets::Frame<'_>,
) -> Result<(), fastwebsockets::WebSocketError> {
self.0.write_frame(frame).await
}
#[inline]
pub async fn send_ping(
&mut self,
data: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self._send_frame(fastwebsockets::Frame::new(
true,
crate::OpCode::Ping,
None,
data.as_bytes().into(),
))
.await?;
Ok(())
}
#[inline]
pub async fn send_pong(
&mut self,
data: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self._send_frame(fastwebsockets::Frame::pong(data.as_bytes().into()))
.await?;
Ok(())
}
#[inline]
pub async fn send_string(
&mut self,
data: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self._send_frame(fastwebsockets::Frame::text(data.as_bytes().into()))
.await?;
Ok(())
}
#[inline]
pub async fn send_json(
&mut self,
data: impl serde::Serialize,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let json_bytes = serde_json::to_vec(&data)
.expect("Failed to serialize data passed to send_json into JSON");
self._send_frame(fastwebsockets::Frame::text(json_bytes.into()))
.await?;
Ok(())
}
#[inline]
pub async fn send_binary(
&mut self,
data: &[u8],
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self._send_frame(fastwebsockets::Frame::binary(data.into()))
.await?;
Ok(())
}
pub async fn send_close(
&mut self,
data: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
self._send_frame(fastwebsockets::Frame::close(
fastwebsockets::CloseCode::Normal.into(),
data.as_bytes(),
))
.await?;
Ok(())
}
}
struct SpawnExecutor;
impl<Fut> hyper::rt::Executor<Fut> for SpawnExecutor
where
Fut: std::future::Future + Send + 'static,
Fut::Output: Send + 'static,
{
fn execute(&self, future: Fut) {
tokio::task::spawn(future);
}
}