use crate::Result;
use async_tungstenite::{
tungstenite::{protocol::Role, Message},
WebSocketStream,
};
use futures_util::{stream::Stream, SinkExt};
use std::{
pin::Pin,
task::{Context, Poll},
};
use stopper::{Stopper, StreamStopper};
use trillium::{
http_types::{headers::Headers, Extensions, Method},
Upgrade,
};
use trillium_http::transport::BoxedTransport;
/// A server-side websocket connection built from a trillium `Upgrade`.
///
/// In addition to the websocket stream itself, it retains the headers, path,
/// method and state of the HTTP request that was upgraded, so they remain
/// available after the protocol switch.
#[derive(Debug)]
pub struct WebSocketConn {
    /// Headers of the upgraded HTTP request.
    request_headers: Headers,
    /// Path of the upgraded HTTP request.
    path: String,
    /// HTTP method of the upgraded request.
    method: Method,
    /// Typed extension map carried over from the upgrade; read and drained
    /// through `state` / `take_state`.
    state: Extensions,
    /// The stopper that `wss` is wrapped with, kept so it can be cloned out.
    stopper: Stopper,
    /// Websocket stream over the upgraded transport, wrapped by `stopper`
    /// (presumably so the stream ends once the stopper is triggered — see the
    /// `stopper` crate docs to confirm).
    wss: StreamStopper<WebSocketStream<BoxedTransport>>,
}
impl WebSocketConn {
pub async fn send_string(&mut self, string: impl Into<String>) {
self.wss.send(Message::text(string)).await.ok();
}
pub async fn send_bytes(&mut self, bin: impl Into<Vec<u8>>) {
self.wss.send(Message::binary(bin)).await.ok();
}
#[cfg(feature = "json")]
pub async fn send_json(&mut self, json: &impl serde::Serialize) -> serde_json::Result<()> {
self.send_string(serde_json::to_string(json)?).await;
Ok(())
}
pub(crate) async fn new(upgrade: Upgrade) -> Self {
let Upgrade {
request_headers,
path,
method,
state,
buffer,
transport,
stopper,
} = upgrade;
let wss = if let Some(vec) = buffer {
WebSocketStream::from_partially_read(transport, vec, Role::Server, None).await
} else {
WebSocketStream::from_raw_socket(transport, Role::Server, None).await
};
Self {
request_headers,
path,
method,
state,
wss: stopper.stop_stream(wss),
stopper,
}
}
pub fn stopper(&self) -> Stopper {
self.stopper.clone()
}
pub async fn close(&mut self) {
self.wss.close(None).await.ok();
}
pub fn headers(&self) -> &Headers {
&self.request_headers
}
pub fn path(&self) -> &str {
&self.path
}
pub fn method(&self) -> &Method {
&self.method
}
pub fn state<T: 'static>(&self) -> Option<&T> {
self.state.get()
}
pub fn take_state<T: 'static>(&mut self) -> Option<T> {
self.state.remove()
}
}
/// Incoming messages are read by delegating to the wrapped, stopper-aware
/// websocket stream; each item is the crate's `Result` type.
impl Stream for WebSocketConn {
    type Item = Result;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` is available because `WebSocketConn: Unpin`, and
        // `Pin::new` on the field relies on the wrapped stream being `Unpin`.
        let this = self.get_mut();
        Pin::new(&mut this.wss).poll_next(cx)
    }
}