Skip to main content

kraken_api_client/spot/ws/
stream.rs

1//! WebSocket stream implementation.
2
3use std::collections::HashMap;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use std::time::{Duration, Instant};
8
9use futures_util::stream::{SplitSink, SplitStream};
10use futures_util::{SinkExt, Stream, StreamExt};
11use tokio::net::TcpStream;
12use tokio::sync::Mutex;
13use tokio::time::{interval, Interval};
14use tokio_tungstenite::tungstenite::Message as WsMessage;
15use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
16
17use crate::error::KrakenError;
18use crate::spot::ws::client::WsConfig;
19use crate::spot::ws::messages::{
20    channels, AddOrderParams, AddOrderResult, CancelAllParams, CancelAllResult, CancelOrderParams,
21    CancelOrderResult, EditOrderParams, EditOrderResult, Heartbeat, PingRequest, PongResponse,
22    SubscribeParams, SubscriptionResult, SystemStatusMessage, WsRequest,
23};
24
/// Full-duplex WebSocket connection over a (possibly TLS-wrapped) TCP stream.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a split `WsStream`; `KrakenStream` keeps it behind an `Arc<Mutex<_>>`
/// so sends can also happen from spawned tasks.
type WsSink = SplitSink<WsStream, WsMessage>;
/// Read half of a split `WsStream`, polled by `KrakenStream`.
type WsReceiver = SplitStream<WsStream>;
28
29/// A message received from the WebSocket connection.
30#[derive(Debug, Clone)]
31pub enum WsMessageEvent {
32    /// System status update.
33    Status(SystemStatusMessage),
34    /// Heartbeat from server.
35    Heartbeat(Heartbeat),
36    /// Pong response to our ping.
37    Pong(PongResponse),
38    /// Subscription confirmed.
39    Subscribed(SubscriptionResult),
40    /// Unsubscription confirmed.
41    Unsubscribed(SubscriptionResult),
42    /// Raw channel data (ticker, book, trade, etc.).
43    ChannelData(serde_json::Value),
44    /// Order added successfully.
45    OrderAdded {
46        /// Request ID from the original request.
47        req_id: Option<u64>,
48        /// Order result details.
49        result: AddOrderResult,
50    },
51    /// Order cancelled successfully.
52    OrderCancelled {
53        /// Request ID from the original request.
54        req_id: Option<u64>,
55        /// Cancel result details.
56        result: CancelOrderResult,
57    },
58    /// All orders cancelled.
59    AllOrdersCancelled {
60        /// Request ID from the original request.
61        req_id: Option<u64>,
62        /// Number of orders cancelled.
63        result: CancelAllResult,
64    },
65    /// Order edited successfully.
66    OrderEdited {
67        /// Request ID from the original request.
68        req_id: Option<u64>,
69        /// Edit result details.
70        result: EditOrderResult,
71    },
72    /// Subscription/unsubscription error.
73    Error { method: String, error: String, req_id: Option<u64> },
74    /// Connection closed.
75    Disconnected,
76    /// Reconnecting.
77    Reconnecting { attempt: u32 },
78    /// Reconnected successfully.
79    Reconnected,
80}
81
82/// Subscription state tracking.
83#[allow(dead_code)]
84#[derive(Debug, Clone)]
85struct SubscriptionState {
86    params: SubscribeParams,
87    status: SubscriptionStatus,
88    last_change: Instant,
89}
90
91#[allow(dead_code)]
92#[derive(Debug, Clone, Copy, PartialEq, Eq)]
93enum SubscriptionStatus {
94    Pending,
95    Active,
96    Error,
97}
98
99/// A stream of messages from a Kraken WebSocket connection.
100///
101/// This stream handles:
102/// - Automatic reconnection with exponential backoff
103/// - Subscription restoration after reconnect
104/// - Heartbeat/ping monitoring
105///
106/// # Example
107///
108/// ```rust,ignore
109/// use kraken_api_client::spot::ws::SpotWsClient;
110/// use kraken_api_client::spot::ws::messages::{SubscribeParams, channels};
111/// use futures_util::StreamExt;
112///
113/// let client = SpotWsClient::new();
114/// let mut stream = client.connect_public().await?;
115///
116/// stream.subscribe(SubscribeParams::public(channels::TICKER, vec!["BTC/USD".into()])).await?;
117///
118/// while let Some(msg) = stream.next().await {
119///     match msg? {
120///         WsMessageEvent::ChannelData(data) => println!("Data: {:?}", data),
121///         WsMessageEvent::Disconnected => println!("Disconnected!"),
122///         _ => {}
123///     }
124/// }
125/// ```
126pub struct KrakenStream {
127    /// WebSocket sink for sending messages.
128    sink: Option<Arc<Mutex<WsSink>>>,
129    /// WebSocket receiver for incoming messages.
130    receiver: Option<WsReceiver>,
131    /// Connection configuration.
132    config: WsConfig,
133    /// URL to connect to.
134    url: String,
135    /// Authentication token (for private connections).
136    token: Option<String>,
137    /// Active subscriptions.
138    subscriptions: HashMap<String, SubscriptionState>,
139    /// Ping interval timer.
140    ping_interval: Interval,
141    /// Last ping sent timestamp.
142    last_ping: Option<Instant>,
143    /// Last message received timestamp.
144    last_message: Instant,
145    /// Current reconnection attempt.
146    reconnect_attempt: u32,
147    /// Request ID counter.
148    req_id: u64,
149    /// Connection state.
150    connected: bool,
151    /// Whether we're currently reconnecting.
152    reconnecting: bool,
153}
154
155impl std::fmt::Debug for KrakenStream {
156    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157        f.debug_struct("KrakenStream")
158            .field("url", &self.url)
159            .field("connected", &self.connected)
160            .field("reconnecting", &self.reconnecting)
161            .field("subscriptions", &self.subscriptions.len())
162            .finish()
163    }
164}
165
166impl KrakenStream {
167    /// Create and connect a new public WebSocket stream.
168    pub(crate) async fn connect_public(url: &str, config: WsConfig) -> Result<Self, KrakenError> {
169        Self::connect(url, config, None).await
170    }
171
172    /// Create and connect a new private WebSocket stream.
173    pub(crate) async fn connect_private(
174        url: &str,
175        config: WsConfig,
176        token: String,
177    ) -> Result<Self, KrakenError> {
178        Self::connect(url, config, Some(token)).await
179    }
180
181    /// Connect to the WebSocket server.
182    async fn connect(
183        url: &str,
184        config: WsConfig,
185        token: Option<String>,
186    ) -> Result<Self, KrakenError> {
187        let (ws_stream, _) = connect_async(url).await.map_err(|e| {
188            KrakenError::WebSocketMsg(format!("Failed to connect to {}: {}", url, e))
189        })?;
190
191        let (sink, receiver) = ws_stream.split();
192        let ping_interval_duration = config.ping_interval;
193
194        Ok(Self {
195            sink: Some(Arc::new(Mutex::new(sink))),
196            receiver: Some(receiver),
197            config,
198            url: url.to_string(),
199            token,
200            subscriptions: HashMap::new(),
201            ping_interval: interval(ping_interval_duration),
202            last_ping: None,
203            last_message: Instant::now(),
204            reconnect_attempt: 0,
205            req_id: 0,
206            connected: true,
207            reconnecting: false,
208        })
209    }
210
211    /// Subscribe to a channel.
212    pub async fn subscribe(&mut self, params: SubscribeParams) -> Result<(), KrakenError> {
213        let key = subscription_key(&params);
214
215        // Store subscription state
216        self.subscriptions.insert(
217            key,
218            SubscriptionState {
219                params: params.clone(),
220                status: SubscriptionStatus::Pending,
221                last_change: Instant::now(),
222            },
223        );
224
225        // Send subscription request
226        self.send_subscribe(params).await
227    }
228
229    /// Unsubscribe from a channel.
230    pub async fn unsubscribe(&mut self, params: SubscribeParams) -> Result<(), KrakenError> {
231        let key = subscription_key(&params);
232        self.subscriptions.remove(&key);
233
234        self.send_unsubscribe(params).await
235    }
236
237    /// Send a subscription request.
238    async fn send_subscribe(&mut self, params: SubscribeParams) -> Result<(), KrakenError> {
239        let req = WsRequest::new("subscribe", params).with_req_id(self.next_req_id());
240        self.send_json(&req).await
241    }
242
243    /// Send an unsubscription request.
244    async fn send_unsubscribe(&mut self, params: SubscribeParams) -> Result<(), KrakenError> {
245        let req = WsRequest::new("unsubscribe", params).with_req_id(self.next_req_id());
246        self.send_json(&req).await
247    }
248
249    /// Send a ping message.
250    pub async fn ping(&mut self) -> Result<(), KrakenError> {
251        let req = WsRequest::new("ping", PingRequest::with_req_id(self.next_req_id()));
252        self.last_ping = Some(Instant::now());
253        self.send_json(&req).await
254    }
255
256    // ========== Trading Operations ==========
257
258    /// Add a new order via WebSocket.
259    ///
260    /// This requires an authenticated connection. Use `connect_private()` first.
261    ///
262    /// # Example
263    ///
264    /// ```rust,ignore
265    /// use kraken_api_client::spot::ws::SpotWsClient;
266    /// use kraken_api_client::spot::ws::messages::AddOrderParams;
267    /// use kraken_api_client::types::{OrderType, BuySell};
268    /// use rust_decimal_macros::dec;
269    ///
270    /// let client = SpotWsClient::new();
271    /// let token = rest_client.get_websocket_token().await?.token;
272    /// let mut stream = client.connect_private(&token).await?;
273    ///
274    /// let params = AddOrderParams::new(OrderType::Limit, BuySell::Buy, "BTC/USD", &token)
275    ///     .order_qty(dec!(0.001))
276    ///     .limit_price(dec!(50000))
277    ///     .validate(true); // Validate only, don't submit
278    ///
279    /// stream.add_order(params).await?;
280    /// ```
281    pub async fn add_order(&mut self, params: AddOrderParams) -> Result<u64, KrakenError> {
282        self.ensure_private()?;
283        let req_id = self.next_req_id();
284        let req = WsRequest::new("add_order", params).with_req_id(req_id);
285        self.send_json(&req).await?;
286        Ok(req_id)
287    }
288
289    /// Cancel one or more orders via WebSocket.
290    ///
291    /// This requires an authenticated connection. Use `connect_private()` first.
292    ///
293    /// # Example
294    ///
295    /// ```rust,ignore
296    /// use kraken_api_client::spot::ws::messages::CancelOrderParams;
297    ///
298    /// // Cancel by order ID
299    /// let params = CancelOrderParams::by_order_id(
300    ///     vec!["OQCLML-BW3P3-BUCMWZ".into()],
301    ///     &token
302    /// );
303    /// stream.cancel_order(params).await?;
304    ///
305    /// // Cancel by client order ID
306    /// let params = CancelOrderParams::by_cl_ord_id(
307    ///     vec!["my-order-1".into()],
308    ///     &token
309    /// );
310    /// stream.cancel_order(params).await?;
311    /// ```
312    pub async fn cancel_order(&mut self, params: CancelOrderParams) -> Result<u64, KrakenError> {
313        self.ensure_private()?;
314        let req_id = self.next_req_id();
315        let req = WsRequest::new("cancel_order", params).with_req_id(req_id);
316        self.send_json(&req).await?;
317        Ok(req_id)
318    }
319
320    /// Cancel all open orders via WebSocket.
321    ///
322    /// This requires an authenticated connection. Use `connect_private()` first.
323    ///
324    /// # Example
325    ///
326    /// ```rust,ignore
327    /// use kraken_api_client::spot::ws::messages::CancelAllParams;
328    ///
329    /// let params = CancelAllParams::new(&token);
330    /// stream.cancel_all_orders(params).await?;
331    /// ```
332    pub async fn cancel_all_orders(&mut self, params: CancelAllParams) -> Result<u64, KrakenError> {
333        self.ensure_private()?;
334        let req_id = self.next_req_id();
335        let req = WsRequest::new("cancel_all", params).with_req_id(req_id);
336        self.send_json(&req).await?;
337        Ok(req_id)
338    }
339
340    /// Edit an existing order via WebSocket.
341    ///
342    /// This requires an authenticated connection. Use `connect_private()` first.
343    ///
344    /// # Example
345    ///
346    /// ```rust,ignore
347    /// use kraken_api_client::spot::ws::messages::EditOrderParams;
348    /// use rust_decimal_macros::dec;
349    ///
350    /// let params = EditOrderParams::new("OQCLML-BW3P3-BUCMWZ", &token)
351    ///     .limit_price(dec!(51000))
352    ///     .order_qty(dec!(0.002));
353    ///
354    /// stream.edit_order(params).await?;
355    /// ```
356    pub async fn edit_order(&mut self, params: EditOrderParams) -> Result<u64, KrakenError> {
357        self.ensure_private()?;
358        let req_id = self.next_req_id();
359        let req = WsRequest::new("edit_order", params).with_req_id(req_id);
360        self.send_json(&req).await?;
361        Ok(req_id)
362    }
363
364    /// Ensure this is a private (authenticated) connection.
365    fn ensure_private(&self) -> Result<(), KrakenError> {
366        if self.token.is_none() {
367            return Err(KrakenError::MissingCredentials);
368        }
369        Ok(())
370    }
371
372    /// Send a JSON message.
373    async fn send_json<T: serde::Serialize>(&self, msg: &T) -> Result<(), KrakenError> {
374        let sink = self
375            .sink
376            .as_ref()
377            .ok_or_else(|| KrakenError::WebSocketMsg("Not connected".into()))?;
378
379        let json = serde_json::to_string(msg)
380            .map_err(|e| KrakenError::WebSocketMsg(format!("Failed to serialize message: {}", e)))?;
381
382        let mut sink = sink.lock().await;
383        sink.send(WsMessage::Text(json.into()))
384            .await
385            .map_err(|e| KrakenError::WebSocketMsg(format!("Failed to send message: {}", e)))
386    }
387
388    /// Get the next request ID.
389    fn next_req_id(&mut self) -> u64 {
390        self.req_id += 1;
391        self.req_id
392    }
393
394    /// Check if we should reconnect.
395    fn should_reconnect(&self) -> bool {
396        match self.config.max_reconnect_attempts {
397            Some(max) => self.reconnect_attempt < max,
398            None => true, // Infinite retries
399        }
400    }
401
402    /// Calculate backoff duration for reconnection.
403    #[allow(dead_code)]
404    fn backoff_duration(&self) -> Duration {
405        let base = self.config.initial_backoff.as_millis() as u64;
406        let max = self.config.max_backoff.as_millis() as u64;
407        let multiplier = 2u64.saturating_pow(self.reconnect_attempt);
408        let backoff_ms = base.saturating_mul(multiplier).min(max);
409        Duration::from_millis(backoff_ms)
410    }
411
412    /// Attempt to reconnect.
413    #[allow(dead_code)]
414    async fn reconnect(&mut self) -> Result<(), KrakenError> {
415        self.reconnect_attempt += 1;
416        self.connected = false;
417        self.reconnecting = true;
418
419        // Close existing connection
420        self.sink = None;
421        self.receiver = None;
422
423        // Wait with backoff
424        let backoff = self.backoff_duration();
425        tokio::time::sleep(backoff).await;
426
427        // Try to reconnect
428        let (ws_stream, _) = connect_async(&self.url).await.map_err(|e| {
429            KrakenError::WebSocketMsg(format!("Failed to reconnect: {}", e))
430        })?;
431
432        let (sink, receiver) = ws_stream.split();
433        self.sink = Some(Arc::new(Mutex::new(sink)));
434        self.receiver = Some(receiver);
435        self.connected = true;
436        self.reconnecting = false;
437        self.reconnect_attempt = 0;
438        self.last_message = Instant::now();
439
440        // Restore subscriptions
441        self.restore_subscriptions().await?;
442
443        Ok(())
444    }
445
446    /// Restore subscriptions after reconnection.
447    #[allow(dead_code)]
448    async fn restore_subscriptions(&mut self) -> Result<(), KrakenError> {
449        let subs: Vec<_> = self.subscriptions.values().map(|s| s.params.clone()).collect();
450
451        for params in subs {
452            self.send_subscribe(params).await?;
453        }
454
455        Ok(())
456    }
457
458    /// Parse and handle an incoming message.
459    fn parse_message(&mut self, text: &str) -> Option<WsMessageEvent> {
460        self.last_message = Instant::now();
461
462        // Try to parse as JSON
463        let value: serde_json::Value = match serde_json::from_str(text) {
464            Ok(v) => v,
465            Err(e) => {
466                tracing::warn!("Failed to parse WebSocket message: {}", e);
467                return None;
468            }
469        };
470
471        // Check if it's a response message (has "method" at top level)
472        if let Some(method) = value.get("method").and_then(|m| m.as_str()) {
473            return self.handle_response_message(method, &value);
474        }
475
476        // Check if it's a channel message (has "channel" at top level)
477        if let Some(channel) = value.get("channel").and_then(|c| c.as_str()) {
478            let channel = channel.to_string(); // Clone the channel string to avoid borrow
479            return self.handle_channel_message(&channel, value);
480        }
481
482        // Unknown message format
483        tracing::debug!("Unknown message format: {}", text);
484        Some(WsMessageEvent::ChannelData(value))
485    }
486
487    /// Handle a response message (method-based).
488    fn handle_response_message(
489        &mut self,
490        method: &str,
491        value: &serde_json::Value,
492    ) -> Option<WsMessageEvent> {
493        let req_id = value.get("req_id").and_then(|r| r.as_u64());
494
495        match method {
496            "pong" => {
497                if let Ok(pong) = serde_json::from_value::<PongResponse>(value.clone()) {
498                    self.last_ping = None;
499                    return Some(WsMessageEvent::Pong(pong));
500                }
501            }
502            "subscribe" => {
503                // Check for success/error
504                let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
505                if success {
506                    if let Some(result) = value.get("result") {
507                        if let Ok(sub_result) = serde_json::from_value::<SubscriptionResult>(result.clone()) {
508                            // Update subscription state
509                            let key = subscription_key_from_result(&sub_result);
510                            if let Some(state) = self.subscriptions.get_mut(&key) {
511                                state.status = SubscriptionStatus::Active;
512                                state.last_change = Instant::now();
513                            }
514                            return Some(WsMessageEvent::Subscribed(sub_result));
515                        }
516                    }
517                } else {
518                    let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
519                    return Some(WsMessageEvent::Error {
520                        method: method.to_string(),
521                        error: error.to_string(),
522                        req_id,
523                    });
524                }
525            }
526            "unsubscribe" => {
527                let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
528                if success {
529                    if let Some(result) = value.get("result") {
530                        if let Ok(sub_result) = serde_json::from_value::<SubscriptionResult>(result.clone()) {
531                            return Some(WsMessageEvent::Unsubscribed(sub_result));
532                        }
533                    }
534                } else {
535                    let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
536                    return Some(WsMessageEvent::Error {
537                        method: method.to_string(),
538                        error: error.to_string(),
539                        req_id,
540                    });
541                }
542            }
543            "add_order" => {
544                let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
545                if success {
546                    if let Some(result) = value.get("result") {
547                        if let Ok(order_result) = serde_json::from_value::<AddOrderResult>(result.clone()) {
548                            return Some(WsMessageEvent::OrderAdded {
549                                req_id,
550                                result: order_result,
551                            });
552                        }
553                    }
554                } else {
555                    let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
556                    return Some(WsMessageEvent::Error {
557                        method: method.to_string(),
558                        error: error.to_string(),
559                        req_id,
560                    });
561                }
562            }
563            "cancel_order" => {
564                let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
565                if success {
566                    if let Some(result) = value.get("result") {
567                        if let Ok(cancel_result) = serde_json::from_value::<CancelOrderResult>(result.clone()) {
568                            return Some(WsMessageEvent::OrderCancelled {
569                                req_id,
570                                result: cancel_result,
571                            });
572                        }
573                    }
574                } else {
575                    let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
576                    return Some(WsMessageEvent::Error {
577                        method: method.to_string(),
578                        error: error.to_string(),
579                        req_id,
580                    });
581                }
582            }
583            "cancel_all" => {
584                let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
585                if success {
586                    if let Some(result) = value.get("result") {
587                        if let Ok(cancel_result) = serde_json::from_value::<CancelAllResult>(result.clone()) {
588                            return Some(WsMessageEvent::AllOrdersCancelled {
589                                req_id,
590                                result: cancel_result,
591                            });
592                        }
593                    }
594                } else {
595                    let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
596                    return Some(WsMessageEvent::Error {
597                        method: method.to_string(),
598                        error: error.to_string(),
599                        req_id,
600                    });
601                }
602            }
603            "edit_order" => {
604                let success = value.get("success").and_then(|s| s.as_bool()).unwrap_or(false);
605                if success {
606                    if let Some(result) = value.get("result") {
607                        if let Ok(edit_result) = serde_json::from_value::<EditOrderResult>(result.clone()) {
608                            return Some(WsMessageEvent::OrderEdited {
609                                req_id,
610                                result: edit_result,
611                            });
612                        }
613                    }
614                } else {
615                    let error = value.get("error").and_then(|e| e.as_str()).unwrap_or("Unknown error");
616                    return Some(WsMessageEvent::Error {
617                        method: method.to_string(),
618                        error: error.to_string(),
619                        req_id,
620                    });
621                }
622            }
623            _ => {
624                // Unknown method, return as raw data
625                return Some(WsMessageEvent::ChannelData(value.clone()));
626            }
627        }
628
629        None
630    }
631
632    /// Handle a channel message.
633    fn handle_channel_message(
634        &mut self,
635        channel: &str,
636        value: serde_json::Value,
637    ) -> Option<WsMessageEvent> {
638        match channel {
639            channels::STATUS => {
640                if let Ok(status) = serde_json::from_value::<SystemStatusMessage>(value) {
641                    return Some(WsMessageEvent::Status(status));
642                }
643            }
644            channels::HEARTBEAT => {
645                if let Ok(heartbeat) = serde_json::from_value::<Heartbeat>(value) {
646                    return Some(WsMessageEvent::Heartbeat(heartbeat));
647                }
648            }
649            _ => {
650                // Market data or user data channel
651                return Some(WsMessageEvent::ChannelData(value));
652            }
653        }
654
655        None
656    }
657
658    /// Check connection health (ping timeout).
659    fn check_connection_health(&self) -> bool {
660        // Check if ping response is overdue
661        if let Some(ping_time) = self.last_ping {
662            if ping_time.elapsed() > self.config.pong_timeout {
663                return false;
664            }
665        }
666
667        true
668    }
669
670    /// Close the connection gracefully.
671    pub async fn close(&mut self) -> Result<(), KrakenError> {
672        if let Some(sink) = self.sink.take() {
673            let mut sink = sink.lock().await;
674            let _ = sink.send(WsMessage::Close(None)).await;
675        }
676        self.receiver = None;
677        self.connected = false;
678        Ok(())
679    }
680
681    /// Check if the connection is open.
682    pub fn is_connected(&self) -> bool {
683        self.connected
684    }
685}
686
687impl Stream for KrakenStream {
688    type Item = Result<WsMessageEvent, KrakenError>;
689
690    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
691        // Check ping interval
692        if self.ping_interval.poll_tick(cx).is_ready() && self.connected {
693            // Only send ping if not waiting for pong
694            if self.last_ping.is_none() {
695                let this = self.as_mut().get_mut();
696                let ping_req = WsRequest::new("ping", PingRequest::with_req_id(this.next_req_id()));
697                this.last_ping = Some(Instant::now());
698
699                if let Some(sink) = &this.sink {
700                    let sink = sink.clone();
701                    if let Ok(json) = serde_json::to_string(&ping_req) {
702                        tokio::spawn(async move {
703                            let mut sink = sink.lock().await;
704                            let _ = sink.send(WsMessage::Text(json.into())).await;
705                        });
706                    }
707                }
708            }
709        }
710
711        // Check connection health
712        if !self.check_connection_health() && self.connected {
713            let this = self.as_mut().get_mut();
714            this.connected = false;
715
716            if this.should_reconnect() {
717                return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
718                    attempt: this.reconnect_attempt + 1,
719                })));
720            } else {
721                return Poll::Ready(Some(Ok(WsMessageEvent::Disconnected)));
722            }
723        }
724
725        // Poll the receiver for messages
726        if let Some(receiver) = self.receiver.as_mut() {
727            match Pin::new(receiver).poll_next(cx) {
728                Poll::Ready(Some(Ok(msg))) => {
729                    let this = self.as_mut().get_mut();
730                    match msg {
731                        WsMessage::Text(text) => {
732                            if let Some(event) = this.parse_message(&text) {
733                                return Poll::Ready(Some(Ok(event)));
734                            }
735                            // If parse returned None, continue polling
736                            cx.waker().wake_by_ref();
737                            return Poll::Pending;
738                        }
739                        WsMessage::Binary(data) => {
740                            // Try to parse binary as JSON text
741                            if let Ok(text) = String::from_utf8(data.to_vec()) {
742                                if let Some(event) = this.parse_message(&text) {
743                                    return Poll::Ready(Some(Ok(event)));
744                                }
745                            }
746                            cx.waker().wake_by_ref();
747                            return Poll::Pending;
748                        }
749                        WsMessage::Ping(_) | WsMessage::Pong(_) => {
750                            // Handled automatically by tungstenite
751                            cx.waker().wake_by_ref();
752                            return Poll::Pending;
753                        }
754                        WsMessage::Close(_) => {
755                            this.connected = false;
756                            if this.should_reconnect() {
757                                return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
758                                    attempt: this.reconnect_attempt + 1,
759                                })));
760                            } else {
761                                return Poll::Ready(Some(Ok(WsMessageEvent::Disconnected)));
762                            }
763                        }
764                        WsMessage::Frame(_) => {
765                            cx.waker().wake_by_ref();
766                            return Poll::Pending;
767                        }
768                    }
769                }
770                Poll::Ready(Some(Err(e))) => {
771                    let this = self.as_mut().get_mut();
772                    this.connected = false;
773                    tracing::warn!("WebSocket error: {}", e);
774
775                    if this.should_reconnect() {
776                        return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
777                            attempt: this.reconnect_attempt + 1,
778                        })));
779                    } else {
780                        return Poll::Ready(Some(Err(KrakenError::WebSocket(e))));
781                    }
782                }
783                Poll::Ready(None) => {
784                    let this = self.as_mut().get_mut();
785                    this.connected = false;
786
787                    if this.should_reconnect() {
788                        return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
789                            attempt: this.reconnect_attempt + 1,
790                        })));
791                    } else {
792                        return Poll::Ready(None);
793                    }
794                }
795                Poll::Pending => {}
796            }
797        } else if !self.reconnecting && self.should_reconnect() {
798            // Need to reconnect
799            return Poll::Ready(Some(Ok(WsMessageEvent::Reconnecting {
800                attempt: self.reconnect_attempt + 1,
801            })));
802        }
803
804        Poll::Pending
805    }
806}
807
808/// Generate a subscription key for tracking.
809fn subscription_key(params: &SubscribeParams) -> String {
810    let symbols = params
811        .symbol
812        .as_ref()
813        .map(|s| s.join(","))
814        .unwrap_or_default();
815    format!("{}:{}", params.channel, symbols)
816}
817
818/// Generate a subscription key from a result.
819fn subscription_key_from_result(result: &SubscriptionResult) -> String {
820    format!(
821        "{}:{}",
822        result.channel,
823        result.symbol.as_deref().unwrap_or("")
824    )
825}
826
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_subscription_key() {
        let ticker = SubscribeParams::public("ticker", vec!["BTC/USD".into(), "ETH/USD".into()]);
        assert_eq!(subscription_key(&ticker), "ticker:BTC/USD,ETH/USD");
    }

    #[test]
    fn test_backoff_calculation_formula() {
        // backoff = initial * 2^attempt, capped at max
        let initial_ms = Duration::from_secs(1).as_millis() as u64;
        let cap_ms = Duration::from_secs(60).as_millis() as u64;

        // (attempt, expected seconds): 1*2^0 = 1, 1*2^3 = 8, 1*2^10 = 1024 -> capped at 60
        let cases: [(u32, u64); 3] = [(0, 1), (3, 8), (10, 60)];
        for &(attempt, expected_secs) in cases.iter() {
            let backoff_ms = (initial_ms * 2u64.saturating_pow(attempt)).min(cap_ms);
            assert_eq!(
                Duration::from_millis(backoff_ms),
                Duration::from_secs(expected_secs),
                "attempt {}",
                attempt
            );
        }
    }
}