1use crate::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
4use crate::endpoint::Endpoint;
5use crate::events::{ConnectionEvent, DisconnectReason, Event, L3Event, MarketEvent, SubscriptionEvent};
6use crate::reconnect::ReconnectConfig;
7use crate::subscription::{Subscription, SubscriptionManager};
8
9use dashmap::DashMap;
10use futures_util::{SinkExt, StreamExt};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13use kraken_book::Orderbook;
14use kraken_types::{Channel, Depth, KrakenError, MethodResponse, WsMessage};
15use parking_lot::RwLock;
16use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
17use std::sync::Arc;
18use tokio::sync::mpsc;
19use tokio::time::{timeout, Duration};
20use tokio_tungstenite::{connect_async, tungstenite::Message};
21use tracing::{debug, error, info, instrument, warn};
22
/// Lifecycle phase of a [`KrakenConnection`], readable at any time through
/// [`KrakenConnection::state`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No session is running: the initial phase, and the phase the run loop
    /// leaves behind when it exits.
    Disconnected,
    /// A connection attempt is in progress that did not follow a failed attempt.
    Connecting,
    /// The server's status message has been received and frames are being processed.
    Connected,
    /// An attempt failed and a retry is pending or underway.
    Reconnecting,
    /// Shutdown has been requested and the run loop has not exited yet.
    ShuttingDown,
}
37
/// Behaviour of a *bounded* event channel when it has no free slot.
///
/// Only consulted when [`ConnectionConfig::channel_capacity`] is `Some`; the
/// unbounded channel never applies backpressure.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum BackpressurePolicy {
    /// Discard the event that did not fit and bump the dropped-event counter.
    /// This is the default.
    #[default]
    DropNewest,
    /// Wait for the receiver to make room instead of discarding the event.
    Block,
}
47
48#[derive(Debug, Clone)]
50pub struct ConnectionConfig {
51 pub endpoint: Endpoint,
53 pub reconnect: ReconnectConfig,
55 pub connect_timeout: Duration,
57 pub depth: Depth,
59 pub heartbeat_timeout: Option<Duration>,
62 pub channel_capacity: Option<usize>,
64 pub backpressure_policy: BackpressurePolicy,
66 pub circuit_breaker: Option<CircuitBreakerConfig>,
68}
69
70impl Default for ConnectionConfig {
71 fn default() -> Self {
72 Self {
73 endpoint: Endpoint::Public,
74 reconnect: ReconnectConfig::default(),
75 connect_timeout: Duration::from_secs(10),
76 depth: Depth::D10,
77 heartbeat_timeout: Some(Duration::from_secs(30)),
78 channel_capacity: None, backpressure_policy: BackpressurePolicy::default(),
80 circuit_breaker: Some(CircuitBreakerConfig::default()), }
82 }
83}
84
85impl ConnectionConfig {
86 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn with_endpoint(mut self, endpoint: Endpoint) -> Self {
93 self.endpoint = endpoint;
94 self
95 }
96
97 pub fn with_reconnect(mut self, config: ReconnectConfig) -> Self {
99 self.reconnect = config;
100 self
101 }
102
103 pub fn without_reconnect(mut self) -> Self {
105 self.reconnect = ReconnectConfig::disabled();
106 self
107 }
108
109 pub fn with_timeout(mut self, timeout: Duration) -> Self {
111 self.connect_timeout = timeout;
112 self
113 }
114
115 pub fn with_depth(mut self, depth: Depth) -> Self {
117 self.depth = depth;
118 self
119 }
120
121 pub fn with_heartbeat_timeout(mut self, timeout: Duration) -> Self {
126 self.heartbeat_timeout = Some(timeout);
127 self
128 }
129
130 pub fn without_heartbeat_timeout(mut self) -> Self {
132 self.heartbeat_timeout = None;
133 self
134 }
135
136 pub fn with_channel_capacity(mut self, capacity: usize, policy: BackpressurePolicy) -> Self {
144 self.channel_capacity = Some(capacity);
145 self.backpressure_policy = policy;
146 self
147 }
148
149 pub fn with_unbounded_channel(mut self) -> Self {
151 self.channel_capacity = None;
152 self
153 }
154
155 pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
160 self.circuit_breaker = Some(config);
161 self
162 }
163
164 pub fn without_circuit_breaker(mut self) -> Self {
166 self.circuit_breaker = None;
167 self
168 }
169}
170
171enum EventSender {
173 Unbounded(mpsc::UnboundedSender<Event>),
174 Bounded {
175 sender: mpsc::Sender<Event>,
176 policy: BackpressurePolicy,
177 dropped_count: std::sync::atomic::AtomicU64,
178 },
179}
180
181impl EventSender {
182 fn send(&self, event: Event) {
183 match self {
184 EventSender::Unbounded(tx) => {
185 let _ = tx.send(event);
186 }
187 EventSender::Bounded { sender, policy, dropped_count } => {
188 match policy {
189 BackpressurePolicy::DropNewest => {
190 if sender.try_send(event).is_err() {
191 dropped_count.fetch_add(1, Ordering::Relaxed);
192 }
193 }
194 BackpressurePolicy::Block => {
195 let _ = sender.blocking_send(event);
197 }
198 }
199 }
200 }
201 }
202
203 fn dropped_count(&self) -> u64 {
204 match self {
205 EventSender::Unbounded(_) => 0,
206 EventSender::Bounded { dropped_count, .. } => dropped_count.load(Ordering::Relaxed),
207 }
208 }
209}
210
211pub enum EventReceiver {
213 Unbounded(mpsc::UnboundedReceiver<Event>),
215 Bounded(mpsc::Receiver<Event>),
217}
218
219impl EventReceiver {
220 #[instrument(skip(self), level = "trace")]
222 pub async fn recv(&mut self) -> Option<Event> {
223 match self {
224 EventReceiver::Unbounded(rx) => rx.recv().await,
225 EventReceiver::Bounded(rx) => rx.recv().await,
226 }
227 }
228}
229
230impl futures::Stream for EventReceiver {
231 type Item = Event;
232
233 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
234 match self.get_mut() {
235 EventReceiver::Unbounded(rx) => Pin::new(rx).poll_recv(cx),
236 EventReceiver::Bounded(rx) => Pin::new(rx).poll_recv(cx),
237 }
238 }
239}
240
241pub struct KrakenConnection {
243 config: ConnectionConfig,
245 state: Arc<RwLock<ConnectionState>>,
247 orderbooks: Arc<DashMap<String, Orderbook>>,
249 subscriptions: Arc<RwLock<SubscriptionManager>>,
251 reconnect_attempt: AtomicU32,
253 shutdown: AtomicBool,
255 event_tx: EventSender,
257 event_rx: Arc<RwLock<Option<EventReceiver>>>,
259 last_message_time: Arc<RwLock<std::time::Instant>>,
261 circuit_breaker: Option<CircuitBreaker>,
263}
264
265impl KrakenConnection {
266 pub fn new(config: ConnectionConfig) -> Self {
268 let (event_tx, event_rx) = match config.channel_capacity {
269 Some(capacity) => {
270 let (tx, rx) = mpsc::channel(capacity);
271 (
272 EventSender::Bounded {
273 sender: tx,
274 policy: config.backpressure_policy,
275 dropped_count: std::sync::atomic::AtomicU64::new(0),
276 },
277 EventReceiver::Bounded(rx),
278 )
279 }
280 None => {
281 let (tx, rx) = mpsc::unbounded_channel();
282 (EventSender::Unbounded(tx), EventReceiver::Unbounded(rx))
283 }
284 };
285
286 let circuit_breaker = config.circuit_breaker.clone().map(CircuitBreaker::new);
287
288 Self {
289 config,
290 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
291 orderbooks: Arc::new(DashMap::new()),
292 subscriptions: Arc::new(RwLock::new(SubscriptionManager::new())),
293 reconnect_attempt: AtomicU32::new(0),
294 shutdown: AtomicBool::new(false),
295 event_tx,
296 event_rx: Arc::new(RwLock::new(Some(event_rx))),
297 last_message_time: Arc::new(RwLock::new(std::time::Instant::now())),
298 circuit_breaker,
299 }
300 }
301
302 pub fn with_defaults() -> Self {
304 Self::new(ConnectionConfig::default())
305 }
306
307 pub fn state(&self) -> ConnectionState {
309 *self.state.read()
310 }
311
312 pub fn is_connected(&self) -> bool {
314 self.state() == ConnectionState::Connected
315 }
316
317 pub fn take_event_receiver(&self) -> Option<EventReceiver> {
319 self.event_rx.write().take()
320 }
321
322 pub fn dropped_event_count(&self) -> u64 {
326 self.event_tx.dropped_count()
327 }
328
329 pub fn orderbook(&self, symbol: &str) -> Option<dashmap::mapref::one::Ref<'_, String, Orderbook>>
331 {
332 self.orderbooks.get(symbol)
333 }
334
335 #[instrument(skip(self), fields(symbols = ?symbols))]
337 pub fn subscribe_orderbook(&self, symbols: Vec<String>) -> u64 {
338 let sub = Subscription::orderbook(symbols, self.config.depth);
339 self.subscriptions.write().add(sub)
340 }
341
342 #[instrument(skip(self), fields(symbols = ?symbols))]
344 pub fn subscribe_ticker(&self, symbols: Vec<String>) -> u64 {
345 let sub = Subscription::ticker(symbols);
346 self.subscriptions.write().add(sub)
347 }
348
349 #[instrument(skip(self), fields(symbols = ?symbols))]
351 pub fn subscribe_trade(&self, symbols: Vec<String>) -> u64 {
352 let sub = Subscription::trade(symbols);
353 self.subscriptions.write().add(sub)
354 }
355
356 #[instrument(skip(self), fields(symbols = ?symbols))]
361 pub fn subscribe_l3(&self, symbols: Vec<String>) -> u64 {
362 let sub = Subscription::level3(symbols);
363 self.subscriptions.write().add(sub)
364 }
365
366 #[instrument(skip(self), name = "kraken_connection")]
368 pub async fn connect_and_run(&self) -> Result<(), KrakenError> {
369 loop {
370 if self.shutdown.load(Ordering::Relaxed) {
371 break;
372 }
373
374 if let Some(ref breaker) = self.circuit_breaker {
376 if !breaker.allow_request() {
377 let stats = breaker.stats();
378 warn!(
379 "Circuit breaker is open (tripped {} times), waiting for recovery",
380 stats.trips
381 );
382 self.emit(ConnectionEvent::CircuitBreakerOpen {
383 trips: stats.trips,
384 });
385 tokio::time::sleep(Duration::from_secs(5)).await;
387 continue;
388 }
389 }
390
391 {
393 let mut state = self.state.write();
394 if *state == ConnectionState::Reconnecting {
395 } else {
397 *state = ConnectionState::Connecting;
398 }
399 }
400
401 match self.connect_internal().await {
402 Ok(()) => {
403 if let Some(ref breaker) = self.circuit_breaker {
405 breaker.record_success();
406 }
407 break;
408 }
409 Err(e) => {
410 if let Some(ref breaker) = self.circuit_breaker {
412 breaker.record_failure();
413 }
414
415 let attempt = self.reconnect_attempt.fetch_add(1, Ordering::Relaxed) + 1;
416
417 if !self.config.reconnect.should_reconnect(attempt) {
418 error!("Reconnection attempts exhausted after {} tries", attempt);
419 self.emit(ConnectionEvent::ReconnectFailed {
420 error: e.to_string(),
421 });
422 return Err(e);
423 }
424
425 let delay = self.config.reconnect.delay_with_jitter(attempt);
426 warn!(
427 "Connection failed, reconnecting in {:?} (attempt {}): {}",
428 delay, attempt, e
429 );
430
431 self.emit(ConnectionEvent::Reconnecting { attempt, delay });
432 *self.state.write() = ConnectionState::Reconnecting;
433
434 tokio::time::sleep(delay).await;
435 }
436 }
437 }
438
439 *self.state.write() = ConnectionState::Disconnected;
440 Ok(())
441 }
442
443 async fn connect_internal(&self) -> Result<(), KrakenError> {
445 let url = self.config.endpoint.url();
446 info!("Connecting to {}", url);
447
448 let connect_result = timeout(self.config.connect_timeout, connect_async(url)).await;
450
451 let (ws_stream, _response) = match connect_result {
452 Ok(Ok((stream, response))) => (stream, response),
453 Ok(Err(e)) => {
454 return Err(KrakenError::ConnectionFailed {
455 url: url.to_string(),
456 reason: e.to_string(),
457 });
458 }
459 Err(_) => {
460 return Err(KrakenError::ConnectionTimeout {
461 url: url.to_string(),
462 timeout: self.config.connect_timeout,
463 });
464 }
465 };
466
467 let (mut write, mut read) = ws_stream.split();
468
469 let mut connected = false;
471 while let Some(msg_result) = read.next().await {
472 match msg_result {
473 Ok(Message::Text(text)) => {
474 if let Ok(WsMessage::Status(status_msg)) = WsMessage::parse(&text) {
475 if let Some(data) = status_msg.data.first() {
476 info!(
477 "Connected to Kraken API {} (connection_id: {})",
478 data.api_version, data.connection_id
479 );
480
481 self.emit(ConnectionEvent::Connected {
482 api_version: data.api_version.clone(),
483 connection_id: data.connection_id,
484 });
485
486 connected = true;
487 break;
488 }
489 }
490 }
491 Ok(Message::Close(_)) => {
492 return Err(KrakenError::WebSocket("Connection closed before ready".into()));
493 }
494 Err(e) => {
495 return Err(KrakenError::WebSocket(e.to_string()));
496 }
497 _ => {}
498 }
499 }
500
501 if !connected {
502 return Err(KrakenError::WebSocket(
503 "No status message received".into(),
504 ));
505 }
506
507 *self.state.write() = ConnectionState::Connected;
509 self.reconnect_attempt.store(0, Ordering::Relaxed);
510
511 let requests = self.subscriptions.write().restoration_requests();
514
515 let book_symbols: Vec<String> = requests
517 .iter()
518 .filter_map(|(_, req)| {
519 if req.params.channel == Channel::Book {
520 Some(req.params.symbol.clone())
521 } else {
522 None
523 }
524 })
525 .flatten()
526 .collect();
527
528 if !book_symbols.is_empty() {
530 let instrument_request = serde_json::json!({
531 "method": "subscribe",
532 "params": {
533 "channel": "instrument",
534 "snapshot": true
535 }
536 });
537 let json = instrument_request.to_string();
538 debug!("Sending instrument subscription: {}", json);
539 write
540 .send(Message::Text(json))
541 .await
542 .map_err(|e| KrakenError::WebSocket(e.to_string()))?;
543
544 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
547 }
548
549 for (_req_id, request) in &requests {
551 let json = serde_json::to_string(request).map_err(|e| {
552 KrakenError::InvalidJson {
553 message: e.to_string(),
554 raw: None,
555 }
556 })?;
557 debug!("Sending subscription: {}", json);
558 write
559 .send(Message::Text(json))
560 .await
561 .map_err(|e| KrakenError::WebSocket(e.to_string()))?;
562 }
563
564 if !requests.is_empty() {
565 self.emit(ConnectionEvent::SubscriptionsRestored {
566 count: requests.len(),
567 });
568 }
569
570 *self.last_message_time.write() = std::time::Instant::now();
572
573 loop {
575 if self.shutdown.load(Ordering::Relaxed) {
576 info!("Shutdown requested, closing connection");
577 let _ = write.send(Message::Close(None)).await;
578 break;
579 }
580
581 let heartbeat_timeout = self.config.heartbeat_timeout.unwrap_or(Duration::from_secs(3600));
583
584 let msg_result = tokio::select! {
585 msg = read.next() => msg,
586 _ = tokio::time::sleep(heartbeat_timeout) => {
587 let elapsed = self.last_message_time.read().elapsed();
589 if elapsed >= heartbeat_timeout {
590 warn!("Heartbeat timeout: no message received for {:?}", elapsed);
591 self.emit(ConnectionEvent::Disconnected {
592 reason: DisconnectReason::HeartbeatTimeout,
593 });
594 return Err(KrakenError::WebSocket("Heartbeat timeout".into()));
595 }
596 continue;
597 }
598 };
599
600 match msg_result {
601 Some(Ok(Message::Text(text))) => {
602 *self.last_message_time.write() = std::time::Instant::now();
603 self.handle_message(&text);
604 }
605 Some(Ok(Message::Ping(data))) => {
606 *self.last_message_time.write() = std::time::Instant::now();
607 let _ = write.send(Message::Pong(data)).await;
608 }
609 Some(Ok(Message::Pong(_))) => {
610 *self.last_message_time.write() = std::time::Instant::now();
611 }
612 Some(Ok(Message::Close(_))) => {
613 info!("Server closed connection");
614 self.emit(ConnectionEvent::Disconnected {
615 reason: DisconnectReason::ServerClosed,
616 });
617 return Err(KrakenError::WebSocket("Server closed connection".into()));
618 }
619 Some(Err(e)) => {
620 error!("WebSocket error: {}", e);
621 self.emit(ConnectionEvent::Disconnected {
622 reason: DisconnectReason::NetworkError(e.to_string()),
623 });
624 return Err(KrakenError::WebSocket(e.to_string()));
625 }
626 Some(Ok(_)) => {}
627 None => {
628 info!("WebSocket stream ended");
629 break;
630 }
631 }
632 }
633
634 Ok(())
635 }
636
637 fn handle_message(&self, text: &str) {
639 match WsMessage::parse(text) {
640 Ok(msg) => match msg {
641 WsMessage::Status(status_msg) => {
642 if let Some(data) = status_msg.data.first() {
643 self.emit(MarketEvent::Status {
644 system: data.system.to_string(),
645 version: data.api_version.clone(),
646 });
647 }
648 }
649 WsMessage::Method(resp) => {
650 self.handle_subscribe_response(&resp);
651 }
652 WsMessage::Book(book_msg) => {
653 if let Some(data) = book_msg.data.first() {
654 let symbol = &data.symbol;
655 let is_snapshot = book_msg.msg_type == "snapshot";
656
657 let mut orderbook =
659 self.orderbooks.entry(symbol.clone()).or_insert_with(|| {
660 Orderbook::with_depth(symbol, self.config.depth as u32)
661 });
662
663 match orderbook.apply_book_data(data, is_snapshot) {
665 Ok(_result) => {
666 let snapshot = orderbook.snapshot();
667 let event = if is_snapshot {
668 MarketEvent::OrderbookSnapshot {
669 symbol: symbol.clone(),
670 snapshot,
671 }
672 } else {
673 MarketEvent::OrderbookUpdate {
674 symbol: symbol.clone(),
675 snapshot,
676 }
677 };
678 self.emit(event);
679 }
680 Err(mismatch) => {
681 warn!(
682 "Checksum mismatch for {}: expected {}, computed {}",
683 mismatch.symbol, mismatch.expected, mismatch.computed
684 );
685 self.emit(MarketEvent::ChecksumMismatch {
686 symbol: symbol.clone(),
687 expected: mismatch.expected,
688 computed: mismatch.computed,
689 });
690 }
691 }
692 }
693 }
694 WsMessage::Ticker(_ticker_msg) => {
695 debug!("Ticker update received");
697 }
698 WsMessage::Trade(_trade_msg) => {
699 debug!("Trade update received");
701 }
702 WsMessage::Ohlc(_ohlc_msg) => {
703 debug!("OHLC update received");
705 }
706 WsMessage::Instrument(instrument_msg) => {
707 for pair in &instrument_msg.data.pairs {
709 let symbol = &pair.symbol;
710
711 let mut orderbook =
713 self.orderbooks.entry(symbol.clone()).or_insert_with(|| {
714 Orderbook::with_depth(symbol, self.config.depth as u32)
715 });
716
717 orderbook.set_precision(pair.price_precision, pair.qty_precision);
718
719 debug!(
720 "Updated precision for {}: price={}, qty={}",
721 symbol, pair.price_precision, pair.qty_precision
722 );
723 }
724 }
725 WsMessage::Executions(_executions_msg) => {
726 debug!("Executions update received");
728 }
729 WsMessage::Balances(_balances_msg) => {
730 debug!("Balances update received");
732 }
733 WsMessage::Level3(l3_msg) => {
734 if let Some(data) = l3_msg.data.first() {
736 let is_snapshot = l3_msg.msg_type == "snapshot";
737 let event = L3Event::from_data(data, is_snapshot);
738 debug!(
739 "L3 {} received for {} ({} bids, {} asks)",
740 if is_snapshot { "snapshot" } else { "update" },
741 data.symbol,
742 data.bids.len(),
743 data.asks.len()
744 );
745 self.emit(event);
746 }
747 }
748 WsMessage::Heartbeat => {
749 self.emit(MarketEvent::Heartbeat);
750 }
751 WsMessage::Unknown(_) => {
752 debug!("Unknown message: {}", text);
753 }
754 _ => {
756 debug!("Unhandled message variant");
757 }
758 },
759 Err(e) => {
760 warn!("Failed to parse message: {} - {}", e, text);
761 }
762 }
763 }
764
765 fn handle_subscribe_response(&self, resp: &MethodResponse) {
767 if let Some(req_id) = resp.req_id {
768 if resp.success {
769 self.subscriptions.write().confirm(req_id);
770
771 if let Some(result) = &resp.result {
772 self.emit(SubscriptionEvent::Subscribed {
773 channel: result.channel.clone(),
774 symbols: result.symbol.clone().into_iter().collect(),
775 });
776 }
777 } else {
778 self.subscriptions.write().reject(req_id);
779
780 self.emit(SubscriptionEvent::Rejected {
781 channel: "unknown".to_string(),
782 reason: resp.error.clone().unwrap_or_default(),
783 });
784 }
785 }
786 }
787
788 fn emit(&self, event: impl Into<Event>) {
790 self.event_tx.send(event.into());
791 }
792
793 #[instrument(skip(self))]
795 pub fn shutdown(&self) {
796 info!("Shutdown requested");
797 self.shutdown.store(true, Ordering::Relaxed);
798 *self.state.write() = ConnectionState::ShuttingDown;
799 }
800
801 #[instrument(skip(self))]
806 pub async fn shutdown_gracefully(&self, timeout: Duration) -> bool {
807 info!("Graceful shutdown requested with timeout {:?}", timeout);
808 self.shutdown.store(true, Ordering::Relaxed);
809 *self.state.write() = ConnectionState::ShuttingDown;
810
811 let deadline = std::time::Instant::now() + timeout;
813 loop {
814 if self.state() == ConnectionState::Disconnected {
815 info!("Graceful shutdown complete");
816 return true;
817 }
818 if std::time::Instant::now() >= deadline {
819 warn!("Shutdown timed out after {:?}", timeout);
820 return false;
821 }
822 tokio::time::sleep(Duration::from_millis(50)).await;
823 }
824 }
825
826 pub fn is_shutting_down(&self) -> bool {
828 self.shutdown.load(Ordering::Relaxed)
829 }
830
831 pub fn time_since_last_message(&self) -> Duration {
833 self.last_message_time.read().elapsed()
834 }
835}
836
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_config() {
        let config = ConnectionConfig::new()
            .with_endpoint(Endpoint::PublicBeta)
            .with_depth(Depth::D25)
            .with_timeout(Duration::from_secs(5));

        assert_eq!(config.endpoint, Endpoint::PublicBeta);
        assert_eq!(config.depth, Depth::D25);
        assert_eq!(config.connect_timeout, Duration::from_secs(5));
    }

    /// The optional features can be switched off and back on through the builders.
    #[test]
    fn test_connection_config_toggles() {
        let config = ConnectionConfig::new()
            .without_heartbeat_timeout()
            .with_channel_capacity(8, BackpressurePolicy::Block)
            .without_circuit_breaker();

        assert_eq!(config.heartbeat_timeout, None);
        assert_eq!(config.channel_capacity, Some(8));
        assert_eq!(config.backpressure_policy, BackpressurePolicy::Block);
        assert!(config.circuit_breaker.is_none());

        let config = config
            .with_unbounded_channel()
            .with_heartbeat_timeout(Duration::from_secs(5));

        assert_eq!(config.channel_capacity, None);
        assert_eq!(config.heartbeat_timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_connection_state() {
        let conn = KrakenConnection::with_defaults();
        assert_eq!(conn.state(), ConnectionState::Disconnected);
        assert!(!conn.is_connected());
    }

    /// The receiver is handed out exactly once.
    #[test]
    fn test_event_receiver_taken_once() {
        let conn = KrakenConnection::with_defaults();
        assert!(conn.take_event_receiver().is_some());
        assert!(conn.take_event_receiver().is_none());
    }

    /// A fresh bounded connection has not dropped anything.
    #[test]
    fn test_bounded_channel_starts_without_drops() {
        let conn = KrakenConnection::new(
            ConnectionConfig::new().with_channel_capacity(4, BackpressurePolicy::DropNewest),
        );
        assert_eq!(conn.dropped_event_count(), 0);
    }

    /// `shutdown` sets the flag and the phase without needing a runtime.
    #[test]
    fn test_shutdown_marks_state() {
        let conn = KrakenConnection::with_defaults();
        assert!(!conn.is_shutting_down());

        conn.shutdown();

        assert!(conn.is_shutting_down());
        assert_eq!(conn.state(), ConnectionState::ShuttingDown);
        assert!(!conn.is_connected());
    }
}