use std::collections::HashSet;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::Instant;
use async_trait::async_trait;
use futures_util::{StreamExt, SinkExt, stream::{SplitSink, SplitStream}};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{broadcast, Mutex};
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream, MaybeTlsStream};
use crate::core::{
Credentials, AccountType, Symbol,
ExchangeError, ExchangeResult,
ConnectionStatus, StreamEvent, StreamType, SubscriptionRequest,
};
use crate::core::types::{WebSocketResult, WebSocketError, OrderbookCapabilities, WsBookChannel};
use crate::core::traits::WebSocketConnector;
use super::auth::GeminiAuth;
use super::endpoints::{GeminiUrls, normalize_symbol, format_symbol};
use super::parser::GeminiParser;
/// Full-duplex WebSocket over plain TCP or TLS, as produced by `connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Read half of a split [`WsStream`]; handed to the background reader task.
type WsReader = SplitStream<WsStream>;
/// Write half of a split [`WsStream`]; used for subscribe, pong and close frames.
type WsSink = SplitSink<WsStream, Message>;
/// Outgoing market-data subscription request.
///
/// Serializes as `{"type": ..., "subscriptions": [...]}`; serde emits struct
/// fields in declaration order, so the field order below is the key order on the wire.
#[derive(Debug, Clone, Serialize)]
struct SubscribeMessage {
/// Message kind; the `subscribe` helper always sends `"subscribe"`.
#[serde(rename = "type")]
msg_type: String,
/// Feeds requested in this single message.
subscriptions: Vec<SubscriptionItem>,
}
/// One feed entry inside a [`SubscribeMessage`].
#[derive(Debug, Clone, Serialize)]
struct SubscriptionItem {
/// Feed name as built by the subscribe helpers, e.g. `"l2"` or `"candles_<interval>"`.
name: String,
/// Exchange symbols for this feed; the callers in this file upper-case them.
symbols: Vec<String>,
}
/// Typed view of the keys an incoming frame may carry.
///
/// Every field is an `Option`, so any subset of keys deserializes. Nothing in
/// this file deserializes into this struct (frames are parsed as
/// `serde_json::Value`), which is why `dead_code` is allowed.
/// NOTE(review): the field set looks like a union of market-data and
/// order-event payload keys — confirm against the exchange docs before use.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
struct IncomingMessage {
/// The frame's `type` key (renamed because `type` is a Rust keyword).
#[serde(rename = "type")]
msg_type: Option<String>,
symbol: Option<String>,
/// Kept as raw JSON; its structure is not fixed here.
changes: Option<Value>,
trades: Option<Vec<Value>>,
event_id: Option<i64>,
price: Option<String>,
quantity: Option<String>,
side: Option<String>,
timestamp: Option<i64>,
timestampms: Option<i64>,
order_id: Option<String>,
socket_sequence: Option<u64>,
}
pub struct GeminiWebSocket {
ws_type: WebSocketType,
auth: Option<GeminiAuth>,
urls: GeminiUrls,
status: Arc<Mutex<ConnectionStatus>>,
subscriptions: Arc<Mutex<HashSet<String>>>,
broadcast_tx: Arc<StdMutex<Option<broadcast::Sender<WebSocketResult<StreamEvent>>>>>,
ws_writer: Arc<Mutex<Option<WsSink>>>,
last_heartbeat: Arc<Mutex<Instant>>,
ws_ping_rtt_ms: Arc<Mutex<u64>>,
}
#[derive(Debug, Clone, Copy)]
pub enum WebSocketType {
MarketData,
OrderEvents,
}
impl GeminiWebSocket {
pub async fn new_market_data(testnet: bool) -> ExchangeResult<Self> {
let urls = if testnet {
GeminiUrls::TESTNET
} else {
GeminiUrls::MAINNET
};
Ok(Self {
ws_type: WebSocketType::MarketData,
auth: None,
urls,
status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
subscriptions: Arc::new(Mutex::new(HashSet::new())),
broadcast_tx: Arc::new(StdMutex::new(None)),
ws_writer: Arc::new(Mutex::new(None)),
last_heartbeat: Arc::new(Mutex::new(Instant::now())),
ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
})
}
pub async fn new_order_events(
credentials: Credentials,
testnet: bool,
) -> ExchangeResult<Self> {
let urls = if testnet {
GeminiUrls::TESTNET
} else {
GeminiUrls::MAINNET
};
let auth = Some(GeminiAuth::new(&credentials)?);
Ok(Self {
ws_type: WebSocketType::OrderEvents,
auth,
urls,
status: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
subscriptions: Arc::new(Mutex::new(HashSet::new())),
broadcast_tx: Arc::new(StdMutex::new(None)),
ws_writer: Arc::new(Mutex::new(None)),
last_heartbeat: Arc::new(Mutex::new(Instant::now())),
ws_ping_rtt_ms: Arc::new(Mutex::new(0)),
})
}
pub async fn connect(&self) -> ExchangeResult<()> {
let url = match self.ws_type {
WebSocketType::MarketData => self.urls.ws_market_url(),
WebSocketType::OrderEvents => self.urls.ws_orders_url(),
};
let ws_stream = if matches!(self.ws_type, WebSocketType::OrderEvents) {
let auth = self.auth.as_ref()
.ok_or_else(|| ExchangeError::Auth("Authentication required for order events".to_string()))?;
let headers = auth.sign_websocket_request("/v1/order/events")?;
use tokio_tungstenite::tungstenite::handshake::client::Request;
let mut request = Request::builder()
.uri(url);
for (key, value) in headers {
request = request.header(key, value);
}
let request = request.body(())
.map_err(|e| ExchangeError::Network(format!("Failed to build request: {}", e)))?;
let (ws, _) = connect_async(request).await
.map_err(|e| ExchangeError::Network(e.to_string()))?;
ws
} else {
let (ws, _) = connect_async(url).await
.map_err(|e| ExchangeError::Network(e.to_string()))?;
ws
};
let (write, read) = ws_stream.split();
*self.ws_writer.lock().await = Some(write);
*self.status.lock().await = ConnectionStatus::Connected;
let (broadcast_sender, _) = broadcast::channel(1000);
*self.broadcast_tx.lock().unwrap() = Some(broadcast_sender);
self.start_message_handler(read);
Ok(())
}
pub async fn disconnect(&self) -> ExchangeResult<()> {
if let Some(mut writer) = self.ws_writer.lock().await.take() {
writer.close().await.ok();
}
let _ = self.broadcast_tx.lock().unwrap().take();
*self.status.lock().await = ConnectionStatus::Disconnected;
Ok(())
}
pub async fn subscribe_orderbook(&self, symbol: Symbol) -> ExchangeResult<()> {
self.subscribe_orderbook_with_account_type(symbol, AccountType::Spot).await
}
pub async fn subscribe_orderbook_with_account_type(&self, symbol: Symbol, account_type: AccountType) -> ExchangeResult<()> {
let symbol_str = normalize_symbol(&format_symbol(&symbol.base, &symbol.quote, account_type));
self.subscribe("l2", vec![symbol_str.to_uppercase()]).await
}
pub async fn subscribe_candles(&self, symbol: Symbol, interval: &str) -> ExchangeResult<()> {
self.subscribe_candles_with_account_type(symbol, interval, AccountType::Spot).await
}
pub async fn subscribe_candles_with_account_type(&self, symbol: Symbol, interval: &str, account_type: AccountType) -> ExchangeResult<()> {
let symbol_str = normalize_symbol(&format_symbol(&symbol.base, &symbol.quote, account_type));
let feed_name = format!("candles_{}", interval);
self.subscribe(&feed_name, vec![symbol_str.to_uppercase()]).await
}
async fn subscribe(&self, feed_name: &str, symbols: Vec<String>) -> ExchangeResult<()> {
if !matches!(self.ws_type, WebSocketType::MarketData) {
return Err(ExchangeError::Network("Subscriptions only for market data".to_string()));
}
let status = *self.status.lock().await;
if status != ConnectionStatus::Connected {
return Err(ExchangeError::Network("Not connected to WebSocket".to_string()));
}
let subscribe_msg = SubscribeMessage {
msg_type: "subscribe".to_string(),
subscriptions: vec![SubscriptionItem {
name: feed_name.to_string(),
symbols: symbols.clone(),
}],
};
let json_str = serde_json::to_string(&subscribe_msg)
.map_err(|e| ExchangeError::Parse(e.to_string()))?;
let mut writer_guard = self.ws_writer.lock().await;
if let Some(writer) = writer_guard.as_mut() {
writer.send(Message::Text(json_str)).await
.map_err(|e| ExchangeError::Network(e.to_string()))?;
} else {
return Err(ExchangeError::Network("WebSocket stream not available".to_string()));
}
for sym in symbols {
self.subscriptions.lock().await.insert(format!("{}:{}", feed_name, sym));
}
Ok(())
}
pub fn event_stream(&self) -> broadcast::Receiver<WebSocketResult<StreamEvent>> {
self.broadcast_tx.lock().unwrap().as_ref()
.map(|tx| tx.subscribe())
.unwrap_or_else(|| broadcast::channel(1).1)
}
fn start_message_handler(&self, mut reader: WsReader) {
let ws_writer = Arc::clone(&self.ws_writer);
let broadcast_tx = Arc::clone(&self.broadcast_tx);
let status = Arc::clone(&self.status);
let last_heartbeat = Arc::clone(&self.last_heartbeat);
let ws_type = self.ws_type;
tokio::spawn(async move {
while let Some(msg) = reader.next().await {
match msg {
Ok(Message::Text(text)) => {
if let Ok(events) = Self::parse_message(&text, ws_type) {
if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
for evt in events {
tx.send(Ok(evt)).ok();
}
}
}
if matches!(ws_type, WebSocketType::OrderEvents) {
*last_heartbeat.lock().await = Instant::now();
}
}
Ok(Message::Ping(data)) => {
let mut writer_guard = ws_writer.lock().await;
if let Some(writer) = writer_guard.as_mut() {
writer.send(Message::Pong(data)).await.ok();
}
}
Ok(Message::Close(_)) => {
*status.lock().await = ConnectionStatus::Disconnected;
break;
}
Err(e) => {
if let Some(tx) = broadcast_tx.lock().unwrap().as_ref() {
tx.send(Err(WebSocketError::ConnectionError(e.to_string()))).ok();
}
break;
}
_ => {}
}
}
let _ = broadcast_tx.lock().unwrap().take();
*status.lock().await = ConnectionStatus::Disconnected;
});
}
fn parse_message(text: &str, ws_type: WebSocketType) -> ExchangeResult<Vec<StreamEvent>> {
let value: Value = serde_json::from_str(text)
.map_err(|e| ExchangeError::Parse(e.to_string()))?;
let msg_type = value.get("type").and_then(|t| t.as_str());
match (ws_type, msg_type) {
(WebSocketType::MarketData, Some("subscribed")) => {
Ok(vec![])
}
(WebSocketType::MarketData, Some("l2_updates")) => {
let mut events: Vec<StreamEvent> = Vec::new();
let has_changes = value.get("changes")
.and_then(|c| c.as_array())
.map(|a| !a.is_empty())
.unwrap_or(false);
if has_changes {
match GeminiParser::parse_ws_l2_update(&value) {
Ok(ev) => events.push(ev),
Err(_) => {} }
}
let has_trades = value.get("trades")
.and_then(|t| t.as_array())
.map(|a| !a.is_empty())
.unwrap_or(false);
if has_trades {
match GeminiParser::parse_ws_l2_trade(&value) {
Ok(ev) => events.push(ev),
Err(_) => {}
}
}
Ok(events)
}
(WebSocketType::MarketData, Some(t)) if t.starts_with("candles_") => {
let kline = GeminiParser::parse_ws_candle(&value)?;
Ok(vec![StreamEvent::Kline(kline)])
}
(WebSocketType::OrderEvents, Some("subscription_ack")) => {
Ok(vec![])
}
(WebSocketType::OrderEvents, Some("heartbeat")) => {
Ok(vec![])
}
(WebSocketType::OrderEvents, Some("initial" | "accepted" | "booked" | "fill" | "cancelled" | "rejected" | "closed")) => {
let order_event = GeminiParser::parse_ws_order_event(&value)?;
Ok(vec![StreamEvent::OrderUpdate(order_event)])
}
_ => Ok(vec![]),
}
}
}
#[async_trait]
impl WebSocketConnector for GeminiWebSocket {
async fn connect(&mut self, _account_type: AccountType) -> WebSocketResult<()> {
Self::connect(self).await
.map_err(|e| WebSocketError::ConnectionError(e.to_string()))
}
async fn disconnect(&mut self) -> WebSocketResult<()> {
Self::disconnect(self).await
.map_err(|e| WebSocketError::ConnectionError(e.to_string()))
}
fn connection_status(&self) -> ConnectionStatus {
self.status.try_lock()
.map(|guard| *guard)
.unwrap_or(ConnectionStatus::Disconnected)
}
async fn subscribe(&mut self, request: SubscriptionRequest) -> WebSocketResult<()> {
match request.stream_type {
StreamType::Ticker => {
self.subscribe_orderbook(request.symbol).await
.map_err(|e| WebSocketError::Subscription(e.to_string()))
}
StreamType::Trade => {
self.subscribe_orderbook(request.symbol).await
.map_err(|e| WebSocketError::Subscription(e.to_string()))
}
StreamType::Orderbook => {
self.subscribe_orderbook(request.symbol).await
.map_err(|e| WebSocketError::Subscription(e.to_string()))
}
StreamType::Kline { interval } => {
self.subscribe_candles(request.symbol, &interval).await
.map_err(|e| WebSocketError::Subscription(e.to_string()))
}
_ => Err(WebSocketError::Subscription(format!("{:?} not supported", request.stream_type))),
}
}
async fn unsubscribe(&mut self, _request: SubscriptionRequest) -> WebSocketResult<()> {
Err(WebSocketError::Subscription("Unsubscribe not supported by Gemini".to_string()))
}
fn event_stream(&self) -> std::pin::Pin<Box<dyn futures_util::Stream<Item = WebSocketResult<StreamEvent>> + Send>> {
let rx = self.broadcast_tx.lock().unwrap().as_ref()
.map(|tx| tx.subscribe())
.unwrap_or_else(|| broadcast::channel(1).1);
Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
match rx.recv().await {
Ok(event) => Some((event, rx)),
Err(_) => None,
}
}))
}
fn active_subscriptions(&self) -> Vec<SubscriptionRequest> {
vec![]
}
fn ping_rtt_handle(&self) -> Option<Arc<Mutex<u64>>> {
Some(self.ws_ping_rtt_ms.clone())
}
fn orderbook_capabilities(&self, _account_type: AccountType) -> OrderbookCapabilities {
static GEMINI_CHANNELS: &[WsBookChannel] = &[
WsBookChannel::snapshot("depth5", 5, 1000),
WsBookChannel::snapshot("depth10", 10, 1000),
WsBookChannel::snapshot("depth20", 20, 1000),
WsBookChannel::snapshot("depth5@100ms", 5, 100),
WsBookChannel::snapshot("depth10@100ms", 10, 100),
WsBookChannel::snapshot("depth20@100ms", 20, 100),
WsBookChannel::delta("depth", None, Some(1000)),
WsBookChannel::delta("depth@100ms", None, Some(100)),
];
OrderbookCapabilities {
ws_depths: &[5, 10, 20],
ws_default_depth: Some(20),
rest_max_depth: None,
rest_depth_values: &[],
supports_snapshot: true,
supports_delta: true,
update_speeds_ms: &[100, 1000],
default_speed_ms: Some(1000),
ws_channels: GEMINI_CHANNELS,
checksum: None,
has_sequence: true,
has_prev_sequence: false,
supports_aggregation: false,
aggregation_levels: &[],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built mainnet market-data client starts out disconnected.
    #[tokio::test]
    async fn test_websocket_creation() {
        let client = GeminiWebSocket::new_market_data(false)
            .await
            .expect("building a market-data client should succeed");
        let initial_status = *client.status.lock().await;
        assert_eq!(ConnectionStatus::Disconnected, initial_status);
    }
}