1pub mod auth;
18pub mod handler;
19pub mod metrics;
20pub mod type_map;
21
22use crate::server::handler::{scan_max_txn_id_from_storage, EntHandler};
23use crate::server::metrics::ServerMetrics;
24use entdb::catalog::Catalog;
25use entdb::error::{EntDbError, Result};
26use entdb::query::history::OptimizerHistoryRecorder;
27use entdb::query::optimizer::OptimizerConfig;
28use entdb::storage::buffer_pool::BufferPool;
29use entdb::storage::buffer_pool::BufferPoolStats;
30use entdb::storage::disk_manager::DiskManager;
31use entdb::tx::TransactionManager;
32use entdb::wal::log_manager::LogManager;
33use entdb::wal::recovery::RecoveryManager;
34use futures::Sink;
35use pgwire::api::auth::md5pass::Md5PasswordAuthStartupHandler;
36use pgwire::api::auth::scram::SASLScramAuthStartupHandler;
37use pgwire::api::auth::DefaultServerParameterProvider;
38use pgwire::api::auth::StartupHandler;
39use pgwire::api::copy::NoopCopyHandler;
40use pgwire::api::ClientInfo;
41use pgwire::api::NoopErrorHandler;
42use pgwire::api::PgWireServerHandlers;
43use pgwire::error::PgWireError;
44use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
45use pgwire::tokio::process_socket;
46use rustls_pemfile::{certs, pkcs8_private_keys};
47use std::path::{Path, PathBuf};
48use std::sync::Arc;
49use tokio::net::TcpListener;
50use tokio::sync::{oneshot, OwnedSemaphorePermit, Semaphore};
51use tokio_rustls::rustls::ServerConfig as RustlsServerConfig;
52use tokio_rustls::TlsAcceptor;
53use tracing::{error, info};
54
/// Runtime configuration for the entdb server.
///
/// All fields are checked by [`ServerConfig::validate`] before the server
/// binds its listener; see that method for the accepted ranges.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Path to the main database file. Sibling files (WAL, txn state,
    /// optimizer history) are derived from it by swapping the extension.
    pub data_path: PathBuf,
    /// Interface to bind, e.g. "127.0.0.1". Must be non-empty.
    pub host: String,
    /// TCP port to listen on. Must be > 0.
    pub port: u16,
    /// Number of pages in the buffer pool. Must be > 0.
    pub buffer_pool_size: usize,
    /// Maximum concurrent client connections; further accepts are refused.
    pub max_connections: usize,
    /// Upper bound on a single SQL statement's byte length. Must be > 0.
    pub max_statement_bytes: usize,
    /// Per-query timeout in milliseconds. Must be > 0.
    pub query_timeout_ms: u64,
    /// Authentication scheme (MD5 or SCRAM-SHA-256).
    pub auth_method: crate::server::auth::AuthMethod,
    /// SCRAM iteration count; validate() requires >= 4096 when SCRAM is used.
    pub scram_iterations: usize,
    /// Expected login user name. Must be non-empty.
    pub auth_user: String,
    /// Expected login password. Must be non-empty.
    pub auth_password: String,
    /// PEM certificate chain path; must be set together with `tls_key`.
    pub tls_cert: Option<PathBuf>,
    /// PEM PKCS#8 private key path; must be set together with `tls_cert`.
    pub tls_key: Option<PathBuf>,
}
71
72impl ServerConfig {
73 pub fn listen_addr(&self) -> String {
74 format!("{}:{}", self.host, self.port)
75 }
76
77 pub fn validate(&self) -> Result<()> {
78 if self.host.trim().is_empty() {
79 return Err(EntDbError::Query("host cannot be empty".to_string()));
80 }
81 if self.port == 0 {
82 return Err(EntDbError::Query("port must be > 0".to_string()));
83 }
84 if self.buffer_pool_size == 0 {
85 return Err(EntDbError::Query(
86 "buffer_pool_size must be > 0".to_string(),
87 ));
88 }
89 if self.max_connections == 0 {
90 return Err(EntDbError::Query("max_connections must be > 0".to_string()));
91 }
92 if self.max_statement_bytes == 0 {
93 return Err(EntDbError::Query(
94 "max_statement_bytes must be > 0".to_string(),
95 ));
96 }
97 if self.query_timeout_ms == 0 {
98 return Err(EntDbError::Query(
99 "query_timeout_ms must be > 0".to_string(),
100 ));
101 }
102 if self.auth_user.trim().is_empty() {
103 return Err(EntDbError::Query("auth_user cannot be empty".to_string()));
104 }
105 if self.auth_password.is_empty() {
106 return Err(EntDbError::Query(
107 "auth_password cannot be empty".to_string(),
108 ));
109 }
110 if matches!(
111 self.auth_method,
112 crate::server::auth::AuthMethod::ScramSha256
113 ) && self.scram_iterations < 4096
114 {
115 return Err(EntDbError::Query(
116 "scram_iterations must be >= 4096".to_string(),
117 ));
118 }
119 match (&self.tls_cert, &self.tls_key) {
120 (Some(cert), Some(key)) => {
121 if !cert.exists() {
122 return Err(EntDbError::Query(format!(
123 "tls_cert path does not exist: {}",
124 cert.display()
125 )));
126 }
127 if !key.exists() {
128 return Err(EntDbError::Query(format!(
129 "tls_key path does not exist: {}",
130 key.display()
131 )));
132 }
133 }
134 (None, None) => {}
135 _ => {
136 return Err(EntDbError::Query(
137 "tls_cert and tls_key must both be provided together".to_string(),
138 ))
139 }
140 }
141 Ok(())
142 }
143}
144
/// The opened database: storage, logging, catalog and transaction state
/// shared (via `Arc`) between all server connections.
pub struct Database {
    /// Page-level file I/O for the main data file.
    pub disk_manager: Arc<DiskManager>,
    /// Write-ahead log manager (sibling ".wal" file).
    pub log_manager: Arc<LogManager>,
    /// Buffer pool layered over `disk_manager`, WAL-coupled via `log_manager`.
    pub buffer_pool: Arc<BufferPool>,
    /// Table/index catalog loaded from the buffer pool.
    pub catalog: Arc<Catalog>,
    /// Transaction manager; persisted to sibling ".txn.json"/".txn.wal" files.
    pub txn_manager: Arc<TransactionManager>,
    /// Recorder feeding the history-based optimizer.
    pub optimizer_history: Arc<OptimizerHistoryRecorder>,
    /// Optimizer settings (defaults plus ENTDB_* env overrides), sanitized.
    pub optimizer_config: OptimizerConfig,
}
154
155impl Database {
156 pub fn open(data_path: &Path, buffer_pool_size: usize) -> Result<Self> {
157 let disk_manager = Arc::new(DiskManager::new(data_path)?);
158
159 let mut wal_path = data_path.to_path_buf();
160 wal_path.set_extension("wal");
161 let log_manager = Arc::new(LogManager::new(wal_path, 4096)?);
162
163 let buffer_pool = Arc::new(BufferPool::with_log_manager(
164 buffer_pool_size,
165 Arc::clone(&disk_manager),
166 Arc::clone(&log_manager),
167 ));
168
169 RecoveryManager::new(Arc::clone(&log_manager), Arc::clone(&buffer_pool)).recover()?;
170
171 let catalog = Arc::new(Catalog::load(Arc::clone(&buffer_pool))?);
172 validate_catalog_page_references(&catalog)?;
173
174 let mut txn_state_path = data_path.to_path_buf();
175 txn_state_path.set_extension("txn.json");
176 let mut txn_wal_path = data_path.to_path_buf();
177 txn_wal_path.set_extension("txn.wal");
178
179 let txn_manager = TransactionManager::with_wal_persistence(&txn_state_path, &txn_wal_path)
180 .or_else(|_| TransactionManager::with_persistence(&txn_state_path))
181 .unwrap_or_else(|_| TransactionManager::new());
182
183 if let Ok(max_txn) = scan_max_txn_id_from_storage(&catalog) {
184 txn_manager.ensure_min_next_txn_id(max_txn.saturating_add(1));
185 }
186
187 let optimizer_history_path = optimizer_history_path_for_data_path(data_path);
188 let optimizer_history = OptimizerHistoryRecorder::new(
189 optimizer_history_path,
190 optimizer_history_schema_hash(),
191 16,
192 1024,
193 )
194 .or_else(|_| {
195 OptimizerHistoryRecorder::new(
196 std::env::temp_dir().join("entdb.optimizer_history.server.fallback.json"),
197 optimizer_history_schema_hash(),
198 16,
199 1024,
200 )
201 })?;
202
203 let mut optimizer_config = OptimizerConfig::default();
204 if let Ok(v) = std::env::var("ENTDB_CBO") {
205 optimizer_config.cbo_enabled = v == "1" || v.eq_ignore_ascii_case("true");
206 }
207 if let Ok(v) = std::env::var("ENTDB_HBO") {
208 optimizer_config.hbo_enabled = v == "1" || v.eq_ignore_ascii_case("true");
209 }
210 if let Ok(v) = std::env::var("ENTDB_OPT_MAX_SEARCH_MS") {
211 if let Ok(ms) = v.parse::<u64>() {
212 optimizer_config.max_search_ms = ms;
213 }
214 }
215 if let Ok(v) = std::env::var("ENTDB_OPT_MAX_JOIN_RELATIONS") {
216 if let Ok(n) = v.parse::<usize>() {
217 optimizer_config.max_join_relations = n;
218 }
219 }
220 let optimizer_config = optimizer_config.sanitize();
221
222 Ok(Self {
223 disk_manager,
224 log_manager,
225 buffer_pool,
226 catalog,
227 txn_manager: Arc::new(txn_manager),
228 optimizer_history: Arc::new(optimizer_history),
229 optimizer_config,
230 })
231 }
232}
233
/// Derives the optimizer-history sidecar path from the data file path by
/// swapping the extension, e.g. "server.db" -> "server.optimizer_history.json".
pub fn optimizer_history_path_for_data_path(data_path: &Path) -> PathBuf {
    data_path.with_extension("optimizer_history.json")
}
239
/// Version tag guarding the optimizer-history file format; bump it whenever
/// the recorded schema or the planner changes incompatibly.
pub fn optimizer_history_schema_hash() -> &'static str {
    const SCHEMA_HASH: &str = "optimizer_history_schema_v1_planner_v1";
    SCHEMA_HASH
}
243
244fn validate_catalog_page_references(catalog: &Catalog) -> Result<()> {
245 let bp = catalog.buffer_pool();
246 for table in catalog.list_tables() {
247 bp.fetch_page(table.first_page_id).map_err(|e| {
248 EntDbError::Corruption(format!(
249 "catalog table '{}' references missing first_page_id {}: {e}",
250 table.name, table.first_page_id
251 ))
252 })?;
253
254 for idx in &table.indexes {
255 bp.fetch_page(idx.root_page_id).map_err(|e| {
256 EntDbError::Corruption(format!(
257 "catalog index '{}.{}' references missing root_page_id {}: {e}",
258 table.name, idx.name, idx.root_page_id
259 ))
260 })?;
261 }
262 }
263 Ok(())
264}
265
/// Bundles the pgwire handlers for one connection: an authentication/startup
/// handler and a query handler (shared by the simple and extended protocols).
struct EntHandlerFactory {
    /// Wraps MD5 or SCRAM auth, chosen from `ServerConfig::auth_method`.
    startup_handler: Arc<EntStartupHandler>,
    /// Executes SQL; serves both simple and extended query flows.
    query_handler: Arc<EntHandler>,
}
270
/// Startup/authentication handler selected by `ServerConfig::auth_method`.
/// Both variants authenticate against `EntAuthSource` and report default
/// server parameters to the client.
pub enum EntStartupHandler {
    /// Legacy MD5 password exchange.
    Md5(
        Md5PasswordAuthStartupHandler<
            crate::server::auth::EntAuthSource,
            DefaultServerParameterProvider,
        >,
    ),
    /// SCRAM-SHA-256 SASL exchange.
    Scram(
        SASLScramAuthStartupHandler<
            crate::server::auth::EntAuthSource,
            DefaultServerParameterProvider,
        >,
    ),
}
285
#[async_trait::async_trait]
impl StartupHandler for EntStartupHandler {
    /// Delegates each startup/authentication message to whichever concrete
    /// pgwire handler this enum wraps; no extra logic is added here.
    async fn on_startup<C>(
        &self,
        client: &mut C,
        message: PgWireFrontendMessage,
    ) -> pgwire::error::PgWireResult<()>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
        C::Error: std::fmt::Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        match self {
            EntStartupHandler::Md5(h) => h.on_startup(client, message).await,
            EntStartupHandler::Scram(h) => h.on_startup(client, message).await,
        }
    }
}
304
305impl EntHandlerFactory {
306 fn new(config: Arc<ServerConfig>, db: Arc<Database>, metrics: Arc<ServerMetrics>) -> Self {
307 let auth_source = Arc::new(crate::server::auth::EntAuthSource {
308 method: config.auth_method,
309 expected_user: config.auth_user.clone(),
310 expected_password: config.auth_password.clone(),
311 scram_iterations: config.scram_iterations,
312 });
313 let params = Arc::new(DefaultServerParameterProvider::default());
314 let startup_handler = match config.auth_method {
315 crate::server::auth::AuthMethod::Md5 => Arc::new(EntStartupHandler::Md5(
316 Md5PasswordAuthStartupHandler::new(auth_source, params),
317 )),
318 crate::server::auth::AuthMethod::ScramSha256 => {
319 let mut scram = SASLScramAuthStartupHandler::new(auth_source, params);
320 scram.set_iterations(config.scram_iterations);
321 Arc::new(EntStartupHandler::Scram(scram))
322 }
323 };
324 Self {
325 startup_handler,
326 query_handler: Arc::new(EntHandler::new(
327 db,
328 config.max_statement_bytes,
329 config.query_timeout_ms,
330 metrics,
331 )),
332 }
333 }
334}
335
/// Wires the factory's handlers into pgwire. Simple and extended query flows
/// share the same `EntHandler`; COPY and error handling use pgwire no-ops.
impl PgWireServerHandlers for EntHandlerFactory {
    type StartupHandler = EntStartupHandler;
    type SimpleQueryHandler = EntHandler;
    type ExtendedQueryHandler = EntHandler;
    type CopyHandler = NoopCopyHandler;
    type ErrorHandler = NoopErrorHandler;

    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
        Arc::clone(&self.query_handler)
    }

    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
        Arc::clone(&self.query_handler)
    }

    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
        Arc::clone(&self.startup_handler)
    }

    // COPY is not supported; pgwire's no-op handler rejects/ignores it.
    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
        Arc::new(NoopCopyHandler)
    }

    fn error_handler(&self) -> Arc<Self::ErrorHandler> {
        Arc::new(NoopErrorHandler)
    }
}
363
/// Convenience entry point: validates the config, opens the database, builds
/// the optional TLS acceptor, binds the TCP listener, and serves forever
/// (no shutdown channel is installed here — see [`serve`] for graceful stop).
pub async fn run(config: ServerConfig) -> Result<()> {
    config.validate()?;
    let config = Arc::new(config);
    let database = Arc::new(Database::open(&config.data_path, config.buffer_pool_size)?);
    let tls_acceptor = build_tls_acceptor(&config)?;
    let listener = TcpListener::bind(config.listen_addr()).await?;
    serve(listener, config, database, tls_acceptor, None).await
}
372
/// Accept loop: serves pgwire connections on `listener` until the optional
/// `shutdown` channel fires (or an accept error occurs).
///
/// Concurrency is capped by a semaphore sized to `config.max_connections`;
/// over-limit connects are refused and counted. On shutdown, dirty pages are
/// flushed and transaction state persisted before returning.
pub async fn serve(
    listener: TcpListener,
    config: Arc<ServerConfig>,
    database: Arc<Database>,
    tls_acceptor: Option<Arc<TlsAcceptor>>,
    mut shutdown: Option<oneshot::Receiver<()>>,
) -> Result<()> {
    let addr = listener.local_addr()?;
    info!(%addr, "entdb server listening");
    let conn_limit = Arc::new(Semaphore::new(config.max_connections));
    let metrics = Arc::new(ServerMetrics::default());

    loop {
        tokio::select! {
            // Shutdown branch: only armed when a receiver was supplied. A
            // dropped sender also completes the await, triggering shutdown.
            _ = async {
                if let Some(rx) = &mut shutdown {
                    let _ = rx.await;
                }
            }, if shutdown.is_some() => {
                info!(%addr, "shutdown signal received");
                let bp_stats: BufferPoolStats = database.buffer_pool.stats();
                metrics.set_buffer_pool_pressure(bp_stats);

                // Flush dirty pages, then persist txn state; both timed for
                // the final metrics snapshot.
                let flush_started = std::time::Instant::now();
                database.buffer_pool.flush_all()?;
                metrics.on_shutdown_flush(flush_started.elapsed().as_nanos() as u64);

                let persist_started = std::time::Instant::now();
                database.txn_manager.persist_state()?;
                metrics.on_shutdown_persist(persist_started.elapsed().as_nanos() as u64);
                info!(?bp_stats, metrics=?metrics.snapshot(), "server shutdown metrics");
                break;
            }
            accepted = listener.accept() => {
                // NOTE: an accept error propagates out and stops the server.
                let (socket, peer) = accepted?;
                // try_acquire_owned never waits: at the cap we refuse rather
                // than queue the connection.
                let permit = match Arc::clone(&conn_limit).try_acquire_owned() {
                    Ok(p) => p,
                    Err(_) => {
                        metrics.on_connection_refused();
                        info!(%peer, "connection refused due to max_connections limit");
                        continue;
                    }
                };
                metrics.on_connection_accepted();
                // NOTE(review): a fresh factory (and EntHandler) is built per
                // connection — presumably intentional so each connection gets
                // its own handler state; confirm before hoisting out of loop.
                let factory = Arc::new(EntHandlerFactory::new(
                    Arc::clone(&config),
                    Arc::clone(&database),
                    Arc::clone(&metrics),
                ));
                let tls_for_conn = tls_acceptor.clone();
                let metrics_for_conn = Arc::clone(&metrics);
                info!(%peer, "accepted connection");
                tokio::spawn(async move {
                    // Hold the permit for the connection's lifetime so the
                    // semaphore slot frees only when the task finishes.
                    let _permit: OwnedSemaphorePermit = permit;
                    if let Err(err) = process_socket(socket, tls_for_conn, factory).await {
                        error!(%peer, error = %err, "connection processing error");
                    }
                    metrics_for_conn.on_connection_closed();
                });
            }
        }
    }

    Ok(())
}
438
439fn build_tls_acceptor(config: &ServerConfig) -> Result<Option<Arc<TlsAcceptor>>> {
440 let (Some(cert_path), Some(key_path)) = (&config.tls_cert, &config.tls_key) else {
441 return Ok(None);
442 };
443
444 let cert_file = std::fs::File::open(cert_path)?;
445 let mut cert_reader = std::io::BufReader::new(cert_file);
446 let cert_chain = certs(&mut cert_reader)
447 .collect::<std::result::Result<Vec<_>, _>>()
448 .map_err(|e| EntDbError::Query(format!("invalid tls cert PEM: {e}")))?;
449
450 let key_file = std::fs::File::open(key_path)?;
451 let mut key_reader = std::io::BufReader::new(key_file);
452 let mut keys = pkcs8_private_keys(&mut key_reader)
453 .collect::<std::result::Result<Vec<_>, _>>()
454 .map_err(|e| EntDbError::Query(format!("invalid tls key PEM: {e}")))?;
455 let Some(key) = keys.pop() else {
456 return Err(EntDbError::Query(
457 "tls key PEM has no PKCS8 private key".to_string(),
458 ));
459 };
460
461 let rustls = RustlsServerConfig::builder()
462 .with_no_client_auth()
463 .with_single_cert(cert_chain, key.into())
464 .map_err(|e| EntDbError::Query(format!("failed to build tls config: {e}")))?;
465
466 Ok(Some(Arc::new(TlsAcceptor::from(Arc::new(rustls)))))
467}
468
#[cfg(test)]
mod tests {
    use super::{serve, Database, ServerConfig};
    use entdb::catalog::{Column, Schema};
    use entdb::types::DataType;
    use std::sync::Arc;
    use tempfile::tempdir;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;

