hyperstack_sdk/
connection.rs

1use crate::config::ConnectionConfig;
2use crate::frame::{parse_frame, Frame};
3use crate::subscription::{ClientMessage, Subscription, SubscriptionRegistry, Unsubscription};
4use futures_util::{SinkExt, StreamExt};
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7use tokio::time::{sleep, Duration};
8use tokio_tungstenite::{connect_async, tungstenite::Message};
9
/// Lifecycle state of the WebSocket connection, as published by the background
/// connection task and read through [`ConnectionManager::state`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// Initial state, and the state set after an explicit disconnect command.
    Disconnected,
    /// A connection attempt is in flight.
    Connecting,
    /// The WebSocket handshake succeeded and the session is being served.
    Connected,
    /// Waiting out a backoff delay before the next connection attempt.
    ///
    /// `attempt` is zero-based: the first retry is reported as `0`.
    Reconnecting { attempt: u32 },
    /// The connection task gave up: auto-reconnect is disabled or the retry budget is spent.
    Error,
}
18
19pub enum ConnectionCommand {
20    Subscribe(Subscription),
21    Unsubscribe(Unsubscription),
22    Disconnect,
23}
24
25struct ConnectionManagerInner {
26    #[allow(dead_code)]
27    url: String,
28    state: Arc<RwLock<ConnectionState>>,
29    subscriptions: Arc<RwLock<SubscriptionRegistry>>,
30    #[allow(dead_code)]
31    config: ConnectionConfig,
32    command_tx: mpsc::Sender<ConnectionCommand>,
33}
34
35#[derive(Clone)]
36pub struct ConnectionManager {
37    inner: Arc<ConnectionManagerInner>,
38}
39
40impl ConnectionManager {
41    pub async fn new(url: String, config: ConnectionConfig, frame_tx: mpsc::Sender<Frame>) -> Self {
42        let (command_tx, command_rx) = mpsc::channel(100);
43        let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
44        let subscriptions = Arc::new(RwLock::new(SubscriptionRegistry::new()));
45
46        let inner = ConnectionManagerInner {
47            url: url.clone(),
48            state: state.clone(),
49            subscriptions: subscriptions.clone(),
50            config: config.clone(),
51            command_tx,
52        };
53
54        spawn_connection_loop(url, state, subscriptions, config, frame_tx, command_rx);
55
56        Self {
57            inner: Arc::new(inner),
58        }
59    }
60
61    pub async fn state(&self) -> ConnectionState {
62        *self.inner.state.read().await
63    }
64
65    pub async fn ensure_subscription(&self, view: &str, key: Option<&str>) {
66        self.ensure_subscription_with_opts(view, key, None, None)
67            .await
68    }
69
70    pub async fn ensure_subscription_with_opts(
71        &self,
72        view: &str,
73        key: Option<&str>,
74        take: Option<u32>,
75        skip: Option<u32>,
76    ) {
77        let sub = Subscription {
78            view: view.to_string(),
79            key: key.map(|s| s.to_string()),
80            partition: None,
81            filters: None,
82            take,
83            skip,
84        };
85
86        if !self.inner.subscriptions.read().await.contains(&sub) {
87            let _ = self
88                .inner
89                .command_tx
90                .send(ConnectionCommand::Subscribe(sub))
91                .await;
92        }
93    }
94
95    pub async fn subscribe(&self, sub: Subscription) {
96        let _ = self
97            .inner
98            .command_tx
99            .send(ConnectionCommand::Subscribe(sub))
100            .await;
101    }
102
103    pub async fn unsubscribe(&self, unsub: Unsubscription) {
104        let _ = self
105            .inner
106            .command_tx
107            .send(ConnectionCommand::Unsubscribe(unsub))
108            .await;
109    }
110
111    pub async fn disconnect(&self) {
112        let _ = self
113            .inner
114            .command_tx
115            .send(ConnectionCommand::Disconnect)
116            .await;
117    }
118}
119
120fn spawn_connection_loop(
121    url: String,
122    state: Arc<RwLock<ConnectionState>>,
123    subscriptions: Arc<RwLock<SubscriptionRegistry>>,
124    config: ConnectionConfig,
125    frame_tx: mpsc::Sender<Frame>,
126    mut command_rx: mpsc::Receiver<ConnectionCommand>,
127) {
128    tokio::spawn(async move {
129        let mut reconnect_attempt: u32 = 0;
130        let mut should_run = true;
131
132        while should_run {
133            *state.write().await = ConnectionState::Connecting;
134
135            match connect_async(&url).await {
136                Ok((ws, _)) => {
137                    *state.write().await = ConnectionState::Connected;
138                    reconnect_attempt = 0;
139
140                    let (mut ws_tx, mut ws_rx) = ws.split();
141
142                    let subs = subscriptions.read().await.all();
143                    for sub in subs {
144                        let client_msg = ClientMessage::Subscribe(sub);
145                        if let Ok(msg) = serde_json::to_string(&client_msg) {
146                            let _ = ws_tx.send(Message::Text(msg)).await;
147                        }
148                    }
149
150                    let ping_interval = config.ping_interval;
151                    let mut ping_timer = tokio::time::interval(ping_interval);
152
153                    loop {
154                        tokio::select! {
155                            msg = ws_rx.next() => {
156                                match msg {
157                                    Some(Ok(Message::Binary(bytes))) => {
158                                        if let Ok(frame) = parse_frame(&bytes) {
159                                            let _ = frame_tx.send(frame).await;
160                                        }
161                                    }
162                                    Some(Ok(Message::Text(text))) => {
163                                        if let Ok(frame) = serde_json::from_str::<Frame>(&text) {
164                                            let _ = frame_tx.send(frame).await;
165                                        }
166                                    }
167                                    Some(Ok(Message::Ping(payload))) => {
168                                        let _ = ws_tx.send(Message::Pong(payload)).await;
169                                    }
170                                    Some(Ok(Message::Close(_))) => {
171                                        break;
172                                    }
173                                    Some(Err(_)) => {
174                                        break;
175                                    }
176                                    None => {
177                                        break;
178                                    }
179                                    _ => {}
180                                }
181                            }
182                            cmd = command_rx.recv() => {
183                                match cmd {
184                                    Some(ConnectionCommand::Subscribe(sub)) => {
185                                        subscriptions.write().await.add(sub.clone());
186                                        let client_msg = ClientMessage::Subscribe(sub);
187                                        if let Ok(msg) = serde_json::to_string(&client_msg) {
188                                            let _ = ws_tx.send(Message::Text(msg)).await;
189                                        }
190                                    }
191                                    Some(ConnectionCommand::Unsubscribe(unsub)) => {
192                                        let sub = Subscription {
193                                            view: unsub.view.clone(),
194                                            key: unsub.key.clone(),
195                                            partition: None,
196                                            filters: None,
197                                            take: None,
198                                            skip: None,
199                                        };
200                                        subscriptions.write().await.remove(&sub);
201                                        let client_msg = ClientMessage::Unsubscribe(unsub);
202                                        if let Ok(msg) = serde_json::to_string(&client_msg) {
203                                            let _ = ws_tx.send(Message::Text(msg)).await;
204                                        }
205                                    }
206                                    Some(ConnectionCommand::Disconnect) => {
207                                        let _ = ws_tx.close().await;
208                                        *state.write().await = ConnectionState::Disconnected;
209                                        should_run = false;
210                                        break;
211                                    }
212                                    None => {
213                                        should_run = false;
214                                        break;
215                                    }
216                                }
217                            }
218                            _ = ping_timer.tick() => {
219                                if let Ok(msg) = serde_json::to_string(&ClientMessage::Ping) {
220                                    let _ = ws_tx.send(Message::Text(msg)).await;
221                                }
222                            }
223                        }
224                    }
225                }
226                Err(e) => {
227                    tracing::error!("Connection failed: {}", e);
228                }
229            }
230
231            if !should_run {
232                break;
233            }
234
235            if !config.auto_reconnect {
236                *state.write().await = ConnectionState::Error;
237                break;
238            }
239
240            if reconnect_attempt >= config.max_reconnect_attempts {
241                *state.write().await = ConnectionState::Error;
242                break;
243            }
244
245            let delay = config
246                .reconnect_intervals
247                .get(reconnect_attempt as usize)
248                .copied()
249                .unwrap_or_else(|| {
250                    config
251                        .reconnect_intervals
252                        .last()
253                        .copied()
254                        .unwrap_or(Duration::from_secs(16))
255                });
256
257            *state.write().await = ConnectionState::Reconnecting {
258                attempt: reconnect_attempt,
259            };
260            reconnect_attempt += 1;
261
262            tracing::info!(
263                "Reconnecting in {:?} (attempt {})",
264                delay,
265                reconnect_attempt
266            );
267            sleep(delay).await;
268        }
269    });
270}