1use std::collections::HashMap;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use std::time::{Duration, Instant};
8
9use futures_util::stream::{SplitSink, SplitStream};
10use futures_util::{SinkExt, Stream, StreamExt};
11use tokio::net::TcpStream;
12use tokio::sync::Mutex;
13use tokio::time::{interval, Interval};
14use tokio_tungstenite::tungstenite::Message as WsMessage;
15use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
16
17use crate::error::KrakenError;
18use crate::spot::ws::client::WsConfig;
19use crate::spot::ws::messages::{
20 channels, AddOrderParams, AddOrderResult, CancelAllParams, CancelAllResult, CancelOrderParams,
21 CancelOrderResult, EditOrderParams, EditOrderResult, Heartbeat, PingRequest, PongResponse,
22 SubscribeParams, SubscriptionResult, SystemStatusMessage, WsRequest,
23};
24
25type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
26type WsSink = SplitSink<WsStream, WsMessage>;
27type WsReceiver = SplitStream<WsStream>;
28
29#[derive(Debug, Clone)]
31pub enum WsMessageEvent {
32 Status(SystemStatusMessage),
34 Heartbeat(Heartbeat),
36 Pong(PongResponse),
38 Subscribed(SubscriptionResult),
40 Unsubscribed(SubscriptionResult),
42 ChannelData(serde_json::Value),
44 OrderAdded {
46 req_id: Option<u64>,
48 result: AddOrderResult,
50 },
51 OrderCancelled {
53 req_id: Option<u64>,
55 result: CancelOrderResult,
57 },
58 AllOrdersCancelled {
60 req_id: Option<u64>,
62 result: CancelAllResult,
64 },
65 OrderEdited {
67 req_id: Option<u64>,
69 result: EditOrderResult,
71 },
72 Error { method: String, error: String, req_id: Option<u64> },
74 Disconnected,
76 Reconnecting { attempt: u32 },
78 Reconnected,
80}
81
82#[allow(dead_code)]
84#[derive(Debug, Clone)]
85struct SubscriptionState {
86 params: SubscribeParams,
87 status: SubscriptionStatus,
88 last_change: Instant,
89}
90
/// Acknowledgement state of a tracked subscription.
// `Error` is never constructed yet, hence the allow.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SubscriptionStatus {
    /// Sent, not yet acknowledged.
    Pending,
    /// Acknowledged with `success: true`.
    Active,
    /// Rejected by the server.
    Error,
}
98
99pub struct KrakenStream {
127 sink: Option<Arc<Mutex<WsSink>>>,
129 receiver: Option<WsReceiver>,
131 config: WsConfig,
133 url: String,
135 token: Option<String>,
137 subscriptions: HashMap<String, SubscriptionState>,
139 ping_interval: Interval,
141 last_ping: Option<Instant>,
143 last_message: Instant,
145 reconnect_attempt: u32,
147 req_id: u64,
149 connected: bool,
151 reconnecting: bool,
153}
154
155impl std::fmt::Debug for KrakenStream {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("KrakenStream")
158 .field("url", &self.url)
159 .field("connected", &self.connected)
160 .field("reconnecting", &self.reconnecting)
161 .field("subscriptions", &self.subscriptions.len())
162 .finish()
163 }
164}
165
166impl KrakenStream {
167 pub(crate) async fn connect_public(url: &str, config: WsConfig) -> Result<Self, KrakenError> {
169 Self::connect(url, config, None).await
170 }
171
172 pub(crate) async fn connect_private(
174 url: &str,
175 config: WsConfig,
176 token: String,
177 ) -> Result<Self, KrakenError> {
178 Self::connect(url, config, Some(token)).await
179 }
180
181 async fn connect(
183 url: &str,
184 config: WsConfig,
185 token: Option<String>,
186 ) -> Result<Self, KrakenError> {
187 let (ws_stream, _) = connect_async(url).await.map_err(|e| {
188 KrakenError::WebSocketMsg(format!("Failed to connect to {}: {}", url, e))
189 })?;
190
191 let (sink, receiver) = ws_stream.split();
192 let ping_interval_duration = config.ping_interval;
193
194 Ok(Self {
195 sink: Some(Arc::new(Mutex::new(sink))),
196 receiver: Some(receiver),
197 config,
198 url: url.to_string(),
199 token,
200 subscriptions: HashMap::new(),
201 ping_interval: interval(ping_interval_duration),
202 last_ping: None,
203 last_message: Instant::now(),
204 reconnect_attempt: 0,
205 req_id: 0,
206 connected: true,
207 reconnecting: false,
208 })
209 }
210
211 pub async fn subscribe(&mut self, params: SubscribeParams) -> Result<(), KrakenError> {
213 let key = subscription_key(¶ms);
214
215 self.subscriptions.insert(
217 key,
218 SubscriptionState {
219 params: params.clone(),
220 status: SubscriptionStatus::Pending,
221 last_change: Instant::now(),
222 },
223 );
224
225 self.send_subscribe(params).await
227 }
228
229 pub async fn unsubscribe(&mut self, params: SubscribeParams) -> Result<(), KrakenError> {
231 let key = subscription_key(¶ms);
232 self.subscriptions.remove(&key);
233
234 self.send_unsubscribe(params).await
235 }
236
237 async fn send_subscribe(&mut self, params: SubscribeParams) -> Result<(), KrakenError> {
239 let req = WsRequest::new("subscribe", params).with_req_id(self.next_req_id());
240 self.send_json(&req).await
241 }
242
243 async fn send_unsubscribe(&mut self, params: SubscribeParams) -> Result<(), KrakenError> {
245 let req = WsRequest::new("unsubscribe", params).with_req_id(self.next_req_id());
246 self.send_json(&req).await
247 }
248
249 pub async fn ping(&mut self) -> Result<(), KrakenError> {
251 let req = WsRequest::new("ping", PingRequest::with_req_id(self.next_req_id()));
252 self.last_ping = Some(Instant::now());
253 self.send_json(&req).await
254 }
255
256 pub async fn add_order(&mut self, params: AddOrderParams) -> Result<u64, KrakenError> {
282 self.ensure_private()?;
283 let req_id = self.next_req_id();
284 let req = WsRequest::new("add_order", params).with_req_id(req_id);
285 self.send_json(&req).await?;
286 Ok(req_id)
287 }
288
289 pub async fn cancel_order(&mut self, params: CancelOrderParams) -> Result<u64, KrakenError> {
313 self.ensure_private()?;
314 let req_id = self.next_req_id();
315 let req = WsRequest::new("cancel_order", params).with_req_id(req_id);
316 self.send_json(&req).await?;
317 Ok(req_id)
318 }
319
320 pub async fn cancel_all_orders(&mut self, params: CancelAllParams) -> Result<u64, KrakenError> {
333 self.ensure_private()?;
334 let req_id = self.next_req_id();
335 let req = WsRequest::new("cancel_all", params).with_req_id(req_id);
336 self.send_json(&req).await?;
337 Ok(req_id)
338 }
339
340 pub async fn edit_order(&mut self, params: EditOrderParams) -> Result<u64, KrakenError> {
357 self.ensure_private()?;
358 let req_id = self.next_req_id();
359 let req = WsRequest::new("edit_order", params).with_req_id(req_id);
360 self.send_json(&req).await?;
361 Ok(req_id)
362 }
363
364 fn ensure_private(&self) -> Result<(), KrakenError> {
366 if self.token.is_none() {
367 return Err(KrakenError::MissingCredentials);
368 }
369 Ok(())
370 }
371
372 async fn send_json<T: serde::Serialize>(&self, msg: &T) -> Result<(), KrakenError> {
374 let sink = self
375 .sink
376 .as_ref()
377 .ok_or_else(|| KrakenError::WebSocketMsg("Not connected".into()))?;
378
379 let json = serde_json::to_string(msg)
380 .map_err(|e| KrakenError::WebSocketMsg(format!("Failed to serialize message: {}", e)))?;
381
382 let mut sink = sink.lock().await;
383 sink.send(WsMessage::Text(json.into()))
384 .await
385 .map_err(|e| KrakenError::WebSocketMsg(format!("Failed to send message: {}", e)))
386 }
387
388 fn next_req_id(&mut self) -> u64 {
390 self.req_id += 1;
391 self.req_id
392 }
393
394 fn should_reconnect(&self) -> bool {
396 match self.config.max_reconnect_attempts {
397 Some(max) => self.reconnect_attempt < max,
398 None => true, }
400 }
401
402 #[allow(dead_code)]
404 fn backoff_duration(&self) -> Duration {
405 let base = self.config.initial_backoff.as_millis() as u64;
406 let max = self.config.max_backoff.as_millis() as u64;
407 let multiplier = 2u64.saturating_pow(self.reconnect_attempt);
408 let backoff_ms = base.saturating_mul(multiplier).min(max);
409 Duration::from_millis(backoff_ms)
410 }
411
412 #[allow(dead_code)]
414 async fn reconnect(&mut self) -> Result<(), KrakenError> {
415 self.reconnect_attempt += 1;
416 self.connected = false;
417 self.reconnecting = true;
418
419 self.sink = None;
421 self.receiver = None;
422
423 let backoff = self.backoff_duration();
425 tokio::time::sleep(backoff).await;
426
427 let (ws_stream, _) = connect_async(&self.url).await.map_err(|e| {
429 KrakenError::WebSocketMsg(format!("Failed to reconnect: {}", e))
430 })?;
431
432 let (sink, receiver) = ws_stream.split();
433 self.sink = Some(Arc::new(Mutex::new(sink)));
434 self.receiver = Some(receiver);
435 self.connected = true;
436 self.reconnecting = false;
437 self.reconnect_attempt = 0;
438 self.last_message = Instant::now();
439
440 self.restore_subscriptions().await?;
442
443 Ok(())
444 }
445
446 #[allow(dead_code)]
448 async fn restore_subscriptions(&mut self) -> Result<(), KrakenError> {
449 let subs: Vec<_> = self.subscriptions.values().map(|s| s.params.clone()).collect();
450
451 for params in subs {
452 self.send_subscribe(params).await?;
453 }
454
455 Ok(())
456 }
457
458 fn parse_message(&mut self, text: &str) -> Option<WsMessageEvent> {
460 self.last_message = Instant::now();
461
462 let value: serde_json::Value = match serde_json::from_str(text) {
464 Ok(v) => v,
465 Err(e) => {
466 tracing::warn!("Failed to parse WebSocket message: {}", e);
467 return None;
468 }
469 };
470
471 if let Some(method) = value.get("method").and_then(|m| m.as_str()) {
473 return self.handle_response_message(method, &value);
474 }
475
476 if let Some(channel) = value.get("channel").and_then(|c| c.as_str()) {
478 let channel = channel.to_string(); return self.handle_channel_message(&channel, value);
480 }
481
482 tracing::debug!("Unknown message format: {}", text);
484 Some(WsMessageEvent::ChannelData(value))
485 }
486
487 fn handle_response_message(
489 &mut self,
490 method: &str,
491 value: &serde_json::Value,
492 ) -> Option<WsMessageEvent> {
493 let req_id = value.get("req_id").and_then(|r| r.as_u64());
494
495 match method {
496 "pong" => {
497 if let Ok(pong) = serde_json::from_value::<PongResponse>(value.clone()) {
498 self.last_ping = None;
499 return Some(WsMessageEvent::Pong(pong));
500 }
501 }
502 "subscribe" => {
503 let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
505 if success {
506 if let Some(result) = value.get("result") {
507 if let Ok(sub_result) = serde_json::from_value::<SubscriptionResult>(result.clone()) {
508 let key = subscription_key_from_result(&sub_result);
510 if let Some(state) = self.subscriptions.get_mut(&key) {
511 state.status = SubscriptionStatus::Active;
512 state.last_change = Instant::now();
513 }
514 return Some(WsMessageEvent::Subscribed(sub_result));
515 }
516 }
517 } else {
518 let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
519 return Some(WsMessageEvent::Error {
520 method: method.to_string(),
521 error: error.to_string(),
522 req_id,
523 });
524 }
525 }
526 "unsubscribe" => {
527 let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
528 if success {
529 if let Some(result) = value.get("result") {
530 if let Ok(sub_result) = serde_json::from_value::<SubscriptionResult>(result.clone()) {
531 return Some(WsMessageEvent::Unsubscribed(sub_result));
532 }
533 }
534 } else {
535 let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
536 return Some(WsMessageEvent::Error {
537 method: method.to_string(),
538 error: error.to_string(),
539 req_id,
540 });
541 }
542 }
543 "add_order" => {
544 let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
545 if success {
546 if let Some(result) = value.get("result") {
547 if let Ok(order_result) = serde_json::from_value::<AddOrderResult>(result.clone()) {
548 return Some(WsMessageEvent::OrderAdded {
549 req_id,
550 result: order_result,
551 });
552 }
553 }
554 } else {
555 let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
556 return Some(WsMessageEvent::Error {
557 method: method.to_string(),
558 error: error.to_string(),
559 req_id,
560 });
561 }
562 }
563 "cancel_order" => {
564 let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
565 if success {
566 if let Some(result) = value.get("result") {
567 if let Ok(cancel_result) = serde_json::from_value::<CancelOrderResult>(result.clone()) {
568 return Some(WsMessageEvent::OrderCancelled {
569 req_id,
570 result: cancel_result,
571 });
572 }
573 }
574 } else {
575 let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
576 return Some(WsMessageEvent::Error {
577 method: method.to_string(),
578 error: error.to_string(),
579 req_id,
580 });
581 }
582 }
583 "cancel_all" => {
584 let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
585 if success {
586 if let Some(result) = value.get("result") {
587 if let Ok(cancel_result) = serde_json::from_value::<CancelAllResult>(result.clone()) {
588 return Some(WsMessageEvent::AllOrdersCancelled {
589 req_id,
590 result: cancel_result,
591 });
592 }
593 }
594 } else {
595 let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
596 return Some(WsMessageEvent::Error {
597 method: method.to_string(),
598 error: error.to_string(),
599 req_id,
600 });
601 }
602 }
603 "edit_order" => {
604 let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
605 if success {
606 if let Some(result) = value.get("result") {
607 if let Ok(edit_result) = serde_json::from_value::<EditOrderResult>(result.clone()) {
608 return Some(WsMessageEvent::OrderEdited {
609 req_id,
610 result: edit_result,
611 });
612 }
613 }
614 } else {
615 let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
616 return Some(WsMessageEvent::Error {
617 method: method.to_string(),
618 error: error.to_string(),
619 req_id,
620 });
621 }
622 }
623 _ => {
624 return Some(WsMessageEvent::ChannelData(value.clone()));
626 }
627 }
628
629 None
630 }
631
632 fn handle_channel_message(
634 &mut self,
635 channel: &str,
636 value: serde_json::Value,
637 ) -> Option<WsMessageEvent> {
638 match channel {
639 channels::STATUS => {
640 if let Ok(status) = serde_json::from_value::<SystemStatusMessage>(value) {
641 return Some(WsMessageEvent::Status(status));
642 }
643 }
644 channels::HEARTBEAT => {
645 if let Ok(heartbeat) = serde_json::from_value::<Heartbeat>(value) {
646 return Some(WsMessageEvent::Heartbeat(heartbeat));
647 }
648 }
649 _ => {
650 return Some(WsMessageEvent::ChannelData(value));
652 }
653 }
654
655 None
656 }
657
658 fn check_connection_health(&self) -> bool {
660 if let Some(ping_time) = self.last_ping {
662 if ping_time.elapsed() > self.config.pong_timeout {
663 return false;
664 }
665 }
666
667 true
668 }
669
670 pub async fn close(&mut self) -> Result<(), KrakenError> {
672 if let Some(sink) = self.sink.take() {
673 let mut sink = sink.lock().await;
674 let _ = sink.send(WsMessage::Close(None)).await;
675 }
676 self.receiver = None;
677 self.connected = false;
678 Ok(())
679 }
680
681 pub fn is_connected(&self) -> bool {
683 self.connected
684 }
685}
686
687impl Stream for KrakenStream {
688 type Item = Result<WsMessageEvent, KrakenError>;
689
690 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
691 if self.ping_interval.poll_tick(cx).is_ready() && self.connected {
693 if self.last_ping.is_none() {
695 let this = self.as_mut().get_mut();
696 let ping_req = WsRequest::new("ping", PingRequest::with_req_id(this.next_req_id()));
697 this.last_ping = Some(Instant::now());
698
699 if let Some(sink) = &this.sink {
700 let sink = sink.clone();
701 if let Ok(json) = serde_json::to_string(&ping_req) {
702 tokio::spawn(async move {
703 let mut sink = sink.lock().await;
704 let _ = sink.send(WsMessage::Text(json.into())).await;
705 });
706 }
707 }
708 }
709 }
710
711 if !self.check_connection_health() && self.connected {
713 let this = self.as_mut().get_mut();
714 this.connected = false;
715
716 if this.should_reconnect() {
717 return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
718 attempt: this.reconnect_attempt + 1,
719 })));
720 } else {
721 return Poll::Ready(Some(Ok(WsMessageEvent::Disconnected)));
722 }
723 }
724
725 if let Some(receiver) = self.receiver.as_mut() {
727 match Pin::new(receiver).poll_next(cx) {
728 Poll::Ready(Some(Ok(msg))) => {
729 let this = self.as_mut().get_mut();
730 match msg {
731 WsMessage::Text(text) => {
732 if let Some(event) = this.parse_message(&text) {
733 return Poll::Ready(Some(Ok(event)));
734 }
735 cx.waker().wake_by_ref();
737 return Poll::Pending;
738 }
739 WsMessage::Binary(data) => {
740 if let Ok(text) = String::from_utf8(data.to_vec()) {
742 if let Some(event) = this.parse_message(&text) {
743 return Poll::Ready(Some(Ok(event)));
744 }
745 }
746 cx.waker().wake_by_ref();
747 return Poll::Pending;
748 }
749 WsMessage::Ping(_) | WsMessage::Pong(_) => {
750 cx.waker().wake_by_ref();
752 return Poll::Pending;
753 }
754 WsMessage::Close(_) => {
755 this.connected = false;
756 if this.should_reconnect() {
757 return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
758 attempt: this.reconnect_attempt + 1,
759 })));
760 } else {
761 return Poll::Ready(Some(Ok(WsMessageEvent::Disconnected)));
762 }
763 }
764 WsMessage::Frame(_) => {
765 cx.waker().wake_by_ref();
766 return Poll::Pending;
767 }
768 }
769 }
770 Poll::Ready(Some(Err(e))) => {
771 let this = self.as_mut().get_mut();
772 this.connected = false;
773 tracing::warn!("WebSocket error: {}", e);
774
775 if this.should_reconnect() {
776 return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
777 attempt: this.reconnect_attempt + 1,
778 })));
779 } else {
780 return Poll::Ready(Some(Err(KrakenError::WebSocket(e))));
781 }
782 }
783 Poll::Ready(None) => {
784 let this = self.as_mut().get_mut();
785 this.connected = false;
786
787 if this.should_reconnect() {
788 return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
789 attempt: this.reconnect_attempt + 1,
790 })));
791 } else {
792 return Poll::Ready(None);
793 }
794 }
795 Poll::Pending => {}
796 }
797 } else if !self.reconnecting && self.should_reconnect() {
798 return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
800 attempt: self.reconnect_attempt + 1,
801 })));
802 }
803
804 Poll::Pending
805 }
806}
807
808fn subscription_key(params: &SubscribeParams) -> String {
810 let symbols = params
811 .symbol
812 .as_ref()
813 .map(|s| s.join(","))
814 .unwrap_or_default();
815 format!("{}:{}", params.channel, symbols)
816}
817
818fn subscription_key_from_result(result: &SubscriptionResult) -> String {
820 format!(
821 "{}:{}",
822 result.channel,
823 result.symbol.as_deref().unwrap_or("")
824 )
825}
826
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_subscription_key() {
        let params = SubscribeParams::public("ticker", vec!["BTC/USD".into(), "ETH/USD".into()]);
        assert_eq!(subscription_key(&params), "ticker:BTC/USD,ETH/USD");
    }

    /// Mirrors the formula in `backoff_duration`; a `KrakenStream` can only be
    /// built by opening a real connection, so the method is not called directly.
    #[test]
    fn test_backoff_calculation_formula() {
        let initial_ms = Duration::from_secs(1).as_millis() as u64;
        let max_ms = Duration::from_secs(60).as_millis() as u64;

        // (attempt, expected delay in seconds)
        let cases: [(u32, u64); 3] = [(0, 1), (3, 8), (10, 60)];
        for (attempt, expected_secs) in cases {
            let result = (initial_ms * 2u64.saturating_pow(attempt)).min(max_ms);
            assert_eq!(
                Duration::from_millis(result),
                Duration::from_secs(expected_secs),
                "attempt {}",
                attempt
            );
        }
    }
}