1use std::future::Future;
2
3use anyhow::{Context, Result};
4use axum::body::Bytes;
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::extract::{Query, State};
7use axum::http::HeaderMap;
8use axum::response::IntoResponse;
9use axum::routing::{get, post};
10use axum::{Json, Router};
11use futures::{SinkExt, StreamExt};
12use serde::Deserialize;
13use tokio::net::TcpListener;
14use tokio::sync::broadcast::error::RecvError;
15use tracing::{error, info, warn};
16
17use crate::config::ApiConfig;
18use crate::runtime::RpcRuntime;
19use crate::stream_domain::KEEPALIVE_INTERVAL_MS;
20
21#[derive(Clone)]
22struct AppState {
23 runtime: RpcRuntime,
24}
25
26#[derive(Debug, Clone, Deserialize)]
27#[serde(rename_all = "camelCase")]
28struct StreamQuery {
29 subscription_id: Option<String>,
30}
31
32pub fn router(runtime: RpcRuntime) -> Router {
33 Router::new()
34 .route("/health", get(health_handler))
35 .route("/rpc/v1", post(rpc_handler))
36 .route("/rpc/v1/capabilities", get(capabilities_handler))
37 .route("/rpc/v1/stream", get(stream_handler))
38 .with_state(AppState { runtime })
39}
40
/// Formats `host` and `port` as a socket address string suitable for binding.
///
/// Surrounding whitespace on `host` is trimmed. A host that contains `:` and does
/// not already start with `[` is treated as a bare IPv6 literal and wrapped in
/// brackets (`::1` becomes `[::1]:3000`). Every other host — IPv4 literals,
/// hostnames, and hosts that already start with `[` — is joined to the port as-is.
pub(crate) fn bind_addr_string(host: &str, port: u16) -> String {
    let host = host.trim();

    // Only a bare IPv6 literal needs brackets; a leading '[' means the caller
    // already supplied them.
    let is_bare_ipv6 = !host.starts_with('[') && host.contains(':');

    if is_bare_ipv6 {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}
58
59pub async fn serve(config: ApiConfig) -> Result<()> {
60 let bind_address = bind_addr_string(&config.host, config.port);
61 let listener = TcpListener::bind(&bind_address)
62 .await
63 .with_context(|| format!("failed binding listener at {bind_address}"))?;
64
65 info!(
66 address = %bind_address,
67 auth_mode = %config.auth_mode.as_contract_mode(),
68 "starting ralph-api server"
69 );
70
71 let runtime = RpcRuntime::new(config)?;
72 serve_with_listener(listener, runtime, shutdown_signal()).await
73}
74
75pub async fn serve_with_listener<F>(
76 listener: TcpListener,
77 runtime: RpcRuntime,
78 shutdown: F,
79) -> Result<()>
80where
81 F: Future<Output = ()> + Send + 'static,
82{
83 let local_addr = listener
84 .local_addr()
85 .context("failed to read listener local_addr")?;
86 info!(%local_addr, "ralph-api listening");
87
88 axum::serve(listener, router(runtime))
89 .with_graceful_shutdown(shutdown)
90 .await
91 .context("axum server terminated with error")
92}
93
94async fn health_handler(State(state): State<AppState>) -> Json<serde_json::Value> {
95 Json(state.runtime.health_payload())
96}
97
98async fn capabilities_handler(State(state): State<AppState>) -> Json<serde_json::Value> {
99 Json(state.runtime.capabilities_payload())
100}
101
102async fn rpc_handler(
103 State(state): State<AppState>,
104 headers: HeaderMap,
105 body: Bytes,
106) -> impl IntoResponse {
107 let (status, payload) = state.runtime.handle_http_request(&body, &headers);
108 (status, Json(payload))
109}
110
111async fn stream_handler(
112 ws: WebSocketUpgrade,
113 Query(query): Query<StreamQuery>,
114 headers: HeaderMap,
115 State(state): State<AppState>,
116) -> impl IntoResponse {
117 let principal = match state.runtime.authenticate_websocket(&headers) {
118 Ok(p) => p,
119 Err(error) => {
120 let status = error.status;
121 let error_payload =
122 crate::protocol::error_envelope(&error, &state.runtime.config.served_by);
123 return (status, Json(error_payload)).into_response();
124 }
125 };
126
127 ws.on_upgrade(move |socket| {
128 stream_connection(socket, state.runtime, query.subscription_id, principal)
129 })
130}
131
132async fn stream_connection(
133 mut socket: WebSocket,
134 runtime: RpcRuntime,
135 subscription_id: Option<String>,
136 principal: String,
137) {
138 let Some(subscription_id) = subscription_id else {
139 warn!("stream connection missing subscriptionId query parameter");
140 let _ = socket.close().await;
141 return;
142 };
143
144 let streams = runtime.stream_domain();
145 if !streams.has_subscription(&subscription_id) {
146 warn!(subscription_id, "stream subscription does not exist");
147 let _ = socket.close().await;
148 return;
149 }
150
151 if streams
152 .get_subscription_principal(&subscription_id)
153 .as_deref()
154 != Some(principal.as_str())
155 {
156 warn!(subscription_id, "stream connection auth principal mismatch");
157 let _ = socket.close().await;
158 return;
159 }
160
161 let replay = match streams.replay_for_subscription(&subscription_id) {
162 Ok(replay) => replay,
163 Err(error) => {
164 warn!(subscription_id, error = %error.message, "failed preparing replay batch");
165 let _ = socket.close().await;
166 return;
167 }
168 };
169
170 if replay.dropped_count > 0 {
171 let event = streams.backpressure_event(&subscription_id, replay.dropped_count);
172 if !send_stream_event(&mut socket, &event).await {
173 return;
174 }
175 }
176
177 for event in replay.events {
178 if !send_stream_event(&mut socket, &event).await {
179 return;
180 }
181 }
182
183 let mut live_rx = streams.live_receiver();
184 let mut ticker = tokio::time::interval(std::time::Duration::from_millis(KEEPALIVE_INTERVAL_MS));
185
186 loop {
187 tokio::select! {
188 _ = ticker.tick() => {
189 let keepalive = streams.keepalive_event(&subscription_id, KEEPALIVE_INTERVAL_MS);
190 if !send_stream_event(&mut socket, &keepalive).await {
191 break;
192 }
193 }
194 message = live_rx.recv() => {
195 match message {
196 Ok(event) => {
197 if streams.matches_subscription(&subscription_id, &event)
198 && !send_stream_event(&mut socket, &event).await
199 {
200 break;
201 }
202 }
203 Err(RecvError::Lagged(skipped)) => {
204 let event = streams.backpressure_event(
205 &subscription_id,
206 usize::try_from(skipped).unwrap_or(usize::MAX),
207 );
208 if !send_stream_event(&mut socket, &event).await {
209 break;
210 }
211 }
212 Err(RecvError::Closed) => break,
213 }
214 }
215 message = socket.next() => {
216 match message {
217 None | Some(Ok(Message::Close(_)) | Err(_)) => break,
218 Some(Ok(Message::Ping(payload))) => {
219 if socket.send(Message::Pong(payload)).await.is_err() {
220 break;
221 }
222 }
223 Some(Ok(Message::Text(_) | Message::Binary(_) | Message::Pong(_))) => {}
224 }
225 }
226 }
227 }
228}
229
230async fn send_stream_event(
231 socket: &mut WebSocket,
232 event: &crate::stream_domain::StreamEventEnvelope,
233) -> bool {
234 match serde_json::to_string(event) {
235 Ok(serialized) => socket.send(Message::Text(serialized.into())).await.is_ok(),
236 Err(error) => {
237 error!(%error, "failed to serialize stream event");
238 false
239 }
240 }
241}
242
243async fn shutdown_signal() {
244 if let Err(error) = tokio::signal::ctrl_c().await {
245 error!(%error, "failed waiting for ctrl-c shutdown signal");
246 }
247 info!("shutdown signal received");
248}
249
#[cfg(test)]
mod tests {
    use super::bind_addr_string;

    /// An IPv4 literal is joined to the port unchanged.
    #[test]
    fn ipv4_loopback_formats_without_brackets() {
        let formatted = bind_addr_string("127.0.0.1", 3000);
        assert_eq!(formatted, "127.0.0.1:3000");
    }

    /// A hostname is joined to the port unchanged.
    #[test]
    fn hostname_formats_without_brackets() {
        let formatted = bind_addr_string("localhost", 8080);
        assert_eq!(formatted, "localhost:8080");
    }

    /// The bare IPv6 loopback literal gains brackets.
    #[test]
    fn ipv6_loopback_wraps_in_brackets() {
        let formatted = bind_addr_string("::1", 3000);
        assert_eq!(formatted, "[::1]:3000");
    }

    /// The bare IPv6 unspecified address gains brackets.
    #[test]
    fn ipv6_any_wraps_in_brackets() {
        let formatted = bind_addr_string("::", 3000);
        assert_eq!(formatted, "[::]:3000");
    }

    /// A longer bare IPv6 literal gains brackets.
    #[test]
    fn ipv6_full_address_wraps_in_brackets() {
        let formatted = bind_addr_string("2001:db8::1", 443);
        assert_eq!(formatted, "[2001:db8::1]:443");
    }

    /// A host that already carries brackets is not wrapped a second time.
    #[test]
    fn pre_bracketed_ipv6_does_not_double_wrap() {
        let formatted = bind_addr_string("[::1]", 3000);
        assert_eq!(formatted, "[::1]:3000");
    }
}