hyperstack_sdk/
connection.rs1use crate::config::ConnectionConfig;
2use crate::frame::{parse_frame, Frame};
3use crate::subscription::{ClientMessage, Subscription, SubscriptionRegistry, Unsubscription};
4use futures_util::{SinkExt, StreamExt};
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7use tokio::time::{sleep, Duration};
8use tokio_tungstenite::{connect_async, tungstenite::Message};
9
/// Lifecycle state of the WebSocket connection, as reported by
/// `ConnectionManager::state`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection is open. This is the initial state, and the state after
    /// an explicit disconnect.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The socket is open and frames are being exchanged.
    Connected,
    /// The previous session ended (or the connect attempt failed) and the
    /// background task is waiting before retrying. `attempt` is the index of
    /// the pending retry.
    Reconnecting { attempt: u32 },
    /// The background task gave up: either automatic reconnection is disabled
    /// or the configured retry limit was reached.
    Error,
}
18
19pub enum ConnectionCommand {
20 Subscribe(Subscription),
21 Unsubscribe(Unsubscription),
22 Disconnect,
23}
24
25struct ConnectionManagerInner {
26 #[allow(dead_code)]
27 url: String,
28 state: Arc<RwLock<ConnectionState>>,
29 subscriptions: Arc<RwLock<SubscriptionRegistry>>,
30 #[allow(dead_code)]
31 config: ConnectionConfig,
32 command_tx: mpsc::Sender<ConnectionCommand>,
33}
34
35#[derive(Clone)]
36pub struct ConnectionManager {
37 inner: Arc<ConnectionManagerInner>,
38}
39
40impl ConnectionManager {
41 pub async fn new(url: String, config: ConnectionConfig, frame_tx: mpsc::Sender<Frame>) -> Self {
42 let (command_tx, command_rx) = mpsc::channel(100);
43 let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
44 let subscriptions = Arc::new(RwLock::new(SubscriptionRegistry::new()));
45
46 let inner = ConnectionManagerInner {
47 url: url.clone(),
48 state: state.clone(),
49 subscriptions: subscriptions.clone(),
50 config: config.clone(),
51 command_tx,
52 };
53
54 spawn_connection_loop(url, state, subscriptions, config, frame_tx, command_rx);
55
56 Self {
57 inner: Arc::new(inner),
58 }
59 }
60
61 pub async fn state(&self) -> ConnectionState {
62 *self.inner.state.read().await
63 }
64
65 pub async fn ensure_subscription(&self, view: &str, key: Option<&str>) {
66 let sub = Subscription {
67 view: view.to_string(),
68 key: key.map(|s| s.to_string()),
69 partition: None,
70 filters: None,
71 };
72
73 if !self.inner.subscriptions.read().await.contains(&sub) {
74 let _ = self
75 .inner
76 .command_tx
77 .send(ConnectionCommand::Subscribe(sub))
78 .await;
79 }
80 }
81
82 pub async fn subscribe(&self, sub: Subscription) {
83 let _ = self
84 .inner
85 .command_tx
86 .send(ConnectionCommand::Subscribe(sub))
87 .await;
88 }
89
90 pub async fn unsubscribe(&self, unsub: Unsubscription) {
91 let _ = self
92 .inner
93 .command_tx
94 .send(ConnectionCommand::Unsubscribe(unsub))
95 .await;
96 }
97
98 pub async fn disconnect(&self) {
99 let _ = self
100 .inner
101 .command_tx
102 .send(ConnectionCommand::Disconnect)
103 .await;
104 }
105}
106
107fn spawn_connection_loop(
108 url: String,
109 state: Arc<RwLock<ConnectionState>>,
110 subscriptions: Arc<RwLock<SubscriptionRegistry>>,
111 config: ConnectionConfig,
112 frame_tx: mpsc::Sender<Frame>,
113 mut command_rx: mpsc::Receiver<ConnectionCommand>,
114) {
115 tokio::spawn(async move {
116 let mut reconnect_attempt: u32 = 0;
117 let mut should_run = true;
118
119 while should_run {
120 *state.write().await = ConnectionState::Connecting;
121
122 match connect_async(&url).await {
123 Ok((ws, _)) => {
124 *state.write().await = ConnectionState::Connected;
125 reconnect_attempt = 0;
126
127 let (mut ws_tx, mut ws_rx) = ws.split();
128
129 let subs = subscriptions.read().await.all();
130 for sub in subs {
131 let client_msg = ClientMessage::Subscribe(sub);
132 if let Ok(msg) = serde_json::to_string(&client_msg) {
133 let _ = ws_tx.send(Message::Text(msg)).await;
134 }
135 }
136
137 let ping_interval = config.ping_interval;
138 let mut ping_timer = tokio::time::interval(ping_interval);
139
140 loop {
141 tokio::select! {
142 msg = ws_rx.next() => {
143 match msg {
144 Some(Ok(Message::Binary(bytes))) => {
145 if let Ok(frame) = parse_frame(&bytes) {
146 let _ = frame_tx.send(frame).await;
147 }
148 }
149 Some(Ok(Message::Text(text))) => {
150 if let Ok(frame) = serde_json::from_str::<Frame>(&text) {
151 let _ = frame_tx.send(frame).await;
152 }
153 }
154 Some(Ok(Message::Ping(payload))) => {
155 let _ = ws_tx.send(Message::Pong(payload)).await;
156 }
157 Some(Ok(Message::Close(_))) => {
158 break;
159 }
160 Some(Err(_)) => {
161 break;
162 }
163 None => {
164 break;
165 }
166 _ => {}
167 }
168 }
169 cmd = command_rx.recv() => {
170 match cmd {
171 Some(ConnectionCommand::Subscribe(sub)) => {
172 subscriptions.write().await.add(sub.clone());
173 let client_msg = ClientMessage::Subscribe(sub);
174 if let Ok(msg) = serde_json::to_string(&client_msg) {
175 let _ = ws_tx.send(Message::Text(msg)).await;
176 }
177 }
178 Some(ConnectionCommand::Unsubscribe(unsub)) => {
179 let sub = Subscription {
180 view: unsub.view.clone(),
181 key: unsub.key.clone(),
182 partition: None,
183 filters: None,
184 };
185 subscriptions.write().await.remove(&sub);
186 let client_msg = ClientMessage::Unsubscribe(unsub);
187 if let Ok(msg) = serde_json::to_string(&client_msg) {
188 let _ = ws_tx.send(Message::Text(msg)).await;
189 }
190 }
191 Some(ConnectionCommand::Disconnect) => {
192 let _ = ws_tx.close().await;
193 *state.write().await = ConnectionState::Disconnected;
194 should_run = false;
195 break;
196 }
197 None => {
198 should_run = false;
199 break;
200 }
201 }
202 }
203 _ = ping_timer.tick() => {
204 if let Ok(msg) = serde_json::to_string(&ClientMessage::Ping) {
205 let _ = ws_tx.send(Message::Text(msg)).await;
206 }
207 }
208 }
209 }
210 }
211 Err(e) => {
212 tracing::error!("Connection failed: {}", e);
213 }
214 }
215
216 if !should_run {
217 break;
218 }
219
220 if !config.auto_reconnect {
221 *state.write().await = ConnectionState::Error;
222 break;
223 }
224
225 if reconnect_attempt >= config.max_reconnect_attempts {
226 *state.write().await = ConnectionState::Error;
227 break;
228 }
229
230 let delay = config
231 .reconnect_intervals
232 .get(reconnect_attempt as usize)
233 .copied()
234 .unwrap_or_else(|| {
235 config
236 .reconnect_intervals
237 .last()
238 .copied()
239 .unwrap_or(Duration::from_secs(16))
240 });
241
242 *state.write().await = ConnectionState::Reconnecting {
243 attempt: reconnect_attempt,
244 };
245 reconnect_attempt += 1;
246
247 tracing::info!(
248 "Reconnecting in {:?} (attempt {})",
249 delay,
250 reconnect_attempt
251 );
252 sleep(delay).await;
253 }
254 });
255}