Skip to main content

ralph_api/
transport.rs

1use std::future::Future;
2
3use anyhow::{Context, Result};
4use axum::body::Bytes;
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::extract::{Query, State};
7use axum::http::HeaderMap;
8use axum::response::IntoResponse;
9use axum::routing::{get, post};
10use axum::{Json, Router};
11use futures::{SinkExt, StreamExt};
12use serde::Deserialize;
13use tokio::net::TcpListener;
14use tokio::sync::broadcast::error::RecvError;
15use tracing::{error, info, warn};
16
17use crate::config::ApiConfig;
18use crate::runtime::RpcRuntime;
19use crate::stream_domain::KEEPALIVE_INTERVAL_MS;
20
21#[derive(Clone)]
22struct AppState {
23    runtime: RpcRuntime,
24}
25
26#[derive(Debug, Clone, Deserialize)]
27#[serde(rename_all = "camelCase")]
28struct StreamQuery {
29    subscription_id: Option<String>,
30}
31
32pub fn router(runtime: RpcRuntime) -> Router {
33    Router::new()
34        .route("/health", get(health_handler))
35        .route("/rpc/v1", post(rpc_handler))
36        .route("/rpc/v1/capabilities", get(capabilities_handler))
37        .route("/rpc/v1/stream", get(stream_handler))
38        .with_state(AppState { runtime })
39}
40
/// Build the `host:port` string handed to `TcpListener::bind`.
///
/// Surrounding whitespace in `host` is trimmed first. A host that contains
/// `:` is treated as a raw IPv6 literal and wrapped in brackets so the port
/// separator is unambiguous (`::1` becomes `[::1]:3000`), unless it already
/// starts with `[` (e.g. `[::1]` from an env var), in which case it is used
/// verbatim to avoid double wrapping. IPv4 addresses and hostnames are used
/// unchanged.
pub(crate) fn bind_addr_string(host: &str, port: u16) -> String {
    let host = host.trim();
    let is_raw_ipv6 = host.contains(':') && !host.starts_with('[');
    if is_raw_ipv6 {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}
58
59pub async fn serve(config: ApiConfig) -> Result<()> {
60    let bind_address = bind_addr_string(&config.host, config.port);
61    let listener = TcpListener::bind(&bind_address)
62        .await
63        .with_context(|| format!("failed binding listener at {bind_address}"))?;
64
65    info!(
66        address = %bind_address,
67        auth_mode = %config.auth_mode.as_contract_mode(),
68        "starting ralph-api server"
69    );
70
71    let runtime = RpcRuntime::new(config)?;
72    serve_with_listener(listener, runtime, shutdown_signal()).await
73}
74
75pub async fn serve_with_listener<F>(
76    listener: TcpListener,
77    runtime: RpcRuntime,
78    shutdown: F,
79) -> Result<()>
80where
81    F: Future<Output = ()> + Send + 'static,
82{
83    let local_addr = listener
84        .local_addr()
85        .context("failed to read listener local_addr")?;
86    info!(%local_addr, "ralph-api listening");
87
88    axum::serve(listener, router(runtime))
89        .with_graceful_shutdown(shutdown)
90        .await
91        .context("axum server terminated with error")
92}
93
94async fn health_handler(State(state): State<AppState>) -> Json<serde_json::Value> {
95    Json(state.runtime.health_payload())
96}
97
98async fn capabilities_handler(State(state): State<AppState>) -> Json<serde_json::Value> {
99    Json(state.runtime.capabilities_payload())
100}
101
102async fn rpc_handler(
103    State(state): State<AppState>,
104    headers: HeaderMap,
105    body: Bytes,
106) -> impl IntoResponse {
107    let (status, payload) = state.runtime.handle_http_request(&body, &headers);
108    (status, Json(payload))
109}
110
111async fn stream_handler(
112    ws: WebSocketUpgrade,
113    Query(query): Query<StreamQuery>,
114    headers: HeaderMap,
115    State(state): State<AppState>,
116) -> impl IntoResponse {
117    let principal = match state.runtime.authenticate_websocket(&headers) {
118        Ok(p) => p,
119        Err(error) => {
120            let status = error.status;
121            let error_payload =
122                crate::protocol::error_envelope(&error, &state.runtime.config.served_by);
123            return (status, Json(error_payload)).into_response();
124        }
125    };
126
127    ws.on_upgrade(move |socket| {
128        stream_connection(socket, state.runtime, query.subscription_id, principal)
129    })
130}
131
132async fn stream_connection(
133    mut socket: WebSocket,
134    runtime: RpcRuntime,
135    subscription_id: Option<String>,
136    principal: String,
137) {
138    let Some(subscription_id) = subscription_id else {
139        warn!("stream connection missing subscriptionId query parameter");
140        let _ = socket.close().await;
141        return;
142    };
143
144    let streams = runtime.stream_domain();
145    if !streams.has_subscription(&subscription_id) {
146        warn!(subscription_id, "stream subscription does not exist");
147        let _ = socket.close().await;
148        return;
149    }
150
151    if streams
152        .get_subscription_principal(&subscription_id)
153        .as_deref()
154        != Some(principal.as_str())
155    {
156        warn!(subscription_id, "stream connection auth principal mismatch");
157        let _ = socket.close().await;
158        return;
159    }
160
161    let replay = match streams.replay_for_subscription(&subscription_id) {
162        Ok(replay) => replay,
163        Err(error) => {
164            warn!(subscription_id, error = %error.message, "failed preparing replay batch");
165            let _ = socket.close().await;
166            return;
167        }
168    };
169
170    if replay.dropped_count > 0 {
171        let event = streams.backpressure_event(&subscription_id, replay.dropped_count);
172        if !send_stream_event(&mut socket, &event).await {
173            return;
174        }
175    }
176
177    for event in replay.events {
178        if !send_stream_event(&mut socket, &event).await {
179            return;
180        }
181    }
182
183    let mut live_rx = streams.live_receiver();
184    let mut ticker = tokio::time::interval(std::time::Duration::from_millis(KEEPALIVE_INTERVAL_MS));
185
186    loop {
187        tokio::select! {
188            _ = ticker.tick() => {
189                let keepalive = streams.keepalive_event(&subscription_id, KEEPALIVE_INTERVAL_MS);
190                if !send_stream_event(&mut socket, &keepalive).await {
191                    break;
192                }
193            }
194            message = live_rx.recv() => {
195                match message {
196                    Ok(event) => {
197                        if streams.matches_subscription(&subscription_id, &event)
198                            && !send_stream_event(&mut socket, &event).await
199                        {
200                            break;
201                        }
202                    }
203                    Err(RecvError::Lagged(skipped)) => {
204                        let event = streams.backpressure_event(
205                            &subscription_id,
206                            usize::try_from(skipped).unwrap_or(usize::MAX),
207                        );
208                        if !send_stream_event(&mut socket, &event).await {
209                            break;
210                        }
211                    }
212                    Err(RecvError::Closed) => break,
213                }
214            }
215            message = socket.next() => {
216                match message {
217                    None | Some(Ok(Message::Close(_)) | Err(_)) => break,
218                    Some(Ok(Message::Ping(payload))) => {
219                        if socket.send(Message::Pong(payload)).await.is_err() {
220                            break;
221                        }
222                    }
223                    Some(Ok(Message::Text(_) | Message::Binary(_) | Message::Pong(_))) => {}
224                }
225            }
226        }
227    }
228}
229
230async fn send_stream_event(
231    socket: &mut WebSocket,
232    event: &crate::stream_domain::StreamEventEnvelope,
233) -> bool {
234    match serde_json::to_string(event) {
235        Ok(serialized) => socket.send(Message::Text(serialized.into())).await.is_ok(),
236        Err(error) => {
237            error!(%error, "failed to serialize stream event");
238            false
239        }
240    }
241}
242
243async fn shutdown_signal() {
244    if let Err(error) = tokio::signal::ctrl_c().await {
245        error!(%error, "failed waiting for ctrl-c shutdown signal");
246    }
247    info!("shutdown signal received");
248}
249
#[cfg(test)]
mod tests {
    use super::bind_addr_string;

    #[test]
    fn ipv4_loopback_formats_without_brackets() {
        assert_eq!(bind_addr_string("127.0.0.1", 3000), "127.0.0.1:3000");
    }

    #[test]
    fn ipv4_any_formats_without_brackets() {
        assert_eq!(bind_addr_string("0.0.0.0", 8080), "0.0.0.0:8080");
    }

    #[test]
    fn hostname_formats_without_brackets() {
        assert_eq!(bind_addr_string("localhost", 8080), "localhost:8080");
    }

    #[test]
    fn ipv6_loopback_wraps_in_brackets() {
        assert_eq!(bind_addr_string("::1", 3000), "[::1]:3000");
    }

    #[test]
    fn ipv6_any_wraps_in_brackets() {
        assert_eq!(bind_addr_string("::", 3000), "[::]:3000");
    }

    #[test]
    fn ipv6_full_address_wraps_in_brackets() {
        assert_eq!(bind_addr_string("2001:db8::1", 443), "[2001:db8::1]:443");
    }

    #[test]
    fn pre_bracketed_ipv6_does_not_double_wrap() {
        // RALPH_API_HOST=[::1] should not become [[::1]]:3000
        assert_eq!(bind_addr_string("[::1]", 3000), "[::1]:3000");
    }

    #[test]
    fn surrounding_whitespace_is_trimmed() {
        // Values copied from env files often carry stray whitespace.
        assert_eq!(bind_addr_string("  127.0.0.1\t", 3000), "127.0.0.1:3000");
        assert_eq!(bind_addr_string(" ::1 ", 3000), "[::1]:3000");
        assert_eq!(bind_addr_string(" [::1] ", 3000), "[::1]:3000");
    }

    #[test]
    fn ipv6_output_parses_as_socket_addr() {
        // The whole point of bracketing is that the result is a valid
        // socket address; check that directly rather than by string shape.
        for host in ["::1", "::", "2001:db8::1", "[::1]"] {
            let formatted = bind_addr_string(host, 3000);
            let parsed: std::net::SocketAddr = formatted
                .parse()
                .expect("bracketed IPv6 output must parse as a SocketAddr");
            assert_eq!(parsed.port(), 3000);
        }
    }
}