use std::collections::HashSet;
use std::pin::Pin;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use futures_util::{Stream, StreamExt, SinkExt};
use serde::Serialize;
use serde_json::{json, Value};
use tokio::sync::{mpsc, broadcast, Mutex};
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
Credentials, AccountType, ExchangeResult,
ConnectionStatus, StreamEvent, SubscriptionRequest,
Ticker, PublicTrade, OrderBook,
};
use crate::core::types::OrderBookLevel;
use crate::core::types::TradeSide;
use crate::core::types::{WebSocketResult, WebSocketError, OrderbookCapabilities};
use crate::core::traits::WebSocketConnector;
use super::auth::LighterAuth;
use super::endpoints::{LighterUrls, symbol_to_market_id};
/// Maps a generic stream type plus a base asset (e.g. `"BTC"`) to the Lighter
/// subscription channel name, e.g. `trade/<market_id>`.
///
/// `Ticker` is served from the `market_stats` channel.
///
/// # Errors
/// `WebSocketError::UnsupportedOperation` when the base asset has no known
/// market id, or when the stream type has no Lighter equivalent. The asset is
/// resolved first, so an unknown asset is reported even for unsupported streams.
fn build_channel(stream_type: &crate::core::types::StreamType, base: &str) -> Result<String, WebSocketError> {
    let market_id = match symbol_to_market_id(base) {
        Some(id) => id,
        None => {
            return Err(WebSocketError::UnsupportedOperation(format!(
                "Unknown Lighter market for base asset '{}'. Known: ETH(0), BTC(1), SOL(2), etc.",
                base
            )))
        }
    };
    let prefix = match stream_type {
        crate::core::types::StreamType::Ticker => "market_stats",
        crate::core::types::StreamType::Trade => "trade",
        crate::core::types::StreamType::Orderbook => "order_book",
        unsupported => {
            return Err(WebSocketError::UnsupportedOperation(format!(
                "Stream type {:?} not supported for Lighter WebSocket",
                unsupported
            )))
        }
    };
    Ok(format!("{}/{}", prefix, market_id))
}
/// Outgoing subscription frame, serialized as
/// `{"type": "subscribe", "channel": "...", "auth": "..."}`.
///
/// `auth` is left out of the JSON entirely when it is `None`.
#[derive(Debug, Clone, Serialize)]
struct SubscribeMessage {
/// Emitted under the JSON key `type`; this file always sets it to `"subscribe"`.
#[serde(rename = "type")]
msg_type: String,
/// Channel name, e.g. `order_book/0`.
channel: String,
/// Optional auth token; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
auth: Option<String>,
}
/// Application-level ping frame, serialized as `{"type": "ping"}`.
// Only constructed by `LighterWebSocket::send_ping`, which is itself marked
// `#[allow(dead_code)]`; the keepalive task sends protocol-level Ping frames
// instead. Hence the allow here.
#[derive(Debug, Clone, Serialize)]
#[allow(dead_code)]
struct PingMessage {
/// Emitted under the JSON key `type`.
#[serde(rename = "type")]
msg_type: String,
}
/// Application-level pong frame, `{"type": "pong", "timestamp": ...}`.
// Not constructed anywhere in this file: the reader task builds the same
// shape with `json!` when it answers a server `{"type": "ping"}`.
#[derive(Debug, Clone, Serialize)]
#[allow(dead_code)]
struct PongMessage {
/// Emitted under the JSON key `type`.
#[serde(rename = "type")]
msg_type: String,
/// Echoed ping timestamp; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
timestamp: Option<i64>,
}
/// Thin wrapper around one parsed JSON text frame received from the socket.
///
/// Fields are read on demand through the accessor methods instead of being
/// deserialized into a typed struct, so wrapping never fails on unknown shapes.
#[derive(Debug, Clone)]
struct IncomingMessage {
/// The complete JSON value of the frame.
raw: Value,
}
impl IncomingMessage {
fn from_value(v: Value) -> Self {
Self { raw: v }
}
fn msg_type(&self) -> Option<&str> {
self.raw.get("type").and_then(|v| v.as_str())
}
fn channel(&self) -> Option<&str> {
self.raw.get("channel").and_then(|v| v.as_str())
}
fn error_message(&self) -> Option<String> {
if let Some(msg) = self.raw.get("message").and_then(|v| v.as_str()) {
return Some(msg.to_string());
}
if let Some(err) = self.raw.get("error") {
if let Some(msg) = err.get("message").and_then(|v| v.as_str()) {
return Some(msg.to_string());
}
}
None
}
fn data_object(&self, key: &str) -> Option<&Value> {
self.raw.get(key)
}
}
/// Concrete socket type produced by `connect_async`: a WebSocket over a plain or TLS TCP stream.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// WebSocket connector for the Lighter exchange.
///
/// Mutable state is held behind `Arc`s so it can be shared with the background
/// tasks spawned on `connect`: the reader loop, the event forwarder and the
/// keepalive pinger.
// NOTE(review): `dead_code` is allowed presumably because some fields
// (`auth`, `testnet`, `subscriptions`) are written but never read in this file.
#[allow(dead_code)]
pub struct LighterWebSocket {
/// Parsed credentials, if any were passed to `new`.
auth: Option<LighterAuth>,
/// Mainnet or testnet endpoint set, chosen from `testnet` in `new`.
urls: LighterUrls,
/// Whether this instance targets testnet.
testnet: bool,
/// The live socket; `None` while disconnected. The reader task, the ping task
/// and the subscribe/unsubscribe methods all lock it to use it.
ws: Arc<Mutex<Option<WsStream>>>,
/// Last known connection state.
status: Arc<Mutex<ConnectionStatus>>,
/// Raw channel names currently subscribed (e.g. `trade/0`).
subscriptions: Arc<Mutex<HashSet<String>>>,
/// Requests accepted through `WebSocketConnector::subscribe`.
subscription_requests: Arc<Mutex<Vec<SubscriptionRequest>>>,
/// Producer side for parsed events; cloned into the reader task.
event_tx: mpsc::UnboundedSender<WebSocketResult<StreamEvent>>,
/// Consumer side for parsed events; taken once by the forwarder task.
event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<WebSocketResult<StreamEvent>>>>>,
/// Fan-out sender that `event_stream` subscribes to; `None` while disconnected.
/// Presumably a std mutex so the synchronous `event_stream` can lock it.
broadcast_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
/// When the most recent ping was sent; used to compute round-trip time on pong.
last_ping: Arc<Mutex<Instant>>,
/// Period of the protocol-level keepalive pings (30 s, set in `new`).
ping_interval: Duration,
/// Latest measured ping round-trip time in milliseconds (0 until the first pong).
ws_ping_rtt_ms: Arc<Mutex<u64>>,
}
impl LighterWebSocket {
    /// Creates a connector for mainnet or testnet without opening a socket.
    ///
    /// When `credentials` are supplied they are converted with `LighterAuth::new`;
    /// a failure there is propagated. Call `connect` to open the connection.
    pub async fn new(
        credentials: Option<Credentials>,
        testnet: bool,
    ) -> ExchangeResult<Self> {
        let urls = if testnet {
            LighterUrls::TESTNET
        } else {
            LighterUrls::MAINNET
        };
        let auth = credentials
            .as_ref()
            .map(LighterAuth::new)
            .transpose()?;
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        Ok(Self {
            auth,
            urls,
            testnet,
            ws: Arc::new(Mutex::new(None)),
            status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
            subscriptions: Arc::new(Mutex::new(HashSet::new())),
            subscription_requests: Arc::new(Mutex::new(Vec::new())),
            event_tx,
            event_rx: Arc::new(Mutex::new(Some(event_rx))),
            broadcast_tx: Arc::new(StdMutex::new(None)),
            last_ping: Arc::new(Mutex::new(Instant::now())),
            ping_interval: Duration::from_secs(30),
            ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
        })
    }

