use bytes::{Buf, Bytes, TryGetError};
use bytestring::ByteString;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::mem::take;
use std::ops::ControlFlow;
use tokio::sync::mpsc;
use tokio::task::coop;
use tokio_websockets::proto::ProtocolError;
use super::error::ErrorKind;
use crate::error::Error;
use crate::request::Payload;
pub use tokio_websockets::CloseCode;
/// Sending half of the `tokio::sync::mpsc` channel that carries [`Message`] values.
pub(super) type Sender = mpsc::Sender<Message>;
/// Receiving half of the `tokio::sync::mpsc` channel that carries [`Message`] values.
pub(super) type Receiver = mpsc::Receiver<Message>;
/// A [`Sender`] paired with a [`Receiver`].
///
/// Field `0` is used by [`Channel::send`] and field `1` by [`Channel::recv`].
pub struct Channel(Sender, Receiver);
/// A message exchanged over a [`Channel`].
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard
/// arm.
#[derive(Debug)]
#[non_exhaustive]
pub enum Message {
/// A binary payload.
Binary(Bytes),
/// A close message, optionally carrying a close code which may itself be
/// accompanied by a textual reason.
Close(Option<(CloseCode, Option<ByteString>)>),
/// A UTF-8 text payload.
Text(ByteString),
}
impl Channel {
pub(super) fn new(tx: Sender, rx: Receiver) -> Self {
Self(tx, rx)
}
pub async fn send(&mut self, message: impl Into<Message>) -> super::Result<()> {
if self.0.send(message.into()).await.is_err() {
Err(ControlFlow::Break(ErrorKind::CLOSED.into()))
} else {
Ok(())
}
}
pub fn recv(&mut self) -> impl Future<Output = Option<Message>> {
coop::unconstrained(self.1.recv())
}
}
impl Message {
    /// Serializes `data` to a JSON string and wraps it in a message via the
    /// `From<String>` conversion.
    ///
    /// # Errors
    ///
    /// Propagates the error from `serde_json::to_string`, converted into
    /// [`Error`].
    pub fn json(data: &impl Serialize) -> Result<Self, Error> {
        let encoded = serde_json::to_string(data)?;
        Ok(Self::from(encoded))
    }
}
impl From<Bytes> for Message {
#[inline]
fn from(data: Bytes) -> Self {
Self::Binary(data)
}
}
impl From<ByteString> for Message {
#[inline]
fn from(data: ByteString) -> Self {
Self::Text(data)
}
}
impl From<Vec<u8>> for Message {
    /// Converts the vector into `Bytes` and wraps it in [`Message::Binary`].
    #[inline]
    fn from(buffer: Vec<u8>) -> Self {
        let payload = Bytes::from(buffer);
        Self::Binary(payload)
    }
}
impl From<&'_ [u8]> for Message {
    /// Copies the slice into a new `Bytes` buffer and wraps it in
    /// [`Message::Binary`].
    #[inline]
    fn from(slice: &'_ [u8]) -> Self {
        let payload = Bytes::copy_from_slice(slice);
        Self::Binary(payload)
    }
}
impl From<String> for Message {
    /// Converts the string into a `ByteString` and wraps it in
    /// [`Message::Text`].
    #[inline]
    fn from(text: String) -> Self {
        Self::Text(ByteString::from(text))
    }
}
impl From<&'_ str> for Message {
    /// Converts the string slice into a `ByteString` and wraps it in
    /// [`Message::Text`].
    #[inline]
    fn from(text: &'_ str) -> Self {
        Self::Text(ByteString::from(text))
    }
}
impl TryFrom<tokio_websockets::Message> for Message {
    type Error = tokio_websockets::Error;

