use std::{marker::PhantomData, sync::Arc, time::Duration};
use serde::{Serialize, de::DeserializeOwned};
use tokio::task::JoinHandle;
use tracing::info;
use super::{
config::WsConfig,
connection::{Connection, SubscriptionGuard},
protocol::{ProtocolHandler, WsMessage},
types::Topic,
};
use crate::error::TransportResult;
/// A WebSocket client whose wire protocol is interpreted by a [`ProtocolHandler`].
///
/// All operations are forwarded to an internal connection handle. Clones hold a
/// clone of that handle and share the same background event-drain task.
pub struct WsClient<H: ProtocolHandler> {
    /// Handle through which requests, subscriptions and sends are issued.
    handle: super::connection::ConnectionHandle,
    /// Background task that drains the connection's event stream, wrapped in
    /// `Arc` so every clone refers to the same task.
    ///
    /// Dropping a Tokio `JoinHandle` detaches the task rather than cancelling
    /// it. Releasing the last clone therefore does not stop the task; it runs
    /// until the event stream ends.
    _stream_task: Arc<JoinHandle<()>>,
    /// Associates the client with its handler type without owning an `H`.
    ///
    /// `fn() -> H` keeps `WsClient<H>` covariant in `H` (like `PhantomData<H>`).
    /// It does not make `Send`/`Sync` or drop-check depend on `H`, which the
    /// client never stores.
    _marker: PhantomData<fn() -> H>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would require
// `H: Clone`, but no field holds an `H`, so cloning never needs that bound.
impl<H: ProtocolHandler> Clone for WsClient<H> {
    fn clone(&self) -> Self {
        Self {
            handle: self.handle.clone(),
            _stream_task: Arc::clone(&self._stream_task),
            _marker: PhantomData,
        }
    }
}
impl<H: ProtocolHandler> WsClient<H> {
pub async fn connect(config: WsConfig, handler: H) -> TransportResult<Self> {
let url = config.url.clone();
let (handle, stream) = Connection::connect(config, handler).await?;
let stream_task = tokio::spawn(async move {
let mut stream = stream;
while let Some(_event) = stream.next().await {}
});
info!(url = %url, "WebSocket client created");
Ok(Self {
handle,
_stream_task: Arc::new(stream_task),
_marker: PhantomData,
})
}
pub fn is_connected(&self) -> bool {
self.handle.is_connected()
}
pub async fn request<R, T>(&self, request: &R) -> TransportResult<T>
where
R: Serialize,
T: DeserializeOwned,
{
self.request_with_timeout(request, None).await
}
pub async fn request_with_timeout<R, T>(
&self,
request: &R,
timeout: Option<Duration>,
) -> TransportResult<T>
where
R: Serialize,
T: DeserializeOwned,
{
self.handle.request_with_timeout(request, timeout).await
}
pub async fn request_raw(&self, message: WsMessage) -> TransportResult<String> {
self.handle.request_raw(message).await
}
pub async fn request_raw_with_timeout(
&self,
message: WsMessage,
timeout: Option<Duration>,
) -> TransportResult<String> {
self.handle.request_raw_with_timeout(message, timeout).await
}
pub async fn subscribe(&self, topic: impl Into<Topic>) -> TransportResult<SubscriptionGuard> {
self.handle.subscribe(topic).await
}
pub async fn subscribe_many(
&self,
topics: impl IntoIterator<Item = impl Into<Topic>>,
) -> TransportResult<Vec<SubscriptionGuard>> {
self.handle.subscribe_many(topics).await
}
pub async fn unsubscribe(&self, topic: impl Into<Topic>) -> TransportResult<()> {
self.handle.unsubscribe(topic).await
}
pub async fn send(&self, message: WsMessage) -> TransportResult<()> {
self.handle.send(message).await
}
pub async fn send_json<T: Serialize>(&self, payload: &T) -> TransportResult<()> {
let json = serde_json::to_string(payload)?;
self.send(WsMessage::text(json)).await
}
pub async fn close(&self) -> TransportResult<()> {
self.handle.close().await
}
pub fn pending_count(&self) -> usize {
self.handle.pending_count()
}
pub fn subscription_count(&self) -> usize {
self.handle.subscription_count()
}
pub fn subscribed_topics(&self) -> Vec<Topic> {
self.handle.subscribed_topics()
}
}
#[cfg(test)]
mod tests {
    use super::{
        super::types::{MessageKind, RequestId},
        *,
    };

    /// Minimal handler for compile-time checks: every message is classified as
    /// an update, nothing is extracted, and (un)subscribe frames are `{}`.
    struct TestHandler;

    impl ProtocolHandler for TestHandler {
        fn classify_message(&self, _message: &str) -> MessageKind {
            MessageKind::Update
        }

        fn extract_topic(&self, _message: &str) -> Option<Topic> {
            None
        }

        fn extract_request_id(&self, _message: &str) -> Option<RequestId> {
            None
        }

        fn build_subscribe(&self, _topics: &[Topic], _request_id: RequestId) -> WsMessage {
            WsMessage::text("{}")
        }

        fn build_unsubscribe(&self, _topics: &[Topic], _request_id: RequestId) -> WsMessage {
            WsMessage::text("{}")
        }
    }

    /// Compile-time check that `WsClient` can be moved and shared across threads.
    #[test]
    fn test_ws_client_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<WsClient<TestHandler>>();
    }

    /// Compile-time check that `WsClient` keeps a `handle` field reachable
    /// from within this module tree.
    #[test]
    fn ws_client_exposes_connection_handle() {
        fn borrow_handle<H: ProtocolHandler>(client: &WsClient<H>) {
            let _ = &client.handle;
        }
        let _ = borrow_handle::<TestHandler>;
    }
}