1use std::{collections::HashMap, future::Future, time::Duration};
28
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31
32#[derive(Debug, Error)]
35pub enum BridgeError {
36 #[error("WebSocket error: {0}")]
37 WebSocket(String),
38
39 #[error("Connection failed after {0} attempts")]
40 MaxRetriesExceeded(u32),
41}
42
/// Reconnect policy used by the bridge client.
#[derive(Clone, Debug)]
pub struct BridgeConfig {
    /// Delay before the first reconnect, in milliseconds; it doubles with
    /// every further consecutive attempt.
    pub initial_backoff_ms: u64,
    /// Upper bound on any single reconnect delay, in milliseconds.
    pub max_backoff_ms: u64,
    /// Number of connection attempts allowed before giving up; `None` means
    /// keep retrying forever.
    pub max_reconnect_attempts: Option<u32>,
}

impl Default for BridgeConfig {
    /// One-second initial backoff, capped at one minute, with unlimited reconnects.
    fn default() -> Self {
        const ONE_SECOND_MS: u64 = 1_000;
        const ONE_MINUTE_MS: u64 = 60 * ONE_SECOND_MS;

        Self {
            max_reconnect_attempts: None,
            max_backoff_ms: ONE_MINUTE_MS,
            initial_backoff_ms: ONE_SECOND_MS,
        }
    }
}
65
66#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
68pub struct BridgeMessage {
69 pub id: String,
71 pub content: String,
73 pub source: Option<String>,
75 pub metadata: Option<HashMap<String, String>>,
77}
78
79impl BridgeMessage {
80 pub fn new(content: impl Into<String>) -> Self {
82 Self {
83 id: uuid::Uuid::new_v4().to_string(),
84 content: content.into(),
85 source: None,
86 metadata: None,
87 }
88 }
89
90 pub fn reply(original: &BridgeMessage, content: impl Into<String>) -> Self {
92 Self {
93 id: original.id.clone(),
94 content: content.into(),
95 source: None,
96 metadata: None,
97 }
98 }
99}
100
101pub struct BridgeClient {
105 url: String,
106 config: BridgeConfig,
107}
108
109impl BridgeClient {
110 pub fn new(url: impl Into<String>, config: BridgeConfig) -> Self {
112 Self {
113 url: url.into(),
114 config,
115 }
116 }
117
118 pub fn backoff_duration(&self, attempt: u32) -> Duration {
122 let multiplier = 1u64.checked_shl(attempt.min(62)).unwrap_or(u64::MAX);
124 let ms = self
125 .config
126 .initial_backoff_ms
127 .saturating_mul(multiplier)
128 .min(self.config.max_backoff_ms);
129 Duration::from_millis(ms)
130 }
131
132 pub async fn connect_and_relay<F, Fut>(&self, handler: F) -> Result<(), BridgeError>
141 where
142 F: Fn(BridgeMessage) -> Fut + Clone,
143 Fut: Future<Output = BridgeMessage>,
144 {
145 self.connect_and_relay_bidirectional(handler, None).await
146 }
147
148 pub async fn connect_and_relay_bidirectional<F, Fut>(
154 &self,
155 handler: F,
156 mut proactive_rx: Option<tokio::sync::broadcast::Receiver<BridgeMessage>>,
157 ) -> Result<(), BridgeError>
158 where
159 F: Fn(BridgeMessage) -> Fut + Clone,
160 Fut: Future<Output = BridgeMessage>,
161 {
162 use futures_util::{SinkExt, StreamExt};
163 use tokio_tungstenite::{connect_async, tungstenite::Message};
164
165 let mut attempt = 0u32;
166
167 loop {
168 if let Some(max) = self.config.max_reconnect_attempts {
170 if attempt >= max {
171 return Err(BridgeError::MaxRetriesExceeded(attempt));
172 }
173 }
174
175 if attempt > 0 {
177 let backoff = self.backoff_duration(attempt - 1);
178 tracing::info!(
179 url = %self.url,
180 attempt,
181 backoff_ms = backoff.as_millis(),
182 "Reconnecting to bridge gateway"
183 );
184 tokio::time::sleep(backoff).await;
185 }
186
187 tracing::info!(url = %self.url, "Connecting to bridge gateway");
188
189 let ws_stream = match connect_async(&self.url).await {
190 Err(e) => {
191 tracing::warn!(url = %self.url, error = %e, "Bridge connection failed");
192 attempt += 1;
193 continue;
194 }
195 Ok((ws, _response)) => {
196 tracing::info!(url = %self.url, "Bridge connected");
197 attempt = 0; ws
199 }
200 };
201
202 let (mut sink, mut stream) = ws_stream.split();
203
204 loop {
205 tokio::select! {
206 ws_msg = stream.next() => {
207 match ws_msg {
208 None => {
209 tracing::warn!(url = %self.url, "Bridge stream ended (EOF)");
210 break;
211 }
212 Some(Err(e)) => {
213 tracing::warn!(url = %self.url, error = %e, "Bridge WebSocket error");
214 break;
215 }
216 Some(Ok(Message::Ping(data))) => {
217 if sink.send(Message::Pong(data)).await.is_err() {
218 break;
219 }
220 }
221 Some(Ok(Message::Close(_))) => {
222 tracing::info!(url = %self.url, "Bridge connection closed by remote");
223 break;
224 }
225 Some(Ok(Message::Text(text))) => {
226 let msg: BridgeMessage = match serde_json::from_str(&text) {
227 Ok(m) => m,
228 Err(e) => {
229 tracing::warn!(
230 error = %e,
231 raw = %text,
232 "Ignoring unparseable bridge message"
233 );
234 continue;
235 }
236 };
237
238 let msg_id = msg.id.clone();
239 let response = handler.clone()(msg).await;
240
241 match serde_json::to_string(&response) {
242 Ok(payload) => {
243 if sink.send(Message::Text(payload.into())).await.is_err() {
244 tracing::warn!(id = %msg_id, "Failed to send bridge response");
245 break;
246 }
247 }
248 Err(e) => {
249 tracing::error!(id = %msg_id, error = %e, "Failed to serialise response");
250 }
251 }
252 }
253 Some(Ok(_)) => {} }
255 }
256 proactive = async {
257 match proactive_rx.as_mut() {
258 Some(rx) => rx.recv().await,
259 None => std::future::pending().await,
260 }
261 } => {
262 match proactive {
263 Ok(msg) => {
264 if let Ok(payload) = serde_json::to_string(&msg) {
265 if sink.send(Message::Text(payload.into())).await.is_err() {
266 tracing::warn!("Failed to push proactive to bridge");
267 break;
268 }
269 tracing::debug!(id = %msg.id, "Proactive pushed to bridge");
270 }
271 }
272 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
273 tracing::warn!(skipped = n, "Bridge proactive receiver lagged");
274 }
275 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
276 tracing::info!("Proactive channel closed, bridge continues in relay-only mode");
277 proactive_rx = None;
278 }
279 }
280 }
281 }
282 }
283
284 attempt += 1;
285 }
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a loopback `ws://` URL whose port had nothing listening a moment ago.
    ///
    /// Binding port 0 lets the OS pick a free port; dropping the listener frees
    /// it again, so a connect attempt should be refused. This avoids relying on
    /// a hard-coded port happening to be unused on the test machine.
    fn unused_ws_url() -> String {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind an ephemeral loopback port");
        let port = listener.local_addr().expect("read bound address").port();
        drop(listener);
        format!("ws://127.0.0.1:{port}")
    }

    #[test]
    fn test_bridge_config_defaults() {
        let cfg = BridgeConfig::default();
        assert_eq!(cfg.initial_backoff_ms, 1_000);
        assert_eq!(cfg.max_backoff_ms, 60_000);
        assert!(cfg.max_reconnect_attempts.is_none());
    }

    #[test]
    fn test_error_display_messages() {
        assert_eq!(
            BridgeError::WebSocket("boom".to_string()).to_string(),
            "WebSocket error: boom"
        );
        assert_eq!(
            BridgeError::MaxRetriesExceeded(3).to_string(),
            "Connection failed after 3 attempts"
        );
    }

    #[test]
    fn test_backoff_grows_exponentially() {
        let client = BridgeClient::new("ws://unused", BridgeConfig::default());
        assert_eq!(client.backoff_duration(0), Duration::from_millis(1_000));
        assert_eq!(client.backoff_duration(1), Duration::from_millis(2_000));
        assert_eq!(client.backoff_duration(2), Duration::from_millis(4_000));
        assert_eq!(client.backoff_duration(3), Duration::from_millis(8_000));
    }

    #[test]
    fn test_backoff_capped_at_max() {
        let cfg = BridgeConfig {
            initial_backoff_ms: 1_000,
            max_backoff_ms: 5_000,
            max_reconnect_attempts: None,
        };
        let client = BridgeClient::new("ws://unused", cfg);
        assert_eq!(client.backoff_duration(10), Duration::from_millis(5_000));
        assert_eq!(client.backoff_duration(30), Duration::from_millis(5_000));
    }

    #[test]
    fn test_backoff_saturates_for_huge_attempts() {
        let client = BridgeClient::new("ws://unused", BridgeConfig::default());
        // Must neither panic on shift overflow nor wrap around.
        assert_eq!(client.backoff_duration(u32::MAX), Duration::from_millis(60_000));
    }

    #[test]
    fn test_bridge_message_new() {
        let msg = BridgeMessage::new("hello");
        assert!(!msg.id.is_empty());
        assert_eq!(msg.content, "hello");
        assert!(msg.source.is_none());
        assert!(msg.metadata.is_none());
    }

    #[test]
    fn test_bridge_message_reply_shares_id() {
        let original = BridgeMessage::new("what time is it?");
        let reply = BridgeMessage::reply(&original, "It is noon.");
        assert_eq!(
            reply.id, original.id,
            "reply should reuse original message ID"
        );
        assert_eq!(reply.content, "It is noon.");
    }

    #[test]
    fn test_bridge_message_reply_drops_source_and_metadata() {
        let mut meta = HashMap::new();
        meta.insert("channel".to_string(), "#general".to_string());
        let original = BridgeMessage {
            id: "req-1".to_string(),
            content: "ping".to_string(),
            source: Some("slack".to_string()),
            metadata: Some(meta),
        };
        let reply = BridgeMessage::reply(&original, "pong");
        assert_eq!(reply.id, "req-1");
        assert!(reply.source.is_none());
        assert!(reply.metadata.is_none());
    }

    #[test]
    fn test_bridge_message_roundtrip_json() {
        let mut meta = HashMap::new();
        meta.insert("channel".to_string(), "#general".to_string());
        let original = BridgeMessage {
            id: "test-id-123".to_string(),
            content: "Deploy to prod?".to_string(),
            source: Some("slack".to_string()),
            metadata: Some(meta),
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: BridgeMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[tokio::test]
    async fn test_connect_fails_and_hits_max_retries() {
        let cfg = BridgeConfig {
            // Keep backoff tiny so the test runs quickly.
            initial_backoff_ms: 1,
            max_backoff_ms: 1,
            max_reconnect_attempts: Some(3),
        };
        let client = BridgeClient::new(unused_ws_url(), cfg);

        let result = client
            .connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
            .await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(3))),
            "expected MaxRetriesExceeded(3), got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_zero_max_attempts_fails_without_connecting() {
        let cfg = BridgeConfig {
            initial_backoff_ms: 1,
            max_backoff_ms: 1,
            max_reconnect_attempts: Some(0),
        };
        let client = BridgeClient::new(unused_ws_url(), cfg);

        let result = client
            .connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
            .await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(0))),
            "expected MaxRetriesExceeded(0), got: {:?}",
            result
        );
    }
}