use std::collections::{HashMap, hash_map::Entry};
use std::sync::Arc;
use alloy::primitives::Address;
use async_stream::try_stream;
use futures::Stream;
use futures::StreamExt as _;
use rust_decimal::Decimal;
use super::config::Config;
use super::connection::{ConnectionManager, ConnectionState};
use super::interest::InterestTracker;
use super::subscription::{ChannelType, SubscriptionManager};
use super::types::response::{
BookUpdate, MidpointUpdate, OrderMessage, PriceChange, TradeMessage, WsMessage,
};
use crate::Result;
use crate::auth::state::{Authenticated, State, Unauthenticated};
use crate::auth::{Credentials, Kind as AuthKind, Normal};
use crate::error::Error;
/// WebSocket client for the Polymarket CLOB subscription endpoints.
///
/// The type parameter records the authentication state (`Unauthenticated` by default,
/// `Authenticated<_>` after [`Client::authenticate`]). All state lives behind an `Arc`, so a
/// derived clone shares the same channel handles rather than opening new connections.
#[derive(Clone)]
pub struct Client<S: State = Unauthenticated> {
    inner: Arc<ClientInner<S>>,
}

impl Default for Client<Unauthenticated> {
    /// Builds an unauthenticated client for the public Polymarket endpoint with `Config::default()`.
    ///
    /// # Panics
    ///
    /// Panics if [`Client::new`] returns an error for the default endpoint.
    fn default() -> Self {
        const DEFAULT_ENDPOINT: &str = "wss://ws-subscriptions-clob.polymarket.com";
        let config = Config::default();
        Self::new(DEFAULT_ENDPOINT, config)
            .expect("WebSocket client with default endpoint should succeed")
    }
}
/// Shared state behind a `Client`: the authentication state, the configuration used for
/// channel connections, the normalized base endpoint, and the handles of every channel that
/// has been set up.
struct ClientInner<S: State> {
    // Authentication state marker/data (`Unauthenticated` or `Authenticated<_>`).
    state: S,
    // Passed to `ChannelHandles::connect` whenever a channel connection is created.
    config: Config,
    // Base URL with trailing slashes and any `/ws`, `/ws/market` or `/ws/user` suffix removed;
    // per-channel URLs are derived from it via `channel_endpoint`.
    base_endpoint: String,
    // Handles keyed by channel; a channel is only present after it was successfully connected.
    channels: HashMap<ChannelType, ChannelHandles>,
}
impl Client<Unauthenticated> {
pub fn new(endpoint: &str, config: Config) -> Result<Self> {
let normalized = normalize_base_endpoint(endpoint);
let market_handles =
ChannelHandles::connect(channel_endpoint(&normalized, ChannelType::Market), &config)?;
let mut channels = HashMap::new();
channels.insert(ChannelType::Market, market_handles);
Ok(Self {
inner: Arc::new(ClientInner {
state: Unauthenticated,
config,
base_endpoint: normalized,
channels,
}),
})
}
pub fn authenticate(
self,
credentials: Credentials,
address: Address,
) -> Result<Client<Authenticated<Normal>>> {
let inner = Arc::into_inner(self.inner).ok_or(Error::validation(
"Cannot authenticate while other references to this client exist; \
drop all clones before calling authenticate",
))?;
let ClientInner {
config,
base_endpoint,
mut channels,
..
} = inner;
if let Entry::Vacant(slot) = channels.entry(ChannelType::User) {
let handles = ChannelHandles::connect(
channel_endpoint(&base_endpoint, ChannelType::User),
&config,
)?;
slot.insert(handles);
}
Ok(Client {
inner: Arc::new(ClientInner {
state: Authenticated {
address,
credentials,
kind: Normal,
},
config,
base_endpoint,
channels,
}),
})
}
}
impl<S: State> Client<S> {
    /// Subscribes to order book snapshots for `asset_ids` on the market channel.
    ///
    /// Only `WsMessage::Book` messages are yielded. Other message kinds are skipped, and stream
    /// errors are forwarded as `Err` items without ending the stream.
    ///
    /// # Errors
    ///
    /// Returns a validation error if the market channel is missing. Returns any error produced
    /// by the subscription manager's `subscribe_market`.
    pub fn subscribe_orderbook(
        &self,
        asset_ids: Vec<String>,
    ) -> Result<impl Stream<Item = Result<BookUpdate>>> {
        let stream = self
            .market_handles()?
            .subscriptions
            .subscribe_market(asset_ids)?;
        Ok(stream.filter_map(|msg_result| async move {
            match msg_result {
                Ok(WsMessage::Book(book)) => Some(Ok(book)),
                Err(e) => Some(Err(e)),
                _ => None,
            }
        }))
    }

