use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::tungstenite;
use crate::protocol::validate_protocol;
use crate::transport::WsTransport;
use crate::{alpn, Config, Error, Session, Version, PREFIX_QMUX, PREFIX_WEBTRANSPORT};
/// Keep-alive settings handed to `WsTransport::with_keep_alive`.
///
/// The exact meaning of each duration is defined by the transport; the notes on
/// the fields below are assumptions to confirm against `WsTransport`.
#[derive(Debug, Clone, Copy)]
pub struct KeepAlive {
    /// Presumably the period between keep-alive probes — TODO confirm in `WsTransport`.
    pub interval: Duration,
    /// Presumably how long the peer may stay silent before the connection is
    /// treated as dead — TODO confirm in `WsTransport`.
    pub timeout: Duration,
}

impl KeepAlive {
    /// Builds keep-alive settings from an explicit `interval` and `timeout`.
    pub fn new(interval: Duration, timeout: Duration) -> Self {
        KeepAlive { interval, timeout }
    }
}

impl Default for KeepAlive {
    /// A 5 second interval with a 30 second timeout.
    fn default() -> Self {
        const DEFAULT_INTERVAL: Duration = Duration::from_secs(5);
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
        Self::new(DEFAULT_INTERVAL, DEFAULT_TIMEOUT)
    }
}
/// Maps a negotiated sub-protocol / ALPN string onto a protocol [`Version`] and an
/// optional application protocol.
///
/// Rules are tried in this order:
/// 1. starts with `PREFIX_QMUX` → `QMux00`, with the remainder as the protocol
///    (`None` when the remainder is empty);
/// 2. equals `crate::ALPN_QMUX` → `QMux00` with no protocol;
/// 3. starts with `PREFIX_WEBTRANSPORT` → `WebTransport`, with the remainder as the
///    protocol (`None` when the remainder is empty);
/// 4. anything else, including `None` and `""` → `WebTransport` with no protocol.
fn parse_alpn(alpn: Option<&str>) -> (Version, Option<String>) {
    // An empty remainder after a prefix means "no application protocol".
    fn optional_suffix(rest: &str) -> Option<String> {
        (!rest.is_empty()).then(|| rest.to_owned())
    }

    let value = match alpn.filter(|s| !s.is_empty()) {
        Some(v) => v,
        None => return (Version::WebTransport, None),
    };

    if let Some(rest) = value.strip_prefix(PREFIX_QMUX) {
        (Version::QMux00, optional_suffix(rest))
    } else if value == crate::ALPN_QMUX {
        (Version::QMux00, None)
    } else if let Some(rest) = value.strip_prefix(PREFIX_WEBTRANSPORT) {
        (Version::WebTransport, optional_suffix(rest))
    } else {
        (Version::WebTransport, None)
    }
}
/// Wraps an already-established WebSocket message stream/sink so it can be turned
/// into a [`Session`]; this type performs no HTTP handshake itself.
pub struct Bare<T> {
// The underlying stream/sink of tungstenite messages (bounds are on the `impl`).
ws: T,
// Sub-protocol string fed to `parse_alpn` by `connect`/`accept`; `None` selects WebTransport.
alpn: Option<String>,
// Optional keep-alive settings forwarded to `WsTransport::with_keep_alive`.
keep_alive: Option<KeepAlive>,
}
impl<T> Bare<T>
where
T: futures::Stream<Item = Result<tungstenite::Message, tungstenite::Error>>
+ futures::Sink<tungstenite::Message, Error = tungstenite::Error>
+ Unpin
+ Send
+ 'static,
{
pub fn new(ws: T) -> Self {
Self {
ws,
alpn: None,
keep_alive: None,
}
}
pub fn with_alpn(mut self, alpn: &str) -> Self {
self.alpn = Some(alpn.to_string());
self
}
pub fn with_keep_alive(mut self, keep_alive: KeepAlive) -> Self {
self.keep_alive = Some(keep_alive);
self
}
pub fn connect(self) -> Session {
let (version, protocol) = parse_alpn(self.alpn.as_deref());
Session::connect(self.into_transport(), Config::new(version, protocol))
}
pub fn accept(self) -> Session {
let (version, protocol) = parse_alpn(self.alpn.as_deref());
Session::accept(self.into_transport(), Config::new(version, protocol))
}
fn into_transport(self) -> WsTransport<T> {
let transport = WsTransport::new(self.ws);
match self.keep_alive {
Some(ka) => transport.with_keep_alive(ka),
None => transport,
}
}
}
/// WebSocket client that offers qmux / WebTransport sub-protocols in the upgrade
/// request and wraps the resulting stream in a [`Session`].
#[derive(Default, Clone)]
pub struct Client {
// Application protocols to offer; validated, then expanded via `alpn::build`.
protocols: Vec<String>,
// Optional tungstenite settings passed to the connect call.
config: Option<tungstenite::protocol::WebSocketConfig>,
// Optional keep-alive settings forwarded to the transport.
keep_alive: Option<KeepAlive>,
// TLS connector passed to `connect_async_tls_with_config` (only with the `wss` feature).
#[cfg(feature = "wss")]
connector: Option<tokio_tungstenite::Connector>,
}
impl Client {
pub fn new() -> Self {
Self::default()
}
pub fn with_protocol(mut self, protocol: &str) -> Self {
self.protocols.push(protocol.to_string());
self
}
pub fn with_protocols(mut self, protocols: &[&str]) -> Self {
self.protocols
.extend(protocols.iter().map(|s| s.to_string()));
self
}
pub fn with_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
self.config = Some(config);
self
}
pub fn with_keep_alive(mut self, keep_alive: KeepAlive) -> Self {
self.keep_alive = Some(keep_alive);
self
}
#[cfg(feature = "wss")]
pub fn with_connector(mut self, connector: tokio_tungstenite::Connector) -> Self {
self.connector = Some(connector);
self
}
pub async fn connect(&self, url: &str) -> Result<Session, Error> {
use tungstenite::{client::IntoClientRequest, http};
for p in &self.protocols {
validate_protocol(p)?;
}
let mut request = url
.into_client_request()
.map_err(|e| Error::Io(e.to_string()))?;
let protocol_value = alpn::build(&self.protocols).join(", ");
request.headers_mut().insert(
http::header::SEC_WEBSOCKET_PROTOCOL,
http::HeaderValue::from_str(&protocol_value)
.map_err(|_| Error::InvalidProtocol(protocol_value))?,
);
#[cfg(feature = "wss")]
let (ws_stream, response) = {
tokio_tungstenite::connect_async_tls_with_config(
request,
self.config,
false,
self.connector.clone(),
)
.await
.map_err(|e| Error::Io(e.to_string()))?
};
#[cfg(not(feature = "wss"))]
let (ws_stream, response) =
tokio_tungstenite::connect_async_with_config(request, self.config, false)
.await
.map_err(|e| Error::Io(e.to_string()))?;
let negotiated = response
.headers()
.get(http::header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|h| h.to_str().ok());
let (version, protocol) = parse_alpn(negotiated);
let transport = match self.keep_alive {
Some(ka) => WsTransport::new(ws_stream).with_keep_alive(ka),
None => WsTransport::new(ws_stream),
};
Ok(Session::connect(transport, Config::new(version, protocol)))
}
}
/// WebSocket server-side acceptor that negotiates the qmux / WebTransport
/// sub-protocol during the HTTP upgrade and wraps the stream in a [`Session`].
#[derive(Default, Clone)]
pub struct Server {
    /// Application protocols this server accepts. When empty, only the bare
    /// `crate::ALPN_QMUX` / `crate::ALPN_WEBTRANSPORT` tokens are accepted.
    protocols: Vec<String>,
    /// Optional tungstenite settings for accepted connections; `None` keeps
    /// tungstenite's defaults (the previous, hard-coded behavior).
    config: Option<tungstenite::protocol::WebSocketConfig>,
    /// Optional keep-alive settings forwarded to the transport.
    keep_alive: Option<KeepAlive>,
}

