// brainos_bridge/lib.rs
//! # Brain Bridge
//!
//! External service relay — WebSocket client that connects Brain to a
//! remote messaging gateway and relays inbound messages through Brain's
//! signal processing pipeline.
//!
//! ## Protocol
//! 1. [`BridgeClient`] connects to the configured `url` via WebSocket.
//! 2. Each inbound text frame is expected to be a JSON-encoded [`BridgeMessage`];
//!    frames that fail to parse are logged and skipped.
//! 3. A caller-supplied `handler` function processes the message and returns
//!    a [`BridgeMessage`] response.
//! 4. The response is serialised and sent back as a text frame.
//! 5. On disconnect, the client reconnects with exponential backoff.
//! 6. Optionally, messages from a broadcast channel are pushed to the gateway
//!    as well (see [`BridgeClient::connect_and_relay_bidirectional`]).
//!
//! ## Usage
//! ```no_run
//! # use brainos_bridge::{BridgeClient, BridgeConfig, BridgeMessage};
//! # #[tokio::main] async fn main() -> anyhow::Result<()> {
//! let client = BridgeClient::new("ws://gateway.example.com/brain", BridgeConfig::default());
//! client.connect_and_relay(|msg| async move {
//!     BridgeMessage::reply(&msg, format!("Echo: {}", msg.content))
//! }).await?;
//! # Ok(())
//! # }
//! ```
26
27use std::{collections::HashMap, future::Future, time::Duration};
28
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31
32// ─── Errors ──────────────────────────────────────────────────────────────────
33
34#[derive(Debug, Error)]
35pub enum BridgeError {
36    #[error("WebSocket error: {0}")]
37    WebSocket(String),
38
39    #[error("Connection failed after {0} attempts")]
40    MaxRetriesExceeded(u32),
41}
42
// ─── Types ────────────────────────────────────────────────────────────────────

/// Configuration for [`BridgeClient`] reconnection behaviour.
///
/// Retry delays start at `initial_backoff_ms`, double with each consecutive
/// failure and never exceed `max_backoff_ms`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BridgeConfig {
    /// Delay before the first retry after a failed connect or a dropped
    /// session (milliseconds). Default: 1 000 ms.
    pub initial_backoff_ms: u64,
    /// Upper bound on the delay between retries (milliseconds).
    /// Default: 60 000 ms.
    pub max_backoff_ms: u64,
    /// Maximum number of *consecutive* connection attempts before the client
    /// gives up with [`BridgeError::MaxRetriesExceeded`]. `None` = reconnect
    /// forever (the default).
    ///
    /// The count includes the very first connection attempt and is reset to
    /// zero whenever a connection succeeds; a session that later drops counts
    /// as one attempt. Consequently `Some(0)` fails without ever connecting,
    /// and `Some(1)` never reconnects after a dropped session.
    pub max_reconnect_attempts: Option<u32>,
}

