1use std::collections::HashMap;
6use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
7use std::sync::Arc;
8
9use futures_util::{SinkExt, StreamExt};
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use tokio::sync::mpsc;
14use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
15
16use crate::error::ChainStreamError;
17use crate::openapi::types::Resolution;
18
19use super::fields::replace_filter_fields;
20use super::models::*;
21
/// Heap-allocated callback that can be shared across threads and is invoked
/// with each value of type `T`.
pub type StreamCallback<T> = Box<dyn Fn(T) + Sync + Send + 'static>;
24
25pub struct Unsubscribe {
27 channel: String,
28 callback_id: u64,
29 api: Arc<StreamApiInner>,
30}
31
32impl Unsubscribe {
33 pub fn unsubscribe(self) {
35 self.api.unsubscribe(&self.channel, self.callback_id);
36 }
37}
38
39#[derive(Debug, Serialize)]
41struct CentrifugeCommand {
42 id: u64,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 connect: Option<ConnectRequest>,
45 #[serde(skip_serializing_if = "Option::is_none")]
46 subscribe: Option<SubscribeRequest>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 unsubscribe: Option<UnsubscribeRequest>,
49}
50
51#[derive(Debug, Serialize)]
52struct ConnectRequest {
53 token: String,
54}
55
56#[derive(Debug, Serialize)]
57struct SubscribeRequest {
58 channel: String,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 delta: Option<String>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 filter: Option<String>,
63}
64
65#[derive(Debug, Serialize)]
66struct UnsubscribeRequest {
67 channel: String,
68}
69
70#[derive(Debug, Deserialize)]
72struct CentrifugeResponse {
73 #[serde(default)]
74 id: u64,
75 #[serde(default)]
76 connect: Option<Value>,
77 #[serde(default)]
78 subscribe: Option<Value>,
79 #[serde(default)]
80 push: Option<PushData>,
81 #[serde(default)]
82 error: Option<ErrorData>,
83}
84
85#[derive(Debug, Deserialize)]
86struct PushData {
87 channel: String,
88 #[serde(rename = "pub")]
89 pub_data: Option<PublicationData>,
90}
91
92#[derive(Debug, Deserialize)]
93struct PublicationData {
94 data: Value,
95}
96
97#[derive(Debug, Deserialize)]
98struct ErrorData {
99 code: i32,
100 message: String,
101}
102
103struct CallbackWrapper {
105 id: u64,
106 callback: Box<dyn Fn(Value) + Send + Sync>,
107}
108
109struct StreamApiInner {
111 url: String,
112 access_token: String,
113 connected: AtomicBool,
114 command_id: AtomicU64,
115 callback_id: AtomicU64,
116 listeners: RwLock<HashMap<String, Vec<CallbackWrapper>>>,
117 subscriptions: RwLock<HashMap<String, u64>>,
118 command_tx: RwLock<Option<mpsc::UnboundedSender<Message>>>,
119}
120
121impl StreamApiInner {
122 fn new(url: String, access_token: String) -> Self {
123 Self {
124 url,
125 access_token,
126 connected: AtomicBool::new(false),
127 command_id: AtomicU64::new(1),
128 callback_id: AtomicU64::new(1),
129 listeners: RwLock::new(HashMap::new()),
130 subscriptions: RwLock::new(HashMap::new()),
131 command_tx: RwLock::new(None),
132 }
133 }
134
135 fn next_command_id(&self) -> u64 {
136 self.command_id.fetch_add(1, Ordering::SeqCst)
137 }
138
139 fn next_callback_id(&self) -> u64 {
140 self.callback_id.fetch_add(1, Ordering::SeqCst)
141 }
142
143 fn add_listener<F>(&self, channel: &str, callback: F) -> u64
144 where
145 F: Fn(Value) + Send + Sync + 'static,
146 {
147 let callback_id = self.next_callback_id();
148 let wrapper = CallbackWrapper {
149 id: callback_id,
150 callback: Box::new(callback),
151 };
152
153 let mut listeners = self.listeners.write();
154 listeners
155 .entry(channel.to_string())
156 .or_default()
157 .push(wrapper);
158
159 callback_id
160 }
161
162 fn unsubscribe(&self, channel: &str, callback_id: u64) {
163 let mut listeners = self.listeners.write();
164 if let Some(callbacks) = listeners.get_mut(channel) {
165 callbacks.retain(|c| c.id != callback_id);
166
167 if callbacks.is_empty() {
169 listeners.remove(channel);
170 drop(listeners);
171
172 if let Some(tx) = self.command_tx.read().as_ref() {
174 let cmd = CentrifugeCommand {
175 id: self.next_command_id(),
176 connect: None,
177 subscribe: None,
178 unsubscribe: Some(UnsubscribeRequest {
179 channel: channel.to_string(),
180 }),
181 };
182 if let Ok(json) = serde_json::to_string(&cmd) {
183 let _ = tx.send(Message::Text(json.into()));
184 }
185 }
186
187 self.subscriptions.write().remove(channel);
188 log::info!("[streaming] unsubscribed from channel: {}", channel);
189 }
190 }
191 }
192
193 fn dispatch_message(&self, channel: &str, data: Value) {
194 let listeners = self.listeners.read();
195 if let Some(callbacks) = listeners.get(channel) {
196 for callback in callbacks {
197 (callback.callback)(data.clone());
198 }
199 }
200 }
201
202 fn send_subscribe(&self, channel: &str, filter: Option<&str>) {
203 if let Some(tx) = self.command_tx.read().as_ref() {
204 let cmd = CentrifugeCommand {
205 id: self.next_command_id(),
206 connect: None,
207 subscribe: Some(SubscribeRequest {
208 channel: channel.to_string(),
209 delta: Some("fossil".to_string()),
210 filter: filter.map(|f| f.to_string()),
211 }),
212 unsubscribe: None,
213 };
214 if let Ok(json) = serde_json::to_string(&cmd) {
215 let _ = tx.send(Message::Text(json.into()));
216 }
217 }
218 }
219}
220
221pub struct StreamApi {
223 inner: Arc<StreamApiInner>,
224}
225
226impl StreamApi {
227 pub fn new(url: &str, access_token: &str) -> Self {
229 let url_with_token = if url.contains('?') {
231 format!("{}&token={}", url, access_token)
232 } else {
233 format!("{}?token={}", url, access_token)
234 };
235
236 Self {
237 inner: Arc::new(StreamApiInner::new(
238 url_with_token,
239 access_token.to_string(),
240 )),
241 }
242 }
243
244 pub fn is_connected(&self) -> bool {
246 self.inner.connected.load(Ordering::SeqCst)
247 }
248
249 pub async fn connect(&self) -> Result<(), ChainStreamError> {
251 if self.is_connected() {
252 return Ok(());
253 }
254
255 let url = &self.inner.url;
256 log::info!("[streaming] connecting to {}", url);
257
258 let (ws_stream, _) = connect_async(url)
259 .await
260 .map_err(|e| ChainStreamError::WebSocket(format!("Failed to connect: {}", e)))?;
261
262 let (mut write, mut read) = ws_stream.split();
263
264 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
266 *self.inner.command_tx.write() = Some(tx.clone());
267
268 let connect_cmd = CentrifugeCommand {
270 id: self.inner.next_command_id(),
271 connect: Some(ConnectRequest {
272 token: self.inner.access_token.clone(),
273 }),
274 subscribe: None,
275 unsubscribe: None,
276 };
277 let connect_json = serde_json::to_string(&connect_cmd)
278 .map_err(|e| ChainStreamError::Serialization(e.to_string()))?;
279 write
280 .send(Message::Text(connect_json.into()))
281 .await
282 .map_err(|e| ChainStreamError::WebSocket(format!("Failed to send connect: {}", e)))?;
283
284 self.inner.connected.store(true, Ordering::SeqCst);
285
286 let inner_write = self.inner.clone();
288 tokio::spawn(async move {
289 while let Some(msg) = rx.recv().await {
290 if write.send(msg).await.is_err() {
291 inner_write.connected.store(false, Ordering::SeqCst);
292 break;
293 }
294 }
295 });
296
297 let inner_read = self.inner.clone();
299 tokio::spawn(async move {
300 while let Some(msg) = read.next().await {
301 match msg {
302 Ok(Message::Text(text)) => {
303 if let Ok(response) = serde_json::from_str::<CentrifugeResponse>(&text) {
304 if let Some(push) = response.push {
306 if let Some(pub_data) = push.pub_data {
307 inner_read.dispatch_message(&push.channel, pub_data.data);
308 }
309 }
310 if let Some(err) = response.error {
312 log::error!(
313 "[streaming] error: code={}, message={}",
314 err.code,
315 err.message
316 );
317 }
318 }
319 }
320 Ok(Message::Close(_)) => {
321 log::info!("[streaming] connection closed");
322 inner_read.connected.store(false, Ordering::SeqCst);
323 break;
324 }
325 Ok(Message::Ping(data)) => {
326 if let Some(tx) = inner_read.command_tx.read().as_ref() {
327 let _ = tx.send(Message::Pong(data));
328 }
329 }
330 Err(e) => {
331 log::error!("[streaming] read error: {}", e);
332 inner_read.connected.store(false, Ordering::SeqCst);
333 break;
334 }
335 _ => {}
336 }
337 }
338 });
339
340 Ok(())
341 }
342
343 pub async fn disconnect(&self) {
345 if let Some(tx) = self.inner.command_tx.write().take() {
346 let _ = tx.send(Message::Close(None));
347 }
348 self.inner.connected.store(false, Ordering::SeqCst);
349 log::info!("[streaming] disconnected");
350 }
351
352 pub async fn subscribe<F>(
354 &self,
355 channel: &str,
356 callback: F,
357 filter: Option<&str>,
358 method_name: Option<&str>,
359 ) -> Result<Unsubscribe, ChainStreamError>
360 where
361 F: Fn(Value) + Send + Sync + 'static,
362 {
363 if !self.is_connected() {
365 self.connect().await?;
366 }
367
368 let processed_filter = match (filter, method_name) {
370 (Some(f), Some(m)) if !f.is_empty() => Some(replace_filter_fields(f, m)),
371 (Some(f), _) if !f.is_empty() => Some(f.to_string()),
372 _ => None,
373 };
374
375 let needs_subscribe = {
377 let subs = self.inner.subscriptions.read();
378 !subs.contains_key(channel)
379 };
380
381 let callback_id = self.inner.add_listener(channel, callback);
383
384 if needs_subscribe {
386 self.inner
387 .send_subscribe(channel, processed_filter.as_deref());
388 self.inner
389 .subscriptions
390 .write()
391 .insert(channel.to_string(), self.inner.next_command_id());
392 log::info!("[streaming] subscribed to channel: {}", channel);
393 }
394
395 Ok(Unsubscribe {
396 channel: channel.to_string(),
397 callback_id,
398 api: self.inner.clone(),
399 })
400 }
401
402 pub async fn subscribe_token_candles<F>(
406 &self,
407 chain: &str,
408 token_address: &str,
409 resolution: Resolution,
410 callback: F,
411 filter: Option<&str>,
412 ) -> Result<Unsubscribe, ChainStreamError>
413 where
414 F: Fn(TokenCandle) + Send + Sync + 'static,
415 {
416 let channel = format!("dex-candle:{}_{}", chain, token_address);
417 let channel_with_resolution = format!("{}_{}", channel, resolution);
418
419 self.subscribe(
420 &channel_with_resolution,
421 move |data| {
422 if let Ok(candle) = parse_token_candle(&data) {
423 callback(candle);
424 }
425 },
426 filter,
427 Some("subscribe_token_candles"),
428 )
429 .await
430 }
431
432 pub async fn subscribe_token_stats<F>(
434 &self,
435 chain: &str,
436 token_address: &str,
437 callback: F,
438 filter: Option<&str>,
439 ) -> Result<Unsubscribe, ChainStreamError>
440 where
441 F: Fn(TokenStat) + Send + Sync + 'static,
442 {
443 let channel = format!("dex-token-stats:{}_{}", chain, token_address);
444
445 self.subscribe(
446 &channel,
447 move |data| {
448 if let Ok(stat) = serde_json::from_value::<TokenStat>(data) {
449 callback(stat);
450 }
451 },
452 filter,
453 Some("subscribe_token_stats"),
454 )
455 .await
456 }
457
458 pub async fn subscribe_new_token<F>(
460 &self,
461 chain: &str,
462 callback: F,
463 filter: Option<&str>,
464 ) -> Result<Unsubscribe, ChainStreamError>
465 where
466 F: Fn(NewToken) + Send + Sync + 'static,
467 {
468 let channel = format!("dex-new-token:{}", chain);
469
470 self.subscribe(
471 &channel,
472 move |data| {
473 if let Ok(token) = parse_new_token(&data) {
474 callback(token);
475 }
476 },
477 filter,
478 Some("subscribe_new_token"),
479 )
480 .await
481 }
482
483 pub async fn subscribe_token_trade<F>(
485 &self,
486 chain: &str,
487 token_address: &str,
488 callback: F,
489 filter: Option<&str>,
490 ) -> Result<Unsubscribe, ChainStreamError>
491 where
492 F: Fn(TradeActivity) + Send + Sync + 'static,
493 {
494 let channel = format!("dex-trade:{}_{}", chain, token_address);
495
496 self.subscribe(
497 &channel,
498 move |data| {
499 if let Ok(trade) = parse_trade_activity(&data) {
500 callback(trade);
501 }
502 },
503 filter,
504 Some("subscribe_token_trades"),
505 )
506 .await
507 }
508
509 pub async fn subscribe_wallet_balance<F>(
511 &self,
512 chain: &str,
513 wallet_address: &str,
514 callback: F,
515 filter: Option<&str>,
516 ) -> Result<Unsubscribe, ChainStreamError>
517 where
518 F: Fn(WalletBalance) + Send + Sync + 'static,
519 {
520 let channel = format!("dex-wallet-balance:{}_{}", chain, wallet_address);
521
522 self.subscribe(
523 &channel,
524 move |data| {
525 if let Ok(balance) = serde_json::from_value::<WalletBalance>(data) {
526 callback(balance);
527 }
528 },
529 filter,
530 Some("subscribe_wallet_balance"),
531 )
532 .await
533 }
534
535 pub async fn subscribe_token_holders<F>(
537 &self,
538 chain: &str,
539 token_address: &str,
540 callback: F,
541 filter: Option<&str>,
542 ) -> Result<Unsubscribe, ChainStreamError>
543 where
544 F: Fn(TokenHolder) + Send + Sync + 'static,
545 {
546 let channel = format!("dex-token-holder:{}_{}", chain, token_address);
547
548 self.subscribe(
549 &channel,
550 move |data| {
551 if let Ok(holder) = serde_json::from_value::<TokenHolder>(data) {
552 callback(holder);
553 }
554 },
555 filter,
556 Some("subscribe_token_holders"),
557 )
558 .await
559 }
560
561 pub async fn subscribe_token_supply<F>(
563 &self,
564 chain: &str,
565 token_address: &str,
566 callback: F,
567 filter: Option<&str>,
568 ) -> Result<Unsubscribe, ChainStreamError>
569 where
570 F: Fn(TokenSupply) + Send + Sync + 'static,
571 {
572 let channel = format!("dex-token-supply:{}_{}", chain, token_address);
573
574 self.subscribe(
575 &channel,
576 move |data| {
577 if let Ok(supply) = serde_json::from_value::<TokenSupply>(data) {
578 callback(supply);
579 }
580 },
581 filter,
582 Some("subscribe_token_supply"),
583 )
584 .await
585 }
586
587 pub async fn subscribe_dex_pool_balance<F>(
589 &self,
590 chain: &str,
591 pool_address: &str,
592 callback: F,
593 filter: Option<&str>,
594 ) -> Result<Unsubscribe, ChainStreamError>
595 where
596 F: Fn(DexPoolBalance) + Send + Sync + 'static,
597 {
598 let channel = format!("dex-pool-balance:{}_{}", chain, pool_address);
599
600 self.subscribe(
601 &channel,
602 move |data| {
603 if let Ok(balance) = serde_json::from_value::<DexPoolBalance>(data) {
604 callback(balance);
605 }
606 },
607 filter,
608 Some("subscribe_dex_pool_balance"),
609 )
610 .await
611 }
612
613 pub async fn subscribe_token_max_liquidity<F>(
615 &self,
616 chain: &str,
617 token_address: &str,
618 callback: F,
619 filter: Option<&str>,
620 ) -> Result<Unsubscribe, ChainStreamError>
621 where
622 F: Fn(TokenMaxLiquidity) + Send + Sync + 'static,
623 {
624 let channel = format!("dex-token-max-liquidity:{}_{}", chain, token_address);
625
626 self.subscribe(
627 &channel,
628 move |data| {
629 if let Ok(liquidity) = serde_json::from_value::<TokenMaxLiquidity>(data) {
630 callback(liquidity);
631 }
632 },
633 filter,
634 Some("subscribe_token_max_liquidity"),
635 )
636 .await
637 }
638
639 pub async fn subscribe_token_total_liquidity<F>(
641 &self,
642 chain: &str,
643 token_address: &str,
644 callback: F,
645 filter: Option<&str>,
646 ) -> Result<Unsubscribe, ChainStreamError>
647 where
648 F: Fn(TokenTotalLiquidity) + Send + Sync + 'static,
649 {
650 let channel = format!("dex-token-total-liquidity:{}_{}", chain, token_address);
651
652 self.subscribe(
653 &channel,
654 move |data| {
655 if let Ok(liquidity) = serde_json::from_value::<TokenTotalLiquidity>(data) {
656 callback(liquidity);
657 }
658 },
659 filter,
660 Some("subscribe_token_total_liquidity"),
661 )
662 .await
663 }
664
665 pub async fn subscribe_wallet_pnl<F>(
667 &self,
668 chain: &str,
669 wallet_address: &str,
670 callback: F,
671 filter: Option<&str>,
672 ) -> Result<Unsubscribe, ChainStreamError>
673 where
674 F: Fn(WalletTokenPnl) + Send + Sync + 'static,
675 {
676 let channel = format!("dex-wallet-pnl:{}_{}", chain, wallet_address);
677
678 self.subscribe(
679 &channel,
680 move |data| {
681 if let Ok(pnl) = serde_json::from_value::<WalletTokenPnl>(data) {
682 callback(pnl);
683 }
684 },
685 filter,
686 Some("subscribe_wallet_pnl"),
687 )
688 .await
689 }
690
691 pub async fn subscribe_new_tokens_metadata<F>(
693 &self,
694 chain: &str,
695 callback: F,
696 filter: Option<&str>,
697 ) -> Result<Unsubscribe, ChainStreamError>
698 where
699 F: Fn(TokenMetadata) + Send + Sync + 'static,
700 {
701 let channel = format!("dex-new-token-metadata:{}", chain);
702
703 self.subscribe(
704 &channel,
705 move |data| {
706 if let Ok(metadata) = serde_json::from_value::<TokenMetadata>(data) {
707 callback(metadata);
708 }
709 },
710 filter,
711 Some("subscribe_new_tokens_metadata"),
712 )
713 .await
714 }
715
716 pub async fn subscribe_ranking_tokens_list<F>(
718 &self,
719 chain: &str,
720 ranking_type: RankingType,
721 callback: F,
722 filter: Option<&str>,
723 ) -> Result<Unsubscribe, ChainStreamError>
724 where
725 F: Fn(RankingTokenList) + Send + Sync + 'static,
726 {
727 let ranking_str = match ranking_type {
728 RankingType::New => "new",
729 RankingType::Hot => "trending",
730 RankingType::Stocks => "stocks",
731 RankingType::FinalStretch => "completed",
732 RankingType::Migrated => "graduated",
733 };
734 let channel = format!("dex-ranking-token-list:{}_{}", chain, ranking_str);
735
736 self.subscribe(
737 &channel,
738 move |data| {
739 if let Ok(ranking) = serde_json::from_value::<RankingTokenList>(data) {
740 callback(ranking);
741 }
742 },
743 filter,
744 None,
745 )
746 .await
747 }
748}
749
750fn parse_token_candle(data: &Value) -> Result<TokenCandle, String> {
753 let obj = data.as_object().ok_or_else(|| {
754 "expected object".to_string()
755 })?;
756
757 Ok(TokenCandle {
758 open: get_string(obj, "o"),
759 close: get_string(obj, "c"),
760 high: get_string(obj, "h"),
761 low: get_string(obj, "l"),
762 volume: get_string(obj, "v"),
763 resolution: get_string(obj, "r"),
764 time: get_i64(obj, "t"),
765 number: get_i32(obj, "n"),
766 })
767}
768
769fn parse_new_token(data: &Value) -> Result<NewToken, String> {
770 let obj = data.as_object().ok_or_else(|| {
771 "expected object".to_string()
772 })?;
773
774 Ok(NewToken {
775 token_address: get_string(obj, "a"),
776 name: get_string(obj, "n"),
777 symbol: get_string(obj, "s"),
778 decimals: obj.get("d").and_then(|v| v.as_i64()).map(|v| v as i32),
779 launch_from: None, created_at_ms: get_i64(obj, "cts"),
781 })
782}
783
784fn parse_trade_activity(data: &Value) -> Result<TradeActivity, String> {
785 let obj = data.as_object().ok_or_else(|| {
786 "expected object".to_string()
787 })?;
788
789 Ok(TradeActivity {
790 token_address: get_string(obj, "a"),
791 timestamp: get_i64(obj, "t"),
792 kind: get_string(obj, "k"),
793 buy_amount: get_string(obj, "ba"),
794 buy_amount_in_usd: get_string(obj, "baiu"),
795 buy_token_address: get_string(obj, "btma"),
796 buy_token_name: get_string(obj, "btn"),
797 buy_token_symbol: get_string(obj, "bts"),
798 buy_wallet_address: get_string(obj, "bwa"),
799 sell_amount: get_string(obj, "sa"),
800 sell_amount_in_usd: get_string(obj, "saiu"),
801 sell_token_address: get_string(obj, "stma"),
802 sell_token_name: get_string(obj, "stn"),
803 sell_token_symbol: get_string(obj, "sts"),
804 sell_wallet_address: get_string(obj, "swa"),
805 tx_hash: get_string(obj, "h"),
806 })
807}
808
809fn get_string(obj: &serde_json::Map<String, Value>, key: &str) -> String {
810 obj.get(key)
811 .and_then(|v| v.as_str())
812 .unwrap_or_default()
813 .to_string()
814}
815
816fn get_i64(obj: &serde_json::Map<String, Value>, key: &str) -> i64 {
817 obj.get(key).and_then(|v| v.as_i64()).unwrap_or_default()
818}
819
820fn get_i32(obj: &serde_json::Map<String, Value>, key: &str) -> i32 {
821 obj.get(key)
822 .and_then(|v| v.as_i64())
823 .map(|v| v as i32)
824 .unwrap_or_default()
825}