use crate::CowStr;
use crate::deps::fluent_uri::Uri;
use crate::stream::StreamError;
use alloc::boxed::Box;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use bytes::Bytes;
use core::borrow::Borrow;
use core::fmt::{self, Display};
use core::future::Future;
use core::ops::Deref;
use core::pin::Pin;
use n0_future::Stream;
/// UTF-8 text payload of a WebSocket text message, backed by a [`Bytes`] buffer.
///
/// Invariant: the wrapped bytes are always valid UTF-8. The safe constructors in this module
/// start from `str`/`String` data or validate with `core::str::from_utf8`;
/// [`WsText::from_bytes_unchecked`] moves that obligation onto its caller.
///
/// `Hash` is implemented by hand instead of derived so that hashing a `WsText` produces the
/// same result as hashing the equivalent `str`. The `Borrow<str>` impl requires this, so that
/// hash maps keyed by `WsText` can be queried with a plain `&str`. The derived `Eq`/`Ord`
/// compare the raw bytes, which already agrees with `str` comparison.
#[repr(transparent)]
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)]
pub struct WsText(Bytes);

impl core::hash::Hash for WsText {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        // Hash as `str`, not as `Bytes`. `Bytes` hashes like `[u8]`, and `[u8]` and `str` feed
        // the hasher differently. A derived impl would therefore disagree with the
        // `Borrow<str>` view and break `&str` lookups in hash-based collections.
        core::hash::Hash::hash(self.as_str(), state);
    }
}

impl WsText {
    /// Creates a `WsText` pointing at a `'static` string without copying it.
    pub const fn from_static(s: &'static str) -> Self {
        Self(Bytes::from_static(s.as_bytes()))
    }

    /// Returns the payload as a string slice.
    pub fn as_str(&self) -> &str {
        // SAFETY: every constructor either starts from `str`/`String` data, validates with
        // `core::str::from_utf8`, or is the `unsafe` `from_bytes_unchecked`, whose caller
        // guarantees valid UTF-8.
        unsafe { core::str::from_utf8_unchecked(&self.0) }
    }

    /// Wraps `bytes` without checking that they are valid UTF-8.
    ///
    /// # Safety
    ///
    /// `bytes` must contain valid UTF-8. Every `&str` view of the result (`as_str`, `Deref`,
    /// `Display`, ...) relies on this and does not re-check it.
    pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
        Self(bytes)
    }

    /// Consumes the text and returns the underlying byte buffer.
    pub fn into_bytes(self) -> Bytes {
        self.0
    }
}
impl Deref for WsText {
type Target = str;
fn deref(&self) -> &str {
self.as_str()
}
}
impl AsRef<str> for WsText {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl AsRef<[u8]> for WsText {
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl AsRef<Bytes> for WsText {
fn as_ref(&self) -> &Bytes {
&self.0
}
}
impl Borrow<str> for WsText {
fn borrow(&self) -> &str {
self.as_str()
}
}
impl Display for WsText {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Display::fmt(self.as_str(), f)
}
}
impl From<String> for WsText {
fn from(s: String) -> Self {
Self(Bytes::from(s))
}
}
impl From<&str> for WsText {
fn from(s: &str) -> Self {
Self(Bytes::copy_from_slice(s.as_bytes()))
}
}
impl From<&String> for WsText {
fn from(s: &String) -> Self {
Self::from(s.as_str())
}
}
impl TryFrom<Bytes> for WsText {
type Error = core::str::Utf8Error;
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
core::str::from_utf8(&bytes)?;
Ok(Self(bytes))
}
}
impl TryFrom<Vec<u8>> for WsText {
type Error = core::str::Utf8Error;
fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
Self::try_from(Bytes::from(vec))
}
}
impl From<WsText> for Bytes {
fn from(t: WsText) -> Bytes {
t.0
}
}
impl Default for WsText {
fn default() -> Self {
Self(Bytes::new())
}
}
/// Status code carried by a WebSocket close frame.
///
/// Codes with a dedicated variant are listed below; every other value is kept verbatim in
/// [`CloseCode::Other`]. Converting a `u16` into `CloseCode` and back always yields the
/// original number. `From<u16>` never produces `Other` for a code that has a named variant.
/// Note, however, that a hand-built `Other(1000)` does not compare equal to `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum CloseCode {
    /// 1000: normal closure.
    Normal = 1000,
    /// 1001: endpoint going away.
    Away = 1001,
    /// 1002: protocol error.
    Protocol = 1002,
    /// 1003: unsupported data type.
    Unsupported = 1003,
    /// 1007: invalid frame payload data.
    Invalid = 1007,
    /// 1008: policy violation.
    Policy = 1008,
    /// 1009: message too big.
    Size = 1009,
    /// 1010: required extension not negotiated.
    Extension = 1010,
    /// 1011: internal error on the peer.
    Error = 1011,
    /// 1015: TLS handshake failure.
    Tls = 1015,
    /// Any code without a dedicated variant, stored as-is.
    Other(u16),
}