    /// Creates a connector without credentials.
    pub async fn public(testnet: bool) -> ExchangeResult<Self> {
        Self::new(None, testnet).await
    }

    /// Opens the socket, stores it and marks the status `Connected`.
    ///
    /// # Errors
    /// `WebSocketError::ConnectionError` if the handshake fails.
    async fn connect_ws(&self) -> WebSocketResult<()> {
        let ws_url = self.urls.ws_url();
        let (ws_stream, _) = connect_async(ws_url)
            .await
            .map_err(|e| WebSocketError::ConnectionError(e.to_string()))?;
        *self.ws.lock().await = Some(ws_stream);
        *self.status.lock().await = ConnectionStatus::Connected;
        Ok(())
    }

    /// Closes and drops the socket (if any), marks the status `Disconnected`
    /// and clears the broadcast sender so `event_stream` yields nothing new.
    ///
    /// Removing the socket also makes the reader and ping tasks exit on their
    /// next iteration.
    async fn disconnect_ws(&self) -> WebSocketResult<()> {
        if let Some(mut ws) = self.ws.lock().await.take() {
            // Best effort: the peer may already be gone.
            let _ = ws.close(None).await;
        }
        *self.status.lock().await = ConnectionStatus::Disconnected;
        let _ = self.broadcast_tx.lock().unwrap().take();
        Ok(())
    }

