brainos_bridge/lib.rs
1//! # Brain Bridge
2//!
3//! External service relay — WebSocket client that connects Brain to a
4//! remote messaging gateway and relays inbound messages through Brain's
5//! signal processing pipeline.
6//!
7//! ## Protocol
8//!
9//! 1. [`BridgeClient`] connects to the configured `url` via WebSocket.
10//! 2. Each inbound text frame is parsed as a JSON-encoded [`BridgeMessage`].
11//! Unparseable text frames are logged and skipped; the connection stays
12//! open. Binary frames are silently ignored.
13//! 3. WebSocket `Ping` frames are answered with `Pong` automatically.
14//! `Close` frames trigger the reconnect loop after the configured
15//! backoff.
16//! 4. A caller-supplied `handler` is invoked for each parsed message and
17//! must return a [`BridgeMessage`] response (the response keeps the
18//! original message's `id` if built via [`BridgeMessage::reply`]).
19//! 5. The response is serialised and sent back as a text frame. Failure
20//! to send tears down the connection and triggers the reconnect loop;
21//! serialisation failure logs and drops the response only.
22//! 6. On disconnect, the client waits `initial_backoff_ms * 2^attempt`
23//! (capped at `max_backoff_ms`) before retrying. With
24//! [`BridgeConfig::max_reconnect_attempts`] set, the loop returns
25//! [`BridgeError::MaxRetriesExceeded`] once the cap is reached; with
26//! `None`, retries continue indefinitely.
27//!
28//! [`BridgeClient::connect_and_relay_bidirectional`] adds an optional
29//! proactive-push channel: each [`BridgeMessage`] received on
30//! `proactive_rx` is serialised and pushed to the gateway alongside the
31//! normal request/response flow. If the broadcast channel lags, lagged
32//! events are logged and dropped; if it closes, the client continues in
33//! relay-only mode.
34//!
35//! ## Usage
36//! ```no_run
37//! # use brainos_bridge::{BridgeClient, BridgeConfig, BridgeMessage};
38//! # #[tokio::main] async fn main() -> anyhow::Result<()> {
39//! let client = BridgeClient::new("ws://gateway.example.com/brain", BridgeConfig::default());
40//! client.connect_and_relay(|msg| async move {
41//! BridgeMessage::reply(&msg, format!("Echo: {}", msg.content))
42//! }).await?;
43//! # Ok(())
44//! # }
45//! ```
46
47use std::{collections::HashMap, future::Future, time::Duration};
48
49use serde::{Deserialize, Serialize};
50use thiserror::Error;
51
52// ─── Errors ──────────────────────────────────────────────────────────────────
53
54#[derive(Debug, Error)]
55pub enum BridgeError {
56 #[error("WebSocket error: {0}")]
57 WebSocket(String),
58
59 #[error("Connection failed after {0} attempts")]
60 MaxRetriesExceeded(u32),
61}
62
63// ─── Types ────────────────────────────────────────────────────────────────────
64
/// Reconnection tuning knobs for [`BridgeClient`].
#[derive(Debug, Clone)]
pub struct BridgeConfig {
    /// Starting delay of the exponential backoff, in milliseconds
    /// (default: 1 000 ms).
    ///
    /// The wait before retry `n` (counted from 0 since the most recent
    /// successful connection) is
    /// `min(initial_backoff_ms * 2^n, max_backoff_ms)`.
    pub initial_backoff_ms: u64,
    /// Ceiling applied to every computed backoff delay, in milliseconds
    /// (default: 60 000 ms).
    pub max_backoff_ms: u64,
    /// Limit on the relay loop's attempt counter.
    ///
    /// The counter goes up by one for every failed connect and for every
    /// established session that ends, and it is reset to zero by each
    /// successful connect. As soon as the counter reaches this value the
    /// loop returns [`BridgeError::MaxRetriesExceeded`]. `None` (the
    /// default) disables the limit, so the loop retries forever.
    pub max_reconnect_attempts: Option<u32>,
}