    /// Subscribes to price-change events for `asset_ids` on the market channel.
    ///
    /// Only `WsMessage::PriceChange` messages are yielded. Other message kinds are skipped, and
    /// stream errors are forwarded as `Err` items without ending the stream.
    ///
    /// # Errors
    ///
    /// Returns a validation error if the market channel is missing. Returns any error produced
    /// by the subscription manager's `subscribe_market`.
    pub fn subscribe_prices(
        &self,
        asset_ids: Vec<String>,
    ) -> Result<impl Stream<Item = Result<PriceChange>>> {
        let stream = self
            .market_handles()?
            .subscriptions
            .subscribe_market(asset_ids)?;
        Ok(stream.filter_map(|msg_result| async move {
            match msg_result {
                Ok(WsMessage::PriceChange(price)) => Some(Ok(price)),
                Err(e) => Some(Err(e)),
                _ => None,
            }
        }))
    }

    /// Subscribes to midpoint prices derived from order book snapshots for `asset_ids`.
    ///
    /// For each book, the midpoint is `(best_bid + best_ask) / 2`. The best bid is the highest
    /// bid price and the best ask is the lowest ask price, regardless of the order in which
    /// levels are listed. Books with an empty bid or ask side produce no item.
    ///
    /// NOTE: unlike [`Self::subscribe_orderbook`], this stream is built with
    /// `async_stream::try_stream!`. The first error from the underlying book stream is yielded
    /// and then the stream ends.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`Self::subscribe_orderbook`].
    pub fn subscribe_midpoints(
        &self,
        asset_ids: Vec<String>,
    ) -> Result<impl Stream<Item = Result<MidpointUpdate>>> {
        let stream = self.subscribe_orderbook(asset_ids)?;
        Ok(try_stream! {
            for await book_result in stream {
                let book = book_result?;
                // Pick the best levels by price, not by position, so the result does not depend
                // on the server's level ordering. The Polymarket docs show ascending price order,
                // where `bids.first()` would be the worst bid.
                let best_bid = book.bids.iter().map(|level| level.price).max();
                let best_ask = book.asks.iter().map(|level| level.price).min();
                if let (Some(bid), Some(ask)) = (best_bid, best_ask) {
                    yield MidpointUpdate {
                        asset_id: book.asset_id,
                        market: book.market,
                        midpoint: (bid + ask) / Decimal::TWO,
                        timestamp: book.timestamp,
                    };
                }
            }
        })
    }

    /// Returns the state of the market channel connection.
    ///
    /// Only the market channel is consulted, even on authenticated clients. Returns
    /// `ConnectionState::Disconnected` if no market channel is present.
    #[must_use]
    pub fn connection_state(&self) -> ConnectionState {
        if let Some(handles) = self.inner.channel(ChannelType::Market) {
            handles.connection.state()
        } else {
            ConnectionState::Disconnected
        }
    }

    /// Returns the total number of subscriptions, summed across every present channel.
    #[must_use]
    pub fn subscription_count(&self) -> usize {
        self.inner
            .channels
            .values()
            .map(|handles| handles.subscriptions.subscription_count())
            .sum()
    }

