use crate::{Result, Role, WebSocketConfig};
use async_tungstenite::{
WebSocketReceiver, WebSocketSender, WebSocketStream,
tungstenite::{self, Message},
};
use futures_lite::{Stream, StreamExt};
use std::{
borrow::Cow,
fmt::Debug,
net::IpAddr,
pin::Pin,
sync::Arc,
task::{self, Poll},
};
use swansong::{Interrupt, Swansong};
use trillium::{Headers, Method, Transport, TypeSet, Upgrade};
use trillium_http::{HttpContext, type_set::entry::Entry};
/// A websocket connection built from an HTTP [`Upgrade`].
///
/// Carries metadata captured from the upgrade request (headers, path, method,
/// state, peer ip) together with the two halves of the split websocket:
/// - an outbound sink, used by the `send*` methods;
/// - an inbound stream of messages.
///
/// `WebSocketConn` itself implements [`Stream`] by delegating to the inbound
/// half until that half is taken with `take_inbound_stream`.
///
/// NOTE(review): struct fields are dropped in declaration order, so `sink` is
/// dropped before `stream`. Keep that in mind before reordering fields.
pub struct WebSocketConn {
    /// Headers taken from the upgrade request.
    request_headers: Headers,
    /// The upgrade request path as received. It may include a `?querystring`
    /// suffix, which `path()` and `querystring()` split apart.
    path: Cow<'static, str>,
    /// HTTP method of the upgrade request.
    method: Method,
    /// Per-connection typed state, taken from the upgrade.
    state: TypeSet,
    /// Peer address, if known. Can be overwritten with `set_peer_ip`.
    peer_ip: Option<IpAddr>,
    /// Shared HTTP context; the source of this connection's swansong.
    context: Arc<HttpContext>,
    /// Outbound half of the split websocket.
    sink: WebSocketSender<Box<dyn Transport>>,
    /// Inbound half of the split websocket; `None` once taken by
    /// `take_inbound_stream`.
    stream: Option<WStream>,
}
impl Debug for WebSocketConn {
    /// Formats the request metadata, state, context and inbound stream.
    ///
    /// The outbound `sink` is not printed, so the output is marked
    /// non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            request_headers,
            path,
            method,
            state,
            peer_ip,
            context,
            stream,
            ..
        } = self;

        let mut builder = f.debug_struct("WebSocketConn");
        builder.field("request_headers", request_headers);
        builder.field("path", path);
        builder.field("method", method);
        builder.field("state", state);
        builder.field("peer_ip", peer_ip);
        builder.field("context", context);
        builder.field("stream", stream);
        builder.finish_non_exhaustive()
    }
}
impl WebSocketConn {
    /// Sends `string` to the peer as a text [`Message`].
    ///
    /// # Errors
    ///
    /// Returns an error if the outbound sink fails to send the message.
    pub async fn send_string(&mut self, string: String) -> Result<()> {
        self.send(Message::text(string)).await
    }

    /// Sends `bin` to the peer as a binary [`Message`].
    ///
    /// # Errors
    ///
    /// Returns an error if the outbound sink fails to send the message.
    pub async fn send_bytes(&mut self, bin: Vec<u8>) -> Result<()> {
        self.send(Message::binary(bin)).await
    }

    /// Serializes `json` with `serde_json` and sends it to the peer as a text
    /// [`Message`].
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails, or if the outbound sink fails
    /// to send the message.
    #[cfg(feature = "json")]
    pub async fn send_json(&mut self, json: &impl serde::Serialize) -> Result<()> {
        self.send_string(serde_json::to_string(json)?).await
    }

    /// Sends an arbitrary [`Message`] through the outbound half of the
    /// websocket.
    ///
    /// # Errors
    ///
    /// Any error returned by the sink is converted into this crate's error
    /// type via [`Into`].
    pub async fn send(&mut self, message: Message) -> Result<()> {
        self.sink.send(message).await.map_err(Into::into)
    }

    /// Builds a `WebSocketConn` from an HTTP upgrade.
    ///
    /// Request metadata is read from the upgrade before it is turned into its
    /// transport: headers, path, method, state, the shared [`HttpContext`] and
    /// the peer ip. If `into_transport` also returns a non-empty read buffer,
    /// those bytes are handed to the websocket via `from_partially_read`
    /// instead of being dropped.
    ///
    /// The websocket is split into an outbound sink and an inbound stream. The
    /// inbound half is wrapped with the context's swansong interrupt
    /// (presumably so it ends once graceful shutdown begins; see the swansong
    /// docs).
    #[doc(hidden)]
    pub async fn new(
        upgrade: impl Into<Upgrade>,
        config: Option<WebSocketConfig>,
        role: Role,
    ) -> Self {
        let mut upgrade = upgrade.into();
        let request_headers = upgrade.take_request_headers();
        let path = upgrade.path().to_string().into();
        let method = upgrade.method();
        let state = upgrade.take_state();
        let context = upgrade.context().clone();
        let peer_ip = upgrade.peer_ip();
        // All metadata must be extracted above: `into_transport` (per its
        // `into_` name) consumes the upgrade.
        let (buffer, transport) = upgrade.into_transport();
        let wss = if buffer.is_empty() {
            WebSocketStream::from_raw_socket(transport, role, config).await
        } else {
            WebSocketStream::from_partially_read(transport, buffer, role, config).await
        };
        let (sink, stream) = wss.split();
        let stream = Some(WStream {
            stream: context.swansong().interrupt(stream),
        });
        Self {
            request_headers,
            path,
            method,
            state,
            peer_ip,
            sink,
            stream,
            context,
        }
    }

