#[doc(hidden)]
pub mod util;
use std::fmt;
use std::str::Utf8Error;
use tracing::warn;
#[doc(hidden)]
pub use hyper::upgrade;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite;
use tungstenite::protocol::Role;
use tungstenite::protocol::Message as ProtMessage;
pub use tungstenite::{
error::Error,
protocol::CloseFrame,
protocol::frame::coding::CloseCode
};
use futures_util::stream::StreamExt;
use futures_util::sink::SinkExt;
pub trait LogWebSocketReturn: fmt::Debug {
fn should_log_error(&self) -> bool;
}
impl<T, E> LogWebSocketReturn for Result<T, E>
where
T: fmt::Debug,
E: fmt::Debug
{
fn should_log_error(&self) -> bool {
self.is_err()
}
}
impl LogWebSocketReturn for () {
fn should_log_error(&self) -> bool {
false
}
}
#[cfg(feature = "json")]
macro_rules! try2 {
($e:expr) => (match $e {
Some(v) => v,
None => return Ok(None)
})
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Message {
Text(String),
Binary(Vec<u8>)
}
impl Message {
pub fn into_data(self) -> Vec<u8> {
match self {
Self::Text(t) => t.into(),
Self::Binary(b) => b
}
}
pub fn to_text(&self) -> Result<&str, Utf8Error> {
match self {
Self::Text(t) => Ok(&t),
Self::Binary(b) => std::str::from_utf8(b)
}
}
}
impl From<String> for Message {
fn from(s: String) -> Self {
Self::Text(s)
}
}
impl From<&str> for Message {
fn from(s: &str) -> Self {
Self::Text(s.into())
}
}
impl From<Vec<u8>> for Message {
fn from(v: Vec<u8>) -> Self {
Self::Binary(v)
}
}
impl From<&[u8]> for Message {
fn from(v: &[u8]) -> Self {
Self::Binary(v.into())
}
}
impl From<Message> for ProtMessage {
fn from(m: Message) -> Self {
match m {
Message::Text(t) => Self::Text(t),
Message::Binary(b) => Self::Binary(b)
}
}
}
#[derive(Debug)]
pub struct WebSocket {
inner: WebSocketStream<upgrade::Upgraded>
}
impl WebSocket {
pub async fn new(upgraded: upgrade::Upgraded) -> Self {
Self {
inner: WebSocketStream::from_raw_socket(
upgraded,
Role::Server,
None
).await
}
}
#[doc(hidden)]
pub fn from_raw(
inner: WebSocketStream<upgrade::Upgraded>
) -> Self {
Self { inner }
}
pub async fn receive(&mut self) -> Result<Option<Message>, Error> {
loop {
let res = self.inner.next().await.transpose();
return match res {
Ok(None) => Ok(None),
Ok(Some(ProtMessage::Text(t))) => Ok(Some(Message::Text(t))),
Ok(Some(ProtMessage::Binary(b))) => {
Ok(Some(Message::Binary(b)))
},
Ok(Some(ProtMessage::Ping(d))) => {
self.inner.send(ProtMessage::Pong(d)).await?;
continue
},
Ok(Some(ProtMessage::Pong(_))) => continue,
Ok(Some(ProtMessage::Close(_))) => Ok(None),
Ok(Some(ProtMessage::Frame(f))) => {
warn!("we received a websocket frame {:?}", f);
continue
},
Err(Error::ConnectionClosed) |
Err(Error::AlreadyClosed) => Ok(None),
Err(e) => Err(e)
};
}
}
pub async fn send<M>(&mut self, msg: M) -> Result<(), Error>
where M: Into<Message> {
self.inner.send(msg.into().into()).await
}
pub async fn close(&mut self, code: CloseCode, reason: String) {
let _ = self.inner.send(ProtMessage::Close(Some(CloseFrame {
code, reason: reason.into()
}))).await;
let _ = self.inner.close(None).await;
}
pub async fn ping(&mut self) -> Result<(), Error> {
self.inner.send(ProtMessage::Ping(vec![])).await
}
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
pub async fn deserialize<D>(&mut self) -> Result<Option<D>, JsonError>
where D: serde::de::DeserializeOwned {
let msg = try2!(self.receive().await?).into_data();
serde_json::from_slice(&msg)
.map(|d| Some(d))
.map_err(|e| e.into())
}
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
pub async fn serialize<S: ?Sized>(&mut self, value: &S) -> Result<(), JsonError>
where S: serde::Serialize {
let v = serde_json::to_string(value)?;
self.send(v).await
.map_err(|e| e.into())
}
}
#[cfg(feature = "json")]
mod json_error {
use super::Error;
use std::fmt;
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[derive(Debug)]
pub enum JsonError {
ConnectionError(Error),
SerdeError(serde_json::Error)
}
impl From<Error> for JsonError {
fn from(e: Error) -> Self {
Self::ConnectionError(e)
}
}
impl From<serde_json::Error> for JsonError {
fn from(e: serde_json::Error) -> Self {
Self::SerdeError(e)
}
}
impl fmt::Display for JsonError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}
impl std::error::Error for JsonError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::ConnectionError(e) => Some(e),
Self::SerdeError(e) => Some(e)
}
}
}
}
#[cfg(feature = "json")]
pub use json_error::JsonError;