use std::{collections::HashMap, future::Future, time::Duration};
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum BridgeError {
#[error("WebSocket error: {0}")]
WebSocket(String),
#[error("Connection failed after {0} attempts")]
MaxRetriesExceeded(u32),
}
#[derive(Debug, Clone)]
pub struct BridgeConfig {
pub initial_backoff_ms: u64,
pub max_backoff_ms: u64,
pub max_reconnect_attempts: Option<u32>,
}
impl Default for BridgeConfig {
fn default() -> Self {
Self {
initial_backoff_ms: 1_000,
max_backoff_ms: 60_000,
max_reconnect_attempts: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BridgeMessage {
pub id: String,
pub content: String,
pub source: Option<String>,
pub metadata: Option<HashMap<String, String>>,
}
impl BridgeMessage {
pub fn new(content: impl Into<String>) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
content: content.into(),
source: None,
metadata: None,
}
}
pub fn reply(original: &BridgeMessage, content: impl Into<String>) -> Self {
Self {
id: original.id.clone(),
content: content.into(),
source: None,
metadata: None,
}
}
}
pub struct BridgeClient {
url: String,
config: BridgeConfig,
}
impl BridgeClient {
pub fn new(url: impl Into<String>, config: BridgeConfig) -> Self {
Self {
url: url.into(),
config,
}
}
pub fn backoff_duration(&self, attempt: u32) -> Duration {
let multiplier = 1u64.checked_shl(attempt.min(62)).unwrap_or(u64::MAX);
let ms = self
.config
.initial_backoff_ms
.saturating_mul(multiplier)
.min(self.config.max_backoff_ms);
Duration::from_millis(ms)
}
pub async fn connect_and_relay<F, Fut>(&self, handler: F) -> Result<(), BridgeError>
where
F: Fn(BridgeMessage) -> Fut + Clone,
Fut: Future<Output = BridgeMessage>,
{
self.connect_and_relay_bidirectional(handler, None).await
}
pub async fn connect_and_relay_bidirectional<F, Fut>(
&self,
handler: F,
mut proactive_rx: Option<tokio::sync::broadcast::Receiver<BridgeMessage>>,
) -> Result<(), BridgeError>
where
F: Fn(BridgeMessage) -> Fut + Clone,
Fut: Future<Output = BridgeMessage>,
{
use futures_util::{SinkExt, StreamExt};
use tokio_tungstenite::{connect_async, tungstenite::Message};
let mut attempt = 0u32;
loop {
if let Some(max) = self.config.max_reconnect_attempts {
if attempt >= max {
return Err(BridgeError::MaxRetriesExceeded(attempt));
}
}
if attempt > 0 {
let backoff = self.backoff_duration(attempt - 1);
tracing::info!(
url = %self.url,
attempt,
backoff_ms = backoff.as_millis(),
"Reconnecting to bridge gateway"
);
tokio::time::sleep(backoff).await;
}
tracing::info!(url = %self.url, "Connecting to bridge gateway");
let ws_stream = match connect_async(&self.url).await {
Err(e) => {
tracing::warn!(url = %self.url, error = %e, "Bridge connection failed");
attempt += 1;
continue;
}
Ok((ws, _response)) => {
tracing::info!(url = %self.url, "Bridge connected");
attempt = 0; ws
}
};
let (mut sink, mut stream) = ws_stream.split();
loop {
tokio::select! {
ws_msg = stream.next() => {
match ws_msg {
None => {
tracing::warn!(url = %self.url, "Bridge stream ended (EOF)");
break;
}
Some(Err(e)) => {
tracing::warn!(url = %self.url, error = %e, "Bridge WebSocket error");
break;
}
Some(Ok(Message::Ping(data))) => {
if sink.send(Message::Pong(data)).await.is_err() {
break;
}
}
Some(Ok(Message::Close(_))) => {
tracing::info!(url = %self.url, "Bridge connection closed by remote");
break;
}
Some(Ok(Message::Text(text))) => {
let msg: BridgeMessage = match serde_json::from_str(&text) {
Ok(m) => m,
Err(e) => {
tracing::warn!(
error = %e,
raw = %text,
"Ignoring unparseable bridge message"
);
continue;
}
};
let msg_id = msg.id.clone();
let response = handler.clone()(msg).await;
match serde_json::to_string(&response) {
Ok(payload) => {
if sink.send(Message::Text(payload.into())).await.is_err() {
tracing::warn!(id = %msg_id, "Failed to send bridge response");
break;
}
}
Err(e) => {
tracing::error!(id = %msg_id, error = %e, "Failed to serialise response");
}
}
}
Some(Ok(_)) => {} }
}
proactive = async {
match proactive_rx.as_mut() {
Some(rx) => rx.recv().await,
None => std::future::pending().await,
}
} => {
match proactive {
Ok(msg) => {
if let Ok(payload) = serde_json::to_string(&msg) {
if sink.send(Message::Text(payload.into())).await.is_err() {
tracing::warn!("Failed to push proactive to bridge");
break;
}
tracing::debug!(id = %msg.id, "Proactive pushed to bridge");
}
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!(skipped = n, "Bridge proactive receiver lagged");
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
tracing::info!("Proactive channel closed, bridge continues in relay-only mode");
proactive_rx = None;
}
}
}
}
}
attempt += 1;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bridge_config_defaults() {
let cfg = BridgeConfig::default();
assert_eq!(cfg.initial_backoff_ms, 1_000);
assert_eq!(cfg.max_backoff_ms, 60_000);
assert!(cfg.max_reconnect_attempts.is_none());
}
#[test]
fn test_backoff_grows_exponentially() {
let client = BridgeClient::new("ws://unused", BridgeConfig::default());
assert_eq!(client.backoff_duration(0), Duration::from_millis(1_000));
assert_eq!(client.backoff_duration(1), Duration::from_millis(2_000));
assert_eq!(client.backoff_duration(2), Duration::from_millis(4_000));
assert_eq!(client.backoff_duration(3), Duration::from_millis(8_000));
}
#[test]
fn test_backoff_capped_at_max() {
let cfg = BridgeConfig {
initial_backoff_ms: 1_000,
max_backoff_ms: 5_000,
max_reconnect_attempts: None,
};
let client = BridgeClient::new("ws://unused", cfg);
assert_eq!(client.backoff_duration(10), Duration::from_millis(5_000));
assert_eq!(client.backoff_duration(30), Duration::from_millis(5_000));
}
#[test]
fn test_bridge_message_new() {
let msg = BridgeMessage::new("hello");
assert!(!msg.id.is_empty());
assert_eq!(msg.content, "hello");
assert!(msg.source.is_none());
assert!(msg.metadata.is_none());
}
#[test]
fn test_bridge_message_reply_shares_id() {
let original = BridgeMessage::new("what time is it?");
let reply = BridgeMessage::reply(&original, "It is noon.");
assert_eq!(
reply.id, original.id,
"reply should reuse original message ID"
);
assert_eq!(reply.content, "It is noon.");
}
#[test]
fn test_bridge_message_roundtrip_json() {
let mut meta = HashMap::new();
meta.insert("channel".to_string(), "#general".to_string());
let original = BridgeMessage {
id: "test-id-123".to_string(),
content: "Deploy to prod?".to_string(),
source: Some("slack".to_string()),
metadata: Some(meta),
};
let json = serde_json::to_string(&original).unwrap();
let decoded: BridgeMessage = serde_json::from_str(&json).unwrap();
assert_eq!(decoded, original);
}
#[tokio::test]
async fn test_connect_fails_and_hits_max_retries() {
let cfg = BridgeConfig {
initial_backoff_ms: 1, max_backoff_ms: 1,
max_reconnect_attempts: Some(3),
};
let client = BridgeClient::new("ws://127.0.0.1:19999", cfg);
let result = client
.connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
.await;
assert!(
matches!(result, Err(BridgeError::MaxRetriesExceeded(3))),
"expected MaxRetriesExceeded(3), got: {:?}",
result
);
}
}