// hyperstack_sdk/connection.rs

use std::sync::Arc;

use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, RwLock};
use tokio::time::{sleep, Duration};
use tokio_tungstenite::{connect_async, tungstenite::Message};

use crate::config::ConnectionConfig;
use crate::frame::{parse_frame, Frame};
use crate::subscription::{Subscription, SubscriptionRegistry};

/// Lifecycle phase of a [`ConnectionManager`]'s WebSocket connection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// No socket is open: the initial phase, and the phase after an explicit disconnect.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The socket is open and incoming frames are being forwarded.
    Connected,
    /// Waiting before retrying after a failed or dropped connection.
    Reconnecting {
        /// Zero-based index of the pending reconnect attempt; the counter is
        /// reset after every successful connect.
        attempt: u32,
    },
    /// The connection task gave up: retries were exhausted or reconnecting is disabled.
    Error,
}

19pub enum ConnectionCommand {
20    Subscribe(Subscription),
21    #[allow(dead_code)]
22    Unsubscribe(Subscription),
23    Disconnect,
24}
25
26struct ConnectionManagerInner {
27    #[allow(dead_code)]
28    url: String,
29    state: Arc<RwLock<ConnectionState>>,
30    subscriptions: Arc<RwLock<SubscriptionRegistry>>,
31    #[allow(dead_code)]
32    config: ConnectionConfig,
33    command_tx: mpsc::Sender<ConnectionCommand>,
34}
35
36#[derive(Clone)]
37pub struct ConnectionManager {
38    inner: Arc<ConnectionManagerInner>,
39}
40
41impl ConnectionManager {
42    pub async fn new(url: String, config: ConnectionConfig, frame_tx: mpsc::Sender<Frame>) -> Self {
43        let (command_tx, command_rx) = mpsc::channel(100);
44        let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
45        let subscriptions = Arc::new(RwLock::new(SubscriptionRegistry::new()));
46
47        let inner = ConnectionManagerInner {
48            url: url.clone(),
49            state: state.clone(),
50            subscriptions: subscriptions.clone(),
51            config: config.clone(),
52            command_tx,
53        };
54
55        spawn_connection_loop(url, state, subscriptions, config, frame_tx, command_rx);
56
57        Self {
58            inner: Arc::new(inner),
59        }
60    }
61
62    pub async fn state(&self) -> ConnectionState {
63        *self.inner.state.read().await
64    }
65
66    pub async fn ensure_subscription(&self, view: &str, key: Option<&str>) {
67        let sub = Subscription {
68            view: view.to_string(),
69            key: key.map(|s| s.to_string()),
70            partition: None,
71            filters: None,
72        };
73
74        if !self.inner.subscriptions.read().await.contains(&sub) {
75            let _ = self
76                .inner
77                .command_tx
78                .send(ConnectionCommand::Subscribe(sub))
79                .await;
80        }
81    }
82
83    #[allow(dead_code)]
84    pub async fn subscribe(&self, sub: Subscription) {
85        let _ = self
86            .inner
87            .command_tx
88            .send(ConnectionCommand::Subscribe(sub))
89            .await;
90    }
91
92    pub async fn disconnect(&self) {
93        let _ = self
94            .inner
95            .command_tx
96            .send(ConnectionCommand::Disconnect)
97            .await;
98    }
99}
100
101fn spawn_connection_loop(
102    url: String,
103    state: Arc<RwLock<ConnectionState>>,
104    subscriptions: Arc<RwLock<SubscriptionRegistry>>,
105    config: ConnectionConfig,
106    frame_tx: mpsc::Sender<Frame>,
107    mut command_rx: mpsc::Receiver<ConnectionCommand>,
108) {
109    tokio::spawn(async move {
110        let mut reconnect_attempt: u32 = 0;
111        let mut should_run = true;
112
113        while should_run {
114            *state.write().await = ConnectionState::Connecting;
115
116            match connect_async(&url).await {
117                Ok((ws, _)) => {
118                    *state.write().await = ConnectionState::Connected;
119                    reconnect_attempt = 0;
120
121                    let (mut ws_tx, mut ws_rx) = ws.split();
122
123                    let subs = subscriptions.read().await.all();
124                    for sub in subs {
125                        if let Ok(msg) = serde_json::to_string(&sub) {
126                            let _ = ws_tx.send(Message::Text(msg)).await;
127                        }
128                    }
129
130                    let ping_interval = config.ping_interval;
131                    let mut ping_timer = tokio::time::interval(ping_interval);
132
133                    loop {
134                        tokio::select! {
135                            msg = ws_rx.next() => {
136                                match msg {
137                                    Some(Ok(Message::Binary(bytes))) => {
138                                        if let Ok(frame) = parse_frame(&bytes) {
139                                            let _ = frame_tx.send(frame).await;
140                                        }
141                                    }
142                                    Some(Ok(Message::Text(text))) => {
143                                        if let Ok(frame) = serde_json::from_str::<Frame>(&text) {
144                                            let _ = frame_tx.send(frame).await;
145                                        }
146                                    }
147                                    Some(Ok(Message::Ping(payload))) => {
148                                        let _ = ws_tx.send(Message::Pong(payload)).await;
149                                    }
150                                    Some(Ok(Message::Close(_))) => {
151                                        break;
152                                    }
153                                    Some(Err(_)) => {
154                                        break;
155                                    }
156                                    None => {
157                                        break;
158                                    }
159                                    _ => {}
160                                }
161                            }
162                            cmd = command_rx.recv() => {
163                                match cmd {
164                                    Some(ConnectionCommand::Subscribe(sub)) => {
165                                        subscriptions.write().await.add(sub.clone());
166                                        if let Ok(msg) = serde_json::to_string(&sub) {
167                                            let _ = ws_tx.send(Message::Text(msg)).await;
168                                        }
169                                    }
170                                    Some(ConnectionCommand::Unsubscribe(sub)) => {
171                                        subscriptions.write().await.remove(&sub);
172                                    }
173                                    Some(ConnectionCommand::Disconnect) => {
174                                        let _ = ws_tx.close().await;
175                                        *state.write().await = ConnectionState::Disconnected;
176                                        should_run = false;
177                                        break;
178                                    }
179                                    None => {
180                                        should_run = false;
181                                        break;
182                                    }
183                                }
184                            }
185                            _ = ping_timer.tick() => {
186                                let _ = ws_tx.send(Message::Ping(vec![])).await;
187                            }
188                        }
189                    }
190                }
191                Err(e) => {
192                    tracing::error!("Connection failed: {}", e);
193                }
194            }
195
196            if !should_run {
197                break;
198            }
199
200            if !config.auto_reconnect {
201                *state.write().await = ConnectionState::Error;
202                break;
203            }
204
205            if reconnect_attempt >= config.max_reconnect_attempts {
206                *state.write().await = ConnectionState::Error;
207                break;
208            }
209
210            let delay = config
211                .reconnect_intervals
212                .get(reconnect_attempt as usize)
213                .copied()
214                .unwrap_or_else(|| {
215                    config
216                        .reconnect_intervals
217                        .last()
218                        .copied()
219                        .unwrap_or(Duration::from_secs(16))
220                });
221
222            *state.write().await = ConnectionState::Reconnecting {
223                attempt: reconnect_attempt,
224            };
225            reconnect_attempt += 1;
226
227            tracing::info!(
228                "Reconnecting in {:?} (attempt {})",
229                delay,
230                reconnect_attempt
231            );
232            sleep(delay).await;
233        }
234    });
235}