pub mod approval;
pub mod dedup;
pub mod events;
pub mod health;
pub use dedup::Deduplicator;
pub use events::{Ack, BridgeEvent, BridgeEventEnvelope};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use futures::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::{HeaderValue, Request};
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tracing::{info, warn};
use url::Url;
/// Maximum accepted length of an incoming event ID, compared against `envelope.id.len()`;
/// envelopes with longer IDs are dropped without being acknowledged.
const MAX_EVENT_ID_LEN: usize = 256;
/// Largest WebSocket message and frame accepted from the server (1 MiB).
const MAX_WS_MESSAGE_SIZE: usize = 1 << 20;
/// Client for the r8r event WebSocket.
///
/// Mutable state lives behind `Arc`s so the tasks spawned by `connect` can share it with
/// this handle.
pub struct R8rBridge {
// WebSocket URL to connect to; logged only through `sanitize_endpoint`, which strips credentials.
endpoint: String,
// Optional bearer token sent as the `Authorization` header during the handshake.
token: Option<String>,
// Outbound queue feeding the writer task; `None` while disconnected.
sender: Arc<Mutex<Option<mpsc::Sender<String>>>>,
// Consulted via `is_new` for each incoming event ID; events it reports as not new are acked but skipped.
dedup: Arc<Mutex<Deduplicator>>,
// Most recent `HealthStatus` event received, if any.
health_status: Arc<Mutex<Option<BridgeEvent>>>,
// Whether a connection is believed to be live; polled by `run` to decide when to reconnect.
connected: Arc<AtomicBool>,
}
/// Builds the WebSocket handshake request for `endpoint`.
///
/// The `Host` header is overwritten with `host[:port]` taken from the parsed URL, so it never
/// carries userinfo from the endpoint. When `token` is given it is sent as
/// `Authorization: Bearer <token>`. That header value is flagged sensitive so that
/// `Debug`-formatting the request or its headers (e.g. in logs) does not reveal the token.
///
/// # Errors
///
/// Returns a human-readable message in these cases:
/// - the endpoint is not a valid URL or has no host;
/// - the endpoint cannot be turned into a client request;
/// - the host or token is not a valid header value.
fn build_ws_request(endpoint: &str, token: Option<&str>) -> Result<Request<()>, String> {
    let parsed = Url::parse(endpoint).map_err(|e| format!("Invalid endpoint URL: {e}"))?;
    let host = parsed
        .host_str()
        .ok_or_else(|| "Endpoint URL has no host".to_string())?;
    // `Url::port` is `None` both when no port is given and when it equals the scheme default,
    // so only an explicit non-default port ends up in the header.
    let host_with_port = match parsed.port() {
        Some(port) => format!("{host}:{port}"),
        None => host.to_string(),
    };
    let mut request = endpoint
        .into_client_request()
        .map_err(|e| format!("Failed to build WS request: {e}"))?;
    let host_header =
        HeaderValue::from_str(&host_with_port).map_err(|e| format!("Invalid Host header: {e}"))?;
    request.headers_mut().insert("Host", host_header);
    if let Some(token) = token {
        let mut auth_header = HeaderValue::from_str(&format!("Bearer {token}"))
            .map_err(|e| format!("Invalid Authorization header: {e}"))?;
        // Keep the bearer token out of `Debug` output of the header map / request.
        auth_header.set_sensitive(true);
        request.headers_mut().insert("Authorization", auth_header);
    }
    Ok(request)
}
/// Renders `endpoint` for log output with any username and password removed.
///
/// Returns the literal `"[invalid url]"` when `endpoint` does not parse as a URL.
fn sanitize_endpoint(endpoint: &str) -> String {
    let Ok(mut url) = Url::parse(endpoint) else {
        return "[invalid url]".to_string();
    };
    // The setters return `Err` only for URLs that cannot carry credentials at all, in which
    // case there is nothing to strip. Password is cleared first so the `@` separator can be
    // dropped once the username is emptied.
    let _ = url.set_password(None);
    let _ = url.set_username("");
    url.as_str().to_owned()
}
impl R8rBridge {
    /// Creates a bridge for `endpoint`, optionally authenticating with a bearer `token`.
    ///
    /// The bridge starts disconnected; nothing touches the network until
    /// [`R8rBridge::connect`] or [`R8rBridge::run`] is called.
    pub fn new(endpoint: String, token: Option<String>) -> Self {
        Self {
            endpoint,
            token,
            sender: Arc::new(Mutex::new(None)),
            dedup: Arc::new(Mutex::new(Deduplicator::default())),
            health_status: Arc::new(Mutex::new(None)),
            connected: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Returns `true` while the bridge believes its current connection is live.
    ///
    /// The flag is set by [`R8rBridge::connect`]. It is cleared by [`R8rBridge::disconnect`]
    /// or when the current connection's receive loop ends.
    pub fn is_connected(&self) -> bool {
        self.connected.load(Ordering::Relaxed)
    }

    /// Opens the WebSocket connection and spawns its writer and reader tasks.
    ///
    /// Any previously installed connection is replaced. Its outbound sender is dropped, so its
    /// writer task closes that socket, and its reader task will not clear this newer
    /// connection's state when it finishes.
    ///
    /// The reader handles each incoming envelope as follows:
    /// - envelopes that fail to parse, or whose ID is longer than `MAX_EVENT_ID_LEN`, are
    ///   dropped;
    /// - every other envelope is acknowledged, duplicates included, before deduplication;
    /// - new events are parsed, `HealthStatus` is stored for
    ///   [`R8rBridge::last_health_status`], and other events are logged.
    ///
    /// # Errors
    ///
    /// Returns a message if the handshake request cannot be built or the connection fails.
    pub async fn connect(&self) -> Result<(), String> {
        let request = build_ws_request(&self.endpoint, self.token.as_deref())?;
        let mut ws_config = WebSocketConfig::default();
        ws_config.max_message_size = Some(MAX_WS_MESSAGE_SIZE);
        ws_config.max_frame_size = Some(MAX_WS_MESSAGE_SIZE);
        let (ws_stream, _response) =
            tokio_tungstenite::connect_async_with_config(request, Some(ws_config), false)
                .await
                .map_err(|e| format!("WebSocket connection failed: {e}"))?;
        let (ws_write, ws_read) = ws_stream.split();

        let (tx, mut rx) = mpsc::channel::<String>(256);
        // A weak handle lets the reader task later check whether it still owns the sender
        // slot, without itself keeping the channel (and therefore the writer task) alive.
        let own_sender = tx.downgrade();
        {
            // Install the sender and raise the flag under one lock, so a stale reader task
            // cannot interleave its cleanup between the two. Replacing an older sender drops
            // it, which ends that connection's writer task.
            let mut sender_guard = self.sender.lock().await;
            *sender_guard = Some(tx);
            self.connected.store(true, Ordering::Relaxed);
        }
        info!(
            "r8r bridge connected to {}",
            sanitize_endpoint(&self.endpoint)
        );

        // Shared by the writer task (queued messages) and the reader task (acks).
        let ws_write = Arc::new(Mutex::new(ws_write));
        let ws_write_clone = Arc::clone(&ws_write);
        tokio::spawn(async move {
            while let Some(msg) = rx.recv().await {
                let mut writer = ws_write_clone.lock().await;
                if let Err(e) = writer.send(WsMessage::Text(msg.into())).await {
                    warn!("r8r bridge send error: {e}");
                    break;
                }
            }
            // We get here after a send error, or once every sender is gone: `disconnect`, a
            // newer `connect`, or the reader's cleanup below. Close the sink so the server
            // ends the session and the reader sees the close instead of lingering on an
            // abandoned socket. Best effort: the connection may already be closed.
            let _ = ws_write_clone.lock().await.close().await;
        });

        let dedup = Arc::clone(&self.dedup);
        let health_status = Arc::clone(&self.health_status);
        let connected = Arc::clone(&self.connected);
        let sender = Arc::clone(&self.sender);
        tokio::spawn(async move {
            let mut ws_read = ws_read;
            while let Some(msg_result) = ws_read.next().await {
                let msg = match msg_result {
                    Ok(m) => m,
                    Err(e) => {
                        warn!("r8r bridge receive error: {e}");
                        break;
                    }
                };
                let text = match msg {
                    WsMessage::Text(t) => t.to_string(),
                    WsMessage::Close(_) => {
                        info!("r8r bridge received close frame");
                        break;
                    }
                    _ => continue,
                };
                let envelope: BridgeEventEnvelope = match serde_json::from_str(&text) {
                    Ok(e) => e,
                    Err(e) => {
                        warn!("r8r bridge: failed to parse envelope: {e}");
                        continue;
                    }
                };
                if envelope.id.len() > MAX_EVENT_ID_LEN {
                    warn!(
                        "r8r bridge: event ID exceeds {} chars, dropping",
                        MAX_EVENT_ID_LEN
                    );
                    continue;
                }
                // Ack before deduplication so redelivered events are acknowledged too.
                let ack = Ack {
                    event_id: envelope.id.clone(),
                };
                if let Ok(ack_json) = serde_json::to_string(&ack) {
                    let mut writer = ws_write.lock().await;
                    if let Err(e) = writer.send(WsMessage::Text(ack_json.into())).await {
                        warn!("r8r bridge: failed to send ack: {e}");
                    }
                }
                {
                    let mut dd = dedup.lock().await;
                    if !dd.is_new(&envelope.id) {
                        info!("r8r bridge: acknowledged duplicate event {}", envelope.id);
                        continue;
                    }
                }
                match BridgeEvent::from_type_and_data(&envelope.event_type, &envelope.data) {
                    Ok(event) => match &event {
                        BridgeEvent::HealthStatus { .. } => {
                            let mut hs = health_status.lock().await;
                            *hs = Some(event);
                        }
                        BridgeEvent::ApprovalRequested {
                            workflow,
                            execution_id,
                            ..
                        } => {
                            info!(
                                "r8r bridge: approval requested for {} ({})",
                                workflow, execution_id
                            );
                        }
                        BridgeEvent::ApprovalTimeout {
                            workflow,
                            execution_id,
                            ..
                        } => {
                            info!(
                                "r8r bridge: approval timeout for {} ({})",
                                workflow, execution_id
                            );
                        }
                        BridgeEvent::ExecutionCompleted {
                            workflow,
                            execution_id,
                            ..
                        } => {
                            info!(
                                "r8r bridge: execution completed for {} ({})",
                                workflow, execution_id
                            );
                        }
                        BridgeEvent::ExecutionFailed {
                            workflow,
                            execution_id,
                            ..
                        } => {
                            info!(
                                "r8r bridge: execution failed for {} ({})",
                                workflow, execution_id
                            );
                        }
                        _ => {
                            info!("r8r bridge: received event type {}", envelope.event_type);
                        }
                    },
                    Err(e) => {
                        warn!("r8r bridge: failed to parse event: {e}");
                    }
                }
            }
            {
                let mut sender_guard = sender.lock().await;
                // Only tear down state that still belongs to this connection. After
                // `disconnect` the slot is empty. After a newer `connect` it holds a different
                // channel, which must not be cleared.
                let owns_slot = match (sender_guard.as_ref(), own_sender.upgrade()) {
                    (Some(current), Some(ours)) => current.same_channel(&ours),
                    _ => false,
                };
                if owns_slot {
                    // Dropping our sender also lets the writer task finish and close the sink.
                    *sender_guard = None;
                    connected.store(false, Ordering::Relaxed);
                }
            }
            info!("r8r bridge disconnected");
        });
        Ok(())
    }

    /// Drops the outbound sender and marks the bridge disconnected.
    ///
    /// The channel closes as soon as no [`R8rBridge::send`] call still holds a clone of the
    /// sender. The writer task then closes the WebSocket. The reader task stops when the
    /// server's close frame (or a read error) arrives, and it leaves any later connection's
    /// state untouched. This does not stop [`R8rBridge::run`], which will reconnect.
    pub async fn disconnect(&self) {
        let mut sender_guard = self.sender.lock().await;
        *sender_guard = None;
        self.connected.store(false, Ordering::Relaxed);
        info!("r8r bridge disconnected (manual)");
    }

    /// Serializes `envelope` to JSON and queues it for the writer task.
    ///
    /// Waits for room if the outbound queue (capacity 256) is full.
    ///
    /// # Errors
    ///
    /// Returns a message if serialization fails, the bridge is not connected, or the writer
    /// task has already stopped and dropped its receiver.
    pub async fn send(&self, envelope: BridgeEventEnvelope) -> Result<(), String> {
        let json = serde_json::to_string(&envelope)
            .map_err(|e| format!("Failed to serialize envelope: {e}"))?;
        // Clone the sender and release the lock before awaiting. Otherwise a full queue would
        // keep the lock held and stall `connect`, `disconnect` and the reader's cleanup.
        let tx = {
            let sender_guard = self.sender.lock().await;
            sender_guard.clone()
        };
        match tx {
            Some(tx) => tx
                .send(json)
                .await
                .map_err(|e| format!("Failed to send message: {e}")),
            None => Err("Not connected".to_string()),
        }
    }

    /// Queues a `HealthPing` event; see [`R8rBridge::send`] for error conditions.
    pub async fn send_health_ping(&self) -> Result<(), String> {
        let envelope = BridgeEventEnvelope::new(BridgeEvent::HealthPing, None);
        self.send(envelope).await
    }

    /// Returns a clone of the most recently received `HealthStatus` event, if any.
    pub async fn last_health_status(&self) -> Option<BridgeEvent> {
        let hs = self.health_status.lock().await;
        hs.clone()
    }

    /// Keeps the bridge connected forever, reconnecting with exponential backoff.
    ///
    /// After each failed attempt or dropped connection it sleeps for the current backoff. The
    /// backoff starts at 1 second and doubles up to `max_interval_secs`. A successful
    /// connection resets it to 1. A `max_interval_secs` of 0 is treated as 1, so retries never
    /// spin without sleeping.
    pub async fn run(&self, max_interval_secs: u64) {
        let max_backoff_secs = max_interval_secs.max(1);
        let mut backoff_secs: u64 = 1;
        loop {
            match self.connect().await {
                Ok(()) => {
                    backoff_secs = 1;
                    // Poll until the reader task or `disconnect` clears the flag.
                    while self.is_connected() {
                        tokio::time::sleep(Duration::from_millis(250)).await;
                    }
                }
                Err(e) => warn!("r8r bridge connection failed: {e}"),
            }
            tokio::time::sleep(Duration::from_secs(backoff_secs)).await;
            backoff_secs = backoff_secs.saturating_mul(2).min(max_backoff_secs);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Credential-free endpoint shared by the request-building and sanitizer tests.
    const EVENTS_ENDPOINT: &str = "ws://localhost:8080/api/ws/events";

    #[test]
    fn test_build_ws_request_sets_host_and_authorization_headers() {
        let request = build_ws_request(EVENTS_ENDPOINT, Some("secret-token")).unwrap();
        let headers = request.headers();
        assert_eq!(headers["Host"], "localhost:8080");
        assert_eq!(headers["Authorization"], "Bearer secret-token");
    }

    #[test]
    fn test_build_ws_request_rejects_invalid_authorization_header() {
        // A newline cannot appear in a header value, so header construction must fail.
        let result = build_ws_request(EVENTS_ENDPOINT, Some("bad\nvalue"));
        let err = result.expect_err("invalid header should return an error");
        assert!(err.contains("Authorization"));
    }

    #[test]
    fn test_sanitize_endpoint_strips_credentials() {
        let sanitized = sanitize_endpoint("ws://user:secret@host:8080/path");
        assert!(!sanitized.contains("secret"), "password should be stripped");
        assert!(!sanitized.contains("user"), "username should be stripped");
        assert!(sanitized.contains("host:8080/path"));
    }

    #[test]
    fn test_sanitize_endpoint_passes_clean_url() {
        assert_eq!(sanitize_endpoint(EVENTS_ENDPOINT), EVENTS_ENDPOINT);
    }

    #[test]
    fn test_sanitize_endpoint_handles_invalid_url() {
        let sanitized = sanitize_endpoint("not a url");
        assert_eq!(sanitized, "[invalid url]");
    }

    #[test]
    fn test_max_event_id_len_constant() {
        // Evaluated at compile time: a violation fails the build, not just this test.
        const { assert!(MAX_EVENT_ID_LEN >= 64, "must allow standard evt_<uuid> IDs") };
        const { assert!(MAX_EVENT_ID_LEN <= 1024, "must not allow absurdly long IDs") };
    }
}