use cfg_if::cfg_if;
mod wasm;
pub use wasm::WebSocketInterface as _WSI;
cfg_if! {
if #[cfg(target_arch = "wasm32")] {
use wasm::WebSocketInterface;
} else {
mod native;
use native::WebSocketInterface;
}
}
pub mod bindings;
pub mod config;
pub mod error;
pub mod message;
pub mod options;
pub mod result;
pub use config::WebSocketConfig;
pub use error::Error;
use futures::Future;
pub use message::*;
pub use options::{ConnectOptions, ConnectStrategy};
pub use result::Result;
use async_trait::async_trait;
use std::pin::Pin;
use std::sync::Arc;
use workflow_core::channel::{oneshot, Channel, Receiver, Sender};
/// Result of [`WebSocket::connect`].
///
/// On success this holds an optional receiver that yields a `Result<()>`
/// (presumably signalling completion of the connection attempt — TODO confirm
/// against the platform interface). On failure it holds an error of type `E`.
pub type ConnectResult<E> = std::result::Result<Option<Receiver<Result<()>>>, E>;

/// Shared, type-erased handshake callback.
///
/// The callback receives the outgoing message sender and the incoming message
/// receiver, and returns a [`HandshakeFnReturn`] future.
pub type HandshakeFn = Arc<
    Box<dyn Fn(&Sender<Message>, &Receiver<Message>) -> HandshakeFnReturn + Send + Sync + 'static>,
>;

/// Pinned, boxed, thread-safe future produced by a [`HandshakeFn`].
pub type HandshakeFnReturn = Pin<Box<dyn Future<Output = Result<()>> + Send + Sync + 'static>>;
/// Asynchronous handshake carried out over a pair of message channels.
///
/// This file does not show where it is invoked.
#[async_trait]
pub trait Handshake: Send + Sync + 'static {
    /// Performs the handshake.
    ///
    /// `sender` emits outgoing [`Message`]s and `receiver` yields incoming
    /// ones. The outcome is reported as a `Result<()>`.
    async fn handshake(
        &self,
        sender: &Sender<Message>,
        receiver: &Receiver<Message>,
    ) -> Result<()>;
}
/// Asynchronous provider of a connection URL.
#[async_trait]
pub trait Resolver: Send + Sync + 'static {
    /// Resolves and returns a URL string, or an error (see [`ResolverResult`]).
    async fn resolve_url(&self) -> ResolverResult;
}
/// Return type of [`Resolver::resolve_url`]: the resolved URL string on success.
pub type ResolverResult = Result<String>;
/// Alias for this module's [`Error`] type.
pub type WebSocketError = Error;
/// State shared by all clones of a [`WebSocket`] handle.
struct Inner {
    // Field order is intentionally unchanged: it determines drop order.
    /// Platform-specific interface (wasm or native, selected by `cfg_if!`).
    client: Arc<WebSocketInterface>,
    /// Outgoing queue; each entry pairs a message with an optional
    /// acknowledgement sender (`None` for `post`, `Some` for `send`).
    sender_channel: Channel<(Message, Ack)>,
    /// Incoming queue read by `WebSocket::recv`. A clone is handed to the
    /// interface at construction, presumably so it can push received messages.
    receiver_channel: Channel<Message>,
}
impl Inner {
pub fn new(
client: Arc<WebSocketInterface>,
sender_channel: Channel<(Message, Ack)>,
receiver_channel: Channel<Message>,
) -> Self {
Self {
client,
sender_channel,
receiver_channel,
}
}
}
/// Cloneable WebSocket client handle.
///
/// Cloning is cheap: every clone shares the same internal state through an
/// [`Arc`]. As a result, all clones refer to the same client interface and
/// message channels.
#[derive(Clone)]
pub struct WebSocket {
    inner: Arc<Inner>,
}
impl WebSocket {
    /// Creates a new WebSocket client.
    ///
    /// `url` is optional; when given, it must start with `ws://` or `wss://`.
    /// `config` falls back to `WebSocketConfig::default()` when `None`. If
    /// `receiver_channel_cap` / `sender_channel_cap` are set, bounded channels
    /// of that capacity are created; otherwise the channels are unbounded.
    ///
    /// # Errors
    ///
    /// - Returns [`Error::AddressSchema`] when `url` has an unsupported scheme.
    /// - Propagates any error returned by the platform interface constructor.
    pub fn new(url: Option<&str>, config: Option<WebSocketConfig>) -> Result<WebSocket> {
        if let Some(url) = url {
            if !url.starts_with("ws://") && !url.starts_with("wss://") {
                return Err(Error::AddressSchema(url.to_string()));
            }
        }
        let config = config.unwrap_or_default();
        let receiver_channel = if let Some(cap) = config.receiver_channel_cap {
            Channel::bounded(cap)
        } else {
            Channel::<Message>::unbounded()
        };
        let sender_channel = if let Some(cap) = config.sender_channel_cap {
            Channel::bounded(cap)
        } else {
            Channel::<(Message, Ack)>::unbounded()
        };
        // The interface gets clones of both channels; the handle keeps the originals.
        let client = Arc::new(WebSocketInterface::new(
            url,
            Some(config),
            sender_channel.clone(),
            receiver_channel.clone(),
        )?);
        Ok(WebSocket {
            inner: Arc::new(Inner::new(client, sender_channel, receiver_channel)),
        })
    }