impl Default for BridgeConfig {
    /// One-second initial backoff, one-minute ceiling, unlimited retries.
    fn default() -> Self {
        BridgeConfig {
            max_reconnect_attempts: None,
            max_backoff_ms: 60_000,
            initial_backoff_ms: 1_000,
        }
    }
}
94
95/// A message exchanged between Brain and the remote gateway.
96#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
97pub struct BridgeMessage {
98 /// Unique message ID (UUID v4 string).
99 pub id: String,
100 /// Text content of the message.
101 pub content: String,
102 /// Optional source label set by the gateway (transport id, channel
103 /// name, sender id — whatever the gateway forwards).
104 pub source: Option<String>,
105 /// Arbitrary key-value metadata forwarded from the gateway.
106 pub metadata: Option<HashMap<String, String>>,
107}
108
109impl BridgeMessage {
110 /// Create a new outbound message with a fresh UUID.
111 pub fn new(content: impl Into<String>) -> Self {
112 Self {
113 id: uuid::Uuid::new_v4().to_string(),
114 content: content.into(),
115 source: None,
116 metadata: None,
117 }
118 }
119
120 /// Create a reply to `original`, reusing the same `id` for correlation.
121 pub fn reply(original: &BridgeMessage, content: impl Into<String>) -> Self {
122 Self {
123 id: original.id.clone(),
124 content: content.into(),
125 source: None,
126 metadata: None,
127 }
128 }
129}
130
131// ─── BridgeClient ─────────────────────────────────────────────────────────────
132
133/// WebSocket client that relays messages between Brain and a remote gateway.
134pub struct BridgeClient {
135 url: String,
136 config: BridgeConfig,
137}
138
139impl BridgeClient {
140 /// Create a new client pointing at `url`.
141 pub fn new(url: impl Into<String>, config: BridgeConfig) -> Self {
142 Self {
143 url: url.into(),
144 config,
145 }
146 }
147
148 /// Calculate the backoff duration for a given retry attempt number (0-indexed).
149 ///
150 /// Uses exponential back-off: `initial * 2^attempt`, capped at `max`.
151 pub fn backoff_duration(&self, attempt: u32) -> Duration {
152 // 2^attempt, capped to avoid overflow
153 let multiplier = 1u64.checked_shl(attempt.min(62)).unwrap_or(u64::MAX);
154 let ms = self
155 .config
156 .initial_backoff_ms
157 .saturating_mul(multiplier)
158 .min(self.config.max_backoff_ms);
159 Duration::from_millis(ms)
160 }
161
162 /// Connect to the remote WebSocket gateway and relay messages indefinitely.
163 ///
164 /// For each inbound [`BridgeMessage`], `handler` is called and its return
165 /// value is sent back as a text frame. On disconnect the client waits for
166 /// the appropriate backoff period then reconnects automatically.
167 ///
168 /// Returns `Err(BridgeError::MaxRetriesExceeded)` only when
169 /// [`BridgeConfig::max_reconnect_attempts`] is set and exceeded.
170 pub async fn connect_and_relay<F, Fut>(&self, handler: F) -> Result<(), BridgeError>
171 where
172 F: Fn(BridgeMessage) -> Fut + Clone,
173 Fut: Future<Output = BridgeMessage>,
174 {
175 self.connect_and_relay_bidirectional(handler, None).await
176 }
177
178 /// Connect with optional proactive push channel.
179 ///
180 /// When `proactive_rx` is provided, proactive notifications are forwarded
181 /// to the gateway as outbound `BridgeMessage` frames alongside the normal
182 /// request-response relay.
183 pub async fn connect_and_relay_bidirectional<F, Fut>(
184 &self,
185 handler: F,
186 mut proactive_rx: Option<tokio::sync::broadcast::Receiver<BridgeMessage>>,
187 ) -> Result<(), BridgeError>
188 where
189 F: Fn(BridgeMessage) -> Fut + Clone,
190 Fut: Future<Output = BridgeMessage>,
191 {
192 use futures_util::{SinkExt, StreamExt};
193 use tokio_tungstenite::{connect_async, tungstenite::Message};
194
195 let mut attempt = 0u32;
196
197 loop {
198 // Check retry limit before sleeping
199 if let Some(max) = self.config.max_reconnect_attempts {
200 if attempt >= max {
201 return Err(BridgeError::MaxRetriesExceeded(attempt));
202 }
203 }
204
205 // Backoff before retries (not before the first attempt)
206 if attempt > 0 {
207 let backoff = self.backoff_duration(attempt - 1);
208 tracing::info!(
209 url = %self.url,
210 attempt,
211 backoff_ms = backoff.as_millis(),
212 "Reconnecting to bridge gateway"
213 );
214 tokio::time::sleep(backoff).await;
215 }
216
217 tracing::info!(url = %self.url, "Connecting to bridge gateway");
218
219 let ws_stream = match connect_async(&self.url).await {
220 Err(e) => {
221 tracing::warn!(url = %self.url, error = %e, "Bridge connection failed");
222 attempt += 1;
223 continue;
224 }
225 Ok((ws, _response)) => {
226 tracing::info!(url = %self.url, "Bridge connected");
227 attempt = 0; // reset on successful connection
228 ws
229 }
230 };
231
232 let (mut sink, mut stream) = ws_stream.split();
233
234 loop {
235 tokio::select! {
236 ws_msg = stream.next() => {
237 match ws_msg {
238 None => {
239 tracing::warn!(url = %self.url, "Bridge stream ended (EOF)");
240 break;
241 }
242 Some(Err(e)) => {
243 tracing::warn!(url = %self.url, error = %e, "Bridge WebSocket error");
244 break;
245 }
246 Some(Ok(Message::Ping(data))) => {
247 if sink.send(Message::Pong(data)).await.is_err() {
248 break;
249 }
250 }
251 Some(Ok(Message::Close(_))) => {
252 tracing::info!(url = %self.url, "Bridge connection closed by remote");
253 break;
254 }
255 Some(Ok(Message::Text(text))) => {
256 let msg: BridgeMessage = match serde_json::from_str(&text) {
257 Ok(m) => m,
258 Err(e) => {
259 tracing::warn!(
260 error = %e,
261 raw = %text,
262 "Ignoring unparseable bridge message"
263 );
264 continue;
265 }
266 };
267
268 let msg_id = msg.id.clone();
269 let response = handler.clone()(msg).await;
270
271 match serde_json::to_string(&response) {
272 Ok(payload) => {
273 if sink.send(Message::Text(payload.into())).await.is_err() {
274 tracing::warn!(id = %msg_id, "Failed to send bridge response");
275 break;
276 }
277 }
278 Err(e) => {
279 tracing::error!(id = %msg_id, error = %e, "Failed to serialise response");
280 }
281 }
282 }
283 Some(Ok(_)) => {} // ignore binary frames and pong
284 }
285 }
286 proactive = async {
287 match proactive_rx.as_mut() {
288 Some(rx) => rx.recv().await,
289 None => std::future::pending().await,
290 }
291 } => {
292 match proactive {
293 Ok(msg) => {
294 if let Ok(payload) = serde_json::to_string(&msg) {
295 if sink.send(Message::Text(payload.into())).await.is_err() {
296 tracing::warn!("Failed to push proactive to bridge");
297 break;
298 }
299 tracing::debug!(id = %msg.id, "Proactive pushed to bridge");
300 }
301 }
302 Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
303 tracing::warn!(skipped = n, "Bridge proactive receiver lagged");
304 }
305 Err(tokio::sync::broadcast::error::RecvError::Closed) => {
306 tracing::info!("Proactive channel closed, bridge continues in relay-only mode");
307 proactive_rx = None;
308 }
309 }
310 }
311 }
312 }
313
314 attempt += 1;
315 }
316 }
317}
318
319// ─── Tests ────────────────────────────────────────────────────────────────────
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Client used by the pure backoff tests; it is never asked to connect.
    fn offline_client(cfg: BridgeConfig) -> BridgeClient {
        BridgeClient::new("ws://unused", cfg)
    }

    /// `BridgeConfig::default()` yields the documented values.
    #[test]
    fn test_bridge_config_defaults() {
        let defaults = BridgeConfig::default();
        assert_eq!(defaults.initial_backoff_ms, 1_000);
        assert_eq!(defaults.max_backoff_ms, 60_000);
        assert!(defaults.max_reconnect_attempts.is_none());
    }

    /// With default settings the delay doubles on each attempt.
    #[test]
    fn test_backoff_grows_exponentially() {
        let client = offline_client(BridgeConfig::default());
        // (attempt, expected delay in ms)
        let expectations = [(0u32, 1_000u64), (1, 2_000), (2, 4_000), (3, 8_000)];
        for (attempt, expected_ms) in expectations {
            assert_eq!(
                client.backoff_duration(attempt),
                Duration::from_millis(expected_ms)
            );
        }
    }

    /// Large attempt numbers never exceed `max_backoff_ms`.
    #[test]
    fn test_backoff_capped_at_max() {
        let client = offline_client(BridgeConfig {
            initial_backoff_ms: 1_000,
            max_backoff_ms: 5_000,
            max_reconnect_attempts: None,
        });
        // After enough doublings the delay should sit at the 5000 ms cap.
        for attempt in [10u32, 30] {
            assert_eq!(
                client.backoff_duration(attempt),
                Duration::from_millis(5_000)
            );
        }
    }

    /// `new` sets an id and the content, and leaves the optional fields unset.
    #[test]
    fn test_bridge_message_new() {
        let fresh = BridgeMessage::new("hello");
        assert!(!fresh.id.is_empty());
        assert_eq!(fresh.content, "hello");
        assert!(fresh.source.is_none());
        assert!(fresh.metadata.is_none());
    }

    /// `reply` carries the id of the message it answers.
    #[test]
    fn test_bridge_message_reply_shares_id() {
        let question = BridgeMessage::new("what time is it?");
        let answer = BridgeMessage::reply(&question, "It is noon.");
        assert_eq!(
            answer.id, question.id,
            "reply should reuse original message ID"
        );
        assert_eq!(answer.content, "It is noon.");
    }

    /// A fully populated message survives a JSON encode/decode round trip.
    #[test]
    fn test_bridge_message_roundtrip_json() {
        let metadata: HashMap<String, String> =
            std::iter::once(("channel".to_string(), "#general".to_string())).collect();
        let sent = BridgeMessage {
            id: "test-id-123".to_string(),
            content: "Deploy to prod?".to_string(),
            source: Some("chat-main".to_string()),
            metadata: Some(metadata),
        };
        let encoded = serde_json::to_string(&sent).unwrap();
        let received: BridgeMessage = serde_json::from_str(&encoded).unwrap();
        assert_eq!(received, sent);
    }

    /// With a retry limit of 3 and an unreachable gateway, the relay loop
    /// gives up with `MaxRetriesExceeded(3)`.
    #[tokio::test]
    async fn test_connect_fails_and_hits_max_retries() {
        // Local port that is expected to have nothing listening on it, with
        // a 1 ms backoff so the retries finish quickly.
        let client = BridgeClient::new(
            "ws://127.0.0.1:19999",
            BridgeConfig {
                initial_backoff_ms: 1,
                max_backoff_ms: 1,
                max_reconnect_attempts: Some(3),
            },
        );

        let always_ok = |msg: BridgeMessage| async move { BridgeMessage::reply(&msg, "ok") };
        let result = client.connect_and_relay(always_ok).await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(3))),
            "expected MaxRetriesExceeded(3), got: {:?}",
            result
        );
    }
}