Skip to main content

brainos_bridge/
lib.rs

1//! # Brain Bridge
2//!
3//! External service relay — WebSocket client that connects Brain to remote
4//! messaging gateways (Slack bots, Telegram bridges, custom agents, etc.)
5//! and relays inbound messages through Brain's signal processing pipeline.
6//!
7//! ## Protocol
8//! 1. `BridgeClient` connects to the configured `url` via WebSocket.
9//! 2. Each inbound text frame must be a JSON-encoded [`BridgeMessage`].
10//! 3. A caller-supplied `handler` function processes the message and returns
11//!    a [`BridgeMessage`] response.
12//! 4. The response is serialised and sent back as a text frame.
13//! 5. On disconnect, the client reconnects with exponential backoff.
14//!
15//! ## Usage
16//! ```no_run
17//! # use brainos_bridge::{BridgeClient, BridgeConfig, BridgeMessage};
18//! # #[tokio::main] async fn main() -> anyhow::Result<()> {
19//! let client = BridgeClient::new("ws://gateway.example.com/brain", BridgeConfig::default());
20//! client.connect_and_relay(|msg| async move {
21//!     BridgeMessage::reply(&msg, format!("Echo: {}", msg.content))
22//! }).await?;
23//! # Ok(())
24//! # }
25//! ```
26
27use std::{collections::HashMap, future::Future, time::Duration};
28
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31
32// ─── Errors ──────────────────────────────────────────────────────────────────
33
/// Errors surfaced by the bridge client.
#[derive(Debug, Error)]
pub enum BridgeError {
    /// A WebSocket-level failure, carried as a human-readable description.
    ///
    /// NOTE(review): no code visible in this file constructs this variant;
    /// connection and stream errors are logged and retried instead — confirm
    /// whether other modules rely on it.
    #[error("WebSocket error: {0}")]
    WebSocket(String),

    /// The configured reconnection limit was reached. The payload is the
    /// value of the attempt counter at the moment the relay loop gave up.
    #[error("Connection failed after {0} attempts")]
    MaxRetriesExceeded(u32),
}
42
43// ─── Types ────────────────────────────────────────────────────────────────────
44
/// Reconnection tuning knobs consumed by [`BridgeClient`].
#[derive(Debug, Clone)]
pub struct BridgeConfig {
    /// Delay used for the first retry, in milliseconds. Default: 1 000 ms.
    pub initial_backoff_ms: u64,
    /// Upper bound on the delay between retries, in milliseconds.
    /// Default: 60 000 ms.
    pub max_backoff_ms: u64,
    /// Limit on the attempt counter kept by the relay loop.
    /// `None` means the client keeps reconnecting forever.
    pub max_reconnect_attempts: Option<u32>,
}