    /// Returns the URL currently held by the client interface, if any.
    pub fn url(&self) -> Option<String> {
        self.inner.client.current_url()
    }

    /// Sets the default URL on the client interface.
    ///
    /// NOTE(review): unlike [`WebSocket::new`], no `ws://`/`wss://` scheme
    /// check is performed here.
    pub fn set_url(&self, url: &str) {
        self.inner.client.set_default_url(url);
    }

    /// Passes a new configuration to the client interface.
    ///
    /// NOTE(review): the message channels are created once in `new`, so
    /// channel capacities in `config` presumably do not resize them — confirm.
    pub fn configure(&self, config: WebSocketConfig) {
        self.inner.client.configure(config);
    }

    /// Returns the sending half of the outgoing channel.
    ///
    /// Messages pushed here directly skip the connection check done by
    /// [`WebSocket::post`] and [`WebSocket::send`].
    pub fn sender_tx(&self) -> &Sender<(Message, Ack)> {
        &self.inner.sender_channel.sender
    }

    /// Returns the receiving half of the incoming channel.
    pub fn receiver_rx(&self) -> &Receiver<Message> {
        &self.inner.receiver_channel.receiver
    }

    /// Reports whether the client interface considers itself connected.
    pub fn is_connected(&self) -> bool {
        self.inner.client.is_connected()
    }

    /// Connects using `options`, delegating to the client interface.
    ///
    /// See [`ConnectResult`] for the shape of the returned value.
    pub async fn connect(&self, options: ConnectOptions) -> ConnectResult<Error> {
        self.inner.client.connect(options).await
    }

    /// Disconnects, delegating to the client interface.
    pub async fn disconnect(&self) -> Result<()> {
        self.inner.client.disconnect().await
    }

    /// Closes the current connection via the interface's `close()`.
    ///
    /// This method does not itself open a new connection. Re-establishing it
    /// is presumably handled by the interface's reconnect logic — TODO confirm.
    pub async fn reconnect(&self) -> Result<()> {
        self.inner.client.close().await
    }

    /// Enqueues `message` for sending without waiting for an acknowledgement.
    ///
    /// # Errors
    ///
    /// - Returns [`Error::NotConnected`] if the client is not connected.
    /// - Returns the converted channel error if the message cannot be enqueued.
    pub async fn post(&self, message: Message) -> Result<&Self> {
        if !self.inner.client.is_connected() {
            return Err(Error::NotConnected);
        }
        self.inner.sender_channel.sender.send((message, None)).await?;
        // Yield after enqueueing, presumably so the dispatch task gets a chance
        // to pick the message up before the caller continues.
        workflow_core::task::yield_now().await;
        Ok(self)
    }

    /// Enqueues `message` and waits for its acknowledgement.
    ///
    /// # Errors
    ///
    /// All errors are wrapped in [`Arc`]:
    /// - [`Error::NotConnected`] if the client is not connected;
    /// - a converted channel error if enqueueing fails;
    /// - [`Error::DispatchChannelAck`] if receiving the acknowledgement fails;
    /// - otherwise, the error carried by the acknowledgement itself.
    pub async fn send(&self, message: Message) -> std::result::Result<&Self, Arc<Error>> {
        if !self.inner.client.is_connected() {
            return Err(Arc::new(Error::NotConnected));
        }
        let (ack_sender, ack_receiver) = oneshot();
        self.inner
            .sender_channel
            .send((message, Some(ack_sender)))
            .await
            .map_err(|err| Arc::new(err.into()))?;
        ack_receiver
            .recv()
            .await
            .map_err(|_| Arc::new(Error::DispatchChannelAck))?
            .map(|_| self)
    }

    /// Waits for the next message on the incoming channel.
    ///
    /// # Errors
    ///
    /// Returns the converted channel error if receiving fails.
    pub async fn recv(&self) -> Result<Message> {
        Ok(self.inner.receiver_channel.receiver.recv().await?)
    }

    /// Delegates to the client interface's `trigger_abort`.
    pub fn trigger_abort(&self) -> Result<()> {
        self.inner.client.trigger_abort()
    }
}