hyperstack_sdk/
connection.rs1use crate::config::ConnectionConfig;
2use crate::frame::{parse_frame, Frame};
3use crate::subscription::{Subscription, SubscriptionRegistry};
4use futures_util::{SinkExt, StreamExt};
5use std::sync::Arc;
6use tokio::sync::{mpsc, RwLock};
7use tokio::time::{sleep, Duration};
8use tokio_tungstenite::{connect_async, tungstenite::Message};
9
/// Lifecycle of the WebSocket connection as published by the background connection task.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No socket is open: the initial state, and the state after a `Disconnect` command.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The socket is open.
    Connected,
    /// Waiting before the next connection attempt.
    ///
    /// `attempt` is zero-based: the first retry after a failure reports `attempt: 0`.
    Reconnecting { attempt: u32 },
    /// The task gave up: reconnection is disabled or the retry budget is exhausted.
    Error,
}
18
19pub enum ConnectionCommand {
20 Subscribe(Subscription),
21 #[allow(dead_code)]
22 Unsubscribe(Subscription),
23 Disconnect,
24}
25
26pub struct ConnectionManager {
27 #[allow(dead_code)]
28 url: String,
29 state: Arc<RwLock<ConnectionState>>,
30 subscriptions: Arc<RwLock<SubscriptionRegistry>>,
31 #[allow(dead_code)]
32 config: ConnectionConfig,
33 command_tx: mpsc::Sender<ConnectionCommand>,
34}
35
36impl ConnectionManager {
37 pub async fn new(url: String, config: ConnectionConfig, frame_tx: mpsc::Sender<Frame>) -> Self {
38 let (command_tx, command_rx) = mpsc::channel(100);
39 let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
40 let subscriptions = Arc::new(RwLock::new(SubscriptionRegistry::new()));
41
42 let manager = Self {
43 url: url.clone(),
44 state: state.clone(),
45 subscriptions: subscriptions.clone(),
46 config: config.clone(),
47 command_tx,
48 };
49
50 spawn_connection_loop(url, state, subscriptions, config, frame_tx, command_rx);
51
52 manager
53 }
54
55 pub async fn state(&self) -> ConnectionState {
56 *self.state.read().await
57 }
58
59 pub async fn ensure_subscription(&self, view: &str, key: Option<&str>) {
60 let sub = Subscription {
61 view: view.to_string(),
62 key: key.map(|s| s.to_string()),
63 partition: None,
64 filters: None,
65 };
66
67 if !self.subscriptions.read().await.contains(&sub) {
68 let _ = self
69 .command_tx
70 .send(ConnectionCommand::Subscribe(sub))
71 .await;
72 }
73 }
74
75 #[allow(dead_code)]
76 pub async fn subscribe(&self, sub: Subscription) {
77 let _ = self
78 .command_tx
79 .send(ConnectionCommand::Subscribe(sub))
80 .await;
81 }
82
83 pub async fn disconnect(&self) {
84 let _ = self.command_tx.send(ConnectionCommand::Disconnect).await;
85 }
86}
87
88fn spawn_connection_loop(
89 url: String,
90 state: Arc<RwLock<ConnectionState>>,
91 subscriptions: Arc<RwLock<SubscriptionRegistry>>,
92 config: ConnectionConfig,
93 frame_tx: mpsc::Sender<Frame>,
94 mut command_rx: mpsc::Receiver<ConnectionCommand>,
95) {
96 tokio::spawn(async move {
97 let mut reconnect_attempt: u32 = 0;
98 let mut should_run = true;
99
100 while should_run {
101 *state.write().await = ConnectionState::Connecting;
102
103 match connect_async(&url).await {
104 Ok((ws, _)) => {
105 *state.write().await = ConnectionState::Connected;
106 reconnect_attempt = 0;
107
108 let (mut ws_tx, mut ws_rx) = ws.split();
109
110 let subs = subscriptions.read().await.all();
111 for sub in subs {
112 if let Ok(msg) = serde_json::to_string(&sub) {
113 let _ = ws_tx.send(Message::Text(msg)).await;
114 }
115 }
116
117 let ping_interval = config.ping_interval;
118 let mut ping_timer = tokio::time::interval(ping_interval);
119
120 loop {
121 tokio::select! {
122 msg = ws_rx.next() => {
123 match msg {
124 Some(Ok(Message::Binary(bytes))) => {
125 if let Ok(frame) = parse_frame(&bytes) {
126 let _ = frame_tx.send(frame).await;
127 }
128 }
129 Some(Ok(Message::Text(text))) => {
130 if let Ok(frame) = serde_json::from_str::<Frame>(&text) {
131 let _ = frame_tx.send(frame).await;
132 }
133 }
134 Some(Ok(Message::Ping(payload))) => {
135 let _ = ws_tx.send(Message::Pong(payload)).await;
136 }
137 Some(Ok(Message::Close(_))) => {
138 break;
139 }
140 Some(Err(_)) => {
141 break;
142 }
143 None => {
144 break;
145 }
146 _ => {}
147 }
148 }
149 cmd = command_rx.recv() => {
150 match cmd {
151 Some(ConnectionCommand::Subscribe(sub)) => {
152 subscriptions.write().await.add(sub.clone());
153 if let Ok(msg) = serde_json::to_string(&sub) {
154 let _ = ws_tx.send(Message::Text(msg)).await;
155 }
156 }
157 Some(ConnectionCommand::Unsubscribe(sub)) => {
158 subscriptions.write().await.remove(&sub);
159 }
160 Some(ConnectionCommand::Disconnect) => {
161 let _ = ws_tx.close().await;
162 *state.write().await = ConnectionState::Disconnected;
163 should_run = false;
164 break;
165 }
166 None => {
167 should_run = false;
168 break;
169 }
170 }
171 }
172 _ = ping_timer.tick() => {
173 let _ = ws_tx.send(Message::Ping(vec![])).await;
174 }
175 }
176 }
177 }
178 Err(e) => {
179 tracing::error!("Connection failed: {}", e);
180 }
181 }
182
183 if !should_run {
184 break;
185 }
186
187 if !config.auto_reconnect {
188 *state.write().await = ConnectionState::Error;
189 break;
190 }
191
192 if reconnect_attempt >= config.max_reconnect_attempts {
193 *state.write().await = ConnectionState::Error;
194 break;
195 }
196
197 let delay = config
198 .reconnect_intervals
199 .get(reconnect_attempt as usize)
200 .copied()
201 .unwrap_or_else(|| {
202 config
203 .reconnect_intervals
204 .last()
205 .copied()
206 .unwrap_or(Duration::from_secs(16))
207 });
208
209 *state.write().await = ConnectionState::Reconnecting {
210 attempt: reconnect_attempt,
211 };
212 reconnect_attempt += 1;
213
214 tracing::info!(
215 "Reconnecting in {:?} (attempt {})",
216 delay,
217 reconnect_attempt
218 );
219 sleep(delay).await;
220 }
221 });
222}