impl From<u16> for CloseCode {
    /// Maps a raw status code to its named variant, falling back to `Other`.
    fn from(code: u16) -> Self {
        match code {
            1000 => Self::Normal,
            1001 => Self::Away,
            1002 => Self::Protocol,
            1003 => Self::Unsupported,
            1007 => Self::Invalid,
            1008 => Self::Policy,
            1009 => Self::Size,
            1010 => Self::Extension,
            1011 => Self::Error,
            1015 => Self::Tls,
            unknown => Self::Other(unknown),
        }
    }
}

impl From<CloseCode> for u16 {
    /// Returns the numeric status code for `code`.
    fn from(code: CloseCode) -> u16 {
        match code {
            CloseCode::Other(raw) => raw,
            CloseCode::Tls => 1015,
            CloseCode::Error => 1011,
            CloseCode::Extension => 1010,
            CloseCode::Size => 1009,
            CloseCode::Policy => 1008,
            CloseCode::Invalid => 1007,
            CloseCode::Unsupported => 1003,
            CloseCode::Protocol => 1002,
            CloseCode::Away => 1001,
            CloseCode::Normal => 1000,
        }
    }
}
/// Payload of a WebSocket close message: a status code plus a textual reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseFrame<'a> {
    /// Status code of the close message.
    pub code: CloseCode,
    /// Reason text, stored as a `CowStr<'a>`.
    pub reason: CowStr<'a>,
}

impl<'a> CloseFrame<'a> {
    /// Builds a close frame from a code and anything convertible into `CowStr<'a>`.
    pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self {
        let reason = reason.into();
        CloseFrame { code, reason }
    }
}
/// A single WebSocket message as exposed by this module's client abstraction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// UTF-8 text payload.
    Text(WsText),
    /// Binary payload.
    Binary(Bytes),
    /// Close message, optionally carrying a code and reason.
    Close(Option<CloseFrame<'static>>),
}

impl WsMessage {
    /// `true` for [`WsMessage::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, Self::Text(_))
    }

    /// `true` for [`WsMessage::Binary`].
    pub fn is_binary(&self) -> bool {
        matches!(self, Self::Binary(_))
    }

    /// `true` for [`WsMessage::Close`].
    pub fn is_close(&self) -> bool {
        matches!(self, Self::Close(_))
    }

    /// Borrows the text of a `Text` message; `None` for every other variant.
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Borrows the payload bytes of a `Text` or `Binary` message; `None` for `Close`.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            Self::Close(_) => None,
            Self::Binary(bytes) => Some(&bytes[..]),
            Self::Text(text) => Some(text.as_str().as_bytes()),
        }
    }
}
impl From<WsText> for WsMessage {
fn from(text: WsText) -> Self {
WsMessage::Text(text)
}
}
impl From<String> for WsMessage {
fn from(s: String) -> Self {
WsMessage::Text(WsText::from(s))
}
}
impl From<&str> for WsMessage {
fn from(s: &str) -> Self {
WsMessage::Text(WsText::from(s))
}
}
impl From<Bytes> for WsMessage {
fn from(bytes: Bytes) -> Self {
WsMessage::Binary(bytes)
}
}
impl From<Vec<u8>> for WsMessage {
fn from(vec: Vec<u8>) -> Self {
WsMessage::Binary(Bytes::from(vec))
}
}
/// Type-erased, pinned stream of incoming [`WsMessage`]s or [`StreamError`]s.
///
/// On non-`wasm32` targets the boxed stream must be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>);
/// Type-erased, pinned stream of incoming [`WsMessage`]s or [`StreamError`]s
/// (no `Send` bound on `wasm32`).
#[cfg(target_arch = "wasm32")]
pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>);