    /// Returns a clone of the [`Swansong`] held by this connection's
    /// [`HttpContext`].
    pub fn swansong(&self) -> Swansong {
        self.context.swansong().clone()
    }

    /// Sends a close frame with no close code or reason.
    ///
    /// This only sends the frame; it does not itself read the peer's reply.
    ///
    /// # Errors
    ///
    /// Returns an error if the outbound sink fails to send the frame.
    pub async fn close(&mut self) -> Result<()> {
        self.send(Message::Close(None)).await
    }

    /// The headers of the original upgrade request.
    pub fn headers(&self) -> &Headers {
        &self.request_headers
    }

    /// The peer ip address, if known.
    ///
    /// This is the value captured from the upgrade, or the value set by
    /// [`set_peer_ip`](Self::set_peer_ip).
    pub fn peer_ip(&self) -> Option<IpAddr> {
        self.peer_ip
    }

    /// Replaces the recorded peer ip address, returning `&mut Self` for
    /// chaining.
    pub fn set_peer_ip(&mut self, peer_ip: Option<IpAddr>) -> &mut Self {
        self.peer_ip = peer_ip;
        self
    }

    /// The request path, without the querystring.
    ///
    /// Everything before the first `?` is returned; with no `?`, this is the
    /// whole path.
    pub fn path(&self) -> &str {
        // Mirrors `querystring`. `split_once` avoids the unreachable fallback
        // of `split(..).next()`, which always yields at least one item.
        self.path
            .split_once('?')
            .map_or(&*self.path, |(before, _)| before)
    }

    /// The querystring: everything after the first `?` in the request path,
    /// or `""` if there is no `?`.
    pub fn querystring(&self) -> &str {
        self.path
            .split_once('?')
            .map(|(_, query)| query)
            .unwrap_or_default()
    }

    /// The HTTP method of the upgrade request.
    pub fn method(&self) -> Method {
        self.method
    }

    /// Borrows the state value of type `T`, if present.
    pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.state.get()
    }

    /// Mutably borrows the state value of type `T`, if present.
    pub fn state_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
        self.state.get_mut()
    }

    /// Inserts `state` into the connection's typed state.
    ///
    /// Returns an `Option<T>`, presumably the previously stored `T` (if any).
    /// Confirm against the `TypeSet::insert` docs.
    pub fn insert_state<T: Send + Sync + 'static>(&mut self, state: T) -> Option<T> {
        self.state.insert(state)
    }

    /// Returns the state set's [`Entry`] for type `T`.
    pub fn state_entry<T: Send + Sync + 'static>(&mut self) -> Entry<'_, T> {
        self.state.entry()
    }

    /// Takes the `T` value out of the connection's state, if present.
    pub fn take_state<T: Send + Sync + 'static>(&mut self) -> Option<T> {
        self.state.take()
    }

    /// Takes ownership of the inbound message stream, leaving `None` behind.
    ///
    /// Returns `None` if the stream was already taken. Once it has been taken,
    /// polling this connection as a [`Stream`] yields `None`.
    pub fn take_inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + use<>> {
        self.stream.take()
    }

    /// Mutably borrows the inbound message stream, if it has not been taken.
    pub fn inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + '_> {
        self.stream.as_mut()
    }
}
type MessageResult = std::result::Result<Message, tungstenite::Error>;
/// The inbound half of a [`WebSocketConn`]: a [`Stream`] of received messages.
///
/// Created by `WebSocketConn::new`. Obtained by reference with
/// `inbound_stream`, or by value with `take_inbound_stream`.
pub struct WStream {
    /// The websocket receiver, wrapped in a swansong [`Interrupt`] built from
    /// the connection's context (presumably ends on graceful shutdown;
    /// confirm in the swansong docs).
    stream: Interrupt<WebSocketReceiver<Box<dyn Transport>>>,
}
impl Debug for WStream {
    /// Renders as `WStream { .. }`; the wrapped receiver is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WStream");
        builder.finish_non_exhaustive()
    }
}
impl Stream for WStream {
    type Item = MessageResult;

    /// Delegates to the wrapped, interruptible websocket receiver.
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        // `WStream` and its receiver are `Unpin`, so re-pinning the field
        // through a plain `&mut` is sound.
        let receiver = &mut self.get_mut().stream;
        Pin::new(receiver).poll_next(cx)
    }
}
impl AsMut<TypeSet> for WebSocketConn {
    /// Mutable access to the connection's typed state.
    fn as_mut(&mut self) -> &mut TypeSet {
        let Self { state, .. } = self;
        state
    }
}
impl AsRef<TypeSet> for WebSocketConn {
    /// Shared access to the connection's typed state.
    fn as_ref(&self) -> &TypeSet {
        let Self { state, .. } = self;
        state
    }
}
impl Stream for WebSocketConn {
    type Item = MessageResult;

    /// Polls the inbound stream.
    ///
    /// Reports end-of-stream (`Ready(None)`) if the stream has been taken with
    /// [`WebSocketConn::take_inbound_stream`].
    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
        let Some(inbound) = self.get_mut().stream.as_mut() else {
            return Poll::Ready(None);
        };
        Pin::new(inbound).poll_next(cx)
    }
}