use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio_rustls::TlsAcceptor;
use tokio_tungstenite::tungstenite;
use tungstenite::handshake::server::{Request, Response};
use tungstenite::http::header::{HeaderName, HeaderValue, SEC_WEBSOCKET_PROTOCOL};
use websock_proto::{Error, Result, ServerOptions};
use crate::Connection;
use crate::connection::{ConnectionInfo, map_tungstenite_err};
/// Binds a WebSocket server to `addr`.
///
/// The header and subprotocol configuration in `opts` is validated before the
/// socket is bound, so an invalid configuration fails without ever acquiring
/// the port. When `tls` is `Some`, every connection returned by
/// [`Server::accept`] goes through a TLS handshake using that configuration.
///
/// # Errors
///
/// Returns [`Error::Protocol`] if a configured header or subprotocol is
/// invalid, and [`Error::Io`] if binding the listener fails.
pub async fn bind<A>(
    addr: A,
    opts: ServerOptions,
    tls: Option<rustls::ServerConfig>,
) -> Result<Server>
where
    A: ToSocketAddrs,
{
    // Check the pure configuration first: it is cheap, and failing here avoids
    // binding (and then immediately releasing) a port for a server that could
    // never be constructed.
    let headers = prepare_headers(&opts)?;
    validate_protocols(&opts)?;
    let acceptor = tls.map(|cfg| TlsAcceptor::from(Arc::new(cfg)));
    let listener = TcpListener::bind(addr)
        .await
        .map_err(|e| Error::Io(e.to_string()))?;
    Ok(Server {
        listener,
        opts,
        headers,
        acceptor,
    })
}
/// Object-safe bundle of the I/O traits a server-side transport must provide,
/// so plain TCP and TLS streams can be handled behind a single boxed type.
pub trait ServerIo: AsyncRead + AsyncWrite + Send + Unpin {}

/// Every sized type with the required bounds is automatically a [`ServerIo`].
impl<T: AsyncRead + AsyncWrite + Send + Unpin> ServerIo for T {}

/// Type-erased transport used by [`Server::accept`], which boxes either the
/// raw TCP stream or a TLS stream wrapping it.
pub type ServerStream = Box<dyn ServerIo + 'static>;
/// A WebSocket server listening on a TCP socket, optionally terminating TLS.
///
/// Constructed by [`bind`]; connections are obtained with [`Server::accept`]
/// or [`Server::accept_tls`].
pub struct Server {
    /// Listening socket that incoming TCP connections are accepted from.
    listener: TcpListener,
    /// TLS acceptor; `None` when the server was bound without a TLS config.
    acceptor: Option<TlsAcceptor>,
    /// The options the server was bound with; subprotocols are read from here
    /// during each handshake.
    opts: ServerOptions,
    /// Extra handshake-response headers, parsed from `opts` once at bind time.
    headers: Vec<(HeaderName, HeaderValue)>,
}
impl Server {
pub async fn accept(&self) -> Result<Connection<ServerStream>> {
let (stream, _addr) = self
.listener
.accept()
.await
.map_err(|e| Error::Io(e.to_string()))?;
let peer = stream.peer_addr().map_err(|e| Error::Io(e.to_string()))?;
let local = stream.local_addr().map_err(|e| Error::Io(e.to_string()))?;
let (stream, is_tls): (ServerStream, bool) = if let Some(acceptor) = &self.acceptor {
let tls_stream = acceptor
.accept(stream)
.await
.map_err(|e| Error::Tls(e.to_string()))?;
(Box::new(tls_stream), true)
} else {
(Box::new(stream), false)
};
let info = ConnectionInfo {
peer,
local,
is_tls,
};
let headers = self.headers.to_vec();
let protocols = self.opts.protocols.clone();
let ws = tokio_tungstenite::accept_hdr_async(
stream,
move |req: &Request, mut resp: Response| {
for (name, value) in &headers {
resp.headers_mut().append(name, value.clone());
}
if let Some(protocol) = select_protocol(req, &protocols) {
let value =
HeaderValue::from_str(&protocol).expect("protocol value validated on bind");
resp.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, value);
}
Ok(resp)
},
)
.await
.map_err(map_tungstenite_err)?;
Ok(Connection { ws, info })
}
pub async fn accept_tls(
&self,
) -> Result<Connection<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>> {
let (stream, _addr) = self
.listener
.accept()
.await
.map_err(|e| Error::Io(e.to_string()))?;
let tls_stream = self
.acceptor
.as_ref()
.ok_or_else(|| Error::Tls("missing tls acceptor".into()))?
.accept(stream)
.await
.map_err(|e| Error::Tls(e.to_string()))?;
let headers = self.headers.clone();
let protocols = self.opts.protocols.clone();
let info = ConnectionInfo {
peer: tls_stream
.get_ref()
.0
.peer_addr()
.map_err(|e| Error::Io(e.to_string()))?,
local: tls_stream
.get_ref()
.0
.local_addr()
.map_err(|e| Error::Io(e.to_string()))?,
is_tls: true,
};
let ws = tokio_tungstenite::accept_hdr_async(
tls_stream,
move |req: &Request, mut resp: Response| {
for (name, value) in &headers {
resp.headers_mut().append(name, value.clone());
}
if let Some(protocol) = select_protocol(req, &protocols) {
let value =
HeaderValue::from_str(&protocol).expect("protocol value validated on bind");
resp.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, value);
}
Ok(resp)
},
)
.await
.map_err(map_tungstenite_err)?;
Ok(Connection { ws, info })
}
pub fn local_addr(&self) -> Result<std::net::SocketAddr> {
self.listener
.local_addr()
.map_err(|e| Error::Io(e.to_string()))
}
}
/// Parses the extra handshake-response headers configured in `opts`, in
/// iteration order.
///
/// # Errors
///
/// Returns [`Error::Protocol`] for the first entry whose name is not a valid
/// HTTP header name or whose value is not a valid HTTP header value.
fn prepare_headers(opts: &ServerOptions) -> Result<Vec<(HeaderName, HeaderValue)>> {
    opts.headers
        .iter()
        .map(|(key, val)| -> Result<(HeaderName, HeaderValue)> {
            let name = HeaderName::from_bytes(key.as_bytes())
                .map_err(|e| Error::Protocol(format!("invalid header name: {e}")))?;
            let value = HeaderValue::from_str(val)
                .map_err(|e| Error::Protocol(format!("invalid header value: {e}")))?;
            Ok((name, value))
        })
        .collect()
}
fn validate_protocols(opts: &ServerOptions) -> Result<()> {
for protocol in &opts.protocols {
HeaderValue::from_str(protocol)
.map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?;
}
Ok(())
}
/// Picks the subprotocol to echo back in the handshake response.
///
/// Walks every `Sec-WebSocket-Protocol` header in `req`, in the order the
/// client sent them. RFC 6455 allows the client to split its offer across
/// several header lines, so all of them are considered. Returns the first
/// comma-separated, whitespace-trimmed entry that is listed in `allowed`, which
/// means the client's preference order wins. Header values that are not visible
/// ASCII, and empty entries, are skipped.
///
/// Returns `None` when `allowed` is empty, the client offered nothing, or no
/// offered entry is allowed.
fn select_protocol(req: &Request, allowed: &[String]) -> Option<String> {
    if allowed.is_empty() {
        return None;
    }
    req.headers()
        .get_all(SEC_WEBSOCKET_PROTOCOL)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty())
        .find(|candidate| allowed.iter().any(|p| p.as_str() == *candidate))
        .map(String::from)
}