hyperstack_sdk/
connection.rs

1use crate::config::ConnectionConfig;
2use crate::frame::{parse_frame, Frame};
3use crate::subscription::{Subscription, SubscriptionRegistry};
4use futures_util::{SinkExt, StreamExt};
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7use tokio::time::{sleep, Duration};
8use tokio_tungstenite::{connect_async, tungstenite::Message};
9
/// Lifecycle of the WebSocket connection owned by a [`ConnectionManager`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ConnectionState {
    /// No socket is open: the initial state, and the state after an explicit disconnect.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The WebSocket handshake succeeded and traffic is flowing.
    Connected,
    /// Waiting before the next connection attempt; `attempt` is the zero-based retry index.
    Reconnecting { attempt: u32 },
    /// Gave up: auto-reconnect is disabled or the retry budget is exhausted.
    Error,
}
18
19pub enum ConnectionCommand {
20    Subscribe(Subscription),
21    #[allow(dead_code)]
22    Unsubscribe(Subscription),
23    Disconnect,
24}
25
26pub struct ConnectionManager {
27    #[allow(dead_code)]
28    url: String,
29    state: Arc<RwLock<ConnectionState>>,
30    subscriptions: Arc<RwLock<SubscriptionRegistry>>,
31    #[allow(dead_code)]
32    config: ConnectionConfig,
33    command_tx: mpsc::Sender<ConnectionCommand>,
34}
35
36impl ConnectionManager {
37    pub async fn new(url: String, config: ConnectionConfig, frame_tx: mpsc::Sender<Frame>) -> Self {
38        let (command_tx, command_rx) = mpsc::channel(100);
39        let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
40        let subscriptions = Arc::new(RwLock::new(SubscriptionRegistry::new()));
41
42        let manager = Self {
43            url: url.clone(),
44            state: state.clone(),
45            subscriptions: subscriptions.clone(),
46            config: config.clone(),
47            command_tx,
48        };
49
50        spawn_connection_loop(url, state, subscriptions, config, frame_tx, command_rx);
51
52        manager
53    }
54
55    pub async fn state(&self) -> ConnectionState {
56        *self.state.read().await
57    }
58
59    pub async fn ensure_subscription(&self, view: &str, key: Option<&str>) {
60        let sub = Subscription {
61            view: view.to_string(),
62            key: key.map(|s| s.to_string()),
63            partition: None,
64            filters: None,
65        };
66
67        if !self.subscriptions.read().await.contains(&sub) {
68            let _ = self
69                .command_tx
70                .send(ConnectionCommand::Subscribe(sub))
71                .await;
72        }
73    }
74
75    #[allow(dead_code)]
76    pub async fn subscribe(&self, sub: Subscription) {
77        let _ = self
78            .command_tx
79            .send(ConnectionCommand::Subscribe(sub))
80            .await;
81    }
82
83    pub async fn disconnect(&self) {
84        let _ = self.command_tx.send(ConnectionCommand::Disconnect).await;
85    }
86}
87
88fn spawn_connection_loop(
89    url: String,
90    state: Arc<RwLock<ConnectionState>>,
91    subscriptions: Arc<RwLock<SubscriptionRegistry>>,
92    config: ConnectionConfig,
93    frame_tx: mpsc::Sender<Frame>,
94    mut command_rx: mpsc::Receiver<ConnectionCommand>,
95) {
96    tokio::spawn(async move {
97        let mut reconnect_attempt: u32 = 0;
98        let mut should_run = true;
99
100        while should_run {
101            *state.write().await = ConnectionState::Connecting;
102
103            match connect_async(&url).await {
104                Ok((ws, _)) => {
105                    *state.write().await = ConnectionState::Connected;
106                    reconnect_attempt = 0;
107
108                    let (mut ws_tx, mut ws_rx) = ws.split();
109
110                    let subs = subscriptions.read().await.all();
111                    for sub in subs {
112                        if let Ok(msg) = serde_json::to_string(&sub) {
113                            let _ = ws_tx.send(Message::Text(msg)).await;
114                        }
115                    }
116
117                    let ping_interval = config.ping_interval;
118                    let mut ping_timer = tokio::time::interval(ping_interval);
119
120                    loop {
121                        tokio::select! {
122                            msg = ws_rx.next() => {
123                                match msg {
124                                    Some(Ok(Message::Binary(bytes))) => {
125                                        if let Ok(frame) = parse_frame(&bytes) {
126                                            let _ = frame_tx.send(frame).await;
127                                        }
128                                    }
129                                    Some(Ok(Message::Text(text))) => {
130                                        if let Ok(frame) = serde_json::from_str::<Frame>(&text) {
131                                            let _ = frame_tx.send(frame).await;
132                                        }
133                                    }
134                                    Some(Ok(Message::Ping(payload))) => {
135                                        let _ = ws_tx.send(Message::Pong(payload)).await;
136                                    }
137                                    Some(Ok(Message::Close(_))) => {
138                                        break;
139                                    }
140                                    Some(Err(_)) => {
141                                        break;
142                                    }
143                                    None => {
144                                        break;
145                                    }
146                                    _ => {}
147                                }
148                            }
149                            cmd = command_rx.recv() => {
150                                match cmd {
151                                    Some(ConnectionCommand::Subscribe(sub)) => {
152                                        subscriptions.write().await.add(sub.clone());
153                                        if let Ok(msg) = serde_json::to_string(&sub) {
154                                            let _ = ws_tx.send(Message::Text(msg)).await;
155                                        }
156                                    }
157                                    Some(ConnectionCommand::Unsubscribe(sub)) => {
158                                        subscriptions.write().await.remove(&sub);
159                                    }
160                                    Some(ConnectionCommand::Disconnect) => {
161                                        let _ = ws_tx.close().await;
162                                        *state.write().await = ConnectionState::Disconnected;
163                                        should_run = false;
164                                        break;
165                                    }
166                                    None => {
167                                        should_run = false;
168                                        break;
169                                    }
170                                }
171                            }
172                            _ = ping_timer.tick() => {
173                                let _ = ws_tx.send(Message::Ping(vec![])).await;
174                            }
175                        }
176                    }
177                }
178                Err(e) => {
179                    tracing::error!("Connection failed: {}", e);
180                }
181            }
182
183            if !should_run {
184                break;
185            }
186
187            if !config.auto_reconnect {
188                *state.write().await = ConnectionState::Error;
189                break;
190            }
191
192            if reconnect_attempt >= config.max_reconnect_attempts {
193                *state.write().await = ConnectionState::Error;
194                break;
195            }
196
197            let delay = config
198                .reconnect_intervals
199                .get(reconnect_attempt as usize)
200                .copied()
201                .unwrap_or_else(|| {
202                    config
203                        .reconnect_intervals
204                        .last()
205                        .copied()
206                        .unwrap_or(Duration::from_secs(16))
207                });
208
209            *state.write().await = ConnectionState::Reconnecting {
210                attempt: reconnect_attempt,
211            };
212            reconnect_attempt += 1;
213
214            tracing::info!(
215                "Reconnecting in {:?} (attempt {})",
216                delay,
217                reconnect_attempt
218            );
219            sleep(delay).await;
220        }
221    });
222}