Skip to main content

k256_sdk/ws/
client.rs

1//! K256 WebSocket client implementation.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use futures_util::{SinkExt, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json;
9use tokio::sync::{mpsc, RwLock};
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11use tracing::{debug, error, info, warn};
12
13use crate::types::{Blockhash, Heartbeat, PoolUpdate, PriorityFees, Quote};
14use crate::ws::decoder::decode_message;
15
/// Settings for [`K256WebSocketClient`].
///
/// Start from the defaults and override what you need, e.g.
/// `Config { api_key: key, ..Config::default() }`.
///
/// NOTE(review): within this module only `api_key` and `endpoint` are read by the
/// client; `reconnect`, the reconnect delays and `ping_interval` are not consulted
/// anywhere here — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct Config {
    /// K256 API key, sent as the `apiKey` query parameter when connecting.
    pub api_key: String,
    /// WebSocket endpoint URL.
    pub endpoint: String,
    /// Whether the client should reconnect automatically after a disconnect.
    pub reconnect: bool,
    /// Delay before the first reconnect attempt.
    pub reconnect_delay_initial: Duration,
    /// Upper bound on the reconnect delay.
    pub reconnect_delay_max: Duration,
    /// Interval between keep-alive pings (zero disables pinging).
    pub ping_interval: Duration,
}