impl Server {
    /// Creates a server with no protocols, default WebSocket settings and no keep-alive.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds one application protocol the server will accept.
    pub fn with_protocol(mut self, protocol: &str) -> Self {
        self.protocols.push(protocol.to_string());
        self
    }

    /// Adds several application protocols the server will accept.
    pub fn with_protocols(mut self, protocols: &[&str]) -> Self {
        self.protocols
            .extend(protocols.iter().map(|s| s.to_string()));
        self
    }

    /// Sets the tungstenite WebSocket configuration (e.g. message/frame size
    /// limits) used for accepted connections, mirroring `Client::with_config`.
    pub fn with_config(mut self, config: tungstenite::protocol::WebSocketConfig) -> Self {
        self.config = Some(config);
        self
    }

    /// Enables keep-alive on accepted connections.
    pub fn with_keep_alive(mut self, keep_alive: KeepAlive) -> Self {
        self.keep_alive = Some(keep_alive);
        self
    }

    /// Performs the WebSocket server handshake on `socket` and starts a session in
    /// the accepting role.
    ///
    /// Sub-protocol selection follows `select_protocol`. A qmux token is preferred
    /// over a WebTransport token, regardless of the order the client listed them.
    ///
    /// # Errors
    ///
    /// Returns an error if a configured protocol fails `validate_protocol`, or if
    /// the handshake fails. This includes the case where the client offered no
    /// acceptable protocol; the client then receives `400 Bad Request`.
    pub async fn accept<T: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
        &self,
        socket: T,
    ) -> Result<Session, Error> {
        use std::sync::{Arc, Mutex};
        use tungstenite::{handshake::server, http};

        for p in &self.protocols {
            validate_protocol(p)?;
        }

        // The callback runs inside tungstenite's handshake and can only return the
        // HTTP response, so it reports the negotiated result through this slot.
        let negotiated = Arc::new(Mutex::new(None::<(Version, Option<String>)>));
        let negotiated_slot = Arc::clone(&negotiated);
        let supported = self.protocols.clone();

        // `ErrorResponse` is an entire `http::Response`, which trips clippy's
        // large-`Err` lint; the signature is fixed by tungstenite's `Callback` trait.
        #[allow(clippy::result_large_err)]
        let callback = move |req: &server::Request,
                             mut response: server::Response|
              -> Result<server::Response, server::ErrorResponse> {
            // Clients may send several headers and/or comma-separated lists.
            let offered: Vec<&str> = req
                .headers()
                .get_all(http::header::SEC_WEBSOCKET_PROTOCOL)
                .iter()
                .filter_map(|v| v.to_str().ok())
                .flat_map(|h| h.split(','))
                .map(|p| p.trim())
                .filter(|p| !p.is_empty())
                .collect();

            match Self::select_protocol(&offered, &supported) {
                Some((selected, version, protocol)) => {
                    // `selected` is one of two things. It is either an offered token,
                    // which already passed `HeaderValue::to_str` (visible ASCII only),
                    // or a crate ALPN constant, which is assumed to be a valid token.
                    let value = http::HeaderValue::from_str(&selected)
                        .expect("selected sub-protocol is a valid header value");
                    response
                        .headers_mut()
                        .insert(http::header::SEC_WEBSOCKET_PROTOCOL, value);
                    *negotiated_slot.lock().unwrap() = Some((version, protocol));
                    Ok(response)
                }
                None => Err(http::Response::builder()
                    .status(http::StatusCode::BAD_REQUEST)
                    .body(Some("no supported protocol".to_string()))
                    .unwrap()),
            }
        };

        let ws =
            tokio_tungstenite::accept_hdr_async_with_config(socket, callback, self.config).await?;
        let (version, protocol) = negotiated
            .lock()
            .unwrap()
            .take()
            .expect("negotiated must be set after successful handshake");
        let transport = match self.keep_alive {
            Some(ka) => WsTransport::new(ws).with_keep_alive(ka),
            None => WsTransport::new(ws),
        };
        Ok(Session::accept(transport, Config::new(version, protocol)))
    }

    /// Chooses the sub-protocol to answer with, in server preference order:
    ///
    /// 1. the first offered `PREFIX_QMUX<proto>` whose `<proto>` is in `supported`;
    /// 2. bare `crate::ALPN_QMUX`, only when `supported` is empty;
    /// 3. the first offered `PREFIX_WEBTRANSPORT<proto>` whose `<proto>` is in `supported`;
    /// 4. bare `crate::ALPN_WEBTRANSPORT`, only when `supported` is empty.
    ///
    /// Returns the header value to echo back together with the negotiated version
    /// and application protocol. Returns `None` when nothing offered is acceptable.
    fn select_protocol(
        offered: &[&str],
        supported: &[String],
    ) -> Option<(String, Version, Option<String>)> {
        if let Some(proto) = Self::find_prefixed(offered, PREFIX_QMUX, supported) {
            return Some((
                format!("{PREFIX_QMUX}{proto}"),
                Version::QMux00,
                Some(proto.to_string()),
            ));
        }
        if supported.is_empty() && offered.contains(&crate::ALPN_QMUX) {
            return Some((crate::ALPN_QMUX.to_string(), Version::QMux00, None));
        }
        if let Some(proto) = Self::find_prefixed(offered, PREFIX_WEBTRANSPORT, supported) {
            return Some((
                format!("{PREFIX_WEBTRANSPORT}{proto}"),
                Version::WebTransport,
                Some(proto.to_string()),
            ));
        }
        if supported.is_empty() && offered.contains(&crate::ALPN_WEBTRANSPORT) {
            return Some((
                crate::ALPN_WEBTRANSPORT.to_string(),
                Version::WebTransport,
                None,
            ));
        }
        None
    }

    /// Returns the remainder of the first offered token that starts with `prefix`
    /// and whose remainder appears in `supported`.
    fn find_prefixed<'a>(
        offered: &[&'a str],
        prefix: &str,
        supported: &[String],
    ) -> Option<&'a str> {
        offered
            .iter()
            .copied()
            .filter_map(|p| p.strip_prefix(prefix))
            .find(|p| supported.iter().any(|s| s == p))
    }
}