    /// Decodes a `tokio_websockets::Message` into a [`Message`].
    ///
    /// * Binary messages become [`Message::Binary`] with the payload as-is.
    /// * Text messages become [`Message::Text`] after the payload is
    ///   validated as UTF-8.
    /// * Every other message is decoded as a close payload: an optional
    ///   big-endian `u16` close code followed by an optional UTF-8 reason.
    ///
    /// NOTE(review): any message that is neither binary nor text takes the
    /// close path. Confirm that callers never pass ping/pong messages here.
    ///
    /// # Errors
    ///
    /// * `ProtocolError::InvalidUtf8` when a text payload or a close reason
    ///   is not valid UTF-8.
    /// * `ProtocolError::InvalidCloseCode` when the close payload is a single
    ///   byte, or the code lies outside `1000..=4999`.
    /// * Whatever error the `u16` to [`CloseCode`] conversion reports for
    ///   codes it rejects within that range.
    fn try_from(message: tokio_websockets::Message) -> Result<Self, Self::Error> {
        // The message kind must be captured before `into_payload` consumes it.
        let is_binary = message.is_binary();
        let is_text = !is_binary && message.is_text();
        let mut bytes = Bytes::from(message.into_payload());

        if is_binary {
            Ok(Self::Binary(bytes))
        } else if is_text {
            let utf8 = bytes.try_into().or(Err(ProtocolError::InvalidUtf8))?;
            Ok(Self::Text(utf8))
        } else {
            match bytes.try_get_u16() {
                // An empty payload is a close message without a code.
                Err(TryGetError { available: 0, .. }) => Ok(Self::Close(None)),
                // Codes 4000..=4999 are reserved for private use and are
                // valid, so the upper rejection bound starts at 5000.
                Ok(0..=999) | Ok(5000..) | Err(_) => Err(ProtocolError::InvalidCloseCode.into()),
                Ok(raw) => {
                    let code = raw.try_into()?;
                    // `try_get_u16` advanced past the code; what remains is
                    // the optional reason.
                    Ok(if bytes.remaining() == 0 {
                        Self::Close(Some((code, None)))
                    } else {
                        let reason = bytes.try_into().or(Err(ProtocolError::InvalidUtf8))?;
                        Self::Close(Some((code, Some(reason))))
                    })
                }
            }
        }
    }
}
impl Payload for ByteString {
    /// Takes the string out of `self` (leaving the default value behind) and
    /// delegates to `copy_to_unique` on its underlying bytes.
    fn copy_to_unique(&mut self) -> Result<Bytes, Error> {
        let mut buffer = take(self).into_bytes();
        buffer.copy_to_unique()
    }

    /// Takes the string out of `self` (leaving the default value behind) and
    /// delegates to `copy_to_vec` on its underlying bytes.
    fn copy_to_vec(&mut self) -> Result<Vec<u8>, Error> {
        let mut buffer = take(self).into_bytes();
        buffer.copy_to_vec()
    }

    /// Takes the string out of `self` (leaving the default value behind) and
    /// delegates to `json` on its underlying bytes to produce a `T`.
    fn json<T>(&mut self) -> Result<T, Error>
    where
        T: DeserializeOwned,
    {
        let mut buffer = take(self).into_bytes();
        buffer.json()
    }
}
impl From<Message> for tokio_websockets::Message {
    /// Encodes a [`Message`] as a `tokio_websockets::Message`.
    ///
    /// Binary and text payloads are passed through. A close message is built
    /// from its optional code, with a missing reason replaced by the empty
    /// string.
    ///
    /// NOTE(review): `tokio_websockets::Message::close` may reject some
    /// code/reason combinations by panicking. Confirm against its
    /// documentation.
    #[inline]
    fn from(message: Message) -> Self {
        match message {
            Message::Binary(payload) => Self::binary(payload),
            Message::Text(text) => Self::text(text.into_bytes()),
            Message::Close(frame) => {
                // Split the optional frame into its code and reason so that a
                // single `close` call covers every case.
                let (code, reason) = match frame {
                    Some((code, reason)) => (Some(code), reason),
                    None => (None, None),
                };

                Self::close(code, reason.as_deref().unwrap_or(""))
            }
        }
    }
}