    /// Sends a `subscribe` frame for `channel` and records it on success.
    ///
    /// # Errors
    /// - `Parse` if the frame cannot be serialized.
    /// - `NotConnected` when there is no socket.
    /// - `SendError` when the write fails.
    async fn subscribe_channel(&self, channel: &str, auth: Option<String>) -> WebSocketResult<()> {
        let msg = SubscribeMessage {
            msg_type: "subscribe".to_string(),
            channel: channel.to_string(),
            auth,
        };
        let json_str = serde_json::to_string(&msg)
            .map_err(|e| WebSocketError::Parse(e.to_string()))?;
        if let Some(ws) = self.ws.lock().await.as_mut() {
            ws.send(Message::Text(json_str))
                .await
                .map_err(|e| WebSocketError::SendError(e.to_string()))?;
            self.subscriptions.lock().await.insert(channel.to_string());
        } else {
            return Err(WebSocketError::NotConnected);
        }
        Ok(())
    }

    /// Sends an `unsubscribe` frame for `channel` and forgets it on success.
    ///
    /// Unlike `subscribe_channel`, this is a silent no-op (returns `Ok`) when
    /// there is no socket.
    ///
    /// # Errors
    /// `Parse` if serialization fails; `SendError` when the write fails.
    async fn unsubscribe_channel(&self, channel: &str) -> WebSocketResult<()> {
        let msg = json!({
            "type": "unsubscribe",
            "channel": channel
        });
        let json_str = serde_json::to_string(&msg)
            .map_err(|e| WebSocketError::Parse(e.to_string()))?;
        if let Some(ws) = self.ws.lock().await.as_mut() {
            ws.send(Message::Text(json_str))
                .await
                .map_err(|e| WebSocketError::SendError(e.to_string()))?;
            self.subscriptions.lock().await.remove(channel);
        }
        Ok(())
    }

    /// Sends an application-level `{"type": "ping"}` frame and records the send time.
    /// No-op when disconnected.
    #[allow(dead_code)]
    async fn send_ping(&self) -> WebSocketResult<()> {
        let msg = PingMessage {
            msg_type: "ping".to_string(),
        };
        let json_str = serde_json::to_string(&msg)
            .map_err(|e| WebSocketError::Parse(e.to_string()))?;
        if let Some(ws) = self.ws.lock().await.as_mut() {
            ws.send(Message::Text(json_str))
                .await
                .map_err(|e| WebSocketError::SendError(e.to_string()))?;
            *self.last_ping.lock().await = Instant::now();
        }
        Ok(())
    }