impl Default for BridgeConfig {
    fn default() -> Self {
        Self {
            initial_backoff_ms: 1_000,
            max_backoff_ms: 60_000,
            max_reconnect_attempts: None,
        }
    }
}
65
66/// A message exchanged between Brain and the remote gateway.
67#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
68pub struct BridgeMessage {
69    /// Unique message ID (UUID v4 string).
70    pub id: String,
71    /// Text content of the message.
72    pub content: String,
73    /// Optional source label set by the gateway (transport id, channel
74    /// name, sender id — whatever the gateway forwards).
75    pub source: Option<String>,
76    /// Arbitrary key-value metadata forwarded from the gateway.
77    pub metadata: Option<HashMap<String, String>>,
78}
79
80impl BridgeMessage {
81    /// Create a new outbound message with a fresh UUID.
82    pub fn new(content: impl Into<String>) -> Self {
83        Self {
84            id: uuid::Uuid::new_v4().to_string(),
85            content: content.into(),
86            source: None,
87            metadata: None,
88        }
89    }
90
91    /// Create a reply to `original`, reusing the same `id` for correlation.
92    pub fn reply(original: &BridgeMessage, content: impl Into<String>) -> Self {
93        Self {
94            id: original.id.clone(),
95            content: content.into(),
96            source: None,
97            metadata: None,
98        }
99    }
100}
101
102// ─── BridgeClient ─────────────────────────────────────────────────────────────
103
104/// WebSocket client that relays messages between Brain and a remote gateway.
105pub struct BridgeClient {
106    url: String,
107    config: BridgeConfig,
108}
109
110impl BridgeClient {
111    /// Create a new client pointing at `url`.
112    pub fn new(url: impl Into<String>, config: BridgeConfig) -> Self {
113        Self {
114            url: url.into(),
115            config,
116        }
117    }
118
119    /// Calculate the backoff duration for a given retry attempt number (0-indexed).
120    ///
121    /// Uses exponential back-off: `initial * 2^attempt`, capped at `max`.
122    pub fn backoff_duration(&self, attempt: u32) -> Duration {
123        // 2^attempt, capped to avoid overflow
124        let multiplier = 1u64.checked_shl(attempt.min(62)).unwrap_or(u64::MAX);
125        let ms = self
126            .config
127            .initial_backoff_ms
128            .saturating_mul(multiplier)
129            .min(self.config.max_backoff_ms);
130        Duration::from_millis(ms)
131    }
132
133    /// Connect to the remote WebSocket gateway and relay messages indefinitely.
134    ///
135    /// For each inbound [`BridgeMessage`], `handler` is called and its return
136    /// value is sent back as a text frame.  On disconnect the client waits for
137    /// the appropriate backoff period then reconnects automatically.
138    ///
139    /// Returns `Err(BridgeError::MaxRetriesExceeded)` only when
140    /// [`BridgeConfig::max_reconnect_attempts`] is set and exceeded.
141    pub async fn connect_and_relay<F, Fut>(&self, handler: F) -> Result<(), BridgeError>
142    where
143        F: Fn(BridgeMessage) -> Fut + Clone,
144        Fut: Future<Output = BridgeMessage>,
145    {
146        self.connect_and_relay_bidirectional(handler, None).await
147    }
148
149    /// Connect with optional proactive push channel.
150    ///
151    /// When `proactive_rx` is provided, proactive notifications are forwarded
152    /// to the gateway as outbound `BridgeMessage` frames alongside the normal
153    /// request-response relay.
154    pub async fn connect_and_relay_bidirectional<F, Fut>(
155        &self,
156        handler: F,
157        mut proactive_rx: Option<tokio::sync::broadcast::Receiver<BridgeMessage>>,
158    ) -> Result<(), BridgeError>
159    where
160        F: Fn(BridgeMessage) -> Fut + Clone,
161        Fut: Future<Output = BridgeMessage>,
162    {
163        use futures_util::{SinkExt, StreamExt};
164        use tokio_tungstenite::{connect_async, tungstenite::Message};
165
166        let mut attempt = 0u32;
167
168        loop {
169            // Check retry limit before sleeping
170            if let Some(max) = self.config.max_reconnect_attempts {
171                if attempt >= max {
172                    return Err(BridgeError::MaxRetriesExceeded(attempt));
173                }
174            }
175
176            // Backoff before retries (not before the first attempt)
177            if attempt > 0 {
178                let backoff = self.backoff_duration(attempt - 1);
179                tracing::info!(
180                    url = %self.url,
181                    attempt,
182                    backoff_ms = backoff.as_millis(),
183                    "Reconnecting to bridge gateway"
184                );
185                tokio::time::sleep(backoff).await;
186            }
187
188            tracing::info!(url = %self.url, "Connecting to bridge gateway");
189
190            let ws_stream = match connect_async(&self.url).await {
191                Err(e) => {
192                    tracing::warn!(url = %self.url, error = %e, "Bridge connection failed");
193                    attempt += 1;
194                    continue;
195                }
196                Ok((ws, _response)) => {
197                    tracing::info!(url = %self.url, "Bridge connected");
198                    attempt = 0; // reset on successful connection
199                    ws
200                }
201            };
202
203            let (mut sink, mut stream) = ws_stream.split();
204
205            loop {
206                tokio::select! {
207                    ws_msg = stream.next() => {
208                        match ws_msg {
209                            None => {
210                                tracing::warn!(url = %self.url, "Bridge stream ended (EOF)");
211                                break;
212                            }
213                            Some(Err(e)) => {
214                                tracing::warn!(url = %self.url, error = %e, "Bridge WebSocket error");
215                                break;
216                            }
217                            Some(Ok(Message::Ping(data))) => {
218                                if sink.send(Message::Pong(data)).await.is_err() {
219                                    break;
220                                }
221                            }
222                            Some(Ok(Message::Close(_))) => {
223                                tracing::info!(url = %self.url, "Bridge connection closed by remote");
224                                break;
225                            }
226                            Some(Ok(Message::Text(text))) => {
227                                let msg: BridgeMessage = match serde_json::from_str(&text) {
228                                    Ok(m) => m,
229                                    Err(e) => {
230                                        tracing::warn!(
231                                            error = %e,
232                                            raw = %text,
233                                            "Ignoring unparseable bridge message"
234                                        );
235                                        continue;
236                                    }
237                                };
238
239                                let msg_id = msg.id.clone();
240                                let response = handler.clone()(msg).await;
241
242                                match serde_json::to_string(&response) {
243                                    Ok(payload) => {
244                                        if sink.send(Message::Text(payload.into())).await.is_err() {
245                                            tracing::warn!(id = %msg_id, "Failed to send bridge response");
246                                            break;
247                                        }
248                                    }
249                                    Err(e) => {
250                                        tracing::error!(id = %msg_id, error = %e, "Failed to serialise response");
251                                    }
252                                }
253                            }
254                            Some(Ok(_)) => {} // ignore binary frames and pong
255                        }
256                    }
257                    proactive = async {
258                        match proactive_rx.as_mut() {
259                            Some(rx) => rx.recv().await,
260                            None => std::future::pending().await,
261                        }
262                    } => {
263                        match proactive {
264                            Ok(msg) => {
265                                if let Ok(payload) = serde_json::to_string(&msg) {
266                                    if sink.send(Message::Text(payload.into())).await.is_err() {
267                                        tracing::warn!("Failed to push proactive to bridge");
268                                        break;
269                                    }
270                                    tracing::debug!(id = %msg.id, "Proactive pushed to bridge");
271                                }
272                            }
273                            Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
274                                tracing::warn!(skipped = n, "Bridge proactive receiver lagged");
275                            }
276                            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
277                                tracing::info!("Proactive channel closed, bridge continues in relay-only mode");
278                                proactive_rx = None;
279                            }
280                        }
281                    }
282                }
283            }
284
285            attempt += 1;
286        }
287    }
288}
289
// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Return a `ws://` URL on a loopback port with no listener, so connection
    /// attempts are refused immediately.
    ///
    /// Binding port 0 lets the OS choose a free port; dropping the listener
    /// releases it again. This avoids hard-coding a port that might happen to
    /// be in use on the machine running the tests.
    fn unused_local_ws_url() -> String {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind an ephemeral loopback port");
        let port = listener
            .local_addr()
            .expect("read the ephemeral port")
            .port();
        drop(listener);
        format!("ws://127.0.0.1:{port}")
    }

    #[test]
    fn test_bridge_config_defaults() {
        let cfg = BridgeConfig::default();
        assert_eq!(cfg.initial_backoff_ms, 1_000);
        assert_eq!(cfg.max_backoff_ms, 60_000);
        assert!(cfg.max_reconnect_attempts.is_none());
    }

    #[test]
    fn test_backoff_grows_exponentially() {
        let client = BridgeClient::new("ws://unused", BridgeConfig::default());
        // attempt 0: 1000 ms
        assert_eq!(client.backoff_duration(0), Duration::from_millis(1_000));
        // attempt 1: 2000 ms
        assert_eq!(client.backoff_duration(1), Duration::from_millis(2_000));
        // attempt 2: 4000 ms
        assert_eq!(client.backoff_duration(2), Duration::from_millis(4_000));
        // attempt 3: 8000 ms
        assert_eq!(client.backoff_duration(3), Duration::from_millis(8_000));
    }

    #[test]
    fn test_backoff_capped_at_max() {
        let cfg = BridgeConfig {
            initial_backoff_ms: 1_000,
            max_backoff_ms: 5_000,
            max_reconnect_attempts: None,
        };
        let client = BridgeClient::new("ws://unused", cfg);
        // After enough doublings it should saturate at 5000
        assert_eq!(client.backoff_duration(10), Duration::from_millis(5_000));
        assert_eq!(client.backoff_duration(30), Duration::from_millis(5_000));
    }

    #[test]
    fn test_backoff_saturates_for_huge_attempts() {
        // Neither the shift nor the multiply may overflow or panic.
        let cfg = BridgeConfig {
            initial_backoff_ms: u64::MAX / 2,
            max_backoff_ms: u64::MAX,
            max_reconnect_attempts: None,
        };
        let client = BridgeClient::new("ws://unused", cfg);
        assert_eq!(
            client.backoff_duration(u32::MAX),
            Duration::from_millis(u64::MAX)
        );
    }

    #[test]
    fn test_bridge_message_new() {
        let msg = BridgeMessage::new("hello");
        assert!(!msg.id.is_empty());
        assert_eq!(msg.content, "hello");
        assert!(msg.source.is_none());
        assert!(msg.metadata.is_none());
    }

    #[test]
    fn test_bridge_message_reply_shares_id() {
        let original = BridgeMessage::new("what time is it?");
        let reply = BridgeMessage::reply(&original, "It is noon.");
        assert_eq!(
            reply.id, original.id,
            "reply should reuse original message ID"
        );
        assert_eq!(reply.content, "It is noon.");
        assert!(reply.source.is_none());
        assert!(reply.metadata.is_none());
    }

    #[test]
    fn test_bridge_message_roundtrip_json() {
        let mut meta = HashMap::new();
        meta.insert("channel".to_string(), "#general".to_string());
        let original = BridgeMessage {
            id: "test-id-123".to_string(),
            content: "Deploy to prod?".to_string(),
            source: Some("chat-main".to_string()),
            metadata: Some(meta),
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: BridgeMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[tokio::test]
    async fn test_connect_fails_and_hits_max_retries() {
        // Point at a loopback port nobody listens on, so connects are refused.
        let cfg = BridgeConfig {
            initial_backoff_ms: 1, // tiny backoff so the test is fast
            max_backoff_ms: 1,
            max_reconnect_attempts: Some(3),
        };
        let client = BridgeClient::new(unused_local_ws_url(), cfg);

        let result = client
            .connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
            .await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(3))),
            "expected MaxRetriesExceeded(3), got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_zero_max_attempts_fails_without_connecting() {
        let cfg = BridgeConfig {
            initial_backoff_ms: 1,
            max_backoff_ms: 1,
            max_reconnect_attempts: Some(0),
        };
        let client = BridgeClient::new(unused_local_ws_url(), cfg);

        let result = client
            .connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
            .await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(0))),
            "expected MaxRetriesExceeded(0), got: {:?}",
            result
        );
    }
}