hyperstack_sdk/
connection.rs1use crate::config::ConnectionConfig;
2use crate::frame::{parse_frame, Frame};
3use crate::subscription::{Subscription, SubscriptionRegistry};
4use futures_util::{SinkExt, StreamExt};
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7use tokio::time::{sleep, Duration};
8use tokio_tungstenite::{connect_async, tungstenite::Message};
9
/// Lifecycle state of the WebSocket connection, as published by the
/// background connection task.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Initial state, and the state set once a `Disconnect` command has been
    /// processed.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The WebSocket handshake succeeded and the connection is being served.
    Connected,
    /// The connection failed or dropped and the task is waiting to retry.
    ///
    /// `attempt` is the number of retries already made since the last
    /// successful connection, so the first retry reports `0`.
    Reconnecting { attempt: u32 },
    /// The task gave up: either auto-reconnect is disabled or the maximum
    /// number of reconnect attempts was reached.
    Error,
}
18
19pub enum ConnectionCommand {
20 Subscribe(Subscription),
21 #[allow(dead_code)]
22 Unsubscribe(Subscription),
23 Disconnect,
24}
25
26struct ConnectionManagerInner {
27 #[allow(dead_code)]
28 url: String,
29 state: Arc<RwLock<ConnectionState>>,
30 subscriptions: Arc<RwLock<SubscriptionRegistry>>,
31 #[allow(dead_code)]
32 config: ConnectionConfig,
33 command_tx: mpsc::Sender<ConnectionCommand>,
34}
35
36#[derive(Clone)]
37pub struct ConnectionManager {
38 inner: Arc<ConnectionManagerInner>,
39}
40
41impl ConnectionManager {
42 pub async fn new(url: String, config: ConnectionConfig, frame_tx: mpsc::Sender<Frame>) -> Self {
43 let (command_tx, command_rx) = mpsc::channel(100);
44 let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
45 let subscriptions = Arc::new(RwLock::new(SubscriptionRegistry::new()));
46
47 let inner = ConnectionManagerInner {
48 url: url.clone(),
49 state: state.clone(),
50 subscriptions: subscriptions.clone(),
51 config: config.clone(),
52 command_tx,
53 };
54
55 spawn_connection_loop(url, state, subscriptions, config, frame_tx, command_rx);
56
57 Self {
58 inner: Arc::new(inner),
59 }
60 }
61
62 pub async fn state(&self) -> ConnectionState {
63 *self.inner.state.read().await
64 }
65
66 pub async fn ensure_subscription(&self, view: &str, key: Option<&str>) {
67 let sub = Subscription {
68 view: view.to_string(),
69 key: key.map(|s| s.to_string()),
70 partition: None,
71 filters: None,
72 };
73
74 if !self.inner.subscriptions.read().await.contains(&sub) {
75 let _ = self
76 .inner
77 .command_tx
78 .send(ConnectionCommand::Subscribe(sub))
79 .await;
80 }
81 }
82
83 #[allow(dead_code)]
84 pub async fn subscribe(&self, sub: Subscription) {
85 let _ = self
86 .inner
87 .command_tx
88 .send(ConnectionCommand::Subscribe(sub))
89 .await;
90 }
91
92 pub async fn disconnect(&self) {
93 let _ = self
94 .inner
95 .command_tx
96 .send(ConnectionCommand::Disconnect)
97 .await;
98 }
99}
100
101fn spawn_connection_loop(
102 url: String,
103 state: Arc<RwLock<ConnectionState>>,
104 subscriptions: Arc<RwLock<SubscriptionRegistry>>,
105 config: ConnectionConfig,
106 frame_tx: mpsc::Sender<Frame>,
107 mut command_rx: mpsc::Receiver<ConnectionCommand>,
108) {
109 tokio::spawn(async move {
110 let mut reconnect_attempt: u32 = 0;
111 let mut should_run = true;
112
113 while should_run {
114 *state.write().await = ConnectionState::Connecting;
115
116 match connect_async(&url).await {
117 Ok((ws, _)) => {
118 *state.write().await = ConnectionState::Connected;
119 reconnect_attempt = 0;
120
121 let (mut ws_tx, mut ws_rx) = ws.split();
122
123 let subs = subscriptions.read().await.all();
124 for sub in subs {
125 if let Ok(msg) = serde_json::to_string(&sub) {
126 let _ = ws_tx.send(Message::Text(msg)).await;
127 }
128 }
129
130 let ping_interval = config.ping_interval;
131 let mut ping_timer = tokio::time::interval(ping_interval);
132
133 loop {
134 tokio::select! {
135 msg = ws_rx.next() => {
136 match msg {
137 Some(Ok(Message::Binary(bytes))) => {
138 if let Ok(frame) = parse_frame(&bytes) {
139 let _ = frame_tx.send(frame).await;
140 }
141 }
142 Some(Ok(Message::Text(text))) => {
143 if let Ok(frame) = serde_json::from_str::<Frame>(&text) {
144 let _ = frame_tx.send(frame).await;
145 }
146 }
147 Some(Ok(Message::Ping(payload))) => {
148 let _ = ws_tx.send(Message::Pong(payload)).await;
149 }
150 Some(Ok(Message::Close(_))) => {
151 break;
152 }
153 Some(Err(_)) => {
154 break;
155 }
156 None => {
157 break;
158 }
159 _ => {}
160 }
161 }
162 cmd = command_rx.recv() => {
163 match cmd {
164 Some(ConnectionCommand::Subscribe(sub)) => {
165 subscriptions.write().await.add(sub.clone());
166 if let Ok(msg) = serde_json::to_string(&sub) {
167 let _ = ws_tx.send(Message::Text(msg)).await;
168 }
169 }
170 Some(ConnectionCommand::Unsubscribe(sub)) => {
171 subscriptions.write().await.remove(&sub);
172 }
173 Some(ConnectionCommand::Disconnect) => {
174 let _ = ws_tx.close().await;
175 *state.write().await = ConnectionState::Disconnected;
176 should_run = false;
177 break;
178 }
179 None => {
180 should_run = false;
181 break;
182 }
183 }
184 }
185 _ = ping_timer.tick() => {
186 let _ = ws_tx.send(Message::Ping(vec![])).await;
187 }
188 }
189 }
190 }
191 Err(e) => {
192 tracing::error!("Connection failed: {}", e);
193 }
194 }
195
196 if !should_run {
197 break;
198 }
199
200 if !config.auto_reconnect {
201 *state.write().await = ConnectionState::Error;
202 break;
203 }
204
205 if reconnect_attempt >= config.max_reconnect_attempts {
206 *state.write().await = ConnectionState::Error;
207 break;
208 }
209
210 let delay = config
211 .reconnect_intervals
212 .get(reconnect_attempt as usize)
213 .copied()
214 .unwrap_or_else(|| {
215 config
216 .reconnect_intervals
217 .last()
218 .copied()
219 .unwrap_or(Duration::from_secs(16))
220 });
221
222 *state.write().await = ConnectionState::Reconnecting {
223 attempt: reconnect_attempt,
224 };
225 reconnect_attempt += 1;
226
227 tracing::info!(
228 "Reconnecting in {:?} (attempt {})",
229 delay,
230 reconnect_attempt
231 );
232 sleep(delay).await;
233 }
234 });
235}