use futures_util::{SinkExt, Stream, StreamExt};
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use crate::error::{Error, Result};
use crate::types::{MarketSubscription, WsEvent};
/// Handle returned alongside the event stream by `MarketWsClient::subscribe_with_handle`.
///
/// Cloning is cheap: every clone shares the same token list through an `Arc`.
#[derive(Debug, Clone)]
pub struct SubscriptionHandle {
    /// Token (asset) IDs recorded for this subscription, behind an async read/write lock.
    current_tokens: Arc<RwLock<Vec<String>>>,
}
impl SubscriptionHandle {
    /// Returns an owned snapshot of the token IDs currently recorded in this handle.
    ///
    /// The shared read lock is held only while the list is copied.
    pub async fn current_tokens(&self) -> Vec<String> {
        let guard = self.current_tokens.read().await;
        guard.to_vec()
    }
}
/// Client for a market-data WebSocket channel (the Polymarket CLOB market endpoint by default).
///
/// The client only stores the endpoint URL; every subscribe call opens its own connection.
#[derive(Debug, Clone)]
pub struct MarketWsClient {
    /// WebSocket endpoint URL to connect to.
    ws_url: String,
}
fn parse_ws_message(
msg: std::result::Result<Message, tokio_tungstenite::tungstenite::Error>,
) -> Option<Result<WsEvent>> {
match msg {
Ok(Message::Text(text)) => {
let trimmed = text.trim();
if trimmed.is_empty() {
return None;
}
if trimmed.eq_ignore_ascii_case("ping") || trimmed.eq_ignore_ascii_case("pong") {
return None;
}
if let Ok(events) = serde_json::from_str::<Vec<serde_json::Value>>(&text) {
if let Some(first) = events.first() {
match serde_json::from_value::<WsEvent>(first.clone()) {
Ok(event) => return Some(Ok(event)),
Err(e) => return Some(Err(Error::Json(e))),
}
} else {
return None;
}
}
match serde_json::from_str::<WsEvent>(&text) {
Ok(event) => Some(Ok(event)),
Err(e) => {
log::warn!(
"Unexpected WebSocket message (first 200 chars): {}",
&text.chars().take(200).collect::<String>()
);
Some(Err(Error::Json(e)))
}
}
}
Ok(Message::Close(_)) => {
Some(Err(Error::ConnectionClosed))
}
Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {
None
}
Ok(Message::Binary(_)) => {
Some(Err(Error::WebSocket(
"Unexpected binary message".to_string(),
)))
}
Ok(Message::Frame(_)) => {
None
}
Err(e) => {
Some(Err(Error::WebSocket(e.to_string())))
}
}
}
impl MarketWsClient {
const DEFAULT_WS_URL: &'static str = "wss://ws-subscriptions-clob.polymarket.com/ws/market";
pub fn new() -> Self {
Self {
ws_url: Self::DEFAULT_WS_URL.to_string(),
}
}
pub fn with_url(ws_url: impl Into<String>) -> Self {
Self {
ws_url: ws_url.into(),
}
}
pub async fn subscribe_with_handle(
&self,
token_ids: Vec<String>,
) -> Result<(
Pin<Box<dyn Stream<Item = Result<WsEvent>> + Send>>,
SubscriptionHandle,
)> {
let (ws_stream, _) = connect_async(&self.ws_url).await?;
let (write, read) = ws_stream.split();
let mut write = write;
let subscription = MarketSubscription {
assets_ids: token_ids.clone(),
};
let subscription_msg = serde_json::to_string(&subscription)?;
write
.send(Message::Text(subscription_msg))
.await
.map_err(|e| Error::WebSocket(e.to_string()))?;
drop(write);
let current_tokens = Arc::new(RwLock::new(token_ids));
let handle = SubscriptionHandle { current_tokens };
let stream = read.filter_map(|msg| async move { parse_ws_message(msg) });
Ok((Box::pin(stream), handle))
}
pub async fn subscribe(
&self,
token_ids: Vec<String>,
) -> Result<Pin<Box<dyn Stream<Item = Result<WsEvent>> + Send>>> {
let (ws_stream, _) = connect_async(&self.ws_url).await?;
let (write, read) = ws_stream.split();
let mut write = write;
let subscription = MarketSubscription {
assets_ids: token_ids,
};
let subscription_msg = serde_json::to_string(&subscription)?;
write
.send(Message::Text(subscription_msg))
.await
.map_err(|e| Error::WebSocket(e.to_string()))?;
drop(write);
let stream = read.filter_map(|msg| async move { parse_ws_message(msg) });
Ok(Box::pin(stream))
}
}
impl Default for MarketWsClient {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new()` targets the built-in default endpoint.
    #[test]
    fn test_client_creation() {
        let client = MarketWsClient::new();
        assert_eq!(client.ws_url, MarketWsClient::DEFAULT_WS_URL);
    }

    /// `with_url()` stores the caller-supplied endpoint verbatim.
    #[test]
    fn test_client_with_custom_url() {
        let custom_url = "wss://custom.example.com/ws";
        let client = MarketWsClient::with_url(custom_url);
        assert_eq!(client.ws_url, custom_url);
    }

    /// `Default` must agree with `new()`.
    #[test]
    fn test_default_matches_new() {
        assert_eq!(MarketWsClient::default().ws_url, MarketWsClient::new().ws_url);
    }
}