    /// End-to-end smoke test: the serve loop starts on an ephemeral port and
    /// exits cleanly when the oneshot shutdown signal is sent.
    #[tokio::test]
    async fn server_accepts_and_stops_with_shutdown_signal() {
        let dir = tempdir().expect("tempdir");
        let data_path = dir.path().join("server.db");
        let db = Arc::new(Database::open(&data_path, 64).expect("open db"));
        let cfg = Arc::new(ServerConfig {
            data_path,
            host: "127.0.0.1".to_string(),
            // port 0 is unused here: the listener below is bound separately.
            port: 0,
            buffer_pool_size: 64,
            max_connections: 4,
            max_statement_bytes: 1024 * 1024,
            query_timeout_ms: 30_000,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 4096,
            auth_user: "entdb".to_string(),
            auth_password: "entdb".to_string(),
            tls_cert: None,
            tls_key: None,
        });
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");

        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move {
            serve(listener, cfg, db, None, Some(rx))
                .await
                .expect("serve")
        });

        // Sending immediately is fine: oneshot delivers even if serve has
        // not polled the receiver yet.
        tx.send(()).expect("signal shutdown");
        handle.await.expect("join");
    }

    /// validate() must reject a config where essentially every field is
    /// invalid (only the first error is reported; any Err suffices here).
    #[test]
    fn server_config_validation_rejects_invalid_limits() {
        let cfg = ServerConfig {
            data_path: "x.db".into(),
            host: "".to_string(),
            port: 0,
            buffer_pool_size: 0,
            max_connections: 0,
            max_statement_bytes: 0,
            query_timeout_ms: 0,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 0,
            auth_user: "".to_string(),
            auth_password: "".to_string(),
            tls_cert: None,
            tls_key: None,
        };
        assert!(cfg.validate().is_err());
    }

    /// Simulates catalog corruption (a table whose first page was deleted)
    /// and checks that reopening fails the startup integrity validation.
    #[test]
    fn database_open_rejects_catalog_with_missing_table_page() {
        let dir = tempdir().expect("tempdir");
        let data_path = dir.path().join("server-corrupt.db");
        let db = Database::open(&data_path, 64).expect("open db");
        let schema = Schema::new(vec![
            Column {
                name: "id".to_string(),
                data_type: DataType::Int32,
                nullable: false,
                default: None,
                primary_key: false,
            },
            Column {
                name: "name".to_string(),
                data_type: DataType::Text,
                nullable: true,
                default: None,
                primary_key: false,
            },
        ]);
        let table = db
            .catalog
            .create_table("users", schema)
            .expect("create table");

        // Delete the table's root page so the catalog dangles, then drop the
        // handle to flush state before reopening.
        db.buffer_pool
            .delete_page(table.first_page_id)
            .expect("delete table root page to simulate corruption");
        drop(db);

        let err = match Database::open(&data_path, 64) {
            Ok(_) => panic!("expected startup integrity validation failure"),
            Err(e) => e,
        };
        assert!(
            err.to_string().contains("references missing first_page_id"),
            "unexpected error: {err}"
        );
    }
}