1use std::{collections::HashMap, future::Future, time::Duration};
28
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31
32#[derive(Debug, Error)]
35pub enum BridgeError {
36 #[error("WebSocket error: {0}")]
37 WebSocket(String),
38
39 #[error("Connection failed after {0} attempts")]
40 MaxRetriesExceeded(u32),
41}
42
/// Reconnect and backoff policy for a bridge client.
#[derive(Debug, Clone)]
pub struct BridgeConfig {
    /// Delay before the first reconnect, in milliseconds; doubles on each further attempt.
    pub initial_backoff_ms: u64,
    /// Upper bound on any single reconnect delay, in milliseconds.
    pub max_backoff_ms: u64,
    /// `None` retries forever. `Some(n)` stops with `MaxRetriesExceeded` once the attempt
    /// counter reaches `n`. The counter includes the initial connection attempt and is
    /// reset whenever a connection succeeds.
    pub max_reconnect_attempts: Option<u32>,
}
55
56impl Default for BridgeConfig {
57 fn default() -> Self {
58 Self {
59 initial_backoff_ms: 1_000,
60 max_backoff_ms: 60_000,
61 max_reconnect_attempts: None,
62 }
63 }
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
68pub struct BridgeMessage {
69 pub id: String,
71 pub content: String,
73 pub source: Option<String>,
76 pub metadata: Option<HashMap<String, String>>,
78}
79
80impl BridgeMessage {
81 pub fn new(content: impl Into<String>) -> Self {
83 Self {
84 id: uuid::Uuid::new_v4().to_string(),
85 content: content.into(),
86 source: None,
87 metadata: None,
88 }
89 }
90
91 pub fn reply(original: &BridgeMessage, content: impl Into<String>) -> Self {
93 Self {
94 id: original.id.clone(),
95 content: content.into(),
96 source: None,
97 metadata: None,
98 }
99 }
100}
101
102pub struct BridgeClient {
106 url: String,
107 config: BridgeConfig,
108}
109
110impl BridgeClient {
111 pub fn new(url: impl Into<String>, config: BridgeConfig) -> Self {
113 Self {
114 url: url.into(),
115 config,
116 }
117 }
118
119 pub fn backoff_duration(&self, attempt: u32) -> Duration {
123 let multiplier = 1u64.checked_shl(attempt.min(62)).unwrap_or(u64::MAX);
125 let ms = self
126 .config
127 .initial_backoff_ms
128 .saturating_mul(multiplier)
129 .min(self.config.max_backoff_ms);
130 Duration::from_millis(ms)
131 }
132
133 pub async fn connect_and_relay<F, Fut>(&self, handler: F) -> Result<(), BridgeError>
142 where
143 F: Fn(BridgeMessage) -> Fut + Clone,
144 Fut: Future<Output = BridgeMessage>,
145 {
146 self.connect_and_relay_bidirectional(handler, None).await
147 }
148
149 pub async fn connect_and_relay_bidirectional<F, Fut>(
155 &self,
156 handler: F,
157 mut proactive_rx: Option<tokio::sync::broadcast::Receiver<BridgeMessage>>,
158 ) -> Result<(), BridgeError>
159 where
160 F: Fn(BridgeMessage) -> Fut + Clone,
161 Fut: Future<Output = BridgeMessage>,
162 {
163 use futures_util::{SinkExt, StreamExt};
164 use tokio_tungstenite::{connect_async, tungstenite::Message};
165
166 let mut attempt = 0u32;
167
168 loop {
169 if let Some(max) = self.config.max_reconnect_attempts {
171 if attempt >= max {
172 return Err(BridgeError::MaxRetriesExceeded(attempt));
173 }
174 }
175
176 if attempt > 0 {
178 let backoff = self.backoff_duration(attempt - 1);
179 tracing::info!(
180 url = %self.url,
181 attempt,
182 backoff_ms = backoff.as_millis(),
183 "Reconnecting to bridge gateway"
184 );
185 tokio::time::sleep(backoff).await;
186 }
187
188 tracing::info!(url = %self.url, "Connecting to bridge gateway");
189
190 let ws_stream = match connect_async(&self.url).await {
191 Err(e) => {
192 tracing::warn!(url = %self.url, error = %e, "Bridge connection failed");
193 attempt += 1;
194 continue;
195 }
196 Ok((ws, _response)) => {
197 tracing::info!(url = %self.url, "Bridge connected");
198 attempt = 0; ws
200 }
201 };
202
203 let (mut sink, mut stream) = ws_stream.split();
204
205 loop {
206 tokio::select! {
207 ws_msg = stream.next() => {
208 match ws_msg {
209 None => {
210 tracing::warn!(url = %self.url, "Bridge stream ended (EOF)");
211 break;
212 }
213 Some(Err(e)) => {
214 tracing::warn!(url = %self.url, error = %e, "Bridge WebSocket error");
215 break;
216 }
217 Some(Ok(Message::Ping(data))) => {
218 if sink.send(Message::Pong(data)).await.is_err() {
219 break;
220 }
221 }
222 Some(Ok(Message::Close(_))) => {
223 tracing::info!(url = %self.url, "Bridge connection closed by remote");
224 break;
225 }
226 Some(Ok(Message::Text(text))) => {
227 let msg: BridgeMessage = match serde_json::from_str(&text) {
228 Ok(m) => m,
229 Err(e) => {
230 tracing::warn!(
231 error = %e,
232 raw = %text,
233 "Ignoring unparseable bridge message"
234 );
235 continue;
236 }
237 };
238
239 let msg_id = msg.id.clone();
240 let response = handler.clone()(msg).await;
241
242 match serde_json::to_string(&response) {
243 Ok(payload) => {
244 if sink.send(Message::Text(payload.into())).await.is_err() {
245 tracing::warn!(id = %msg_id, "Failed to send bridge response");
246 break;
247 }
248 }
249 Err(e) => {
250 tracing::error!(id = %msg_id, error = %e, "Failed to serialise response");
251 }
252 }
253 }
254 Some(Ok(_)) => {} }
256 }
257 proactive = async {
258 match proactive_rx.as_mut() {
259 Some(rx) => rx.recv().await,
260 None => std::future::pending().await,
261 }
262 } => {
263 match proactive {
264 Ok(msg) => {
265 if let Ok(payload) = serde_json::to_string(&msg) {
266 if sink.send(Message::Text(payload.into())).await.is_err() {
267 tracing::warn!("Failed to push proactive to bridge");
268 break;
269 }
270 tracing::debug!(id = %msg.id, "Proactive pushed to bridge");
271 }
272 }
273 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
274 tracing::warn!(skipped = n, "Bridge proactive receiver lagged");
275 }
276 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
277 tracing::info!("Proactive channel closed, bridge continues in relay-only mode");
278 proactive_rx = None;
279 }
280 }
281 }
282 }
283 }
284
285 attempt += 1;
286 }
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a `ws://` URL on a loopback port that nothing is listening on.
    ///
    /// An ephemeral port is bound and released right away, so a connection attempt is refused.
    /// This avoids depending on a hard-coded port being free on the test host.
    fn unused_ws_url() -> String {
        let listener = std::net::TcpListener::bind("127.0.0.1:0")
            .expect("binding an ephemeral loopback port");
        let addr = listener
            .local_addr()
            .expect("ephemeral listener has a local address");
        drop(listener);
        format!("ws://{addr}")
    }

    #[test]
    fn test_bridge_config_defaults() {
        let cfg = BridgeConfig::default();
        assert_eq!(cfg.initial_backoff_ms, 1_000);
        assert_eq!(cfg.max_backoff_ms, 60_000);
        assert!(cfg.max_reconnect_attempts.is_none());
    }

    #[test]
    fn test_backoff_grows_exponentially() {
        let client = BridgeClient::new("ws://unused", BridgeConfig::default());
        assert_eq!(client.backoff_duration(0), Duration::from_millis(1_000));
        assert_eq!(client.backoff_duration(1), Duration::from_millis(2_000));
        assert_eq!(client.backoff_duration(2), Duration::from_millis(4_000));
        assert_eq!(client.backoff_duration(3), Duration::from_millis(8_000));
    }

    #[test]
    fn test_backoff_capped_at_max() {
        let cfg = BridgeConfig {
            initial_backoff_ms: 1_000,
            max_backoff_ms: 5_000,
            max_reconnect_attempts: None,
        };
        let client = BridgeClient::new("ws://unused", cfg);
        assert_eq!(client.backoff_duration(10), Duration::from_millis(5_000));
        assert_eq!(client.backoff_duration(30), Duration::from_millis(5_000));
    }

    #[test]
    fn test_backoff_saturates_for_huge_attempts() {
        // Exercises the shift clamp and the saturating multiply: must not overflow or panic.
        let client = BridgeClient::new("ws://unused", BridgeConfig::default());
        assert_eq!(client.backoff_duration(63), Duration::from_millis(60_000));
        assert_eq!(client.backoff_duration(u32::MAX), Duration::from_millis(60_000));
    }

    #[test]
    fn test_bridge_message_new() {
        let msg = BridgeMessage::new("hello");
        assert!(!msg.id.is_empty());
        assert_eq!(msg.content, "hello");
        assert!(msg.source.is_none());
        assert!(msg.metadata.is_none());
    }

    #[test]
    fn test_bridge_message_reply_shares_id() {
        let original = BridgeMessage::new("what time is it?");
        let reply = BridgeMessage::reply(&original, "It is noon.");
        assert_eq!(
            reply.id, original.id,
            "reply should reuse original message ID"
        );
        assert_eq!(reply.content, "It is noon.");
    }

    #[test]
    fn test_bridge_message_roundtrip_json() {
        let mut meta = HashMap::new();
        meta.insert("channel".to_string(), "#general".to_string());
        let original = BridgeMessage {
            id: "test-id-123".to_string(),
            content: "Deploy to prod?".to_string(),
            source: Some("chat-main".to_string()),
            metadata: Some(meta),
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: BridgeMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[tokio::test]
    async fn test_connect_fails_and_hits_max_retries() {
        // Tiny backoff keeps the three failed attempts fast.
        let cfg = BridgeConfig {
            initial_backoff_ms: 1,
            max_backoff_ms: 1,
            max_reconnect_attempts: Some(3),
        };
        let client = BridgeClient::new(unused_ws_url(), cfg);

        let result = client
            .connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
            .await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(3))),
            "expected MaxRetriesExceeded(3), got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_zero_max_attempts_returns_without_connecting() {
        // The limit is checked before the first connection attempt.
        let cfg = BridgeConfig {
            initial_backoff_ms: 1,
            max_backoff_ms: 1,
            max_reconnect_attempts: Some(0),
        };
        let client = BridgeClient::new(unused_ws_url(), cfg);

        let result = client
            .connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
            .await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(0))),
            "expected MaxRetriesExceeded(0), got: {:?}",
            result
        );
    }
}