impl Default for BridgeConfig {
    /// One-second initial backoff, one-minute ceiling, unlimited retries.
    fn default() -> Self {
        let initial_backoff_ms = 1_000;
        let max_backoff_ms = 60_000;
        BridgeConfig {
            initial_backoff_ms,
            max_backoff_ms,
            max_reconnect_attempts: None,
        }
    }
}
65
/// A message exchanged between Brain and the remote gateway.
///
/// The relay loop parses each inbound text frame as JSON into this type and
/// serialises the handler's response from it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BridgeMessage {
    /// Message identifier. [`BridgeMessage::new`] fills this with a fresh
    /// UUID v4 string; [`BridgeMessage::reply`] copies it from the message
    /// being answered so the two can be correlated.
    pub id: String,
    /// Text content of the message.
    pub content: String,
    /// Optional source label (e.g. `"slack"`, `"telegram"`).
    pub source: Option<String>,
    /// Arbitrary key-value metadata forwarded from the gateway.
    pub metadata: Option<HashMap<String, String>>,
}
78
79impl BridgeMessage {
80    /// Create a new outbound message with a fresh UUID.
81    pub fn new(content: impl Into<String>) -> Self {
82        Self {
83            id: uuid::Uuid::new_v4().to_string(),
84            content: content.into(),
85            source: None,
86            metadata: None,
87        }
88    }
89
90    /// Create a reply to `original`, reusing the same `id` for correlation.
91    pub fn reply(original: &BridgeMessage, content: impl Into<String>) -> Self {
92        Self {
93            id: original.id.clone(),
94            content: content.into(),
95            source: None,
96            metadata: None,
97        }
98    }
99}
100
101// ─── BridgeClient ─────────────────────────────────────────────────────────────
102
/// WebSocket client that relays messages between Brain and a remote gateway.
pub struct BridgeClient {
    /// Gateway address handed to the WebSocket connector on every
    /// (re)connection attempt.
    url: String,
    /// Backoff and retry-limit settings used by the relay loop.
    config: BridgeConfig,
}
108
109impl BridgeClient {
110    /// Create a new client pointing at `url`.
111    pub fn new(url: impl Into<String>, config: BridgeConfig) -> Self {
112        Self {
113            url: url.into(),
114            config,
115        }
116    }
117
118    /// Calculate the backoff duration for a given retry attempt number (0-indexed).
119    ///
120    /// Uses exponential back-off: `initial * 2^attempt`, capped at `max`.
121    pub fn backoff_duration(&self, attempt: u32) -> Duration {
122        // 2^attempt, capped to avoid overflow
123        let multiplier = 1u64.checked_shl(attempt.min(62)).unwrap_or(u64::MAX);
124        let ms = self
125            .config
126            .initial_backoff_ms
127            .saturating_mul(multiplier)
128            .min(self.config.max_backoff_ms);
129        Duration::from_millis(ms)
130    }
131
132    /// Connect to the remote WebSocket gateway and relay messages indefinitely.
133    ///
134    /// For each inbound [`BridgeMessage`], `handler` is called and its return
135    /// value is sent back as a text frame.  On disconnect the client waits for
136    /// the appropriate backoff period then reconnects automatically.
137    ///
138    /// Returns `Err(BridgeError::MaxRetriesExceeded)` only when
139    /// [`BridgeConfig::max_reconnect_attempts`] is set and exceeded.
140    pub async fn connect_and_relay<F, Fut>(&self, handler: F) -> Result<(), BridgeError>
141    where
142        F: Fn(BridgeMessage) -> Fut + Clone,
143        Fut: Future<Output = BridgeMessage>,
144    {
145        self.connect_and_relay_bidirectional(handler, None).await
146    }
147
148    /// Connect with optional proactive push channel.
149    ///
150    /// When `proactive_rx` is provided, proactive notifications are forwarded
151    /// to the gateway as outbound `BridgeMessage` frames alongside the normal
152    /// request-response relay.
153    pub async fn connect_and_relay_bidirectional<F, Fut>(
154        &self,
155        handler: F,
156        mut proactive_rx: Option<tokio::sync::broadcast::Receiver<BridgeMessage>>,
157    ) -> Result<(), BridgeError>
158    where
159        F: Fn(BridgeMessage) -> Fut + Clone,
160        Fut: Future<Output = BridgeMessage>,
161    {
162        use futures_util::{SinkExt, StreamExt};
163        use tokio_tungstenite::{connect_async, tungstenite::Message};
164
165        let mut attempt = 0u32;
166
167        loop {
168            // Check retry limit before sleeping
169            if let Some(max) = self.config.max_reconnect_attempts {
170                if attempt >= max {
171                    return Err(BridgeError::MaxRetriesExceeded(attempt));
172                }
173            }
174
175            // Backoff before retries (not before the first attempt)
176            if attempt > 0 {
177                let backoff = self.backoff_duration(attempt - 1);
178                tracing::info!(
179                    url = %self.url,
180                    attempt,
181                    backoff_ms = backoff.as_millis(),
182                    "Reconnecting to bridge gateway"
183                );
184                tokio::time::sleep(backoff).await;
185            }
186
187            tracing::info!(url = %self.url, "Connecting to bridge gateway");
188
189            let ws_stream = match connect_async(&self.url).await {
190                Err(e) => {
191                    tracing::warn!(url = %self.url, error = %e, "Bridge connection failed");
192                    attempt += 1;
193                    continue;
194                }
195                Ok((ws, _response)) => {
196                    tracing::info!(url = %self.url, "Bridge connected");
197                    attempt = 0; // reset on successful connection
198                    ws
199                }
200            };
201
202            let (mut sink, mut stream) = ws_stream.split();
203
204            loop {
205                tokio::select! {
206                    ws_msg = stream.next() => {
207                        match ws_msg {
208                            None => {
209                                tracing::warn!(url = %self.url, "Bridge stream ended (EOF)");
210                                break;
211                            }
212                            Some(Err(e)) => {
213                                tracing::warn!(url = %self.url, error = %e, "Bridge WebSocket error");
214                                break;
215                            }
216                            Some(Ok(Message::Ping(data))) => {
217                                if sink.send(Message::Pong(data)).await.is_err() {
218                                    break;
219                                }
220                            }
221                            Some(Ok(Message::Close(_))) => {
222                                tracing::info!(url = %self.url, "Bridge connection closed by remote");
223                                break;
224                            }
225                            Some(Ok(Message::Text(text))) => {
226                                let msg: BridgeMessage = match serde_json::from_str(&text) {
227                                    Ok(m) => m,
228                                    Err(e) => {
229                                        tracing::warn!(
230                                            error = %e,
231                                            raw = %text,
232                                            "Ignoring unparseable bridge message"
233                                        );
234                                        continue;
235                                    }
236                                };
237
238                                let msg_id = msg.id.clone();
239                                let response = handler.clone()(msg).await;
240
241                                match serde_json::to_string(&response) {
242                                    Ok(payload) => {
243                                        if sink.send(Message::Text(payload.into())).await.is_err() {
244                                            tracing::warn!(id = %msg_id, "Failed to send bridge response");
245                                            break;
246                                        }
247                                    }
248                                    Err(e) => {
249                                        tracing::error!(id = %msg_id, error = %e, "Failed to serialise response");
250                                    }
251                                }
252                            }
253                            Some(Ok(_)) => {} // ignore binary frames and pong
254                        }
255                    }
256                    proactive = async {
257                        match proactive_rx.as_mut() {
258                            Some(rx) => rx.recv().await,
259                            None => std::future::pending().await,
260                        }
261                    } => {
262                        match proactive {
263                            Ok(msg) => {
264                                if let Ok(payload) = serde_json::to_string(&msg) {
265                                    if sink.send(Message::Text(payload.into())).await.is_err() {
266                                        tracing::warn!("Failed to push proactive to bridge");
267                                        break;
268                                    }
269                                    tracing::debug!(id = %msg.id, "Proactive pushed to bridge");
270                                }
271                            }
272                            Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
273                                tracing::warn!(skipped = n, "Bridge proactive receiver lagged");
274                            }
275                            Err(tokio::sync::broadcast::error::RecvError::Closed) => {
276                                tracing::info!("Proactive channel closed, bridge continues in relay-only mode");
277                                proactive_rx = None;
278                            }
279                        }
280                    }
281                }
282            }
283
284            attempt += 1;
285        }
286    }
287}
288
289// ─── Tests ────────────────────────────────────────────────────────────────────
290
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration matches the documented values.
    #[test]
    fn test_bridge_config_defaults() {
        let cfg = BridgeConfig::default();
        assert_eq!(cfg.initial_backoff_ms, 1_000);
        assert_eq!(cfg.max_backoff_ms, 60_000);
        assert!(cfg.max_reconnect_attempts.is_none());
    }

    /// Each successive attempt doubles the delay while below the cap.
    #[test]
    fn test_backoff_grows_exponentially() {
        let client = BridgeClient::new("ws://unused", BridgeConfig::default());
        // attempt 0: 1000 ms
        assert_eq!(client.backoff_duration(0), Duration::from_millis(1_000));
        // attempt 1: 2000 ms
        assert_eq!(client.backoff_duration(1), Duration::from_millis(2_000));
        // attempt 2: 4000 ms
        assert_eq!(client.backoff_duration(2), Duration::from_millis(4_000));
        // attempt 3: 8000 ms
        assert_eq!(client.backoff_duration(3), Duration::from_millis(8_000));
    }

    /// The delay never exceeds `max_backoff_ms`.
    #[test]
    fn test_backoff_capped_at_max() {
        let cfg = BridgeConfig {
            initial_backoff_ms: 1_000,
            max_backoff_ms: 5_000,
            max_reconnect_attempts: None,
        };
        let client = BridgeClient::new("ws://unused", cfg);
        // After enough doublings it should saturate at 5000
        assert_eq!(client.backoff_duration(10), Duration::from_millis(5_000));
        assert_eq!(client.backoff_duration(30), Duration::from_millis(5_000));
    }

    /// Attempt numbers beyond the shift width must saturate, not panic or wrap.
    #[test]
    fn test_backoff_handles_huge_attempt_without_overflow() {
        let client = BridgeClient::new("ws://unused", BridgeConfig::default());
        assert_eq!(client.backoff_duration(63), Duration::from_millis(60_000));
        assert_eq!(client.backoff_duration(64), Duration::from_millis(60_000));
        assert_eq!(
            client.backoff_duration(u32::MAX),
            Duration::from_millis(60_000)
        );
    }

    /// `new` assigns an id and leaves the optional fields unset.
    #[test]
    fn test_bridge_message_new() {
        let msg = BridgeMessage::new("hello");
        assert!(!msg.id.is_empty());
        assert_eq!(msg.content, "hello");
        assert!(msg.source.is_none());
        assert!(msg.metadata.is_none());
    }

    /// `reply` reuses the original id and does not copy source or metadata.
    #[test]
    fn test_bridge_message_reply_shares_id() {
        let mut original = BridgeMessage::new("what time is it?");
        original.source = Some("slack".to_string());
        let reply = BridgeMessage::reply(&original, "It is noon.");
        assert_eq!(
            reply.id, original.id,
            "reply should reuse original message ID"
        );
        assert_eq!(reply.content, "It is noon.");
        assert!(reply.source.is_none());
        assert!(reply.metadata.is_none());
    }

    /// A fully populated message survives a JSON round trip unchanged.
    #[test]
    fn test_bridge_message_roundtrip_json() {
        let mut meta = HashMap::new();
        meta.insert("channel".to_string(), "#general".to_string());
        let original = BridgeMessage {
            id: "test-id-123".to_string(),
            content: "Deploy to prod?".to_string(),
            source: Some("slack".to_string()),
            metadata: Some(meta),
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: BridgeMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    /// With a limit of zero the relay gives up before opening any connection.
    #[tokio::test]
    async fn test_zero_max_attempts_fails_without_connecting() {
        let cfg = BridgeConfig {
            initial_backoff_ms: 1,
            max_backoff_ms: 1,
            max_reconnect_attempts: Some(0),
        };
        let client = BridgeClient::new("ws://unused", cfg);

        let result = client
            .connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
            .await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(0))),
            "expected MaxRetriesExceeded(0), got: {:?}",
            result
        );
    }

    /// Repeated connection failures stop once the retry limit is reached.
    #[tokio::test]
    async fn test_connect_fails_and_hits_max_retries() {
        // Ask the OS for a free port and release it again, so the address is
        // very likely to refuse connections. A hard-coded port could be in
        // use on the machine running the tests.
        let port = {
            let listener =
                std::net::TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
            listener
                .local_addr()
                .expect("read ephemeral port address")
                .port()
        };

        let cfg = BridgeConfig {
            initial_backoff_ms: 1, // tiny backoff so the test is fast
            max_backoff_ms: 1,
            max_reconnect_attempts: Some(3),
        };
        let client = BridgeClient::new(format!("ws://127.0.0.1:{}", port), cfg);

        let result = client
            .connect_and_relay(|msg| async move { BridgeMessage::reply(&msg, "ok") })
            .await;

        assert!(
            matches!(result, Err(BridgeError::MaxRetriesExceeded(3))),
            "expected MaxRetriesExceeded(3), got: {:?}",
            result
        );
    }
}