    /// Spawns the reader task.
    ///
    /// The task pulls frames from the shared socket and handles them:
    /// - answers application pings and protocol pings;
    /// - records ping RTT on protocol pongs;
    /// - forwards parsed market data to `event_tx`.
    ///
    /// It stops when the socket is removed (disconnect), closed by the peer,
    /// errors, or the stream ends; the last three set the status to `Disconnected`.
    async fn start_message_loop(&self) {
        // The socket lives behind one mutex shared with writers (subscribe,
        // unsubscribe, the ping task, disconnect). Waiting for the next frame
        // while holding that lock would starve them for as long as the socket
        // stays idle. So each read is bounded and the lock is released in
        // between. Dropping a pending `next()` loses no data: partially read
        // frames stay buffered inside the stream.
        const READ_POLL: Duration = Duration::from_millis(50);

        let ws = self.ws.clone();
        let event_tx = self.event_tx.clone();
        let status = self.status.clone();
        let last_ping = self.last_ping.clone();
        let ws_ping_rtt_ms = self.ws_ping_rtt_ms.clone();
        tokio::spawn(async move {
            loop {
                let mut ws_guard = ws.lock().await;
                let stream = match ws_guard.as_mut() {
                    Some(stream) => stream,
                    // Socket removed by `disconnect_ws`.
                    None => break,
                };
                let next = match tokio::time::timeout(READ_POLL, stream.next()).await {
                    Ok(next) => next,
                    // Idle: drop the guard so waiting writers get the lock.
                    Err(_) => continue,
                };
                let msg_result = match next {
                    Some(msg_result) => msg_result,
                    None => {
                        // Stream exhausted. Previously this re-polled forever,
                        // spinning the CPU while holding the lock.
                        *status.lock().await = ConnectionStatus::Disconnected;
                        break;
                    }
                };
                match msg_result {
                    Ok(Message::Text(text)) => {
                        let val = match serde_json::from_str::<Value>(&text) {
                            Ok(v) => v,
                            // Non-JSON text frames are ignored.
                            Err(_) => continue,
                        };
                        if val.get("type").and_then(|t| t.as_str()) == Some("ping") {
                            // Application-level ping: reply with a pong, echoing the timestamp if sent.
                            let pong = match val.get("timestamp").and_then(|t| t.as_i64()) {
                                Some(ts) => json!({"type": "pong", "timestamp": ts}),
                                None => json!({"type": "pong"}),
                            };
                            let _ = stream.send(Message::Text(pong.to_string())).await;
                            continue;
                        }
                        Self::handle_message(IncomingMessage::from_value(val), &event_tx);
                    }
                    Ok(Message::Ping(data)) => {
                        let _ = stream.send(Message::Pong(data)).await;
                    }
                    Ok(Message::Pong(_)) => {
                        let rtt = last_ping.lock().await.elapsed().as_millis() as u64;
                        *ws_ping_rtt_ms.lock().await = rtt;
                    }
                    Ok(Message::Close(_)) | Err(_) => {
                        *status.lock().await = ConnectionStatus::Disconnected;
                        break;
                    }
                    _ => {}
                }
            }
        });
    }

