Skip to main content

dbrest_core/app/
server.rs

1//! HTTP server setup and lifecycle
2//!
3//! Creates the main API server and the admin server, wires up graceful
4//! shutdown, and starts the NOTIFY listener.
5//!
6//! # Startup Sequence
7//!
8//! 1. Create database backend (connect, query version).
9//! 2. Create `AppState`.
10//! 3. Load schema cache.
//! 4. Start NOTIFY listener (background task, if enabled).
//! 5. Start admin server (separate port, if configured).
13//! 6. Start main API server.
14//!
15//! # Graceful Shutdown
16//!
17//! Listens for `SIGTERM` and `Ctrl+C`. On receipt, stops accepting new
18//! connections and drains in-flight requests before exiting.
19
20use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
21use std::sync::Arc;
22
23use hyper_util::rt::{TokioExecutor, TokioIo};
24use hyper_util::server::conn::auto::Builder;
25use hyper_util::service::TowerToHyperService;
26use tokio::net::TcpListener;
27
28use crate::backend::{DatabaseBackend, DbVersion, SqlDialect};
29use crate::config::AppConfig;
30use crate::error::Error;
31
32use super::admin::create_admin_router;
33use super::router::create_router;
34use super::state::AppState;
35
36/// Start the dbrest server with a pre-constructed backend and dialect.
37///
38/// This is the main entry point for the application. It initializes all
39/// components and starts serving HTTP requests.
40pub async fn start_server(_config: AppConfig) -> Result<(), Error> {
41    // This function is kept as a convenience that will be called from the
42    // binary crate after constructing the backend. Since dbrest-core cannot
43    // create PgBackend directly (it lives in dbrest-postgres), the binary
44    // crate should use `start_server_with_backend` instead.
45    //
46    // For backwards compatibility during migration, this returns an error
47    // guiding callers to use the correct function.
48    Err(Error::Internal(
49        "start_server() cannot create a database backend from dbrest-core. \
50         Use start_server_with_backend() instead."
51            .to_string(),
52    ))
53}
54
55/// Start the dbrest server with an already-connected backend.
56///
57/// The caller (typically the root binary crate) is responsible for creating
58/// the database backend and querying its version.
59pub async fn start_server_with_backend(
60    db: Arc<dyn DatabaseBackend>,
61    dialect: Arc<dyn SqlDialect>,
62    db_version: DbVersion,
63    config: AppConfig,
64) -> Result<(), Error> {
65    let state = AppState::new_with_backend(db.clone(), dialect, config.clone(), db_version);
66
67    // 4. Load schema cache
68    tracing::info!("Loading schema cache…");
69    state.reload_schema_cache().await?;
70
71    // 5. Build routers
72    let main_router = create_router(state.clone());
73    let admin_router = create_admin_router(state.clone());
74
75    // 6. Cancellation channel for background tasks
76    let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
77
78    // 7. Start NOTIFY listener
79    if config.db_channel_enabled {
80        let listener_state = state.clone();
81        let listener_db = db.clone();
82        let channel = config.db_channel.clone();
83        tokio::spawn(async move {
84            start_notify_listener(listener_db, listener_state, &channel, cancel_rx).await;
85        });
86    }
87
88    // 8. Start admin server (if configured)
89    if let Some(admin_port) = config.admin_server_port {
90        let admin_ip = parse_address(&config.admin_server_host)?;
91        let admin_addr = SocketAddr::new(admin_ip, admin_port);
92        let admin_listener = TcpListener::bind(admin_addr)
93            .await
94            .map_err(|e| Error::Internal(format!("Failed to bind admin server: {}", e)))?;
95
96        tracing::info!(addr = %admin_addr, "Admin server listening");
97
98        tokio::spawn(async move {
99            loop {
100                let (stream, _addr) = match admin_listener.accept().await {
101                    Ok(v) => v,
102                    Err(e) => {
103                        tracing::warn!(error = %e, "Admin TCP accept error");
104                        continue;
105                    }
106                };
107
108                let svc = admin_router.clone();
109                tokio::spawn(async move {
110                    let io = TokioIo::new(stream);
111                    let hyper_svc = TowerToHyperService::new(svc);
112                    let conn = Builder::new(TokioExecutor::new());
113                    if let Err(e) = conn.serve_connection_with_upgrades(io, hyper_svc).await {
114                        tracing::debug!(error = %e, "Admin connection error");
115                    }
116                });
117            }
118        });
119    }
120
121    // 9. Start main server — Unix socket or TCP
122    #[cfg(unix)]
123    if let Some(ref socket_path) = config.server_unix_socket {
124        serve_unix_socket(main_router, socket_path, config.server_unix_socket_mode).await?;
125    } else {
126        serve_tcp(main_router, &config).await?;
127    }
128
129    #[cfg(not(unix))]
130    {
131        if config.server_unix_socket.is_some() {
132            return Err(Error::InvalidConfig {
133                message: "Unix sockets are not supported on this platform".to_string(),
134            });
135        }
136        serve_tcp(main_router, &config).await?;
137    }
138
139    // 10. Cleanup
140    tracing::info!("Shutting down…");
141    let _ = cancel_tx.send(true);
142
143    Ok(())
144}
145
146/// Background NOTIFY listener using the database backend.
147///
148/// Public variant for use by [`crate::app::builder::DbrestRouters::start_listener`].
149pub async fn start_notify_listener_public(
150    db: Arc<dyn DatabaseBackend>,
151    state: AppState,
152    channel: &str,
153    cancel: tokio::sync::watch::Receiver<bool>,
154) {
155    start_notify_listener(db, state, channel, cancel).await;
156}
157
158/// Background NOTIFY listener (internal).
159async fn start_notify_listener(
160    db: Arc<dyn DatabaseBackend>,
161    state: AppState,
162    channel: &str,
163    cancel: tokio::sync::watch::Receiver<bool>,
164) {
165    tracing::info!(channel = %channel, "Starting NOTIFY listener");
166
167    loop {
168        if *cancel.borrow() {
169            tracing::info!("NOTIFY listener shutting down");
170            return;
171        }
172
173        let state_clone = state.clone();
174        let on_event: std::sync::Arc<dyn Fn(String) + Send + Sync> =
175            std::sync::Arc::new(move |payload: String| {
176                let state = state_clone.clone();
177                tokio::spawn(async move {
178                    if (payload.contains("schema") || payload.contains("reload"))
179                        && let Err(e) = state.reload_schema_cache().await
180                    {
181                        tracing::error!(error = %e, "Failed to reload schema cache");
182                    }
183                    if payload.contains("config")
184                        && let Err(e) = state.reload_config().await
185                    {
186                        tracing::error!(error = %e, "Failed to reload config");
187                    }
188                });
189            });
190
191        match db.start_listener(channel, cancel.clone(), on_event).await {
192            Ok(()) => {
193                tracing::info!("NOTIFY listener exiting normally");
194                return;
195            }
196            Err(e) => {
197                tracing::warn!(error = %e, "NOTIFY listener disconnected, reconnecting in 5s");
198                tokio::time::sleep(std::time::Duration::from_secs(5)).await;
199            }
200        }
201    }
202}
203
204/// Start the main server on a TCP socket with HTTP/1.1 and HTTP/2 support.
205///
206/// Uses `hyper_util::server::conn::auto::Builder` to auto-negotiate the
207/// protocol. Browsers connecting over cleartext will use HTTP/1.1 with
208/// upgrade to h2c; behind a TLS-terminating proxy the ALPN negotiation
209/// selects HTTP/2 transparently.
210async fn serve_tcp(router: axum::Router, config: &AppConfig) -> Result<(), Error> {
211    let server_ip = parse_address(&config.server_host)?;
212    let server_addr = SocketAddr::new(server_ip, config.server_port);
213    let listener = TcpListener::bind(server_addr)
214        .await
215        .map_err(|e| Error::Internal(format!("Failed to bind main server: {}", e)))?;
216
217    tracing::info!(addr = %server_addr, "dbrest server listening (HTTP/1.1 + h2c)");
218
219    let shutdown = shutdown_signal();
220    tokio::pin!(shutdown);
221
222    loop {
223        tokio::select! {
224            result = listener.accept() => {
225                let (stream, _addr) = match result {
226                    Ok(v) => v,
227                    Err(e) => {
228                        tracing::warn!(error = %e, "TCP accept error");
229                        continue;
230                    }
231                };
232
233                let svc = router.clone();
234                tokio::spawn(async move {
235                    let io = TokioIo::new(stream);
236                    let hyper_svc = TowerToHyperService::new(svc);
237                    let conn = Builder::new(TokioExecutor::new());
238                    if let Err(e) = conn.serve_connection_with_upgrades(io, hyper_svc).await {
239                        tracing::debug!(error = %e, "Connection error");
240                    }
241                });
242            }
243            _ = &mut shutdown => {
244                tracing::info!("Shutting down TCP server");
245                break;
246            }
247        }
248    }
249
250    Ok(())
251}
252
253/// Start the main server on a Unix domain socket.
254#[cfg(unix)]
255async fn serve_unix_socket(
256    router: axum::Router,
257    socket_path: &std::path::Path,
258    mode: u32,
259) -> Result<(), Error> {
260    use std::os::unix::fs::PermissionsExt;
261
262    let _ = std::fs::remove_file(socket_path);
263
264    let uds = tokio::net::UnixListener::bind(socket_path).map_err(|e| {
265        Error::Internal(format!(
266            "Failed to bind Unix socket '{}': {}",
267            socket_path.display(),
268            e
269        ))
270    })?;
271
272    std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(mode)).map_err(|e| {
273        Error::Internal(format!(
274            "Failed to set socket permissions on '{}': {}",
275            socket_path.display(),
276            e
277        ))
278    })?;
279
280    tracing::info!(path = %socket_path.display(), "dbrest server listening (Unix socket)");
281
282    let shutdown = shutdown_signal();
283    tokio::pin!(shutdown);
284
285    loop {
286        tokio::select! {
287            result = uds.accept() => {
288                let (stream, _addr) = match result {
289                    Ok(v) => v,
290                    Err(e) => {
291                        tracing::warn!(error = %e, "Unix socket accept error");
292                        continue;
293                    }
294                };
295
296                let svc = router.clone();
297                tokio::spawn(async move {
298                    let io = TokioIo::new(stream);
299                    let hyper_svc = TowerToHyperService::new(svc);
300                    let conn = Builder::new(TokioExecutor::new());
301                    if let Err(e) = conn.serve_connection_with_upgrades(io, hyper_svc).await {
302                        tracing::debug!(error = %e, "Connection error");
303                    }
304                });
305            }
306            _ = &mut shutdown => {
307                tracing::info!("Shutting down Unix socket server");
308                break;
309            }
310        }
311    }
312
313    let _ = std::fs::remove_file(socket_path);
314    Ok(())
315}
316
317/// Parse a host string into an `IpAddr`.
318pub fn parse_address(host: &str) -> Result<IpAddr, Error> {
319    match host {
320        "!4" | "*" | "*4" => Ok(IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
321        "!6" | "*6" => Ok(IpAddr::V6(Ipv6Addr::UNSPECIFIED)),
322        "localhost" => Ok(IpAddr::V4(Ipv4Addr::LOCALHOST)),
323        other => other.parse::<IpAddr>().map_err(|_| Error::InvalidConfig {
324            message: format!("Invalid server host: '{other}'"),
325        }),
326    }
327}
328
329/// Wait for a shutdown signal (SIGTERM or Ctrl+C).
330async fn shutdown_signal() {
331    let ctrl_c = async {
332        tokio::signal::ctrl_c()
333            .await
334            .expect("Failed to install Ctrl+C handler");
335    };
336
337    #[cfg(unix)]
338    let terminate = async {
339        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
340            .expect("Failed to install SIGTERM handler")
341            .recv()
342            .await;
343    };
344
345    #[cfg(not(unix))]
346    let terminate = std::future::pending::<()>();
347
348    tokio::select! {
349        _ = ctrl_c => {},
350        _ = terminate => {},
351    }
352}
353
#[cfg(test)]
mod tests {
    //! Unit tests for `parse_address`: wildcard shortcuts, `localhost`,
    //! literal addresses, and rejection of non-address input.

    use super::*;

    const ANY_V4: IpAddr = IpAddr::V4(Ipv4Addr::UNSPECIFIED);
    const ANY_V6: IpAddr = IpAddr::V6(Ipv6Addr::UNSPECIFIED);
    const LOOPBACK_V4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    const LOOPBACK_V6: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST);

    /// Parse `host`, panicking when it is rejected.
    fn parsed(host: &str) -> IpAddr {
        parse_address(host).unwrap()
    }

    #[test]
    fn test_parse_address_ipv4_any() {
        assert_eq!(parsed("!4"), ANY_V4);
    }

    #[test]
    fn test_parse_address_ipv6_any() {
        assert_eq!(parsed("!6"), ANY_V6);
    }

    #[test]
    fn test_parse_address_star() {
        assert_eq!(parsed("*"), ANY_V4);
    }

    #[test]
    fn test_parse_address_star4() {
        assert_eq!(parsed("*4"), ANY_V4);
    }

    #[test]
    fn test_parse_address_star6() {
        assert_eq!(parsed("*6"), ANY_V6);
    }

    #[test]
    fn test_parse_address_localhost() {
        assert_eq!(parsed("localhost"), LOOPBACK_V4);
    }

    #[test]
    fn test_parse_address_literal_ipv4() {
        assert_eq!(parsed("192.168.1.1"), IpAddr::from([192, 168, 1, 1]));
    }

    #[test]
    fn test_parse_address_literal_ipv6() {
        assert_eq!(parsed("::1"), LOOPBACK_V6);
    }

    #[test]
    fn test_parse_address_invalid() {
        assert!(parse_address("not-an-ip").is_err());
    }

    #[test]
    fn test_parse_address_loopback() {
        assert_eq!(parsed("127.0.0.1"), LOOPBACK_V4);
    }
}
431}