1use std::sync::Arc;
4use std::time::Duration;
5
6use futures_util::{SinkExt, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json;
9use tokio::sync::{mpsc, RwLock};
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11use tracing::{debug, error, info, warn};
12
13use crate::types::{Blockhash, Heartbeat, PoolUpdate, PriorityFees, Quote};
14use crate::ws::decoder::decode_message;
15
/// Connection settings for [`K256WebSocketClient`].
///
/// Usually built with struct-update syntax over the defaults, e.g.
/// `Config { api_key: key, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct Config {
    /// API key, appended to the endpoint as the `apiKey` query parameter when connecting.
    pub api_key: String,
    /// WebSocket endpoint URL (default `wss://gateway.k256.xyz/v1/ws`).
    pub endpoint: String,
    /// Reconnect flag (default `true`); presumably enables automatic reconnection.
    ///
    /// NOTE(review): not read by the client in this file — confirm intended use.
    pub reconnect: bool,
    /// First reconnection delay (default 1 s).
    ///
    /// NOTE(review): not read by the client in this file.
    pub reconnect_delay_initial: Duration,
    /// Upper bound on the reconnection delay (default 60 s).
    ///
    /// NOTE(review): not read by the client in this file.
    pub reconnect_delay_max: Duration,
    /// Keep-alive ping period (default 30 s).
    ///
    /// NOTE(review): not read by the client in this file.
    pub ping_interval: Duration,
}