impl Default for Config {
    /// Production gateway endpoint and an empty API key, with reconnect enabled
    /// (1 s initial delay, 60 s maximum) and a 30 s ping interval.
    fn default() -> Self {
        let endpoint = String::from("wss://gateway.k256.xyz/v1/ws");
        Config {
            api_key: String::default(),
            endpoint,
            reconnect: true,
            reconnect_delay_initial: Duration::from_millis(1_000),
            reconnect_delay_max: Duration::from_secs(60),
            ping_interval: Duration::from_secs(30),
        }
    }
}
45
46/// WebSocket subscription request.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct SubscribeRequest {
49    /// Request type (always "subscribe")
50    #[serde(rename = "type")]
51    pub request_type: String,
52    /// List of channels to subscribe to
53    pub channels: Vec<String>,
54    /// Message format ("binary" or "json")
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub format: Option<String>,
57    /// Optional list of DEX protocols to filter
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub protocols: Option<Vec<String>>,
60    /// Optional list of pool addresses to filter
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub pools: Option<Vec<String>>,
63    /// Optional list of token pairs to filter
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub token_pairs: Option<Vec<(String, String)>>,
66}
67
68impl Default for SubscribeRequest {
69    fn default() -> Self {
70        Self {
71            request_type: "subscribe".to_string(),
72            channels: vec![
73                "pools".to_string(),
74                "priority_fees".to_string(),
75                "blockhash".to_string(),
76            ],
77            format: None,
78            protocols: None,
79            pools: None,
80            token_pairs: None,
81        }
82    }
83}
84
/// A message decoded from a binary WebSocket frame.
///
/// The client obtains these from `decode_message` (in `crate::ws::decoder`) and routes
/// each variant to the matching registered callback; `Subscribed` is only logged.
#[derive(Debug, Clone)]
pub enum DecodedMessage {
    /// A single pool update; passed to the pool-update callback.
    PoolUpdate(PoolUpdate),
    /// Several pool updates in one frame; the client passes each one to the
    /// pool-update callback individually.
    PoolUpdateBatch(Vec<PoolUpdate>),
    /// Priority fee data; passed to the priority-fees callback.
    PriorityFees(PriorityFees),
    /// Blockhash update; passed to the blockhash callback.
    Blockhash(Blockhash),
    /// Quote update; passed to the quote callback.
    Quote(Quote),
    /// Heartbeat; passed to the heartbeat callback.
    Heartbeat(Heartbeat),
    /// Error text from the server; logged and passed to the error callback.
    Error(String),
    /// Subscription confirmation.
    Subscribed {
        /// Channel names included in the confirmation.
        channels: Vec<String>,
    },
}
105
106type Callback<T> = Arc<RwLock<Option<Box<dyn Fn(T) + Send + Sync + 'static>>>>;
107
108/// K256 WebSocket client for real-time Solana liquidity data.
109pub struct K256WebSocketClient {
110    config: Config,
111    tx: mpsc::Sender<Message>,
112    on_pool_update: Callback<PoolUpdate>,
113    on_priority_fees: Callback<PriorityFees>,
114    on_blockhash: Callback<Blockhash>,
115    on_quote: Callback<Quote>,
116    on_heartbeat: Callback<Heartbeat>,
117    on_error: Callback<String>,
118}
119
120impl K256WebSocketClient {
121    /// Create a new WebSocket client with the given configuration.
122    pub fn new(config: Config) -> Self {
123        let (tx, _rx) = mpsc::channel(100);
124        Self {
125            config,
126            tx,
127            on_pool_update: Arc::new(RwLock::new(None)),
128            on_priority_fees: Arc::new(RwLock::new(None)),
129            on_blockhash: Arc::new(RwLock::new(None)),
130            on_quote: Arc::new(RwLock::new(None)),
131            on_heartbeat: Arc::new(RwLock::new(None)),
132            on_error: Arc::new(RwLock::new(None)),
133        }
134    }
135
136    /// Register a callback for pool updates.
137    pub fn on_pool_update<F>(&self, callback: F)
138    where
139        F: Fn(PoolUpdate) + Send + Sync + 'static,
140    {
141        let rt = tokio::runtime::Handle::current();
142        rt.block_on(async {
143            *self.on_pool_update.write().await = Some(Box::new(callback));
144        });
145    }
146
147    /// Register a callback for priority fee updates.
148    pub fn on_priority_fees<F>(&self, callback: F)
149    where
150        F: Fn(PriorityFees) + Send + Sync + 'static,
151    {
152        let rt = tokio::runtime::Handle::current();
153        rt.block_on(async {
154            *self.on_priority_fees.write().await = Some(Box::new(callback));
155        });
156    }
157
158    /// Register a callback for blockhash updates.
159    pub fn on_blockhash<F>(&self, callback: F)
160    where
161        F: Fn(Blockhash) + Send + Sync + 'static,
162    {
163        let rt = tokio::runtime::Handle::current();
164        rt.block_on(async {
165            *self.on_blockhash.write().await = Some(Box::new(callback));
166        });
167    }
168
169    /// Register a callback for quote updates.
170    pub fn on_quote<F>(&self, callback: F)
171    where
172        F: Fn(Quote) + Send + Sync + 'static,
173    {
174        let rt = tokio::runtime::Handle::current();
175        rt.block_on(async {
176            *self.on_quote.write().await = Some(Box::new(callback));
177        });
178    }
179
180    /// Register a callback for heartbeat messages.
181    pub fn on_heartbeat<F>(&self, callback: F)
182    where
183        F: Fn(Heartbeat) + Send + Sync + 'static,
184    {
185        let rt = tokio::runtime::Handle::current();
186        rt.block_on(async {
187            *self.on_heartbeat.write().await = Some(Box::new(callback));
188        });
189    }
190
191    /// Register a callback for errors.
192    pub fn on_error<F>(&self, callback: F)
193    where
194        F: Fn(String) + Send + Sync + 'static,
195    {
196        let rt = tokio::runtime::Handle::current();
197        rt.block_on(async {
198            *self.on_error.write().await = Some(Box::new(callback));
199        });
200    }
201
202    /// Connect to the K256 WebSocket.
203    pub async fn connect(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
204        let url = format!("{}?apiKey={}", self.config.endpoint, self.config.api_key);
205
206        let (ws_stream, _) = connect_async(&url).await?;
207        info!("Connected to K256 WebSocket");
208
209        let (mut write, mut read) = ws_stream.split();
210
211        let on_pool_update = self.on_pool_update.clone();
212        let on_priority_fees = self.on_priority_fees.clone();
213        let on_blockhash = self.on_blockhash.clone();
214        let on_quote = self.on_quote.clone();
215        let on_heartbeat = self.on_heartbeat.clone();
216        let on_error = self.on_error.clone();
217
218        // Message receiving task
219        let recv_task = tokio::spawn(async move {
220            while let Some(msg) = read.next().await {
221                match msg {
222                    Ok(Message::Binary(data)) => {
223                        if data.is_empty() {
224                            continue;
225                        }
226
227                        let msg_type = data[0];
228                        let payload = &data[1..];
229
230                        match decode_message(msg_type, payload) {
231                            Ok(Some(decoded)) => {
232                                match decoded {
233                                    DecodedMessage::PoolUpdate(update) => {
234                                        if let Some(cb) = on_pool_update.read().await.as_ref() {
235                                            cb(update);
236                                        }
237                                    }
238                                    DecodedMessage::PoolUpdateBatch(updates) => {
239                                        if let Some(cb) = on_pool_update.read().await.as_ref() {
240                                            for update in updates {
241                                                cb(update);
242                                            }
243                                        }
244                                    }
245                                    DecodedMessage::PriorityFees(fees) => {
246                                        if let Some(cb) = on_priority_fees.read().await.as_ref() {
247                                            cb(fees);
248                                        }
249                                    }
250                                    DecodedMessage::Blockhash(bh) => {
251                                        if let Some(cb) = on_blockhash.read().await.as_ref() {
252                                            cb(bh);
253                                        }
254                                    }
255                                    DecodedMessage::Quote(quote) => {
256                                        if let Some(cb) = on_quote.read().await.as_ref() {
257                                            cb(quote);
258                                        }
259                                    }
260                                    DecodedMessage::Heartbeat(hb) => {
261                                        if let Some(cb) = on_heartbeat.read().await.as_ref() {
262                                            cb(hb);
263                                        }
264                                    }
265                                    DecodedMessage::Error(err) => {
266                                        error!("Server error: {}", err);
267                                        if let Some(cb) = on_error.read().await.as_ref() {
268                                            cb(err);
269                                        }
270                                    }
271                                    DecodedMessage::Subscribed { channels } => {
272                                        info!("Subscribed to channels: {:?}", channels);
273                                    }
274                                }
275                            }
276                            Ok(None) => {
277                                debug!("Unhandled message type: {}", msg_type);
278                            }
279                            Err(e) => {
280                                error!("Error decoding message: {}", e);
281                            }
282                        }
283                    }
284                    Ok(Message::Text(text)) => {
285                        // Parse JSON text messages for Heartbeat and other JSON responses
286                        if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
287                            if let Some(msg_type) = json.get("type").and_then(|t| t.as_str()) {
288                                match msg_type {
289                                    "heartbeat" => {
290                                        if let Some(cb) = on_heartbeat.read().await.as_ref() {
291                                            let hb = Heartbeat {
292                                                timestamp_ms: json.get("timestamp_ms")
293                                                    .and_then(|v| v.as_u64()).unwrap_or(0),
294                                                uptime_seconds: json.get("uptime_seconds")
295                                                    .and_then(|v| v.as_u64()).unwrap_or(0),
296                                                messages_received: json.get("messages_received")
297                                                    .and_then(|v| v.as_u64()).unwrap_or(0),
298                                                messages_sent: json.get("messages_sent")
299                                                    .and_then(|v| v.as_u64()).unwrap_or(0),
300                                                subscriptions: json.get("subscriptions")
301                                                    .and_then(|v| v.as_u64()).unwrap_or(0) as u32,
302                                            };
303                                            cb(hb);
304                                        }
305                                    }
306                                    "subscribed" => {
307                                        if let Some(channels) = json.get("channels").and_then(|c| c.as_array()) {
308                                            let channel_names: Vec<String> = channels
309                                                .iter()
310                                                .filter_map(|c| c.as_str().map(String::from))
311                                                .collect();
312                                            info!("Subscribed to channels: {:?}", channel_names);
313                                        }
314                                    }
315                                    "error" => {
316                                        let err_msg = json.get("message")
317                                            .and_then(|m| m.as_str())
318                                            .unwrap_or("Unknown error")
319                                            .to_string();
320                                        error!("Server error: {}", err_msg);
321                                        if let Some(cb) = on_error.read().await.as_ref() {
322                                            cb(err_msg);
323                                        }
324                                    }
325                                    _ => {
326                                        debug!("Unhandled text message type: {}", msg_type);
327                                    }
328                                }
329                            }
330                        } else {
331                            debug!("Received non-JSON text message: {}", text);
332                        }
333                    }
334                    Ok(Message::Close(_)) => {
335                        warn!("WebSocket closed");
336                        break;
337                    }
338                    Err(e) => {
339                        error!("WebSocket error: {}", e);
340                        break;
341                    }
342                    _ => {}
343                }
344            }
345        });
346
347        // Message sending task
348        let mut rx = {
349            let (_tx, rx) = mpsc::channel::<Message>(100);
350            // Note: In a real implementation, we'd store tx in self
351            // This is a simplified version
352            rx
353        };
354
355        let send_task = tokio::spawn(async move {
356            while let Some(msg) = rx.recv().await {
357                if let Err(e) = write.send(msg).await {
358                    error!("Failed to send message: {}", e);
359                    break;
360                }
361            }
362        });
363
364        // Wait for tasks
365        tokio::select! {
366            _ = recv_task => {}
367            _ = send_task => {}
368        }
369
370        Ok(())
371    }
372
373    /// Subscribe to channels.
374    pub async fn subscribe(
375        &self,
376        request: SubscribeRequest,
377    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
378        let msg = serde_json::to_string(&request)?;
379        self.tx.send(Message::Text(msg)).await?;
380        Ok(())
381    }
382
383    /// Unsubscribe from all channels.
384    pub async fn unsubscribe(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
385        let msg = r#"{"type":"unsubscribe"}"#;
386        self.tx.send(Message::Text(msg.to_string())).await?;
387        Ok(())
388    }
389}