Skip to main content

brainos_bridge/
lib.rs

1//! # Brain Bridge
2//!
3//! External service relay — WebSocket client that connects Brain to a
4//! remote messaging gateway and relays inbound messages through Brain's
5//! signal processing pipeline.
6//!
7//! ## Protocol
8//!
9//! 1. [`BridgeClient`] connects to the configured `url` via WebSocket.
10//! 2. Each inbound text frame is parsed as a JSON-encoded [`BridgeMessage`].
11//!    Unparseable text frames are logged and skipped; the connection stays
12//!    open. Binary frames are silently ignored.
13//! 3. WebSocket `Ping` frames are answered with `Pong` automatically.
14//!    `Close` frames trigger the reconnect loop after the configured
15//!    backoff.
16//! 4. A caller-supplied `handler` is invoked for each parsed message and
17//!    must return a [`BridgeMessage`] response (the response keeps the
18//!    original message's `id` if built via [`BridgeMessage::reply`]).
19//! 5. The response is serialised and sent back as a text frame. Failure
20//!    to send tears down the connection and triggers the reconnect loop;
21//!    serialisation failure logs and drops the response only.
22//! 6. On disconnect, the client waits `initial_backoff_ms * 2^attempt`
23//!    (capped at `max_backoff_ms`) before retrying. With
24//!    [`BridgeConfig::max_reconnect_attempts`] set, the loop returns
25//!    [`BridgeError::MaxRetriesExceeded`] once the cap is reached; with
26//!    `None`, retries continue indefinitely.
27//!
28//! [`BridgeClient::connect_and_relay_bidirectional`] adds an optional
29//! proactive-push channel: each [`BridgeMessage`] received on
30//! `proactive_rx` is serialised and pushed to the gateway alongside the
31//! normal request/response flow. If the broadcast channel lags, lagged
32//! events are logged and dropped; if it closes, the client continues in
33//! relay-only mode.
34//!
35//! ## Usage
36//! ```no_run
37//! # use brainos_bridge::{BridgeClient, BridgeConfig, BridgeMessage};
38//! # #[tokio::main] async fn main() -> anyhow::Result<()> {
39//! let client = BridgeClient::new("ws://gateway.example.com/brain", BridgeConfig::default());
40//! client.connect_and_relay(|msg| async move {
41//!     BridgeMessage::reply(&msg, format!("Echo: {}", msg.content))
42//! }).await?;
43//! # Ok(())
44//! # }
45//! ```
46
47use std::{collections::HashMap, future::Future, time::Duration};
48
49use serde::{Deserialize, Serialize};
50use thiserror::Error;
51
52// ─── Errors ──────────────────────────────────────────────────────────────────
53
54#[derive(Debug, Error)]
55pub enum BridgeError {
56    #[error("WebSocket error: {0}")]
57    WebSocket(String),
58
59    #[error("Connection failed after {0} attempts")]
60    MaxRetriesExceeded(u32),
61}
62
63// ─── Types ────────────────────────────────────────────────────────────────────
64
/// Reconnection tuning for [`BridgeClient`].
///
/// All delays are expressed in milliseconds.
#[derive(Debug, Clone)]
pub struct BridgeConfig {
    /// Base backoff delay. Default: 1 000 ms.
    ///
    /// Retry `n` (0-based, counted since the most recent successful
    /// connection) waits `min(initial_backoff_ms * 2^n, max_backoff_ms)`.
    pub initial_backoff_ms: u64,
    /// Ceiling applied to every computed backoff delay. Default: 60 000 ms.
    pub max_backoff_ms: u64,
    /// Limit on the relay loop's attempt counter. `None` (the default)
    /// means retry forever.
    ///
    /// The counter is incremented once for every failed connect and once
    /// for every established session that ends; it is reset to zero by
    /// every successful connect. As soon as the counter reaches the limit
    /// the loop returns [`BridgeError::MaxRetriesExceeded`].
    ///
    /// Two consequences worth knowing:
    /// * `Some(0)` makes the loop return the error before any connection
    ///   is attempted.
    /// * Because a dropped session also counts, `Some(1)` never
    ///   reconnects, and `Some(n)` allows `n - 1` failed reconnects after
    ///   a drop (versus `n` failed connects from a cold start).
    pub max_reconnect_attempts: Option<u32>,
}
84
85impl Default for BridgeConfig {
86    fn default() -> Self {
87        Self {
88            initial_backoff_ms: 1_000,
89            max_backoff_ms: 60_000,
90            max_reconnect_attempts: None,
91        }
92    }
93}
94
95/// A message exchanged between Brain and the remote gateway.
96#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
97pub struct BridgeMessage {
98    /// Unique message ID (UUID v4 string).
99    pub id: String,
100    /// Text content of the message.
101    pub content: String,
102    /// Optional source label set by the gateway (transport id, channel
103    /// name, sender id — whatever the gateway forwards).
104    pub source: Option<String>,
105    /// Arbitrary key-value metadata forwarded from the gateway.
106    pub metadata: Option<HashMap<String, String>>,
107}
108
109impl BridgeMessage {
110    /// Create a new outbound message with a fresh UUID.
111    pub fn new(content: impl Into<String>) -> Self {
112        Self {
113            id: uuid::Uuid::new_v4().to_string(),
114            content: content.into(),
115            source: None,
116            metadata: None,
117        }
118    }
119
120    /// Create a reply to `original`, reusing the same `id` for correlation.
121    pub fn reply(original: &BridgeMessage, content: impl Into<String>) -> Self {
122        Self {
123            id: original.id.clone(),
124            content: content.into(),
125            source: None,
126            metadata: None,
127        }
128    }
129}
130
131// ─── BridgeClient ─────────────────────────────────────────────────────────────
132
133/// WebSocket client that relays messages between Brain and a remote gateway.
134pub struct BridgeClient {
135    url: String,
136    config: BridgeConfig,
137}
138
139impl BridgeClient {
140    /// Create a new client pointing at `url`.
141    pub fn new(url: impl Into<String>, config: BridgeConfig) -> Self {
142        Self {
143            url: url.into(),
144            config,
145        }
146    }
147
148    /// Calculate the backoff duration for a given retry attempt number (0-indexed).
149    ///
150    /// Uses exponential back-off: `initial * 2^attempt`, capped at `max`.
151    pub fn backoff_duration(&self, attempt: u32) -> Duration {
152        // 2^attempt, capped to avoid overflow
153        let multiplier = 1u64.checked_shl(attempt.min(62)).unwrap_or(u64::MAX);
154        let ms = self
155            .config
156            .initial_backoff_ms
157            .saturating_mul(multiplier)
158            .min(self.config.max_backoff_ms);
159        Duration::from_millis(ms)
160    }
161
162    /// Connect to the remote WebSocket gateway and relay messages indefinitely.
163    ///
164    /// For each inbound [`BridgeMessage`], `handler` is called and its return
165    /// value is sent back as a text frame.  On disconnect the client waits for
166    /// the appropriate backoff period then reconnects automatically.
167    ///
168    /// Returns `Err(BridgeError::MaxRetriesExceeded)` only when
169    /// [`BridgeConfig::max_reconnect_attempts`] is set and exceeded.
170    pub async fn connect_and_relay<F, Fut>(&self, handler: F) -> Result<(), BridgeError>
171    where
172        F: Fn(BridgeMessage) -> Fut + Clone,
173        Fut: Future<Output = BridgeMessage>,
174    {
175        self.connect_and_relay_bidirectional(handler, None).await
176    }
177
178    /// Connect with optional proactive push channel.
179    ///
180    /// When `proactive_rx` is provided, proactive notifications are forwarded
181    /// to the gateway as outbound `BridgeMessage` frames alongside the normal
182    /// request-response relay.
183    pub async fn connect_and_relay_bidirectional<F, Fut>(
184        &self,
185        handler: F,
186        mut proactive_rx: Option<tokio::sync::broadcast::Receiver<BridgeMessage>>,
187    ) -> Result<(), BridgeError>
188    where
189        F: Fn(BridgeMessage) -> Fut + Clone,
190        Fut: Future<Output = BridgeMessage>,
191    {
192        use futures_util::{SinkExt, StreamExt};
193        use tokio_tungstenite::{connect_async, tungstenite::Message};
194
195        let mut attempt = 0u32;
196
197        loop {
198            // Check retry limit before sleeping
199            if let Some(max) = self.config.max_reconnect_attempts {
200                if attempt >= max {
201                    return Err(BridgeError::MaxRetriesExceeded(attempt));
202                }
203            }
204
205            // Backoff before retries (not before the first attempt)
206            if attempt > 0 {
207                let backoff = self.backoff_duration(attempt - 1);
208                tracing::info!(
209                    url = %self.url,
210                    attempt,
211                    backoff_ms = backoff.as_millis(),
212                    "Reconnecting to bridge gateway"
213                );
214                tokio::time::sleep(backoff).await;
215            }
216
217            tracing::info!(url = %self.url, "Connecting to bridge gateway");
218
219            let ws_stream = match connect_async(&self.url).await {
220                Err(e) => {
221                    tracing::warn!(url = %self.url, error = %e, "Bridge connection failed");
222                    attempt += 1;
223                    continue;
224                }
225                Ok((ws, _response)) => {
226                    tracing::info!(url = %self.url, "Bridge connected");
227                    attempt = 0; // reset on successful connection
228                    ws
229                }
230            };
231
232            let (mut sink, mut stream) = ws_stream.split();
233
234            loop {
235                tokio::select! {
236                    ws_msg = stream.next() => {
237                        match ws_msg {
238                            None => {
239                                tracing::warn!(url = %self.url, "Bridge stream ended (EOF)");
240                                break;
241                            }
242                            Some(Err(e)) => {
243                                tracing::warn!(url = %self.url, error = %e, "Bridge WebSocket error");
244                                break;
245                            }
246                            Some(Ok(Message::Ping(data))) => {
247                                if sink.send(Message::Pong(data)).await.is_err() {
248                                    break;
249                                }
250                            }
251                            Some(Ok(Message::Close(_))) => {
252                                tracing::info!(url = %self.url, "Bridge connection closed by remote");
253                                break;
254                            }
255                            Some(Ok(Message::Text(text))) => {
256                                let msg: BridgeMessage = match serde_json::from_str(&text) {
257                                    Ok(m) => m,
258                                    Err(e) => {
259                                        tracing::warn!(
260                                            error = %e,
261                                            raw = %text,
262                                            "Ignoring unparseable bridge message"
263                                        );
264                                        continue;
265                                    }
266                                };
267
268                                let msg_id = msg.id.clone();
269                                let response = handler.clone()(msg).await;
270
271                                match serde_json::to_string(&response) {
272                                    Ok(payload) => {
273                                        if sink.send(Message::Text(payload.into())).await.is_err() {
274                                            tracing::warn!(id = %msg_id, "Failed to send bridge response");
275                                            break;
276                                        }
277                                    }
278                                    Err(e) => {
279                                        tracing::error!(id = %msg_id, error = %e, "Failed to serialise response");
280                                    }
281                                }
282                            }
283                            Some(Ok(_)) => {} // ignore binary frames and pong
284                        }
285                    }
286                    proactive = async {
287                        match proactive_rx.as_mut() {
288                            Some(rx) => rx.recv().await,
289                            None => std::future::pending().await,
290                        }
291                    } => {
292                        match proactive {
293                            Ok(msg) => {
294                                if let Ok(payload) = serde_json::to_string(&msg) {
295                                    if sink.send(Message::Text(payload.into())).await.is_err() {
296                                        tracing::warn!("Failed to push proactive to bridge");
297                                        break;
298                                    }
299                                    tracing::debug!(id = %msg.id, "Proactive pushed to bridge");
300                                }
301                            }
302                            Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
303                                tracing::warn!(skipped = n, "Bridge proactive receiver lagged");
304                            }
305                            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
306                                tracing::info!("Proactive channel closed, bridge continues in relay-only mode");
307                                proactive_rx = None;
308                            }
309                        }
310                    }
311                }
312            }
313
314            attempt += 1;
315        }
316    }
317}
318
319// ─── Tests ────────────────────────────────────────────────────────────────────
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a client that is never connected; only its config matters.
    fn offline_client(config: BridgeConfig) -> BridgeClient {
        BridgeClient::new("ws://unused", config)
    }

    #[test]
    fn test_bridge_config_defaults() {
        let BridgeConfig {
            initial_backoff_ms,
            max_backoff_ms,
            max_reconnect_attempts,
        } = BridgeConfig::default();

        assert_eq!(initial_backoff_ms, 1_000);
        assert_eq!(max_backoff_ms, 60_000);
        assert!(max_reconnect_attempts.is_none());
    }

    #[test]
    fn test_backoff_grows_exponentially() {
        let client = offline_client(BridgeConfig::default());

        // (attempt, expected delay in ms): doubles each step from 1 s.
        let expectations: [(u32, u64); 4] = [(0, 1_000), (1, 2_000), (2, 4_000), (3, 8_000)];
        for (attempt, expected_ms) in expectations {
            assert_eq!(
                client.backoff_duration(attempt),
                Duration::from_millis(expected_ms)
            );
        }
    }

    #[test]
    fn test_backoff_capped_at_max() {
        let client = offline_client(BridgeConfig {
            initial_backoff_ms: 1_000,
            max_backoff_ms: 5_000,
            max_reconnect_attempts: None,
        });

        // Well past the point where doubling exceeds the 5 s ceiling.
        let ceiling = Duration::from_millis(5_000);
        for attempt in [10u32, 30] {
            assert_eq!(client.backoff_duration(attempt), ceiling);
        }
    }

    #[test]
    fn test_bridge_message_new() {
        let BridgeMessage {
            id,
            content,
            source,
            metadata,
        } = BridgeMessage::new("hello");

        assert!(!id.is_empty());
        assert_eq!(content, "hello");
        assert!(source.is_none());
        assert!(metadata.is_none());
    }

    #[test]
    fn test_bridge_message_reply_shares_id() {
        let question = BridgeMessage::new("what time is it?");
        let answer = BridgeMessage::reply(&question, "It is noon.");

        assert_eq!(
            answer.id, question.id,
            "reply should reuse original message ID"
        );
        assert_eq!(answer.content, "It is noon.");
    }

    #[test]
    fn test_bridge_message_roundtrip_json() {
        let sent = BridgeMessage {
            id: "test-id-123".to_string(),
            content: "Deploy to prod?".to_string(),
            source: Some("chat-main".to_string()),
            metadata: Some(HashMap::from([(
                "channel".to_string(),
                "#general".to_string(),
            )])),
        };

        let wire = serde_json::to_string(&sent).unwrap();
        let received: BridgeMessage = serde_json::from_str(&wire).unwrap();

        assert_eq!(received, sent);
    }

    #[tokio::test]
    async fn test_connect_fails_and_hits_max_retries() {
        // Target an address expected to refuse connections immediately, with
        // a 1 ms backoff so the three attempts finish quickly.
        let client = BridgeClient::new(
            "ws://127.0.0.1:19999",
            BridgeConfig {
                initial_backoff_ms: 1,
                max_backoff_ms: 1,
                max_reconnect_attempts: Some(3),
            },
        );

        let echo_ok = |msg: BridgeMessage| async move { BridgeMessage::reply(&msg, "ok") };
        let result = client.connect_and_relay(echo_ok).await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(3))),
            "expected MaxRetriesExceeded(3), got: {:?}",
            result
        );
    }
}