1pub mod auth;
18pub mod handler;
19pub mod metrics;
20pub mod type_map;
21
22use crate::server::handler::{scan_max_txn_id_from_storage, EntHandler};
23use crate::server::metrics::ServerMetrics;
24use entdb::catalog::Catalog;
25use entdb::error::{EntDbError, Result};
26use entdb::query::history::OptimizerHistoryRecorder;
27use entdb::query::optimizer::OptimizerConfig;
28use entdb::storage::buffer_pool::BufferPool;
29use entdb::storage::buffer_pool::BufferPoolStats;
30use entdb::storage::disk_manager::DiskManager;
31use entdb::tx::{DurabilityMode, TransactionManager};
32use entdb::wal::log_manager::LogManager;
33use entdb::wal::recovery::RecoveryManager;
34use futures::Sink;
35use pgwire::api::auth::md5pass::Md5PasswordAuthStartupHandler;
36use pgwire::api::auth::scram::SASLScramAuthStartupHandler;
37use pgwire::api::auth::DefaultServerParameterProvider;
38use pgwire::api::auth::StartupHandler;
39use pgwire::api::copy::NoopCopyHandler;
40use pgwire::api::ClientInfo;
41use pgwire::api::NoopErrorHandler;
42use pgwire::api::PgWireServerHandlers;
43use pgwire::error::PgWireError;
44use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
45use pgwire::tokio::process_socket;
46use rustls_pemfile::{certs, pkcs8_private_keys};
47use std::path::{Path, PathBuf};
48use std::sync::Arc;
49use tokio::net::TcpListener;
50use tokio::sync::{oneshot, OwnedSemaphorePermit, Semaphore};
51use tokio_rustls::rustls::ServerConfig as RustlsServerConfig;
52use tokio_rustls::TlsAcceptor;
53use tracing::{error, info};
54
/// Runtime configuration for the entdb server process.
///
/// Checked up-front by [`ServerConfig::validate`] before the server starts;
/// sibling files (WAL, txn state, optimizer history) are derived from
/// `data_path` by swapping its extension.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Path to the main database file.
    pub data_path: PathBuf,
    /// Interface to bind, e.g. `"127.0.0.1"`; must be non-empty.
    pub host: String,
    /// TCP port to listen on; `validate` rejects 0.
    pub port: u16,
    /// Number of pages held by the buffer pool; must be > 0.
    pub buffer_pool_size: usize,
    /// Maximum concurrent client connections; over-limit sockets are refused.
    pub max_connections: usize,
    /// Upper bound on a single statement's size in bytes; must be > 0.
    pub max_statement_bytes: usize,
    /// Per-query timeout in milliseconds; must be > 0.
    pub query_timeout_ms: u64,
    /// Which authentication exchange to run (MD5 or SCRAM-SHA-256).
    pub auth_method: crate::server::auth::AuthMethod,
    /// SCRAM iteration count; `validate` requires >= 4096 when SCRAM is used.
    pub scram_iterations: usize,
    /// The single accepted username; must be non-empty.
    pub auth_user: String,
    /// Password for `auth_user`; must be non-empty.
    pub auth_password: String,
    /// PEM certificate chain path; must be provided together with `tls_key`.
    pub tls_cert: Option<PathBuf>,
    /// PEM private key path; must be provided together with `tls_cert`.
    pub tls_key: Option<PathBuf>,
    /// WAL durability policy applied to the transaction manager.
    pub durability_mode: DurabilityMode,
    /// Whether queries wait for durability before acknowledging.
    pub await_durable: bool,
}
73
74impl ServerConfig {
75 pub fn listen_addr(&self) -> String {
76 format!("{}:{}", self.host, self.port)
77 }
78
79 pub fn validate(&self) -> Result<()> {
80 if self.host.trim().is_empty() {
81 return Err(EntDbError::Query("host cannot be empty".to_string()));
82 }
83 if self.port == 0 {
84 return Err(EntDbError::Query("port must be > 0".to_string()));
85 }
86 if self.buffer_pool_size == 0 {
87 return Err(EntDbError::Query(
88 "buffer_pool_size must be > 0".to_string(),
89 ));
90 }
91 if self.max_connections == 0 {
92 return Err(EntDbError::Query("max_connections must be > 0".to_string()));
93 }
94 if self.max_statement_bytes == 0 {
95 return Err(EntDbError::Query(
96 "max_statement_bytes must be > 0".to_string(),
97 ));
98 }
99 if self.query_timeout_ms == 0 {
100 return Err(EntDbError::Query(
101 "query_timeout_ms must be > 0".to_string(),
102 ));
103 }
104 if self.auth_user.trim().is_empty() {
105 return Err(EntDbError::Query("auth_user cannot be empty".to_string()));
106 }
107 if self.auth_password.is_empty() {
108 return Err(EntDbError::Query(
109 "auth_password cannot be empty".to_string(),
110 ));
111 }
112 if matches!(
113 self.auth_method,
114 crate::server::auth::AuthMethod::ScramSha256
115 ) && self.scram_iterations < 4096
116 {
117 return Err(EntDbError::Query(
118 "scram_iterations must be >= 4096".to_string(),
119 ));
120 }
121 match (&self.tls_cert, &self.tls_key) {
122 (Some(cert), Some(key)) => {
123 if !cert.exists() {
124 return Err(EntDbError::Query(format!(
125 "tls_cert path does not exist: {}",
126 cert.display()
127 )));
128 }
129 if !key.exists() {
130 return Err(EntDbError::Query(format!(
131 "tls_key path does not exist: {}",
132 key.display()
133 )));
134 }
135 }
136 (None, None) => {}
137 _ => {
138 return Err(EntDbError::Query(
139 "tls_cert and tls_key must both be provided together".to_string(),
140 ))
141 }
142 }
143 Ok(())
144 }
145}
146
/// The shared, long-lived engine components backing every connection served
/// by this process.
pub struct Database {
    /// Raw page I/O against the main data file.
    pub disk_manager: Arc<DiskManager>,
    /// Write-ahead log writer (sidecar `.wal` file next to the data file).
    pub log_manager: Arc<LogManager>,
    /// WAL-aware page cache layered over the disk manager.
    pub buffer_pool: Arc<BufferPool>,
    /// Table/index metadata loaded from storage at startup.
    pub catalog: Arc<Catalog>,
    /// Transaction lifecycle and durability bookkeeping.
    pub txn_manager: Arc<TransactionManager>,
    /// Records optimizer decisions for history-based optimization.
    pub optimizer_history: Arc<OptimizerHistoryRecorder>,
    /// Sanitized optimizer settings (env-var overrides applied at open).
    pub optimizer_config: OptimizerConfig,
}
156
impl Database {
    /// Opens (or creates) the database rooted at `data_path` and brings it
    /// to a consistent state.
    ///
    /// Startup sequence (order matters): open the data file, open the WAL
    /// (`<data>.wal`), build the buffer pool, replay the WAL, load and
    /// integrity-check the catalog, restore transaction state, then set up
    /// optimizer history and config. Optimizer settings can be overridden
    /// via the `ENTDB_CBO`, `ENTDB_HBO`, `ENTDB_OPT_MAX_SEARCH_MS`, and
    /// `ENTDB_OPT_MAX_JOIN_RELATIONS` environment variables.
    pub fn open(
        data_path: &Path,
        buffer_pool_size: usize,
        durability_mode: DurabilityMode,
    ) -> Result<Self> {
        let disk_manager = Arc::new(DiskManager::new(data_path)?);

        // The WAL lives next to the data file with a `.wal` extension.
        let mut wal_path = data_path.to_path_buf();
        wal_path.set_extension("wal");
        let log_manager = Arc::new(LogManager::new(wal_path, 4096)?);

        let buffer_pool = Arc::new(BufferPool::with_log_manager(
            buffer_pool_size,
            Arc::clone(&disk_manager),
            Arc::clone(&log_manager),
        ));

        // Replay the WAL before anything reads pages, so the catalog load
        // below observes recovered state.
        RecoveryManager::new(Arc::clone(&log_manager), Arc::clone(&buffer_pool)).recover()?;

        let catalog = Arc::new(Catalog::load(Arc::clone(&buffer_pool))?);
        // Fail fast if the catalog references pages that no longer exist.
        validate_catalog_page_references(&catalog)?;

        let mut txn_state_path = data_path.to_path_buf();
        txn_state_path.set_extension("txn.json");
        let mut txn_wal_path = data_path.to_path_buf();
        txn_wal_path.set_extension("txn.wal");

        // Best-effort restoration of transaction state: prefer the WAL-backed
        // form, fall back to snapshot-only persistence, and finally to a
        // fresh manager if neither can be loaded.
        let txn_manager = TransactionManager::with_wal_persistence(&txn_state_path, &txn_wal_path)
            .or_else(|_| TransactionManager::with_persistence(&txn_state_path))
            .unwrap_or_else(|_| TransactionManager::new());
        txn_manager.set_durability_mode(durability_mode);

        // Ensure newly issued txn ids never collide with ids already present
        // in storage (e.g. after losing the persisted txn state files).
        if let Ok(max_txn) = scan_max_txn_id_from_storage(&catalog) {
            txn_manager.ensure_min_next_txn_id(max_txn.saturating_add(1));
        }

        let optimizer_history_path = optimizer_history_path_for_data_path(data_path);
        // If the preferred sidecar location is unusable, fall back to a
        // temp-dir history file rather than failing startup.
        let optimizer_history = OptimizerHistoryRecorder::new(
            optimizer_history_path,
            optimizer_history_schema_hash(),
            16,
            1024,
        )
        .or_else(|_| {
            OptimizerHistoryRecorder::new(
                std::env::temp_dir().join("entdb.optimizer_history.server.fallback.json"),
                optimizer_history_schema_hash(),
                16,
                1024,
            )
        })?;

        // Environment-variable overrides; unparsable values are ignored.
        let mut optimizer_config = OptimizerConfig::default();
        if let Ok(v) = std::env::var("ENTDB_CBO") {
            optimizer_config.cbo_enabled = v == "1" || v.eq_ignore_ascii_case("true");
        }
        if let Ok(v) = std::env::var("ENTDB_HBO") {
            optimizer_config.hbo_enabled = v == "1" || v.eq_ignore_ascii_case("true");
        }
        if let Ok(v) = std::env::var("ENTDB_OPT_MAX_SEARCH_MS") {
            if let Ok(ms) = v.parse::<u64>() {
                optimizer_config.max_search_ms = ms;
            }
        }
        if let Ok(v) = std::env::var("ENTDB_OPT_MAX_JOIN_RELATIONS") {
            if let Ok(n) = v.parse::<usize>() {
                optimizer_config.max_join_relations = n;
            }
        }
        // Clamp/normalize whatever combination of settings we ended up with.
        let optimizer_config = optimizer_config.sanitize();

        Ok(Self {
            disk_manager,
            log_manager,
            buffer_pool,
            catalog,
            txn_manager: Arc::new(txn_manager),
            optimizer_history: Arc::new(optimizer_history),
            optimizer_config,
        })
    }
}
240
/// Derives the optimizer-history sidecar file for a database path by swapping
/// the extension, e.g. `server.db` -> `server.optimizer_history.json`.
pub fn optimizer_history_path_for_data_path(data_path: &Path) -> PathBuf {
    data_path.with_extension("optimizer_history.json")
}
246
/// Version tag identifying the optimizer-history schema/planner revision;
/// recorded history files are keyed by this value.
pub fn optimizer_history_schema_hash() -> &'static str {
    const SCHEMA_HASH: &str = "optimizer_history_schema_v1_planner_v1";
    SCHEMA_HASH
}
250
251fn validate_catalog_page_references(catalog: &Catalog) -> Result<()> {
252 let bp = catalog.buffer_pool();
253 for table in catalog.list_tables() {
254 bp.fetch_page(table.first_page_id).map_err(|e| {
255 EntDbError::Corruption(format!(
256 "catalog table '{}' references missing first_page_id {}: {e}",
257 table.name, table.first_page_id
258 ))
259 })?;
260
261 for idx in &table.indexes {
262 bp.fetch_page(idx.root_page_id).map_err(|e| {
263 EntDbError::Corruption(format!(
264 "catalog index '{}.{}' references missing root_page_id {}: {e}",
265 table.name, idx.name, idx.root_page_id
266 ))
267 })?;
268 }
269 }
270 Ok(())
271}
272
/// Per-connection bundle of pgwire handlers: the configured auth startup
/// handler plus the query handler serving both simple and extended protocols.
struct EntHandlerFactory {
    /// Auth exchange handler (MD5 or SCRAM, chosen at construction).
    startup_handler: Arc<EntStartupHandler>,
    /// Shared handler for simple and extended query protocols.
    query_handler: Arc<EntHandler>,
}
277
/// Startup/auth handler selected by `ServerConfig::auth_method`; wraps the
/// two pgwire auth flows behind a single `StartupHandler` impl.
pub enum EntStartupHandler {
    /// MD5 password exchange.
    Md5(
        Md5PasswordAuthStartupHandler<
            crate::server::auth::EntAuthSource,
            DefaultServerParameterProvider,
        >,
    ),
    /// SCRAM-SHA-256 exchange.
    Scram(
        SASLScramAuthStartupHandler<
            crate::server::auth::EntAuthSource,
            DefaultServerParameterProvider,
        >,
    ),
}
292
293#[async_trait::async_trait]
294impl StartupHandler for EntStartupHandler {
295 async fn on_startup<C>(
296 &self,
297 client: &mut C,
298 message: PgWireFrontendMessage,
299 ) -> pgwire::error::PgWireResult<()>
300 where
301 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
302 C::Error: std::fmt::Debug,
303 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
304 {
305 match self {
306 EntStartupHandler::Md5(h) => h.on_startup(client, message).await,
307 EntStartupHandler::Scram(h) => h.on_startup(client, message).await,
308 }
309 }
310}
311
312impl EntHandlerFactory {
313 fn new(config: Arc<ServerConfig>, db: Arc<Database>, metrics: Arc<ServerMetrics>) -> Self {
314 let auth_source = Arc::new(crate::server::auth::EntAuthSource {
315 method: config.auth_method,
316 expected_user: config.auth_user.clone(),
317 expected_password: config.auth_password.clone(),
318 scram_iterations: config.scram_iterations,
319 });
320 let params = Arc::new(DefaultServerParameterProvider::default());
321 let startup_handler = match config.auth_method {
322 crate::server::auth::AuthMethod::Md5 => Arc::new(EntStartupHandler::Md5(
323 Md5PasswordAuthStartupHandler::new(auth_source, params),
324 )),
325 crate::server::auth::AuthMethod::ScramSha256 => {
326 let mut scram = SASLScramAuthStartupHandler::new(auth_source, params);
327 scram.set_iterations(config.scram_iterations);
328 Arc::new(EntStartupHandler::Scram(scram))
329 }
330 };
331 Self {
332 startup_handler,
333 query_handler: Arc::new(EntHandler::new(
334 db,
335 config.max_statement_bytes,
336 config.query_timeout_ms,
337 metrics,
338 config.await_durable,
339 )),
340 }
341 }
342}
343
344impl PgWireServerHandlers for EntHandlerFactory {
345 type StartupHandler = EntStartupHandler;
346 type SimpleQueryHandler = EntHandler;
347 type ExtendedQueryHandler = EntHandler;
348 type CopyHandler = NoopCopyHandler;
349 type ErrorHandler = NoopErrorHandler;
350
351 fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
352 Arc::clone(&self.query_handler)
353 }
354
355 fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
356 Arc::clone(&self.query_handler)
357 }
358
359 fn startup_handler(&self) -> Arc<Self::StartupHandler> {
360 Arc::clone(&self.startup_handler)
361 }
362
363 fn copy_handler(&self) -> Arc<Self::CopyHandler> {
364 Arc::new(NoopCopyHandler)
365 }
366
367 fn error_handler(&self) -> Arc<Self::ErrorHandler> {
368 Arc::new(NoopErrorHandler)
369 }
370}
371
372pub async fn run(config: ServerConfig) -> Result<()> {
373 config.validate()?;
374 let config = Arc::new(config);
375 let database = Arc::new(Database::open(
376 &config.data_path,
377 config.buffer_pool_size,
378 config.durability_mode,
379 )?);
380 let tls_acceptor = build_tls_acceptor(&config)?;
381 let listener = TcpListener::bind(config.listen_addr()).await?;
382 serve(listener, config, database, tls_acceptor, None).await
383}
384
/// Accept loop: serves pgwire connections on `listener` until the optional
/// `shutdown` channel fires (a dropped sender also triggers shutdown).
///
/// Connection concurrency is bounded by a semaphore sized to
/// `config.max_connections`; over-limit sockets are dropped immediately.
/// On shutdown the buffer pool is flushed and transaction state persisted
/// before returning.
pub async fn serve(
    listener: TcpListener,
    config: Arc<ServerConfig>,
    database: Arc<Database>,
    tls_acceptor: Option<Arc<TlsAcceptor>>,
    mut shutdown: Option<oneshot::Receiver<()>>,
) -> Result<()> {
    let addr = listener.local_addr()?;
    info!(%addr, "entdb server listening");
    let conn_limit = Arc::new(Semaphore::new(config.max_connections));
    let metrics = Arc::new(ServerMetrics::default());

    loop {
        tokio::select! {
            // Shutdown branch: only armed when a receiver was supplied.
            // `rx.await` on `&mut Receiver` leaves the receiver in place if
            // this branch is cancelled by a winning accept.
            _ = async {
                if let Some(rx) = &mut shutdown {
                    let _ = rx.await;
                }
            }, if shutdown.is_some() => {
                info!(%addr, "shutdown signal received");
                let bp_stats: BufferPoolStats = database.buffer_pool.stats();
                metrics.set_buffer_pool_pressure(bp_stats);

                // Flush dirty pages, then persist transaction state; both
                // durations are recorded in the shutdown metrics.
                let flush_started = std::time::Instant::now();
                database.buffer_pool.flush_all()?;
                metrics.on_shutdown_flush(flush_started.elapsed().as_nanos() as u64);

                let persist_started = std::time::Instant::now();
                database.txn_manager.persist_state()?;
                metrics.on_shutdown_persist(persist_started.elapsed().as_nanos() as u64);
                info!(?bp_stats, metrics=?metrics.snapshot(), "server shutdown metrics");
                break;
            }
            accepted = listener.accept() => {
                // NOTE(review): an accept error propagates and stops the whole
                // server; confirm transient errors (e.g. EMFILE) should be fatal.
                let (socket, peer) = accepted?;
                // Enforce max_connections without blocking the accept loop;
                // refused sockets are simply dropped.
                let permit = match Arc::clone(&conn_limit).try_acquire_owned() {
                    Ok(p) => p,
                    Err(_) => {
                        metrics.on_connection_refused();
                        info!(%peer, "connection refused due to max_connections limit");
                        continue;
                    }
                };
                metrics.on_connection_accepted();
                // A fresh factory (and EntHandler) is built per connection —
                // presumably for per-session handler state; confirm before
                // hoisting this out of the loop.
                let factory = Arc::new(EntHandlerFactory::new(
                    Arc::clone(&config),
                    Arc::clone(&database),
                    Arc::clone(&metrics),
                ));
                let tls_for_conn = tls_acceptor.clone();
                let metrics_for_conn = Arc::clone(&metrics);
                info!(%peer, "accepted connection");
                tokio::spawn(async move {
                    // Hold the semaphore permit for the connection's lifetime.
                    let _permit: OwnedSemaphorePermit = permit;
                    if let Err(err) = process_socket(socket, tls_for_conn, factory).await {
                        error!(%peer, error = %err, "connection processing error");
                    }
                    metrics_for_conn.on_connection_closed();
                });
            }
        }
    }

    Ok(())
}
450
451fn build_tls_acceptor(config: &ServerConfig) -> Result<Option<Arc<TlsAcceptor>>> {
452 let (Some(cert_path), Some(key_path)) = (&config.tls_cert, &config.tls_key) else {
453 return Ok(None);
454 };
455
456 let cert_file = std::fs::File::open(cert_path)?;
457 let mut cert_reader = std::io::BufReader::new(cert_file);
458 let cert_chain = certs(&mut cert_reader)
459 .collect::<std::result::Result<Vec<_>, _>>()
460 .map_err(|e| EntDbError::Query(format!("invalid tls cert PEM: {e}")))?;
461
462 let key_file = std::fs::File::open(key_path)?;
463 let mut key_reader = std::io::BufReader::new(key_file);
464 let mut keys = pkcs8_private_keys(&mut key_reader)
465 .collect::<std::result::Result<Vec<_>, _>>()
466 .map_err(|e| EntDbError::Query(format!("invalid tls key PEM: {e}")))?;
467 let Some(key) = keys.pop() else {
468 return Err(EntDbError::Query(
469 "tls key PEM has no PKCS8 private key".to_string(),
470 ));
471 };
472
473 let rustls = RustlsServerConfig::builder()
474 .with_no_client_auth()
475 .with_single_cert(cert_chain, key.into())
476 .map_err(|e| EntDbError::Query(format!("failed to build tls config: {e}")))?;
477
478 Ok(Some(Arc::new(TlsAcceptor::from(Arc::new(rustls)))))
479}
480
#[cfg(test)]
mod tests {
    use super::{serve, Database, ServerConfig};
    use entdb::catalog::{Column, Schema};
    use entdb::types::DataType;
    use entdb::DurabilityMode;
    use std::sync::Arc;
    use tempfile::tempdir;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;