impl WsStream {
    /// Boxes and pins `stream`.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new<S>(stream: S) -> Self
    where
        S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static,
    {
        Self(Box::pin(stream))
    }

    /// Boxes and pins `stream`.
    #[cfg(target_arch = "wasm32")]
    pub fn new<S>(stream: S) -> Self
    where
        S: Stream<Item = Result<WsMessage, StreamError>> + 'static,
    {
        Self(Box::pin(stream))
    }

    /// Returns the boxed stream.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> {
        self.0
    }

    /// Returns the boxed stream.
    #[cfg(target_arch = "wasm32")]
    pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> {
        self.0
    }

    /// Splits this stream into two streams that each receive every `Ok` message.
    ///
    /// A task spawned with `n0_future::task::spawn` drains the source. It pushes a clone of
    /// each message into two unbounded channels. Because the channels are unbounded, a half
    /// that is never polled keeps accumulating messages in memory until it is dropped.
    ///
    /// The task stops when any of the following happens:
    /// - the source ends;
    /// - both halves have been dropped;
    /// - the source yields its first `Err`.
    ///
    /// `StreamError` is not cloned, so that error is delivered to the first half that is
    /// still open (the first half when both are). The other half then simply ends.
    pub fn tee(self) -> (WsStream, WsStream) {
        use futures::channel::mpsc;
        use n0_future::StreamExt as _;

        let (tx1, rx1) = mpsc::unbounded();
        let (tx2, rx2) = mpsc::unbounded();

        n0_future::task::spawn(async move {
            let mut stream = self.0;
            while let Some(result) = stream.next().await {
                match result {
                    Ok(msg) => {
                        let send1 = tx1.unbounded_send(Ok(msg.clone()));
                        let send2 = tx2.unbounded_send(Ok(msg));
                        // Keep forwarding while at least one half is still listening.
                        if send1.is_err() && send2.is_err() {
                            break;
                        }
                    }
                    Err(err) => {
                        // Surface the failure instead of ending both halves with a silent EOF.
                        // If the first half is gone, hand the error to the second one.
                        if let Err(rejected) = tx1.unbounded_send(Err(err)) {
                            let _ = tx2.unbounded_send(rejected.into_inner());
                        }
                        break;
                    }
                }
            }
        });

        (WsStream::new(rx1), WsStream::new(rx2))
    }
}

impl fmt::Debug for WsStream {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WsStream").finish_non_exhaustive()
    }
}
/// Type-erased, pinned sink for outgoing [`WsMessage`]s.
///
/// On non-`wasm32` targets the boxed sink must be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>);
/// Type-erased, pinned sink for outgoing [`WsMessage`]s (no `Send` bound on `wasm32`).
#[cfg(target_arch = "wasm32")]
pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>);

impl WsSink {
    /// Boxes and pins `sink`.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn new<S>(sink: S) -> Self
    where
        S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static,
    {
        let pinned = Box::pin(sink);
        WsSink(pinned)
    }

    /// Boxes and pins `sink`.
    #[cfg(target_arch = "wasm32")]
    pub fn new<S>(sink: S) -> Self
    where
        S: n0_future::Sink<WsMessage, Error = StreamError> + 'static,
    {
        let pinned = Box::pin(sink);
        WsSink(pinned)
    }

    /// Returns the boxed sink.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn into_inner(
        self,
    ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
        let WsSink(inner) = self;
        inner
    }

    /// Returns the boxed sink.
    #[cfg(target_arch = "wasm32")]
    pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> {
        let WsSink(inner) = self;
        inner
    }

    /// Mutable access to the pinned, boxed sink.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn get_mut(
        &mut self,
    ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> {
        &mut self.0
    }

    /// Mutable access to the pinned, boxed sink.
    #[cfg(target_arch = "wasm32")]
    pub fn get_mut(
        &mut self,
    ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> {
        &mut self.0
    }
}

