Skip to main content

entdb_server/server/
mod.rs

/*
 * Copyright 2026 EntDB Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17pub mod auth;
18pub mod handler;
19pub mod metrics;
20pub mod type_map;
21
22use crate::server::handler::{scan_max_txn_id_from_storage, EntHandler};
23use crate::server::metrics::ServerMetrics;
24use entdb::catalog::Catalog;
25use entdb::error::{EntDbError, Result};
26use entdb::query::history::OptimizerHistoryRecorder;
27use entdb::query::optimizer::OptimizerConfig;
28use entdb::storage::buffer_pool::BufferPool;
29use entdb::storage::buffer_pool::BufferPoolStats;
30use entdb::storage::disk_manager::DiskManager;
31use entdb::tx::TransactionManager;
32use entdb::wal::log_manager::LogManager;
33use entdb::wal::recovery::RecoveryManager;
34use futures::Sink;
35use pgwire::api::auth::md5pass::Md5PasswordAuthStartupHandler;
36use pgwire::api::auth::scram::SASLScramAuthStartupHandler;
37use pgwire::api::auth::DefaultServerParameterProvider;
38use pgwire::api::auth::StartupHandler;
39use pgwire::api::copy::NoopCopyHandler;
40use pgwire::api::ClientInfo;
41use pgwire::api::NoopErrorHandler;
42use pgwire::api::PgWireServerHandlers;
43use pgwire::error::PgWireError;
44use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
45use pgwire::tokio::process_socket;
46use rustls_pemfile::{certs, pkcs8_private_keys};
47use std::path::{Path, PathBuf};
48use std::sync::Arc;
49use tokio::net::TcpListener;
50use tokio::sync::{oneshot, OwnedSemaphorePermit, Semaphore};
51use tokio_rustls::rustls::ServerConfig as RustlsServerConfig;
52use tokio_rustls::TlsAcceptor;
53use tracing::{error, info};
54
/// Runtime configuration for the EntDB PostgreSQL-wire server.
///
/// Call [`ServerConfig::validate`] before use; [`run`] does so automatically.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Main database file. The WAL, transaction-state and optimizer-history
    /// files are derived from it by replacing its extension.
    pub data_path: PathBuf,
    /// Host/interface to listen on; combined with `port` by `listen_addr`.
    pub host: String,
    /// TCP port to listen on; `validate` rejects 0.
    pub port: u16,
    /// Size passed to the buffer pool when the database is opened.
    pub buffer_pool_size: usize,
    /// Maximum number of concurrently served connections; connections beyond
    /// this limit are dropped at accept time.
    pub max_connections: usize,
    /// Statement size limit forwarded to the query handler (`EntHandler::new`).
    pub max_statement_bytes: usize,
    /// Query timeout in milliseconds forwarded to the query handler.
    pub query_timeout_ms: u64,
    /// Password authentication mechanism used during connection startup.
    pub auth_method: crate::server::auth::AuthMethod,
    /// SCRAM iteration count; must be >= 4096 when SCRAM-SHA-256 is selected.
    pub scram_iterations: usize,
    /// The single user name accepted by authentication.
    pub auth_user: String,
    /// Password expected for `auth_user`.
    pub auth_password: String,
    /// PEM certificate chain. TLS is enabled only when both `tls_cert` and
    /// `tls_key` are set; setting just one is a validation error.
    pub tls_cert: Option<PathBuf>,
    /// PEM private key (read as PKCS#8) matching `tls_cert`.
    pub tls_key: Option<PathBuf>,
}
71
72impl ServerConfig {
73    pub fn listen_addr(&self) -> String {
74        format!("{}:{}", self.host, self.port)
75    }
76
77    pub fn validate(&self) -> Result<()> {
78        if self.host.trim().is_empty() {
79            return Err(EntDbError::Query("host cannot be empty".to_string()));
80        }
81        if self.port == 0 {
82            return Err(EntDbError::Query("port must be > 0".to_string()));
83        }
84        if self.buffer_pool_size == 0 {
85            return Err(EntDbError::Query(
86                "buffer_pool_size must be > 0".to_string(),
87            ));
88        }
89        if self.max_connections == 0 {
90            return Err(EntDbError::Query("max_connections must be > 0".to_string()));
91        }
92        if self.max_statement_bytes == 0 {
93            return Err(EntDbError::Query(
94                "max_statement_bytes must be > 0".to_string(),
95            ));
96        }
97        if self.query_timeout_ms == 0 {
98            return Err(EntDbError::Query(
99                "query_timeout_ms must be > 0".to_string(),
100            ));
101        }
102        if self.auth_user.trim().is_empty() {
103            return Err(EntDbError::Query("auth_user cannot be empty".to_string()));
104        }
105        if self.auth_password.is_empty() {
106            return Err(EntDbError::Query(
107                "auth_password cannot be empty".to_string(),
108            ));
109        }
110        if matches!(
111            self.auth_method,
112            crate::server::auth::AuthMethod::ScramSha256
113        ) && self.scram_iterations < 4096
114        {
115            return Err(EntDbError::Query(
116                "scram_iterations must be >= 4096".to_string(),
117            ));
118        }
119        match (&self.tls_cert, &self.tls_key) {
120            (Some(cert), Some(key)) => {
121                if !cert.exists() {
122                    return Err(EntDbError::Query(format!(
123                        "tls_cert path does not exist: {}",
124                        cert.display()
125                    )));
126                }
127                if !key.exists() {
128                    return Err(EntDbError::Query(format!(
129                        "tls_key path does not exist: {}",
130                        key.display()
131                    )));
132                }
133            }
134            (None, None) => {}
135            _ => {
136                return Err(EntDbError::Query(
137                    "tls_cert and tls_key must both be provided together".to_string(),
138                ))
139            }
140        }
141        Ok(())
142    }
143}
144
/// Handles to an opened EntDB storage engine.
///
/// Built by [`Database::open`]; `serve` shares one `Arc<Database>` across all
/// connections.
pub struct Database {
    /// Disk manager for the main data file at `data_path`.
    pub disk_manager: Arc<DiskManager>,
    /// Write-ahead log stored at `data_path` with its extension replaced by `wal`.
    pub log_manager: Arc<LogManager>,
    /// Buffer pool built over `disk_manager` and wired to `log_manager`.
    pub buffer_pool: Arc<BufferPool>,
    /// Table and index metadata loaded through `buffer_pool`.
    pub catalog: Arc<Catalog>,
    /// Transaction manager. Its persistence backend depends on which option
    /// `Database::open` was able to initialize.
    pub txn_manager: Arc<TransactionManager>,
    /// Recorder for optimizer history.
    pub optimizer_history: Arc<OptimizerHistoryRecorder>,
    /// Optimizer settings: defaults plus `ENTDB_*` environment overrides, sanitized.
    pub optimizer_config: OptimizerConfig,
}
154
155impl Database {
156    pub fn open(data_path: &Path, buffer_pool_size: usize) -> Result<Self> {
157        let disk_manager = Arc::new(DiskManager::new(data_path)?);
158
159        let mut wal_path = data_path.to_path_buf();
160        wal_path.set_extension("wal");
161        let log_manager = Arc::new(LogManager::new(wal_path, 4096)?);
162
163        let buffer_pool = Arc::new(BufferPool::with_log_manager(
164            buffer_pool_size,
165            Arc::clone(&disk_manager),
166            Arc::clone(&log_manager),
167        ));
168
169        RecoveryManager::new(Arc::clone(&log_manager), Arc::clone(&buffer_pool)).recover()?;
170
171        let catalog = Arc::new(Catalog::load(Arc::clone(&buffer_pool))?);
172        validate_catalog_page_references(&catalog)?;
173
174        let mut txn_state_path = data_path.to_path_buf();
175        txn_state_path.set_extension("txn.json");
176        let mut txn_wal_path = data_path.to_path_buf();
177        txn_wal_path.set_extension("txn.wal");
178
179        let txn_manager = TransactionManager::with_wal_persistence(&txn_state_path, &txn_wal_path)
180            .or_else(|_| TransactionManager::with_persistence(&txn_state_path))
181            .unwrap_or_else(|_| TransactionManager::new());
182
183        if let Ok(max_txn) = scan_max_txn_id_from_storage(&catalog) {
184            txn_manager.ensure_min_next_txn_id(max_txn.saturating_add(1));
185        }
186
187        let optimizer_history_path = optimizer_history_path_for_data_path(data_path);
188        let optimizer_history = OptimizerHistoryRecorder::new(
189            optimizer_history_path,
190            optimizer_history_schema_hash(),
191            16,
192            1024,
193        )
194        .or_else(|_| {
195            OptimizerHistoryRecorder::new(
196                std::env::temp_dir().join("entdb.optimizer_history.server.fallback.json"),
197                optimizer_history_schema_hash(),
198                16,
199                1024,
200            )
201        })?;
202
203        let mut optimizer_config = OptimizerConfig::default();
204        if let Ok(v) = std::env::var("ENTDB_CBO") {
205            optimizer_config.cbo_enabled = v == "1" || v.eq_ignore_ascii_case("true");
206        }
207        if let Ok(v) = std::env::var("ENTDB_HBO") {
208            optimizer_config.hbo_enabled = v == "1" || v.eq_ignore_ascii_case("true");
209        }
210        if let Ok(v) = std::env::var("ENTDB_OPT_MAX_SEARCH_MS") {
211            if let Ok(ms) = v.parse::<u64>() {
212                optimizer_config.max_search_ms = ms;
213            }
214        }
215        if let Ok(v) = std::env::var("ENTDB_OPT_MAX_JOIN_RELATIONS") {
216            if let Ok(n) = v.parse::<usize>() {
217                optimizer_config.max_join_relations = n;
218            }
219        }
220        let optimizer_config = optimizer_config.sanitize();
221
222        Ok(Self {
223            disk_manager,
224            log_manager,
225            buffer_pool,
226            catalog,
227            txn_manager: Arc::new(txn_manager),
228            optimizer_history: Arc::new(optimizer_history),
229            optimizer_config,
230        })
231    }
232}
233
/// Location of the optimizer-history JSON file that sits next to the database file.
///
/// The extension of `data_path` is replaced (or appended when it has none).
/// For example, `data/app.db` maps to `data/app.optimizer_history.json`.
pub fn optimizer_history_path_for_data_path(data_path: &Path) -> PathBuf {
    data_path.with_extension("optimizer_history.json")
}
239
/// Version tag handed to `OptimizerHistoryRecorder::new` alongside the history file.
pub fn optimizer_history_schema_hash() -> &'static str {
    const SCHEMA_HASH: &str = "optimizer_history_schema_v1_planner_v1";
    SCHEMA_HASH
}
243
244fn validate_catalog_page_references(catalog: &Catalog) -> Result<()> {
245    let bp = catalog.buffer_pool();
246    for table in catalog.list_tables() {
247        bp.fetch_page(table.first_page_id).map_err(|e| {
248            EntDbError::Corruption(format!(
249                "catalog table '{}' references missing first_page_id {}: {e}",
250                table.name, table.first_page_id
251            ))
252        })?;
253
254        for idx in &table.indexes {
255            bp.fetch_page(idx.root_page_id).map_err(|e| {
256                EntDbError::Corruption(format!(
257                    "catalog index '{}.{}' references missing root_page_id {}: {e}",
258                    table.name, idx.name, idx.root_page_id
259                ))
260            })?;
261        }
262    }
263    Ok(())
264}
265
/// Per-connection set of pgwire handlers.
///
/// `serve` builds one of these for every accepted connection. It holds one
/// startup (authentication) handler and one query handler, which serves both
/// the simple and extended query protocols.
struct EntHandlerFactory {
    /// Authentication handler chosen from `ServerConfig::auth_method`.
    startup_handler: Arc<EntStartupHandler>,
    /// Returned by both `simple_query_handler` and `extended_query_handler`.
    query_handler: Arc<EntHandler>,
}
270
/// Startup/authentication handler, selected by [`ServerConfig::auth_method`].
///
/// The MD5 and SCRAM startup handlers are distinct generic types. Wrapping them
/// in an enum gives `EntHandlerFactory` a single concrete type to name as its
/// `StartupHandler`.
pub enum EntStartupHandler {
    /// MD5 password authentication.
    Md5(
        Md5PasswordAuthStartupHandler<
            crate::server::auth::EntAuthSource,
            DefaultServerParameterProvider,
        >,
    ),
    /// SASL SCRAM authentication (selected for `AuthMethod::ScramSha256`).
    Scram(
        SASLScramAuthStartupHandler<
            crate::server::auth::EntAuthSource,
            DefaultServerParameterProvider,
        >,
    ),
}
285
286#[async_trait::async_trait]
287impl StartupHandler for EntStartupHandler {
288    async fn on_startup<C>(
289        &self,
290        client: &mut C,
291        message: PgWireFrontendMessage,
292    ) -> pgwire::error::PgWireResult<()>
293    where
294        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
295        C::Error: std::fmt::Debug,
296        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
297    {
298        match self {
299            EntStartupHandler::Md5(h) => h.on_startup(client, message).await,
300            EntStartupHandler::Scram(h) => h.on_startup(client, message).await,
301        }
302    }
303}
304
305impl EntHandlerFactory {
306    fn new(config: Arc<ServerConfig>, db: Arc<Database>, metrics: Arc<ServerMetrics>) -> Self {
307        let auth_source = Arc::new(crate::server::auth::EntAuthSource {
308            method: config.auth_method,
309            expected_user: config.auth_user.clone(),
310            expected_password: config.auth_password.clone(),
311            scram_iterations: config.scram_iterations,
312        });
313        let params = Arc::new(DefaultServerParameterProvider::default());
314        let startup_handler = match config.auth_method {
315            crate::server::auth::AuthMethod::Md5 => Arc::new(EntStartupHandler::Md5(
316                Md5PasswordAuthStartupHandler::new(auth_source, params),
317            )),
318            crate::server::auth::AuthMethod::ScramSha256 => {
319                let mut scram = SASLScramAuthStartupHandler::new(auth_source, params);
320                scram.set_iterations(config.scram_iterations);
321                Arc::new(EntStartupHandler::Scram(scram))
322            }
323        };
324        Self {
325            startup_handler,
326            query_handler: Arc::new(EntHandler::new(
327                db,
328                config.max_statement_bytes,
329                config.query_timeout_ms,
330                metrics,
331            )),
332        }
333    }
334}
335
336impl PgWireServerHandlers for EntHandlerFactory {
337    type StartupHandler = EntStartupHandler;
338    type SimpleQueryHandler = EntHandler;
339    type ExtendedQueryHandler = EntHandler;
340    type CopyHandler = NoopCopyHandler;
341    type ErrorHandler = NoopErrorHandler;
342
343    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
344        Arc::clone(&self.query_handler)
345    }
346
347    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
348        Arc::clone(&self.query_handler)
349    }
350
351    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
352        Arc::clone(&self.startup_handler)
353    }
354
355    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
356        Arc::new(NoopCopyHandler)
357    }
358
359    fn error_handler(&self) -> Arc<Self::ErrorHandler> {
360        Arc::new(NoopErrorHandler)
361    }
362}
363
364pub async fn run(config: ServerConfig) -> Result<()> {
365    config.validate()?;
366    let config = Arc::new(config);
367    let database = Arc::new(Database::open(&config.data_path, config.buffer_pool_size)?);
368    let tls_acceptor = build_tls_acceptor(&config)?;
369    let listener = TcpListener::bind(config.listen_addr()).await?;
370    serve(listener, config, database, tls_acceptor, None).await
371}
372
373pub async fn serve(
374    listener: TcpListener,
375    config: Arc<ServerConfig>,
376    database: Arc<Database>,
377    tls_acceptor: Option<Arc<TlsAcceptor>>,
378    mut shutdown: Option<oneshot::Receiver<()>>,
379) -> Result<()> {
380    let addr = listener.local_addr()?;
381    info!(%addr, "entdb server listening");
382    let conn_limit = Arc::new(Semaphore::new(config.max_connections));
383    let metrics = Arc::new(ServerMetrics::default());
384
385    loop {
386        tokio::select! {
387            _ = async {
388                if let Some(rx) = &mut shutdown {
389                    let _ = rx.await;
390                }
391            }, if shutdown.is_some() => {
392                info!(%addr, "shutdown signal received");
393                let bp_stats: BufferPoolStats = database.buffer_pool.stats();
394                metrics.set_buffer_pool_pressure(bp_stats);
395
396                let flush_started = std::time::Instant::now();
397                database.buffer_pool.flush_all()?;
398                metrics.on_shutdown_flush(flush_started.elapsed().as_nanos() as u64);
399
400                let persist_started = std::time::Instant::now();
401                database.txn_manager.persist_state()?;
402                metrics.on_shutdown_persist(persist_started.elapsed().as_nanos() as u64);
403                info!(?bp_stats, metrics=?metrics.snapshot(), "server shutdown metrics");
404                break;
405            }
406            accepted = listener.accept() => {
407                let (socket, peer) = accepted?;
408                let permit = match Arc::clone(&conn_limit).try_acquire_owned() {
409                    Ok(p) => p,
410                    Err(_) => {
411                        metrics.on_connection_refused();
412                        info!(%peer, "connection refused due to max_connections limit");
413                        continue;
414                    }
415                };
416                metrics.on_connection_accepted();
417                let factory = Arc::new(EntHandlerFactory::new(
418                    Arc::clone(&config),
419                    Arc::clone(&database),
420                    Arc::clone(&metrics),
421                ));
422                let tls_for_conn = tls_acceptor.clone();
423                let metrics_for_conn = Arc::clone(&metrics);
424                info!(%peer, "accepted connection");
425                tokio::spawn(async move {
426                    let _permit: OwnedSemaphorePermit = permit;
427                    if let Err(err) = process_socket(socket, tls_for_conn, factory).await {
428                        error!(%peer, error = %err, "connection processing error");
429                    }
430                    metrics_for_conn.on_connection_closed();
431                });
432            }
433        }
434    }
435
436    Ok(())
437}
438
439fn build_tls_acceptor(config: &ServerConfig) -> Result<Option<Arc<TlsAcceptor>>> {
440    let (Some(cert_path), Some(key_path)) = (&config.tls_cert, &config.tls_key) else {
441        return Ok(None);
442    };
443
444    let cert_file = std::fs::File::open(cert_path)?;
445    let mut cert_reader = std::io::BufReader::new(cert_file);
446    let cert_chain = certs(&mut cert_reader)
447        .collect::<std::result::Result<Vec<_>, _>>()
448        .map_err(|e| EntDbError::Query(format!("invalid tls cert PEM: {e}")))?;
449
450    let key_file = std::fs::File::open(key_path)?;
451    let mut key_reader = std::io::BufReader::new(key_file);
452    let mut keys = pkcs8_private_keys(&mut key_reader)
453        .collect::<std::result::Result<Vec<_>, _>>()
454        .map_err(|e| EntDbError::Query(format!("invalid tls key PEM: {e}")))?;
455    let Some(key) = keys.pop() else {
456        return Err(EntDbError::Query(
457            "tls key PEM has no PKCS8 private key".to_string(),
458        ));
459    };
460
461    let rustls = RustlsServerConfig::builder()
462        .with_no_client_auth()
463        .with_single_cert(cert_chain, key.into())
464        .map_err(|e| EntDbError::Query(format!("failed to build tls config: {e}")))?;
465
466    Ok(Some(Arc::new(TlsAcceptor::from(Arc::new(rustls)))))
467}
468
#[cfg(test)]
mod tests {
    use super::{serve, Database, ServerConfig};
    use entdb::catalog::{Column, Schema};
    use entdb::types::DataType;
    use std::sync::Arc;
    use tempfile::tempdir;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;

    /// A configuration that passes `ServerConfig::validate` (MD5 auth, no TLS).
    fn valid_config() -> ServerConfig {
        ServerConfig {
            data_path: "x.db".into(),
            host: "127.0.0.1".to_string(),
            port: 5432,
            buffer_pool_size: 64,
            max_connections: 4,
            max_statement_bytes: 1024 * 1024,
            query_timeout_ms: 30_000,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 4096,
            auth_user: "entdb".to_string(),
            auth_password: "entdb".to_string(),
            tls_cert: None,
            tls_key: None,
        }
    }

    /// Asserts that applying `mutate` to an otherwise valid config makes
    /// validation fail, so each check is exercised on its own.
    fn assert_rejected(what: &str, mutate: impl FnOnce(&mut ServerConfig)) {
        let mut cfg = valid_config();
        mutate(&mut cfg);
        assert!(cfg.validate().is_err(), "expected validation to reject {what}");
    }

    #[tokio::test]
    async fn server_accepts_and_stops_with_shutdown_signal() {
        let dir = tempdir().expect("tempdir");
        let data_path = dir.path().join("server.db");
        let db = Arc::new(Database::open(&data_path, 64).expect("open db"));
        let cfg = Arc::new(ServerConfig {
            data_path,
            host: "127.0.0.1".to_string(),
            port: 0,
            buffer_pool_size: 64,
            max_connections: 4,
            max_statement_bytes: 1024 * 1024,
            query_timeout_ms: 30_000,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 4096,
            auth_user: "entdb".to_string(),
            auth_password: "entdb".to_string(),
            tls_cert: None,
            tls_key: None,
        });
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");

        let (tx, rx) = oneshot::channel();
        let handle = tokio::spawn(async move {
            serve(listener, cfg, db, None, Some(rx))
                .await
                .expect("serve")
        });

        tx.send(()).expect("signal shutdown");
        handle.await.expect("join");
    }

    #[test]
    fn server_config_validation_rejects_invalid_limits() {
        let cfg = ServerConfig {
            data_path: "x.db".into(),
            host: "".to_string(),
            port: 0,
            buffer_pool_size: 0,
            max_connections: 0,
            max_statement_bytes: 0,
            query_timeout_ms: 0,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 0,
            auth_user: "".to_string(),
            auth_password: "".to_string(),
            tls_cert: None,
            tls_key: None,
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn server_config_validation_checks_each_field() {
        assert!(valid_config().validate().is_ok());

        // scram_iterations is only enforced when SCRAM is selected.
        let mut md5 = valid_config();
        md5.scram_iterations = 0;
        assert!(md5.validate().is_ok());

        assert_rejected("blank host", |c| c.host = "   ".to_string());
        assert_rejected("port 0", |c| c.port = 0);
        assert_rejected("buffer_pool_size 0", |c| c.buffer_pool_size = 0);
        assert_rejected("max_connections 0", |c| c.max_connections = 0);
        assert_rejected("max_statement_bytes 0", |c| c.max_statement_bytes = 0);
        assert_rejected("query_timeout_ms 0", |c| c.query_timeout_ms = 0);
        assert_rejected("blank auth_user", |c| c.auth_user = " ".to_string());
        assert_rejected("empty auth_password", |c| c.auth_password = String::new());
        assert_rejected("too few scram iterations", |c| {
            c.auth_method = crate::server::auth::AuthMethod::ScramSha256;
            c.scram_iterations = 4095;
        });
        assert_rejected("tls_cert without tls_key", |c| {
            c.tls_cert = Some("cert.pem".into());
        });
        assert_rejected("tls_key without tls_cert", |c| {
            c.tls_key = Some("key.pem".into());
        });

        let missing = tempdir().expect("tempdir");
        assert_rejected("nonexistent tls files", |c| {
            c.tls_cert = Some(missing.path().join("cert.pem"));
            c.tls_key = Some(missing.path().join("key.pem"));
        });
    }

    #[test]
    fn database_open_rejects_catalog_with_missing_table_page() {
        let dir = tempdir().expect("tempdir");
        let data_path = dir.path().join("server-corrupt.db");
        let db = Database::open(&data_path, 64).expect("open db");
        let schema = Schema::new(vec![
            Column {
                name: "id".to_string(),
                data_type: DataType::Int32,
                nullable: false,
                default: None,
                primary_key: false,
            },
            Column {
                name: "name".to_string(),
                data_type: DataType::Text,
                nullable: true,
                default: None,
                primary_key: false,
            },
        ]);
        let table = db
            .catalog
            .create_table("users", schema)
            .expect("create table");

        db.buffer_pool
            .delete_page(table.first_page_id)
            .expect("delete table root page to simulate corruption");
        drop(db);

        let err = match Database::open(&data_path, 64) {
            Ok(_) => panic!("expected startup integrity validation failure"),
            Err(e) => e,
        };
        assert!(
            err.to_string().contains("references missing first_page_id"),
            "unexpected error: {err}"
        );
    }
}