    /// Returns the market channel handles, or a validation error if that channel is missing.
    fn market_handles(&self) -> Result<&ChannelHandles> {
        self.inner
            .channel(ChannelType::Market)
            .ok_or_else(|| Error::validation("Market channel unavailable; recreate client"))
    }
}
impl<K: AuthKind> Client<Authenticated<K>> {
pub fn subscribe_user_events(
&self,
markets: Vec<String>,
) -> Result<impl Stream<Item = Result<WsMessage>>> {
let handles = self
.inner
.channel(ChannelType::User)
.ok_or_else(|| Error::validation("User channel unavailable; authenticate first"))?;
handles
.subscriptions
.subscribe_user(markets, self.inner.state.credentials.clone())
}
pub fn subscribe_orders(
&self,
markets: Vec<String>,
) -> Result<impl Stream<Item = Result<OrderMessage>>> {
let stream = self.subscribe_user_events(markets)?;
Ok(stream.filter_map(|msg_result| async move {
match msg_result {
Ok(WsMessage::Order(order)) => Some(Ok(order)),
Err(e) => Some(Err(e)),
_ => None,
}
}))
}
pub fn subscribe_trades(
&self,
markets: Vec<String>,
) -> Result<impl Stream<Item = Result<TradeMessage>>> {
let stream = self.subscribe_user_events(markets)?;
Ok(stream.filter_map(|msg_result| async move {
match msg_result {
Ok(WsMessage::Trade(trade)) => Some(Ok(trade)),
Err(e) => Some(Err(e)),
_ => None,
}
}))
}
pub fn deauthenticate(self) -> Result<Client<Unauthenticated>> {
let inner = Arc::into_inner(self.inner).ok_or(Error::validation(
"Cannot deauthenticate while other references to this client exist; \
drop all clones before calling deauthenticate",
))?;
let ClientInner {
config,
base_endpoint,
mut channels,
..
} = inner;
channels.remove(&ChannelType::User);
Ok(Client {
inner: Arc::new(ClientInner {
state: Unauthenticated,
config,
base_endpoint,
channels,
}),
})
}
}
impl<S: State> ClientInner<S> {
    /// Looks up the handles for the channel `kind`, yielding `None` when that channel is absent.
    fn channel(&self, kind: ChannelType) -> Option<&ChannelHandles> {
        let registry = &self.channels;
        registry.get(&kind)
    }
}
/// Connection manager plus subscription manager for a single WebSocket channel.
#[derive(Clone)]
struct ChannelHandles {
    connection: ConnectionManager,
    subscriptions: Arc<SubscriptionManager>,
}

impl ChannelHandles {
    /// Creates a connection manager for `endpoint` and a subscription manager bound to it.
    ///
    /// Both managers share one `InterestTracker`. The subscription manager's reconnection
    /// handler is started before the handles are returned.
    ///
    /// # Errors
    ///
    /// Returns an error if `ConnectionManager::new` fails.
    fn connect(endpoint: String, config: &Config) -> Result<Self> {
        let tracker = Arc::new(InterestTracker::new());
        let conn = ConnectionManager::new(endpoint, config.clone(), &tracker)?;
        let subs = Arc::new(SubscriptionManager::new(conn.clone(), tracker));
        subs.start_reconnection_handler();
        Ok(Self {
            subscriptions: subs,
            connection: conn,
        })
    }
}
/// Reduces a user-supplied endpoint to its base URL.
///
/// Trailing slashes are trimmed first. Then at most one trailing `/ws/market`, `/ws/user` or
/// `/ws` segment is removed, checked in that order. A slash exposed by the removal is kept;
/// `channel_endpoint` trims it later.
fn normalize_base_endpoint(endpoint: &str) -> String {
    const CHANNEL_SUFFIXES: [&str; 3] = ["/ws/market", "/ws/user", "/ws"];
    let base = endpoint.trim_end_matches('/');
    CHANNEL_SUFFIXES
        .iter()
        .find_map(|&suffix| base.strip_suffix(suffix))
        .unwrap_or(base)
        .to_owned()
}
/// Builds the WebSocket URL for `channel` as `{base}/ws/{market|user}`, ignoring any trailing
/// slashes on `base`.
fn channel_endpoint(base: &str, channel: ChannelType) -> String {
    let root = base.trim_end_matches('/');
    let leaf = match channel {
        ChannelType::Market => "market",
        ChannelType::User => "user",
    };
    let mut url = String::with_capacity(root.len() + "/ws/".len() + leaf.len());
    url.push_str(root);
    url.push_str("/ws/");
    url.push_str(leaf);
    url
}