impl Default for Config {
    /// Production gateway endpoint, empty API key, reconnect enabled,
    /// 1 s / 60 s reconnect delays and a 30 s ping interval.
    fn default() -> Self {
        let one_second = Duration::from_secs(1);
        Config {
            api_key: String::new(),
            endpoint: String::from("wss://gateway.k256.xyz/v1/ws"),
            reconnect: true,
            reconnect_delay_initial: one_second,
            reconnect_delay_max: one_second * 60,
            ping_interval: one_second * 30,
        }
    }
}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct SubscribeRequest {
49 #[serde(rename = "type")]
51 pub request_type: String,
52 pub channels: Vec<String>,
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub format: Option<String>,
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub protocols: Option<Vec<String>>,
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub pools: Option<Vec<String>>,
63 #[serde(skip_serializing_if = "Option::is_none")]
65 pub token_pairs: Option<Vec<(String, String)>>,
66}
67
68impl Default for SubscribeRequest {
69 fn default() -> Self {
70 Self {
71 request_type: "subscribe".to_string(),
72 channels: vec![
73 "pools".to_string(),
74 "priority_fees".to_string(),
75 "blockhash".to_string(),
76 ],
77 format: None,
78 protocols: None,
79 pools: None,
80 token_pairs: None,
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
87pub enum DecodedMessage {
88 PoolUpdate(PoolUpdate),
90 PoolUpdateBatch(Vec<PoolUpdate>),
92 PriorityFees(PriorityFees),
94 Blockhash(Blockhash),
96 Quote(Quote),
98 Heartbeat(Heartbeat),
100 Error(String),
102 Subscribed { channels: Vec<String> },
104}
105
106type Callback<T> = Arc<RwLock<Option<Box<dyn Fn(T) + Send + Sync + 'static>>>>;
107
108pub struct K256WebSocketClient {
110 config: Config,
111 tx: mpsc::Sender<Message>,
112 on_pool_update: Callback<PoolUpdate>,
113 on_priority_fees: Callback<PriorityFees>,
114 on_blockhash: Callback<Blockhash>,
115 on_quote: Callback<Quote>,
116 on_heartbeat: Callback<Heartbeat>,
117 on_error: Callback<String>,
118}
119
120impl K256WebSocketClient {
121 pub fn new(config: Config) -> Self {
123 let (tx, _rx) = mpsc::channel(100);
124 Self {
125 config,
126 tx,
127 on_pool_update: Arc::new(RwLock::new(None)),
128 on_priority_fees: Arc::new(RwLock::new(None)),
129 on_blockhash: Arc::new(RwLock::new(None)),
130 on_quote: Arc::new(RwLock::new(None)),
131 on_heartbeat: Arc::new(RwLock::new(None)),
132 on_error: Arc::new(RwLock::new(None)),
133 }
134 }
135
136 pub fn on_pool_update<F>(&self, callback: F)
138 where
139 F: Fn(PoolUpdate) + Send + Sync + 'static,
140 {
141 let rt = tokio::runtime::Handle::current();
142 rt.block_on(async {
143 *self.on_pool_update.write().await = Some(Box::new(callback));
144 });
145 }
146
147 pub fn on_priority_fees<F>(&self, callback: F)
149 where
150 F: Fn(PriorityFees) + Send + Sync + 'static,
151 {
152 let rt = tokio::runtime::Handle::current();
153 rt.block_on(async {
154 *self.on_priority_fees.write().await = Some(Box::new(callback));
155 });
156 }
157
158 pub fn on_blockhash<F>(&self, callback: F)
160 where
161 F: Fn(Blockhash) + Send + Sync + 'static,
162 {
163 let rt = tokio::runtime::Handle::current();
164 rt.block_on(async {
165 *self.on_blockhash.write().await = Some(Box::new(callback));
166 });
167 }
168
169 pub fn on_quote<F>(&self, callback: F)
171 where
172 F: Fn(Quote) + Send + Sync + 'static,
173 {
174 let rt = tokio::runtime::Handle::current();
175 rt.block_on(async {
176 *self.on_quote.write().await = Some(Box::new(callback));
177 });
178 }
179
180 pub fn on_heartbeat<F>(&self, callback: F)
182 where
183 F: Fn(Heartbeat) + Send + Sync + 'static,
184 {
185 let rt = tokio::runtime::Handle::current();
186 rt.block_on(async {
187 *self.on_heartbeat.write().await = Some(Box::new(callback));
188 });
189 }
190
191 pub fn on_error<F>(&self, callback: F)
193 where
194 F: Fn(String) + Send + Sync + 'static,
195 {
196 let rt = tokio::runtime::Handle::current();
197 rt.block_on(async {
198 *self.on_error.write().await = Some(Box::new(callback));
199 });
200 }
201
202 pub async fn connect(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
204 let url = format!("{}?apiKey={}", self.config.endpoint, self.config.api_key);
205
206 let (ws_stream, _) = connect_async(&url).await?;
207 info!("Connected to K256 WebSocket");
208
209 let (mut write, mut read) = ws_stream.split();
210
211 let on_pool_update = self.on_pool_update.clone();
212 let on_priority_fees = self.on_priority_fees.clone();
213 let on_blockhash = self.on_blockhash.clone();
214 let on_quote = self.on_quote.clone();
215 let on_heartbeat = self.on_heartbeat.clone();
216 let on_error = self.on_error.clone();
217
218 let recv_task = tokio::spawn(async move {
220 while let Some(msg) = read.next().await {
221 match msg {
222 Ok(Message::Binary(data)) => {
223 if data.is_empty() {
224 continue;
225 }
226
227 let msg_type = data[0];
228 let payload = &data[1..];
229
230 match decode_message(msg_type, payload) {
231 Ok(Some(decoded)) => {
232 match decoded {
233 DecodedMessage::PoolUpdate(update) => {
234 if let Some(cb) = on_pool_update.read().await.as_ref() {
235 cb(update);
236 }
237 }
238 DecodedMessage::PoolUpdateBatch(updates) => {
239 if let Some(cb) = on_pool_update.read().await.as_ref() {
240 for update in updates {
241 cb(update);
242 }
243 }
244 }
245 DecodedMessage::PriorityFees(fees) => {
246 if let Some(cb) = on_priority_fees.read().await.as_ref() {
247 cb(fees);
248 }
249 }
250 DecodedMessage::Blockhash(bh) => {
251 if let Some(cb) = on_blockhash.read().await.as_ref() {
252 cb(bh);
253 }
254 }
255 DecodedMessage::Quote(quote) => {
256 if let Some(cb) = on_quote.read().await.as_ref() {
257 cb(quote);
258 }
259 }
260 DecodedMessage::Heartbeat(hb) => {
261 if let Some(cb) = on_heartbeat.read().await.as_ref() {
262 cb(hb);
263 }
264 }
265 DecodedMessage::Error(err) => {
266 error!("Server error: {}", err);
267 if let Some(cb) = on_error.read().await.as_ref() {
268 cb(err);
269 }
270 }
271 DecodedMessage::Subscribed { channels } => {
272 info!("Subscribed to channels: {:?}", channels);
273 }
274 }
275 }
276 Ok(None) => {
277 debug!("Unhandled message type: {}", msg_type);
278 }
279 Err(e) => {
280 error!("Error decoding message: {}", e);
281 }
282 }
283 }
284 Ok(Message::Text(text)) => {
285 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
287 if let Some(msg_type) = json.get("type").and_then(|t| t.as_str()) {
288 match msg_type {
289 "heartbeat" => {
290 if let Some(cb) = on_heartbeat.read().await.as_ref() {
291 let hb = Heartbeat {
292 timestamp_ms: json.get("timestamp_ms")
293 .and_then(|v| v.as_u64()).unwrap_or(0),
294 uptime_seconds: json.get("uptime_seconds")
295 .and_then(|v| v.as_u64()).unwrap_or(0),
296 messages_received: json.get("messages_received")
297 .and_then(|v| v.as_u64()).unwrap_or(0),
298 messages_sent: json.get("messages_sent")
299 .and_then(|v| v.as_u64()).unwrap_or(0),
300 subscriptions: json.get("subscriptions")
301 .and_then(|v| v.as_u64()).unwrap_or(0) as u32,
302 };
303 cb(hb);
304 }
305 }
306 "subscribed" => {
307 if let Some(channels) = json.get("channels").and_then(|c| c.as_array()) {
308 let channel_names: Vec<String> = channels
309 .iter()
310 .filter_map(|c| c.as_str().map(String::from))
311 .collect();
312 info!("Subscribed to channels: {:?}", channel_names);
313 }
314 }
315 "error" => {
316 let err_msg = json.get("message")
317 .and_then(|m| m.as_str())
318 .unwrap_or("Unknown error")
319 .to_string();
320 error!("Server error: {}", err_msg);
321 if let Some(cb) = on_error.read().await.as_ref() {
322 cb(err_msg);
323 }
324 }
325 _ => {
326 debug!("Unhandled text message type: {}", msg_type);
327 }
328 }
329 }
330 } else {
331 debug!("Received non-JSON text message: {}", text);
332 }
333 }
334 Ok(Message::Close(_)) => {
335 warn!("WebSocket closed");
336 break;
337 }
338 Err(e) => {
339 error!("WebSocket error: {}", e);
340 break;
341 }
342 _ => {}
343 }
344 }
345 });
346
347 let mut rx = {
349 let (_tx, rx) = mpsc::channel::<Message>(100);
350 rx
353 };
354
355 let send_task = tokio::spawn(async move {
356 while let Some(msg) = rx.recv().await {
357 if let Err(e) = write.send(msg).await {
358 error!("Failed to send message: {}", e);
359 break;
360 }
361 }
362 });
363
364 tokio::select! {
366 _ = recv_task => {}
367 _ = send_task => {}
368 }
369
370 Ok(())
371 }
372
373 pub async fn subscribe(
375 &self,
376 request: SubscribeRequest,
377 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
378 let msg = serde_json::to_string(&request)?;
379 self.tx.send(Message::Text(msg)).await?;
380 Ok(())
381 }
382
383 pub async fn unsubscribe(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
385 let msg = r#"{"type":"unsubscribe"}"#;
386 self.tx.send(Message::Text(msg.to_string())).await?;
387 Ok(())
388 }
389}