hyperstack_sdk/
connection.rs1use crate::config::ConnectionConfig;
2use crate::frame::{parse_frame, Frame};
3use crate::subscription::{ClientMessage, Subscription, SubscriptionRegistry, Unsubscription};
4use futures_util::{SinkExt, StreamExt};
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7use tokio::time::{sleep, Duration};
8use tokio_tungstenite::{connect_async, tungstenite::Message};
9
/// Lifecycle of the background WebSocket connection, as published by the
/// connection task and read through `ConnectionManager::state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No socket is open: the initial state, or the caller asked to disconnect.
    Disconnected,
    /// A connection attempt is in flight.
    Connecting,
    /// The WebSocket handshake succeeded and the socket is being serviced.
    Connected,
    /// Waiting out the backoff delay before the next attempt; `attempt` is the
    /// zero-based retry index.
    Reconnecting { attempt: u32 },
    /// The task gave up: reconnecting is disabled or the retry budget ran out.
    Error,
}
18
19pub enum ConnectionCommand {
20 Subscribe(Subscription),
21 Unsubscribe(Unsubscription),
22 Disconnect,
23}
24
25struct ConnectionManagerInner {
26 #[allow(dead_code)]
27 url: String,
28 state: Arc<RwLock<ConnectionState>>,
29 subscriptions: Arc<RwLock<SubscriptionRegistry>>,
30 #[allow(dead_code)]
31 config: ConnectionConfig,
32 command_tx: mpsc::Sender<ConnectionCommand>,
33}
34
35#[derive(Clone)]
36pub struct ConnectionManager {
37 inner: Arc<ConnectionManagerInner>,
38}
39
40impl ConnectionManager {
41 pub async fn new(url: String, config: ConnectionConfig, frame_tx: mpsc::Sender<Frame>) -> Self {
42 let (command_tx, command_rx) = mpsc::channel(100);
43 let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
44 let subscriptions = Arc::new(RwLock::new(SubscriptionRegistry::new()));
45
46 let inner = ConnectionManagerInner {
47 url: url.clone(),
48 state: state.clone(),
49 subscriptions: subscriptions.clone(),
50 config: config.clone(),
51 command_tx,
52 };
53
54 spawn_connection_loop(url, state, subscriptions, config, frame_tx, command_rx);
55
56 Self {
57 inner: Arc::new(inner),
58 }
59 }
60
61 pub async fn state(&self) -> ConnectionState {
62 *self.inner.state.read().await
63 }
64
65 pub async fn ensure_subscription(&self, view: &str, key: Option<&str>) {
66 self.ensure_subscription_with_opts(view, key, None, None)
67 .await
68 }
69
70 pub async fn ensure_subscription_with_opts(
71 &self,
72 view: &str,
73 key: Option<&str>,
74 take: Option<u32>,
75 skip: Option<u32>,
76 ) {
77 let sub = Subscription {
78 view: view.to_string(),
79 key: key.map(|s| s.to_string()),
80 partition: None,
81 filters: None,
82 take,
83 skip,
84 };
85
86 if !self.inner.subscriptions.read().await.contains(&sub) {
87 let _ = self
88 .inner
89 .command_tx
90 .send(ConnectionCommand::Subscribe(sub))
91 .await;
92 }
93 }
94
95 pub async fn subscribe(&self, sub: Subscription) {
96 let _ = self
97 .inner
98 .command_tx
99 .send(ConnectionCommand::Subscribe(sub))
100 .await;
101 }
102
103 pub async fn unsubscribe(&self, unsub: Unsubscription) {
104 let _ = self
105 .inner
106 .command_tx
107 .send(ConnectionCommand::Unsubscribe(unsub))
108 .await;
109 }
110
111 pub async fn disconnect(&self) {
112 let _ = self
113 .inner
114 .command_tx
115 .send(ConnectionCommand::Disconnect)
116 .await;
117 }
118}
119
120fn spawn_connection_loop(
121 url: String,
122 state: Arc<RwLock<ConnectionState>>,
123 subscriptions: Arc<RwLock<SubscriptionRegistry>>,
124 config: ConnectionConfig,
125 frame_tx: mpsc::Sender<Frame>,
126 mut command_rx: mpsc::Receiver<ConnectionCommand>,
127) {
128 tokio::spawn(async move {
129 let mut reconnect_attempt: u32 = 0;
130 let mut should_run = true;
131
132 while should_run {
133 *state.write().await = ConnectionState::Connecting;
134
135 match connect_async(&url).await {
136 Ok((ws, _)) => {
137 *state.write().await = ConnectionState::Connected;
138 reconnect_attempt = 0;
139
140 let (mut ws_tx, mut ws_rx) = ws.split();
141
142 let subs = subscriptions.read().await.all();
143 for sub in subs {
144 let client_msg = ClientMessage::Subscribe(sub);
145 if let Ok(msg) = serde_json::to_string(&client_msg) {
146 let _ = ws_tx.send(Message::Text(msg)).await;
147 }
148 }
149
150 let ping_interval = config.ping_interval;
151 let mut ping_timer = tokio::time::interval(ping_interval);
152
153 loop {
154 tokio::select! {
155 msg = ws_rx.next() => {
156 match msg {
157 Some(Ok(Message::Binary(bytes))) => {
158 if let Ok(frame) = parse_frame(&bytes) {
159 let _ = frame_tx.send(frame).await;
160 }
161 }
162 Some(Ok(Message::Text(text))) => {
163 if let Ok(frame) = serde_json::from_str::<Frame>(&text) {
164 let _ = frame_tx.send(frame).await;
165 }
166 }
167 Some(Ok(Message::Ping(payload))) => {
168 let _ = ws_tx.send(Message::Pong(payload)).await;
169 }
170 Some(Ok(Message::Close(_))) => {
171 break;
172 }
173 Some(Err(_)) => {
174 break;
175 }
176 None => {
177 break;
178 }
179 _ => {}
180 }
181 }
182 cmd = command_rx.recv() => {
183 match cmd {
184 Some(ConnectionCommand::Subscribe(sub)) => {
185 subscriptions.write().await.add(sub.clone());
186 let client_msg = ClientMessage::Subscribe(sub);
187 if let Ok(msg) = serde_json::to_string(&client_msg) {
188 let _ = ws_tx.send(Message::Text(msg)).await;
189 }
190 }
191 Some(ConnectionCommand::Unsubscribe(unsub)) => {
192 let sub = Subscription {
193 view: unsub.view.clone(),
194 key: unsub.key.clone(),
195 partition: None,
196 filters: None,
197 take: None,
198 skip: None,
199 };
200 subscriptions.write().await.remove(&sub);
201 let client_msg = ClientMessage::Unsubscribe(unsub);
202 if let Ok(msg) = serde_json::to_string(&client_msg) {
203 let _ = ws_tx.send(Message::Text(msg)).await;
204 }
205 }
206 Some(ConnectionCommand::Disconnect) => {
207 let _ = ws_tx.close().await;
208 *state.write().await = ConnectionState::Disconnected;
209 should_run = false;
210 break;
211 }
212 None => {
213 should_run = false;
214 break;
215 }
216 }
217 }
218 _ = ping_timer.tick() => {
219 if let Ok(msg) = serde_json::to_string(&ClientMessage::Ping) {
220 let _ = ws_tx.send(Message::Text(msg)).await;
221 }
222 }
223 }
224 }
225 }
226 Err(e) => {
227 tracing::error!("Connection failed: {}", e);
228 }
229 }
230
231 if !should_run {
232 break;
233 }
234
235 if !config.auto_reconnect {
236 *state.write().await = ConnectionState::Error;
237 break;
238 }
239
240 if reconnect_attempt >= config.max_reconnect_attempts {
241 *state.write().await = ConnectionState::Error;
242 break;
243 }
244
245 let delay = config
246 .reconnect_intervals
247 .get(reconnect_attempt as usize)
248 .copied()
249 .unwrap_or_else(|| {
250 config
251 .reconnect_intervals
252 .last()
253 .copied()
254 .unwrap_or(Duration::from_secs(16))
255 });
256
257 *state.write().await = ConnectionState::Reconnecting {
258 attempt: reconnect_attempt,
259 };
260 reconnect_attempt += 1;
261
262 tracing::info!(
263 "Reconnecting in {:?} (attempt {})",
264 delay,
265 reconnect_attempt
266 );
267 sleep(delay).await;
268 }
269 });
270}