impl fmt::Debug for WsSink {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("WsSink");
        out.finish_non_exhaustive()
    }
}
/// A client that can open a [`WebSocketConnection`].
///
/// On non-`wasm32` targets the trait is passed through `trait_variant::make(Send)`. Per that
/// crate, this adds `Send` bounds to the returned futures. On `wasm32` the attribute is not
/// applied.
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
pub trait WebSocketClient: Sync {
    /// Error returned when a connection attempt fails.
    type Error: core::error::Error + Send + Sync + 'static;

    /// Opens a WebSocket connection to `uri`.
    fn connect(
        &self,
        uri: Uri<&str>,
    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;

    /// Opens a WebSocket connection to `uri` with extra headers (presumably `(name, value)`
    /// pairs for the opening handshake; confirm against implementors).
    ///
    /// NOTE: the default implementation IGNORES `_headers` and just delegates to
    /// [`WebSocketClient::connect`]. Implementors that support custom headers must override it.
    fn connect_with_headers(
        &self,
        uri: Uri<&str>,
        _headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> {
        async move { self.connect(uri).await }
    }
}
/// A connected WebSocket, stored as its outgoing [`WsSink`] and incoming [`WsStream`].
pub struct WebSocketConnection {
    tx: WsSink,
    rx: WsStream,
}

impl WebSocketConnection {
    /// Pairs an outgoing sink with an incoming stream.
    pub fn new(tx: WsSink, rx: WsStream) -> Self {
        WebSocketConnection { rx, tx }
    }

    /// Shared access to the outgoing sink.
    pub fn sender(&self) -> &WsSink {
        &self.tx
    }

    /// Mutable access to the outgoing sink.
    pub fn sender_mut(&mut self) -> &mut WsSink {
        &mut self.tx
    }

    /// Shared access to the incoming stream.
    pub fn receiver(&self) -> &WsStream {
        &self.rx
    }

    /// Mutable access to the incoming stream.
    pub fn receiver_mut(&mut self) -> &mut WsStream {
        &mut self.rx
    }

    /// Consumes the connection and returns `(sink, stream)`.
    pub fn split(self) -> (WsSink, WsStream) {
        let Self { tx, rx } = self;
        (tx, rx)
    }

    /// Reports whether the connection is open.
    ///
    /// NOTE(review): this always returns `true`. The struct tracks no connection state, so a
    /// closed or failed connection is not detected here.
    pub fn is_open(&self) -> bool {
        true
    }
}

impl fmt::Debug for WebSocketConnection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("WebSocketConnection");
        out.finish_non_exhaustive()
    }
}
/// [`WebSocketClient`] implementation backed by `tokio_tungstenite_wasm`.
pub mod tungstenite_client {
    use super::*;
    use crate::IntoStatic;
    use futures::{SinkExt, StreamExt};

    /// Stateless client that opens connections with `tokio_tungstenite_wasm::connect`.
    #[derive(Debug, Clone, Default)]
    pub struct TungsteniteClient;

    impl TungsteniteClient {
        /// Creates a new client.
        pub fn new() -> Self {
            Self
        }
    }

    impl WebSocketClient for TungsteniteClient {
        type Error = tokio_tungstenite_wasm::Error;

        /// Connects to `uri` and adapts the tungstenite stream/sink halves to
        /// [`WsStream`]/[`WsSink`].
        ///
        /// Transport errors on either half are wrapped with `StreamError::transport`.
        async fn connect(&self, uri: Uri<&str>) -> Result<WebSocketConnection, Self::Error> {
            let ws_stream = tokio_tungstenite_wasm::connect(uri.as_str()).await?;
            let (sink, stream) = ws_stream.split();

            // `convert_message` returns an `Option`; any `None` is skipped by `filter_map`.
            let rx_stream = stream.filter_map(|result| async move {
                match result {
                    Ok(msg) => convert_message(msg).map(Ok),
                    Err(e) => Some(Err(StreamError::transport(e))),
                }
            });
            let rx = WsStream::new(rx_stream);

            let tx_sink = sink
                .with(|msg: WsMessage| async move {
                    Ok::<_, tokio_tungstenite_wasm::Error>(msg.into())
                })
                .sink_map_err(|e| StreamError::transport(e));
            let tx = WsSink::new(tx_sink);

            Ok(WebSocketConnection::new(tx, rx))
        }
    }

