hyperstack_sdk/
connection.rs

1use crate::config::ConnectionConfig;
2use crate::frame::{parse_frame, Frame};
3use crate::subscription::{ClientMessage, Subscription, SubscriptionRegistry, Unsubscription};
4use futures_util::{SinkExt, StreamExt};
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7use tokio::time::{sleep, Duration};
8use tokio_tungstenite::{connect_async, tungstenite::Message};
9
/// Lifecycle state of a managed WebSocket connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Not connected: the initial state, and the state after an explicit disconnect.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The WebSocket handshake succeeded and the connection is live.
    Connected,
    /// Waiting before the next connection attempt; `attempt` counts retries from zero.
    Reconnecting { attempt: u32 },
    /// The connection loop gave up: reconnecting is disabled or the attempt limit was hit.
    Error,
}
18
19pub enum ConnectionCommand {
20    Subscribe(Subscription),
21    Unsubscribe(Unsubscription),
22    Disconnect,
23}
24
25struct ConnectionManagerInner {
26    #[allow(dead_code)]
27    url: String,
28    state: Arc<RwLock<ConnectionState>>,
29    subscriptions: Arc<RwLock<SubscriptionRegistry>>,
30    #[allow(dead_code)]
31    config: ConnectionConfig,
32    command_tx: mpsc::Sender<ConnectionCommand>,
33}
34
35#[derive(Clone)]
36pub struct ConnectionManager {
37    inner: Arc<ConnectionManagerInner>,
38}
39
40impl ConnectionManager {
41    pub async fn new(url: String, config: ConnectionConfig, frame_tx: mpsc::Sender<Frame>) -> Self {
42        let (command_tx, command_rx) = mpsc::channel(100);
43        let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
44        let subscriptions = Arc::new(RwLock::new(SubscriptionRegistry::new()));
45
46        let inner = ConnectionManagerInner {
47            url: url.clone(),
48            state: state.clone(),
49            subscriptions: subscriptions.clone(),
50            config: config.clone(),
51            command_tx,
52        };
53
54        spawn_connection_loop(url, state, subscriptions, config, frame_tx, command_rx);
55
56        Self {
57            inner: Arc::new(inner),
58        }
59    }
60
61    pub async fn state(&self) -> ConnectionState {
62        *self.inner.state.read().await
63    }
64
65    pub async fn ensure_subscription(&self, view: &str, key: Option<&str>) {
66        let sub = Subscription {
67            view: view.to_string(),
68            key: key.map(|s| s.to_string()),
69            partition: None,
70            filters: None,
71        };
72
73        if !self.inner.subscriptions.read().await.contains(&sub) {
74            let _ = self
75                .inner
76                .command_tx
77                .send(ConnectionCommand::Subscribe(sub))
78                .await;
79        }
80    }
81
82    pub async fn subscribe(&self, sub: Subscription) {
83        let _ = self
84            .inner
85            .command_tx
86            .send(ConnectionCommand::Subscribe(sub))
87            .await;
88    }
89
90    pub async fn unsubscribe(&self, unsub: Unsubscription) {
91        let _ = self
92            .inner
93            .command_tx
94            .send(ConnectionCommand::Unsubscribe(unsub))
95            .await;
96    }
97
98    pub async fn disconnect(&self) {
99        let _ = self
100            .inner
101            .command_tx
102            .send(ConnectionCommand::Disconnect)
103            .await;
104    }
105}
106
107fn spawn_connection_loop(
108    url: String,
109    state: Arc<RwLock<ConnectionState>>,
110    subscriptions: Arc<RwLock<SubscriptionRegistry>>,
111    config: ConnectionConfig,
112    frame_tx: mpsc::Sender<Frame>,
113    mut command_rx: mpsc::Receiver<ConnectionCommand>,
114) {
115    tokio::spawn(async move {
116        let mut reconnect_attempt: u32 = 0;
117        let mut should_run = true;
118
119        while should_run {
120            *state.write().await = ConnectionState::Connecting;
121
122            match connect_async(&url).await {
123                Ok((ws, _)) => {
124                    *state.write().await = ConnectionState::Connected;
125                    reconnect_attempt = 0;
126
127                    let (mut ws_tx, mut ws_rx) = ws.split();
128
129                    let subs = subscriptions.read().await.all();
130                    for sub in subs {
131                        let client_msg = ClientMessage::Subscribe(sub);
132                        if let Ok(msg) = serde_json::to_string(&client_msg) {
133                            let _ = ws_tx.send(Message::Text(msg)).await;
134                        }
135                    }
136
137                    let ping_interval = config.ping_interval;
138                    let mut ping_timer = tokio::time::interval(ping_interval);
139
140                    loop {
141                        tokio::select! {
142                            msg = ws_rx.next() => {
143                                match msg {
144                                    Some(Ok(Message::Binary(bytes))) => {
145                                        if let Ok(frame) = parse_frame(&bytes) {
146                                            let _ = frame_tx.send(frame).await;
147                                        }
148                                    }
149                                    Some(Ok(Message::Text(text))) => {
150                                        if let Ok(frame) = serde_json::from_str::<Frame>(&text) {
151                                            let _ = frame_tx.send(frame).await;
152                                        }
153                                    }
154                                    Some(Ok(Message::Ping(payload))) => {
155                                        let _ = ws_tx.send(Message::Pong(payload)).await;
156                                    }
157                                    Some(Ok(Message::Close(_))) => {
158                                        break;
159                                    }
160                                    Some(Err(_)) => {
161                                        break;
162                                    }
163                                    None => {
164                                        break;
165                                    }
166                                    _ => {}
167                                }
168                            }
169                            cmd = command_rx.recv() => {
170                                match cmd {
171                                    Some(ConnectionCommand::Subscribe(sub)) => {
172                                        subscriptions.write().await.add(sub.clone());
173                                        let client_msg = ClientMessage::Subscribe(sub);
174                                        if let Ok(msg) = serde_json::to_string(&client_msg) {
175                                            let _ = ws_tx.send(Message::Text(msg)).await;
176                                        }
177                                    }
178                                    Some(ConnectionCommand::Unsubscribe(unsub)) => {
179                                        let sub = Subscription {
180                                            view: unsub.view.clone(),
181                                            key: unsub.key.clone(),
182                                            partition: None,
183                                            filters: None,
184                                        };
185                                        subscriptions.write().await.remove(&sub);
186                                        let client_msg = ClientMessage::Unsubscribe(unsub);
187                                        if let Ok(msg) = serde_json::to_string(&client_msg) {
188                                            let _ = ws_tx.send(Message::Text(msg)).await;
189                                        }
190                                    }
191                                    Some(ConnectionCommand::Disconnect) => {
192                                        let _ = ws_tx.close().await;
193                                        *state.write().await = ConnectionState::Disconnected;
194                                        should_run = false;
195                                        break;
196                                    }
197                                    None => {
198                                        should_run = false;
199                                        break;
200                                    }
201                                }
202                            }
203                            _ = ping_timer.tick() => {
204                                if let Ok(msg) = serde_json::to_string(&ClientMessage::Ping) {
205                                    let _ = ws_tx.send(Message::Text(msg)).await;
206                                }
207                            }
208                        }
209                    }
210                }
211                Err(e) => {
212                    tracing::error!("Connection failed: {}", e);
213                }
214            }
215
216            if !should_run {
217                break;
218            }
219
220            if !config.auto_reconnect {
221                *state.write().await = ConnectionState::Error;
222                break;
223            }
224
225            if reconnect_attempt >= config.max_reconnect_attempts {
226                *state.write().await = ConnectionState::Error;
227                break;
228            }
229
230            let delay = config
231                .reconnect_intervals
232                .get(reconnect_attempt as usize)
233                .copied()
234                .unwrap_or_else(|| {
235                    config
236                        .reconnect_intervals
237                        .last()
238                        .copied()
239                        .unwrap_or(Duration::from_secs(16))
240                });
241
242            *state.write().await = ConnectionState::Reconnecting {
243                attempt: reconnect_attempt,
244            };
245            reconnect_attempt += 1;
246
247            tracing::info!(
248                "Reconnecting in {:?} (attempt {})",
249                delay,
250                reconnect_attempt
251            );
252            sleep(delay).await;
253        }
254    });
255}