1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
21use std::sync::Arc;
22
23use hyper_util::rt::{TokioExecutor, TokioIo};
24use hyper_util::server::conn::auto::Builder;
25use hyper_util::service::TowerToHyperService;
26use tokio::net::TcpListener;
27
28use crate::backend::{DatabaseBackend, DbVersion, SqlDialect};
29use crate::config::AppConfig;
30use crate::error::Error;
31
32use super::admin::create_admin_router;
33use super::router::create_router;
34use super::state::AppState;
35
36pub async fn start_server(_config: AppConfig) -> Result<(), Error> {
41 Err(Error::Internal(
49 "start_server() cannot create a database backend from dbrest-core. \
50 Use start_server_with_backend() instead."
51 .to_string(),
52 ))
53}
54
55pub async fn start_server_with_backend(
60 db: Arc<dyn DatabaseBackend>,
61 dialect: Arc<dyn SqlDialect>,
62 db_version: DbVersion,
63 config: AppConfig,
64) -> Result<(), Error> {
65 let state = AppState::new_with_backend(db.clone(), dialect, config.clone(), db_version);
66
67 tracing::info!("Loading schema cache…");
69 state.reload_schema_cache().await?;
70
71 let main_router = create_router(state.clone());
73 let admin_router = create_admin_router(state.clone());
74
75 let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
77
78 if config.db_channel_enabled {
80 let listener_state = state.clone();
81 let listener_db = db.clone();
82 let channel = config.db_channel.clone();
83 tokio::spawn(async move {
84 start_notify_listener(listener_db, listener_state, &channel, cancel_rx).await;
85 });
86 }
87
88 if let Some(admin_port) = config.admin_server_port {
90 let admin_ip = parse_address(&config.admin_server_host)?;
91 let admin_addr = SocketAddr::new(admin_ip, admin_port);
92 let admin_listener = TcpListener::bind(admin_addr)
93 .await
94 .map_err(|e| Error::Internal(format!("Failed to bind admin server: {}", e)))?;
95
96 tracing::info!(addr = %admin_addr, "Admin server listening");
97
98 tokio::spawn(async move {
99 loop {
100 let (stream, _addr) = match admin_listener.accept().await {
101 Ok(v) => v,
102 Err(e) => {
103 tracing::warn!(error = %e, "Admin TCP accept error");
104 continue;
105 }
106 };
107
108 let svc = admin_router.clone();
109 tokio::spawn(async move {
110 let io = TokioIo::new(stream);
111 let hyper_svc = TowerToHyperService::new(svc);
112 let conn = Builder::new(TokioExecutor::new());
113 if let Err(e) = conn.serve_connection_with_upgrades(io, hyper_svc).await {
114 tracing::debug!(error = %e, "Admin connection error");
115 }
116 });
117 }
118 });
119 }
120
121 #[cfg(unix)]
123 if let Some(ref socket_path) = config.server_unix_socket {
124 serve_unix_socket(main_router, socket_path, config.server_unix_socket_mode).await?;
125 } else {
126 serve_tcp(main_router, &config).await?;
127 }
128
129 #[cfg(not(unix))]
130 {
131 if config.server_unix_socket.is_some() {
132 return Err(Error::InvalidConfig {
133 message: "Unix sockets are not supported on this platform".to_string(),
134 });
135 }
136 serve_tcp(main_router, &config).await?;
137 }
138
139 tracing::info!("Shutting down…");
141 let _ = cancel_tx.send(true);
142
143 Ok(())
144}
145
146pub async fn start_notify_listener_public(
150 db: Arc<dyn DatabaseBackend>,
151 state: AppState,
152 channel: &str,
153 cancel: tokio::sync::watch::Receiver<bool>,
154) {
155 start_notify_listener(db, state, channel, cancel).await;
156}
157
158async fn start_notify_listener(
160 db: Arc<dyn DatabaseBackend>,
161 state: AppState,
162 channel: &str,
163 cancel: tokio::sync::watch::Receiver<bool>,
164) {
165 tracing::info!(channel = %channel, "Starting NOTIFY listener");
166
167 loop {
168 if *cancel.borrow() {
169 tracing::info!("NOTIFY listener shutting down");
170 return;
171 }
172
173 let state_clone = state.clone();
174 let on_event: std::sync::Arc<dyn Fn(String) + Send + Sync> =
175 std::sync::Arc::new(move |payload: String| {
176 let state = state_clone.clone();
177 tokio::spawn(async move {
178 if (payload.contains("schema") || payload.contains("reload"))
179 && let Err(e) = state.reload_schema_cache().await
180 {
181 tracing::error!(error = %e, "Failed to reload schema cache");
182 }
183 if payload.contains("config")
184 && let Err(e) = state.reload_config().await
185 {
186 tracing::error!(error = %e, "Failed to reload config");
187 }
188 });
189 });
190
191 match db.start_listener(channel, cancel.clone(), on_event).await {
192 Ok(()) => {
193 tracing::info!("NOTIFY listener exiting normally");
194 return;
195 }
196 Err(e) => {
197 tracing::warn!(error = %e, "NOTIFY listener disconnected, reconnecting in 5s");
198 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
199 }
200 }
201 }
202}
203
204async fn serve_tcp(router: axum::Router, config: &AppConfig) -> Result<(), Error> {
211 let server_ip = parse_address(&config.server_host)?;
212 let server_addr = SocketAddr::new(server_ip, config.server_port);
213 let listener = TcpListener::bind(server_addr)
214 .await
215 .map_err(|e| Error::Internal(format!("Failed to bind main server: {}", e)))?;
216
217 tracing::info!(addr = %server_addr, "dbrest server listening (HTTP/1.1 + h2c)");
218
219 let shutdown = shutdown_signal();
220 tokio::pin!(shutdown);
221
222 loop {
223 tokio::select! {
224 result = listener.accept() => {
225 let (stream, _addr) = match result {
226 Ok(v) => v,
227 Err(e) => {
228 tracing::warn!(error = %e, "TCP accept error");
229 continue;
230 }
231 };
232
233 let svc = router.clone();
234 tokio::spawn(async move {
235 let io = TokioIo::new(stream);
236 let hyper_svc = TowerToHyperService::new(svc);
237 let conn = Builder::new(TokioExecutor::new());
238 if let Err(e) = conn.serve_connection_with_upgrades(io, hyper_svc).await {
239 tracing::debug!(error = %e, "Connection error");
240 }
241 });
242 }
243 _ = &mut shutdown => {
244 tracing::info!("Shutting down TCP server");
245 break;
246 }
247 }
248 }
249
250 Ok(())
251}
252
253#[cfg(unix)]
255async fn serve_unix_socket(
256 router: axum::Router,
257 socket_path: &std::path::Path,
258 mode: u32,
259) -> Result<(), Error> {
260 use std::os::unix::fs::PermissionsExt;
261
262 let _ = std::fs::remove_file(socket_path);
263
264 let uds = tokio::net::UnixListener::bind(socket_path).map_err(|e| {
265 Error::Internal(format!(
266 "Failed to bind Unix socket '{}': {}",
267 socket_path.display(),
268 e
269 ))
270 })?;
271
272 std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(mode)).map_err(|e| {
273 Error::Internal(format!(
274 "Failed to set socket permissions on '{}': {}",
275 socket_path.display(),
276 e
277 ))
278 })?;
279
280 tracing::info!(path = %socket_path.display(), "dbrest server listening (Unix socket)");
281
282 let shutdown = shutdown_signal();
283 tokio::pin!(shutdown);
284
285 loop {
286 tokio::select! {
287 result = uds.accept() => {
288 let (stream, _addr) = match result {
289 Ok(v) => v,
290 Err(e) => {
291 tracing::warn!(error = %e, "Unix socket accept error");
292 continue;
293 }
294 };
295
296 let svc = router.clone();
297 tokio::spawn(async move {
298 let io = TokioIo::new(stream);
299 let hyper_svc = TowerToHyperService::new(svc);
300 let conn = Builder::new(TokioExecutor::new());
301 if let Err(e) = conn.serve_connection_with_upgrades(io, hyper_svc).await {
302 tracing::debug!(error = %e, "Connection error");
303 }
304 });
305 }
306 _ = &mut shutdown => {
307 tracing::info!("Shutting down Unix socket server");
308 break;
309 }
310 }
311 }
312
313 let _ = std::fs::remove_file(socket_path);
314 Ok(())
315}
316
317pub fn parse_address(host: &str) -> Result<IpAddr, Error> {
319 match host {
320 "!4" | "*" | "*4" => Ok(IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
321 "!6" | "*6" => Ok(IpAddr::V6(Ipv6Addr::UNSPECIFIED)),
322 "localhost" => Ok(IpAddr::V4(Ipv4Addr::LOCALHOST)),
323 other => other.parse::<IpAddr>().map_err(|_| Error::InvalidConfig {
324 message: format!("Invalid server host: '{other}'"),
325 }),
326 }
327}
328
329async fn shutdown_signal() {
331 let ctrl_c = async {
332 tokio::signal::ctrl_c()
333 .await
334 .expect("Failed to install Ctrl+C handler");
335 };
336
337 #[cfg(unix)]
338 let terminate = async {
339 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
340 .expect("Failed to install SIGTERM handler")
341 .recv()
342 .await;
343 };
344
345 #[cfg(not(unix))]
346 let terminate = std::future::pending::<()>();
347
348 tokio::select! {
349 _ = ctrl_c => {},
350 _ = terminate => {},
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `host` parses successfully to exactly `expected`.
    fn assert_parses(host: &str, expected: IpAddr) {
        match parse_address(host) {
            Ok(actual) => assert_eq!(actual, expected, "host {host:?}"),
            Err(_) => panic!("expected {host:?} to parse as {expected}"),
        }
    }

    #[test]
    fn test_parse_address_ipv4_any() {
        assert_parses("!4", IpAddr::V4(Ipv4Addr::UNSPECIFIED));
    }

    #[test]
    fn test_parse_address_ipv6_any() {
        assert_parses("!6", IpAddr::V6(Ipv6Addr::UNSPECIFIED));
    }

    #[test]
    fn test_parse_address_star() {
        assert_parses("*", IpAddr::V4(Ipv4Addr::UNSPECIFIED));
    }

    #[test]
    fn test_parse_address_star4() {
        assert_parses("*4", IpAddr::V4(Ipv4Addr::UNSPECIFIED));
    }

    #[test]
    fn test_parse_address_star6() {
        assert_parses("*6", IpAddr::V6(Ipv6Addr::UNSPECIFIED));
    }

    #[test]
    fn test_parse_address_localhost() {
        assert_parses("localhost", IpAddr::V4(Ipv4Addr::LOCALHOST));
    }

    #[test]
    fn test_parse_address_literal_ipv4() {
        assert_parses("192.168.1.1", IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)));
    }

    #[test]
    fn test_parse_address_literal_ipv6() {
        assert_parses("::1", IpAddr::V6(Ipv6Addr::LOCALHOST));
    }

    #[test]
    fn test_parse_address_invalid() {
        assert!(parse_address("not-an-ip").is_err());
    }

    #[test]
    fn test_parse_address_loopback() {
        assert_parses("127.0.0.1", IpAddr::V4(Ipv4Addr::LOCALHOST));
    }
}