    /// Spawns a task that sends a protocol-level Ping every `ping_interval`,
    /// recording the send time for RTT measurement.
    ///
    /// The task exits when the socket is gone or a send fails.
    fn start_ws_ping_task(&self) {
        let ws = self.ws.clone();
        let last_ping = self.last_ping.clone();
        let ping_interval = self.ping_interval;
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(ping_interval);
            // A tokio interval's first tick completes immediately; skip it so
            // the first ping goes out one full period after connecting.
            interval.tick().await;
            loop {
                interval.tick().await;
                let mut ws_guard = ws.lock().await;
                if let Some(ws_inner) = ws_guard.as_mut() {
                    *last_ping.lock().await = Instant::now();
                    if ws_inner.send(Message::Ping(vec![])).await.is_err() {
                        break;
                    }
                } else {
                    break;
                }
            }
        });
    }

    /// Installs a fresh broadcast sender, capacity 1000, and spawns a task that
    /// relays every event from the internal mpsc channel into it.
    ///
    /// The mpsc receiver can be taken only once. On later calls the spawned task
    /// returns immediately, while the first relay task keeps forwarding into
    /// whatever sender is currently installed.
    fn start_forwarder(&self) {
        let broadcast_tx = self.broadcast_tx.clone();
        let event_rx = self.event_rx.clone();
        let (tx, _) = broadcast::channel(1000);
        *broadcast_tx.lock().unwrap() = Some(tx);
        let broadcast_tx_inner = self.broadcast_tx.clone();
        tokio::spawn(async move {
            let mut rx = match event_rx.lock().await.take() {
                Some(rx) => rx,
                None => return,
            };
            while let Some(event) = rx.recv().await {
                // The std guard is dropped at the end of this body, before the next await.
                let tx_guard = broadcast_tx_inner.lock().unwrap();
                if let Some(ref tx) = *tx_guard {
                    // Err only means there are currently no subscribers.
                    let _ = tx.send(event);
                }
            }
            let _ = broadcast_tx_inner.lock().unwrap().take();
        });
    }

    /// Dispatches one decoded frame.
    ///
    /// - Any frame with an `error` key, or of type `"error"`, becomes
    ///   `Err(ProtocolError)`; the `"error"` type is also logged to stderr.
    /// - Known `update/*` types are parsed into `StreamEvent`s.
    /// - Everything else is ignored.
    fn handle_message(
        msg: IncomingMessage,
        event_tx: &mpsc::UnboundedSender<WebSocketResult<StreamEvent>>,
    ) {
        if msg.raw.get("error").is_some() {
            let error_msg = msg.error_message().unwrap_or_else(|| "Unknown error".to_string());
            let _ = event_tx.send(Err(WebSocketError::ProtocolError(error_msg)));
            return;
        }
        let channel = msg.channel().unwrap_or("");
        let event = match msg.msg_type() {
            Some("error") => {
                let error_msg = msg.error_message().unwrap_or_else(|| "Unknown error".to_string());
                eprintln!("[lighter-ws] error from server: {}", error_msg);
                let _ = event_tx.send(Err(WebSocketError::ProtocolError(error_msg)));
                return;
            }
            Some("update/orderbook") | Some("update/order_book") => Self::parse_orderbook(&msg, channel),
            Some("update/trade") => Self::parse_trade(&msg, channel),
            Some("update/market_stats") => Self::parse_market_stats(&msg, channel),
            Some("update/ticker") => Self::parse_ticker_channel(&msg, channel),
            // pong, connected, height/account updates, untyped and unknown frames.
            _ => None,
        };
        if let Some(event) = event {
            let _ = event_tx.send(Ok(event));
        }
    }

    /// Returns the last segment of a channel name, used as the market id.
    ///
    /// Colon-delimited names (`order_book:0`) take priority over slash-delimited
    /// ones (`order_book/0`); a name with neither separator is returned whole.
    fn extract_market_id(channel: &str) -> &str {
        // `rsplit(...).next()` always yields Some, which made the '/' fallback
        // unreachable; `rsplit_once` returns None when the separator is absent.
        channel
            .rsplit_once(':')
            .or_else(|| channel.rsplit_once('/'))
            .map_or(channel, |(_, id)| id)
    }

    /// Reads `field` as f64, accepting either a numeric string or a JSON number.
    fn val_f64(obj: &Value, field: &str) -> Option<f64> {
        obj.get(field).and_then(|v| {
            v.as_str().and_then(|s| s.parse::<f64>().ok())
                .or_else(|| v.as_f64())
        })
    }

    /// Reads `field` as a string slice.
    fn val_str<'a>(obj: &'a Value, field: &str) -> Option<&'a str> {
        obj.get(field).and_then(|v| v.as_str())
    }

    /// Reads `field` as a JSON integer (i64).
    fn val_i64(obj: &Value, field: &str) -> Option<i64> {
        obj.get(field).and_then(|v| v.as_i64())
    }

    /// Reads `field` as a JSON unsigned integer (u64).
    fn val_u64(obj: &Value, field: &str) -> Option<u64> {
        obj.get(field).and_then(|v| v.as_u64())
    }

    /// Reads `field` as a JSON boolean.
    fn val_bool(obj: &Value, field: &str) -> Option<bool> {
        obj.get(field).and_then(|v| v.as_bool())
    }

    /// Parses an array of price levels.
    ///
    /// Each entry is either `{"price": .., "size": ..}` or `[price, size, ...]`.
    /// Values may be numeric strings or JSON numbers. Malformed entries are
    /// skipped; a non-array input yields an empty vector.
    fn parse_levels(arr: &Value) -> Vec<OrderBookLevel> {
        // Same string-or-number rule as `val_f64`; previously JSON numbers were dropped.
        let num = |v: &Value| -> Option<f64> {
            v.as_str()
                .and_then(|s| s.parse::<f64>().ok())
                .or_else(|| v.as_f64())
        };
        arr.as_array()
            .map(|levels| {
                levels
                    .iter()
                    .filter_map(|entry| {
                        let (price, size) = if entry.is_object() {
                            (entry.get("price")?, entry.get("size")?)
                        } else {
                            let pair = entry.as_array()?;
                            (pair.first()?, pair.get(1)?)
                        };
                        Some(OrderBookLevel::new(num(price)?, num(size)?))
                    })
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Builds an `OrderbookSnapshot` from the `order_book` object (or the frame
    /// itself when absent).
    ///
    /// Returns `None` when both sides are empty. The `nonce` becomes `sequence`;
    /// the timestamp prefers the top-level value and defaults to 0.
    fn parse_orderbook(msg: &IncomingMessage, _channel: &str) -> Option<StreamEvent> {
        let data = msg.data_object("order_book").unwrap_or(&msg.raw);
        let asks = data.get("asks").map(Self::parse_levels).unwrap_or_default();
        let bids = data.get("bids").map(Self::parse_levels).unwrap_or_default();
        if asks.is_empty() && bids.is_empty() {
            return None;
        }
        let timestamp = Self::val_i64(&msg.raw, "timestamp")
            .or_else(|| Self::val_i64(data, "timestamp"))
            .unwrap_or(0);
        let sequence = Self::val_i64(data, "nonce").map(|n| n.to_string());
        Some(StreamEvent::OrderbookSnapshot(OrderBook {
            bids,
            asks,
            timestamp,
            sequence,
            last_update_id: None,
            first_update_id: None,
            prev_update_id: None,
            event_time: None,
            transaction_time: None,
            checksum: None,
        }))
    }

    /// Builds a `Trade` event from the `trade` object (or the frame itself).
    ///
    /// `price` and `size` are required. The symbol is the market id taken from
    /// the channel name.
    fn parse_trade(msg: &IncomingMessage, channel: &str) -> Option<StreamEvent> {
        let data = msg.data_object("trade").unwrap_or(&msg.raw);
        let price = Self::val_f64(data, "price")?;
        let quantity = Self::val_f64(data, "size")?;
        let timestamp = Self::val_i64(&msg.raw, "timestamp")
            .or_else(|| Self::val_i64(data, "timestamp"))
            .unwrap_or(0);
        let trade_id = Self::val_u64(data, "trade_id").unwrap_or(0);
        let market_id = Self::extract_market_id(channel);
        // An explicit "buy"/"sell" wins. Otherwise a maker resting on the ask
        // is treated as a taker buy; anything else is a sell.
        let side = match Self::val_str(data, "side") {
            Some("buy") => TradeSide::Buy,
            Some("sell") => TradeSide::Sell,
            _ if Self::val_bool(data, "is_maker_ask").unwrap_or(false) => TradeSide::Buy,
            _ => TradeSide::Sell,
        };
        Some(StreamEvent::Trade(PublicTrade {
            id: trade_id.to_string(),
            symbol: market_id.to_string(),
            price,
            quantity,
            side,
            timestamp,
        }))
    }

    /// Builds a `Ticker` event from the `market_stats` object (or the frame itself).
    ///
    /// Requires one of `last_trade_price`, `last_price` or `mark_price`.
    fn parse_market_stats(msg: &IncomingMessage, channel: &str) -> Option<StreamEvent> {
        let data = msg.data_object("market_stats").unwrap_or(&msg.raw);
        let last_price = Self::val_f64(data, "last_trade_price")
            .or_else(|| Self::val_f64(data, "last_price"))
            .or_else(|| Self::val_f64(data, "mark_price"))?;
        let market_id = Self::extract_market_id(channel);
        let symbol_name = Self::val_str(data, "symbol").unwrap_or(market_id);
        let high_24h = Self::val_f64(data, "daily_price_high")
            .or_else(|| Self::val_f64(data, "daily_high"));
        let low_24h = Self::val_f64(data, "daily_price_low")
            .or_else(|| Self::val_f64(data, "daily_low"));
        let volume_24h = Self::val_f64(data, "daily_volume")
            .or_else(|| Self::val_f64(data, "daily_base_token_volume"));
        let price_change_24h = Self::val_f64(data, "daily_price_change")
            .or_else(|| Self::val_f64(data, "daily_change"));
        let timestamp = Self::val_i64(&msg.raw, "timestamp")
            .or_else(|| Self::val_i64(data, "timestamp"))
            .unwrap_or(0);
        // NOTE(review): assumes the daily change is an absolute price delta, so
        // the open is reconstructed as `last - change`; confirm against the API.
        let price_change_percent_24h = price_change_24h.and_then(|change| {
            let open = last_price - change;
            if open.abs() > 1e-10 {
                Some((change / open) * 100.0)
            } else {
                None
            }
        });
        Some(StreamEvent::Ticker(Ticker {
            symbol: symbol_name.to_string(),
            last_price,
            bid_price: None,
            ask_price: None,
            high_24h,
            low_24h,
            volume_24h,
            quote_volume_24h: None,
            price_change_24h,
            price_change_percent_24h,
            timestamp,
        }))
    }

    /// Builds a `Ticker` event from the `ticker` object (or the frame itself).
    ///
    /// Requires `last_price` or `mark_price`; best bid/ask are optional.
    fn parse_ticker_channel(msg: &IncomingMessage, channel: &str) -> Option<StreamEvent> {
        let data = msg.data_object("ticker").unwrap_or(&msg.raw);
        let last_price = Self::val_f64(data, "last_price")
            .or_else(|| Self::val_f64(data, "mark_price"))?;
        let bid_price = Self::val_f64(data, "best_bid")
            .or_else(|| Self::val_f64(data, "bid_price"));
        let ask_price = Self::val_f64(data, "best_ask")
            .or_else(|| Self::val_f64(data, "ask_price"));
        let market_id = Self::extract_market_id(channel);
        let symbol_name = Self::val_str(data, "symbol").unwrap_or(market_id);
        let timestamp = Self::val_i64(&msg.raw, "timestamp")
            .or_else(|| Self::val_i64(data, "timestamp"))
            .unwrap_or(0);
        Some(StreamEvent::Ticker(Ticker {
            symbol: symbol_name.to_string(),
            last_price,
            bid_price,
            ask_price,
            high_24h: None,
            low_24h: None,
            volume_24h: None,
            quote_volume_24h: None,
            price_change_24h: None,
            price_change_percent_24h: None,
            timestamp,
        }))
    }
}
#[async_trait]
impl WebSocketConnector for LighterWebSocket {
    /// Opens the socket, then starts the reader loop, the event forwarder and
    /// the keepalive ping task.
    ///
    /// The account type is ignored; the same endpoint is used for all of them.
    async fn connect(&mut self, _account_type: AccountType) -> WebSocketResult<()> {
        self.connect_ws().await?;
        self.start_message_loop().await;
        self.start_forwarder();
        self.start_ws_ping_task();
        Ok(())
    }

    /// Closes the socket and stops event delivery.
    async fn disconnect(&mut self) -> WebSocketResult<()> {
        self.disconnect_ws().await
    }

    /// Non-blocking status read; a momentarily contended lock reports `Disconnected`.
    fn connection_status(&self) -> ConnectionStatus {
        self.status
            .try_lock()
            .map(|guard| *guard)
            .unwrap_or(ConnectionStatus::Disconnected)
    }

    /// Subscribes to the Lighter channel for `request` and remembers the request.
    ///
    /// Channels built here are sent without an auth token.
    async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
        let channel = build_channel(&request.stream_type, &request.symbol.base)?;
        self.subscribe_channel(&channel, None).await?;
        self.subscription_requests.lock().await.push(request);
        Ok(())
    }

    /// Unsubscribes from the Lighter channel for `request` and drops every stored
    /// request with the same symbol and stream type.
    async fn unsubscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
        let channel = build_channel(&request.stream_type, &request.symbol.base)?;
        self.unsubscribe_channel(&channel).await?;
        let mut requests = self.subscription_requests.lock().await;
        requests.retain(|existing| {
            existing.symbol != request.symbol || existing.stream_type != request.stream_type
        });
        Ok(())
    }

    /// Returns a new receiver over the broadcast channel, or an empty stream when
    /// not connected. Lagged-receiver errors surface as `ConnectionError`.
    fn event_stream(&self) -> Pin<Box<dyn Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
        let sender_slot = self.broadcast_tx.lock().unwrap();
        match sender_slot.as_ref() {
            None => Box::pin(futures_util::stream::empty()),
            Some(sender) => {
                let receiver = sender.subscribe();
                Box::pin(
                    tokio_stream::wrappers::BroadcastStream::new(receiver).map(|item| match item {
                        Ok(event) => event,
                        Err(recv_err) => Err(WebSocketError::ConnectionError(format!(
                            "Broadcast error: {}",
                            recv_err
                        ))),
                    }),
                )
            }
        }
    }

    /// Snapshot of the stored subscription requests; empty if the lock is contended.
    fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
        self.subscription_requests
            .try_lock()
            .map(|subs| subs.to_vec())
            .unwrap_or_default()
    }

    /// Shared handle to the latest protocol-ping round-trip time in milliseconds.
    fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
        Some(Arc::clone(&self.ws_ping_rtt_ms))
    }

    /// Static orderbook capability table for this connector (same for every account type).
    fn orderbook_capabilities(&self, _account_type: AccountType) -> OrderbookCapabilities {
        OrderbookCapabilities {
            ws_depths: &[],
            ws_default_depth: None,
            rest_max_depth: Some(250),
            rest_depth_values: &[],
            supports_snapshot: true,
            supports_delta: true,
            update_speeds_ms: &[50],
            default_speed_ms: Some(50),
            ws_channels: &[],
            checksum: None,
            has_sequence: true,
            has_prev_sequence: true,
            supports_aggregation: false,
            aggregation_levels: &[],
        }
    }
}
/// Direct, channel-level subscription helpers that bypass `SubscriptionRequest`.
///
/// None of these record anything in `subscription_requests`.
impl LighterWebSocket {
    /// Subscribes to `order_book/{market_id}` without an auth token.
    pub async fn subscribe_orderbook(&self, market_id: u16) -> WebSocketResult<()> {
        self.subscribe_channel(&format!("order_book/{}", market_id), None).await
    }