    /// `serve` should exit cleanly when the shutdown channel fires.
    #[tokio::test]
    async fn server_accepts_and_stops_with_shutdown_signal() {
        let dir = tempdir().expect("tempdir");
        let data_path = dir.path().join("server.db");
        let db = Arc::new(Database::open(&data_path, 64, DurabilityMode::Full).expect("open db"));
        // port 0 here is never validated/bound; the listener below binds its
        // own ephemeral port.
        let cfg = Arc::new(ServerConfig {
            data_path,
            host: "127.0.0.1".to_string(),
            port: 0,
            buffer_pool_size: 64,
            max_connections: 4,
            max_statement_bytes: 1024 * 1024,
            query_timeout_ms: 30_000,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 4096,
            auth_user: "entdb".to_string(),
            auth_password: "entdb".to_string(),
            tls_cert: None,
            tls_key: None,
            durability_mode: DurabilityMode::Full,
            await_durable: false,
        });
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");

        // The oneshot value is buffered, so sending before the server polls
        // the shutdown branch is fine.
        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move {
            serve(listener, cfg, db, None, Some(rx))
                .await
                .expect("serve")
        });

        tx.send(()).expect("signal shutdown");
        handle.await.expect("join");
    }

    /// A config violating every limit at once must fail validation.
    #[test]
    fn server_config_validation_rejects_invalid_limits() {
        let cfg = ServerConfig {
            data_path: "x.db".into(),
            host: "".to_string(),
            port: 0,
            buffer_pool_size: 0,
            max_connections: 0,
            max_statement_bytes: 0,
            query_timeout_ms: 0,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 0,
            auth_user: "".to_string(),
            auth_password: "".to_string(),
            tls_cert: None,
            tls_key: None,
            durability_mode: DurabilityMode::Full,
            await_durable: false,
        };
        assert!(cfg.validate().is_err());
    }

    /// Deleting a table's root page must make the startup integrity check
    /// (`validate_catalog_page_references`) fail on reopen.
    #[test]
    fn database_open_rejects_catalog_with_missing_table_page() {
        let dir = tempdir().expect("tempdir");
        let data_path = dir.path().join("server-corrupt.db");
        let db = Database::open(&data_path, 64, DurabilityMode::Full).expect("open db");
        let schema = Schema::new(vec![
            Column {
                name: "id".to_string(),
                data_type: DataType::Int32,
                nullable: false,
                default: None,
                primary_key: false,
            },
            Column {
                name: "name".to_string(),
                data_type: DataType::Text,
                nullable: true,
                default: None,
                primary_key: false,
            },
        ]);
        let table = db
            .catalog
            .create_table("users", schema)
            .expect("create table");

        // Simulate on-disk corruption by dropping the table's root page.
        db.buffer_pool
            .delete_page(table.first_page_id)
            .expect("delete table root page to simulate corruption");
        drop(db);

        let err = match Database::open(&data_path, 64, DurabilityMode::Full) {
            Ok(_) => panic!("expected startup integrity validation failure"),
            Err(e) => e,
        };
        assert!(
            err.to_string().contains("references missing first_page_id"),
            "unexpected error: {err}"
        );
    }
}