    /// Converts a tungstenite message into a [`WsMessage`].
    ///
    /// Every variant currently has a counterpart, so this always returns `Some`.
    fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> {
        use tokio_tungstenite_wasm::Message;
        match msg {
            // `Message::Text` already holds a `String`, which is guaranteed UTF-8, so the safe
            // `From<String>` conversion is enough. It is the same zero-copy path as the
            // unchecked constructor, without `unsafe`.
            Message::Text(text) => Some(WsMessage::Text(WsText::from(text))),
            Message::Binary(data) => Some(WsMessage::Binary(Bytes::from(data))),
            Message::Close(frame) => {
                let close_frame = frame.map(|f| {
                    let code = convert_close_code(f.code);
                    CloseFrame::new(code, CowStr::from(f.reason.into_owned()))
                });
                Some(WsMessage::Close(close_frame))
            }
        }
    }

    /// Maps a tungstenite close code to [`CloseCode`] through its numeric value.
    fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode {
        // Both enums describe the same 16-bit status code. Going through `u16` yields the named
        // `CloseCode` variant for every code it knows and `CloseCode::Other` for the rest,
        // without duplicating the numeric mapping here.
        CloseCode::from(u16::from(code))
    }

    impl From<WsMessage> for tokio_tungstenite_wasm::Message {
        fn from(msg: WsMessage) -> Self {
            use tokio_tungstenite_wasm::Message;
            match msg {
                // Copy out through the checked `&str` view rather than
                // `String::from_utf8_unchecked`. The cost is the same single allocation + copy.
                WsMessage::Text(text) => Message::Text(String::from(text.as_str())),
                WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()),
                WsMessage::Close(frame) => {
                    let close_frame = frame.map(|f| tokio_tungstenite_wasm::CloseFrame {
                        code: u16::from(f.code).into(),
                        reason: f.reason.into_static().to_string().into(),
                    });
                    Message::Close(close_frame)
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ws_text_from_string() {
        assert_eq!(WsText::from("hello").as_str(), "hello");
    }

    #[test]
    fn ws_text_deref() {
        let owned = WsText::from(String::from("world"));
        let view: &str = &owned;
        assert_eq!(view, "world");
    }

    #[test]
    fn ws_text_try_from_bytes() {
        let parsed = WsText::try_from(Bytes::from("test")).unwrap();
        assert_eq!(parsed.as_str(), "test");
    }

    #[test]
    fn ws_text_invalid_utf8() {
        let rejected = WsText::try_from(Bytes::from(vec![0xFF, 0xFE]));
        assert!(rejected.is_err());
    }

    #[test]
    fn ws_message_text() {
        let message = WsMessage::from("hello");
        assert!(message.is_text());
        assert_eq!(message.as_text(), Some("hello"));
    }

    #[test]
    fn ws_message_binary() {
        let message = WsMessage::from(vec![1, 2, 3]);
        assert!(message.is_binary());
        let expected: &[u8] = &[1, 2, 3];
        assert_eq!(message.as_bytes(), Some(expected));
    }

    #[test]
    fn close_code_conversion() {
        let normal: u16 = CloseCode::Normal.into();
        assert_eq!(normal, 1000);
        assert_eq!(CloseCode::from(1000), CloseCode::Normal);
        assert_eq!(CloseCode::from(9999), CloseCode::Other(9999));
    }

    #[test]
    fn websocket_connection_has_tx_and_rx() {
        use futures::sink::SinkExt;
        use futures::stream;

        let incoming = WsStream::new(stream::iter(vec![Ok(WsMessage::from("test"))]));
        let outgoing = WsSink::new(
            futures::sink::drain()
                .sink_map_err(|_: std::convert::Infallible| StreamError::closed()),
        );
        let connection = WebSocketConnection::new(outgoing, incoming);
        assert!(connection.is_open());
    }
}