    /// Subscribes to `trade/{market_id}` without an auth token.
    pub async fn subscribe_trades(&self, market_id: u16) -> WebSocketResult<()> {
        self.subscribe_channel(&format!("trade/{}", market_id), None).await
    }

    /// Subscribes to `market_stats/{market_id}` without an auth token.
    pub async fn subscribe_market_stats(&self, market_id: u16) -> WebSocketResult<()> {
        self.subscribe_channel(&format!("market_stats/{}", market_id), None).await
    }

    /// Subscribes to `account_all/{account_id}` without an auth token.
    pub async fn subscribe_account(&self, account_id: u64) -> WebSocketResult<()> {
        self.subscribe_channel(&format!("account_all/{}", account_id), None).await
    }

    /// Subscribes to the `height` channel without an auth token.
    pub async fn subscribe_height(&self) -> WebSocketResult<()> {
        self.subscribe_channel("height", None).await
    }

    /// Subscribes to `ticker/{market_id}` without an auth token.
    pub async fn subscribe_ticker(&self, market_id: u16) -> WebSocketResult<()> {
        self.subscribe_channel(&format!("ticker/{}", market_id), None).await
    }

    /// Unsubscribes from `ticker/{market_id}` (no-op when disconnected).
    pub async fn unsubscribe_ticker(&self, market_id: u16) -> WebSocketResult<()> {
        self.unsubscribe_channel(&format!("ticker/{}", market_id)).await
    }

    /// Subscribes to `market_stats/all` without an auth token.
    pub async fn subscribe_market_stats_all(&self) -> WebSocketResult<()> {
        self.subscribe_channel("market_stats/all", None).await
    }

    /// Unsubscribes from `market_stats/all` (no-op when disconnected).
    pub async fn unsubscribe_market_stats_all(&self) -> WebSocketResult<()> {
        self.unsubscribe_channel("market_stats/all").await
    }

    /// Subscribes to `account_market/{market_id}/{account_id}`, forwarding the
    /// caller-supplied `auth` token (if any) in the subscribe frame.
    pub async fn subscribe_account_market(
        &self,
        market_id: u16,
        account_id: u64,
        auth: Option<String>,
    ) -> WebSocketResult<()> {
        self.subscribe_channel(&format!("account_market/{}/{}", market_id, account_id), auth)
            .await
    }

    /// Unsubscribes from `account_market/{market_id}/{account_id}` (no-op when disconnected).
    pub async fn unsubscribe_account_market(
        &self,
        market_id: u16,
        account_id: u64,
    ) -> WebSocketResult<()> {
        self.unsubscribe_channel(&format!("account_market/{}/{}", market_id, account_id))
            .await
    }
}