Skip to main content

entdb_server/server/
mod.rs

1/*
2 * Copyright 2026 EntDB Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17pub mod auth;
18pub mod handler;
19pub mod metrics;
20pub mod type_map;
21
22use crate::server::handler::{scan_max_txn_id_from_storage, EntHandler};
23use crate::server::metrics::ServerMetrics;
24use entdb::catalog::Catalog;
25use entdb::error::{EntDbError, Result};
26use entdb::query::history::OptimizerHistoryRecorder;
27use entdb::query::optimizer::OptimizerConfig;
28use entdb::storage::buffer_pool::BufferPool;
29use entdb::storage::buffer_pool::BufferPoolStats;
30use entdb::storage::disk_manager::DiskManager;
31use entdb::tx::{DurabilityMode, TransactionManager};
32use entdb::wal::log_manager::LogManager;
33use entdb::wal::recovery::RecoveryManager;
34use futures::Sink;
35use pgwire::api::auth::md5pass::Md5PasswordAuthStartupHandler;
36use pgwire::api::auth::scram::SASLScramAuthStartupHandler;
37use pgwire::api::auth::DefaultServerParameterProvider;
38use pgwire::api::auth::StartupHandler;
39use pgwire::api::copy::NoopCopyHandler;
40use pgwire::api::ClientInfo;
41use pgwire::api::NoopErrorHandler;
42use pgwire::api::PgWireServerHandlers;
43use pgwire::error::PgWireError;
44use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
45use pgwire::tokio::process_socket;
46use rustls_pemfile::{certs, pkcs8_private_keys};
47use std::path::{Path, PathBuf};
48use std::sync::Arc;
49use tokio::net::TcpListener;
50use tokio::sync::{oneshot, OwnedSemaphorePermit, Semaphore};
51use tokio_rustls::rustls::ServerConfig as RustlsServerConfig;
52use tokio_rustls::TlsAcceptor;
53use tracing::{error, info};
54
55#[derive(Debug, Clone)]
56pub struct ServerConfig {
57    pub data_path: PathBuf,
58    pub host: String,
59    pub port: u16,
60    pub buffer_pool_size: usize,
61    pub max_connections: usize,
62    pub max_statement_bytes: usize,
63    pub query_timeout_ms: u64,
64    pub auth_method: crate::server::auth::AuthMethod,
65    pub scram_iterations: usize,
66    pub auth_user: String,
67    pub auth_password: String,
68    pub tls_cert: Option<PathBuf>,
69    pub tls_key: Option<PathBuf>,
70    pub durability_mode: DurabilityMode,
71    pub await_durable: bool,
72}
73
74impl ServerConfig {
75    pub fn listen_addr(&self) -> String {
76        format!("{}:{}", self.host, self.port)
77    }
78
79    pub fn validate(&self) -> Result<()> {
80        if self.host.trim().is_empty() {
81            return Err(EntDbError::Query("host cannot be empty".to_string()));
82        }
83        if self.port == 0 {
84            return Err(EntDbError::Query("port must be > 0".to_string()));
85        }
86        if self.buffer_pool_size == 0 {
87            return Err(EntDbError::Query(
88                "buffer_pool_size must be > 0".to_string(),
89            ));
90        }
91        if self.max_connections == 0 {
92            return Err(EntDbError::Query("max_connections must be > 0".to_string()));
93        }
94        if self.max_statement_bytes == 0 {
95            return Err(EntDbError::Query(
96                "max_statement_bytes must be > 0".to_string(),
97            ));
98        }
99        if self.query_timeout_ms == 0 {
100            return Err(EntDbError::Query(
101                "query_timeout_ms must be > 0".to_string(),
102            ));
103        }
104        if self.auth_user.trim().is_empty() {
105            return Err(EntDbError::Query("auth_user cannot be empty".to_string()));
106        }
107        if self.auth_password.is_empty() {
108            return Err(EntDbError::Query(
109                "auth_password cannot be empty".to_string(),
110            ));
111        }
112        if matches!(
113            self.auth_method,
114            crate::server::auth::AuthMethod::ScramSha256
115        ) && self.scram_iterations < 4096
116        {
117            return Err(EntDbError::Query(
118                "scram_iterations must be >= 4096".to_string(),
119            ));
120        }
121        match (&self.tls_cert, &self.tls_key) {
122            (Some(cert), Some(key)) => {
123                if !cert.exists() {
124                    return Err(EntDbError::Query(format!(
125                        "tls_cert path does not exist: {}",
126                        cert.display()
127                    )));
128                }
129                if !key.exists() {
130                    return Err(EntDbError::Query(format!(
131                        "tls_key path does not exist: {}",
132                        key.display()
133                    )));
134                }
135            }
136            (None, None) => {}
137            _ => {
138                return Err(EntDbError::Query(
139                    "tls_cert and tls_key must both be provided together".to_string(),
140                ))
141            }
142        }
143        Ok(())
144    }
145}
146
147pub struct Database {
148    pub disk_manager: Arc<DiskManager>,
149    pub log_manager: Arc<LogManager>,
150    pub buffer_pool: Arc<BufferPool>,
151    pub catalog: Arc<Catalog>,
152    pub txn_manager: Arc<TransactionManager>,
153    pub optimizer_history: Arc<OptimizerHistoryRecorder>,
154    pub optimizer_config: OptimizerConfig,
155}
156
157impl Database {
158    pub fn open(
159        data_path: &Path,
160        buffer_pool_size: usize,
161        durability_mode: DurabilityMode,
162    ) -> Result<Self> {
163        let disk_manager = Arc::new(DiskManager::new(data_path)?);
164
165        let mut wal_path = data_path.to_path_buf();
166        wal_path.set_extension("wal");
167        let log_manager = Arc::new(LogManager::new(wal_path, 4096)?);
168
169        let buffer_pool = Arc::new(BufferPool::with_log_manager(
170            buffer_pool_size,
171            Arc::clone(&disk_manager),
172            Arc::clone(&log_manager),
173        ));
174
175        RecoveryManager::new(Arc::clone(&log_manager), Arc::clone(&buffer_pool)).recover()?;
176
177        let catalog = Arc::new(Catalog::load(Arc::clone(&buffer_pool))?);
178        validate_catalog_page_references(&catalog)?;
179
180        let mut txn_state_path = data_path.to_path_buf();
181        txn_state_path.set_extension("txn.json");
182        let mut txn_wal_path = data_path.to_path_buf();
183        txn_wal_path.set_extension("txn.wal");
184
185        let txn_manager = TransactionManager::with_wal_persistence(&txn_state_path, &txn_wal_path)
186            .or_else(|_| TransactionManager::with_persistence(&txn_state_path))
187            .unwrap_or_else(|_| TransactionManager::new());
188        txn_manager.set_durability_mode(durability_mode);
189
190        if let Ok(max_txn) = scan_max_txn_id_from_storage(&catalog) {
191            txn_manager.ensure_min_next_txn_id(max_txn.saturating_add(1));
192        }
193
194        let optimizer_history_path = optimizer_history_path_for_data_path(data_path);
195        let optimizer_history = OptimizerHistoryRecorder::new(
196            optimizer_history_path,
197            optimizer_history_schema_hash(),
198            16,
199            1024,
200        )
201        .or_else(|_| {
202            OptimizerHistoryRecorder::new(
203                std::env::temp_dir().join("entdb.optimizer_history.server.fallback.json"),
204                optimizer_history_schema_hash(),
205                16,
206                1024,
207            )
208        })?;
209
210        let mut optimizer_config = OptimizerConfig::default();
211        if let Ok(v) = std::env::var("ENTDB_CBO") {
212            optimizer_config.cbo_enabled = v == "1" || v.eq_ignore_ascii_case("true");
213        }
214        if let Ok(v) = std::env::var("ENTDB_HBO") {
215            optimizer_config.hbo_enabled = v == "1" || v.eq_ignore_ascii_case("true");
216        }
217        if let Ok(v) = std::env::var("ENTDB_OPT_MAX_SEARCH_MS") {
218            if let Ok(ms) = v.parse::<u64>() {
219                optimizer_config.max_search_ms = ms;
220            }
221        }
222        if let Ok(v) = std::env::var("ENTDB_OPT_MAX_JOIN_RELATIONS") {
223            if let Ok(n) = v.parse::<usize>() {
224                optimizer_config.max_join_relations = n;
225            }
226        }
227        let optimizer_config = optimizer_config.sanitize();
228
229        Ok(Self {
230            disk_manager,
231            log_manager,
232            buffer_pool,
233            catalog,
234            txn_manager: Arc::new(txn_manager),
235            optimizer_history: Arc::new(optimizer_history),
236            optimizer_config,
237        })
238    }
239}
240
/// Derives the optimizer history file path from the database file path by
/// replacing the extension with `optimizer_history.json`
/// (`server.db` -> `server.optimizer_history.json`).
pub fn optimizer_history_path_for_data_path(data_path: &Path) -> PathBuf {
    data_path.with_extension("optimizer_history.json")
}
246
/// Returns the schema/planner version tag passed to the optimizer history
/// recorder when it is created.
pub fn optimizer_history_schema_hash() -> &'static str {
    const SCHEMA_HASH: &str = "optimizer_history_schema_v1_planner_v1";
    SCHEMA_HASH
}
250
251fn validate_catalog_page_references(catalog: &Catalog) -> Result<()> {
252    let bp = catalog.buffer_pool();
253    for table in catalog.list_tables() {
254        bp.fetch_page(table.first_page_id).map_err(|e| {
255            EntDbError::Corruption(format!(
256                "catalog table '{}' references missing first_page_id {}: {e}",
257                table.name, table.first_page_id
258            ))
259        })?;
260
261        for idx in &table.indexes {
262            bp.fetch_page(idx.root_page_id).map_err(|e| {
263                EntDbError::Corruption(format!(
264                    "catalog index '{}.{}' references missing root_page_id {}: {e}",
265                    table.name, idx.name, idx.root_page_id
266                ))
267            })?;
268        }
269    }
270    Ok(())
271}
272
273struct EntHandlerFactory {
274    startup_handler: Arc<EntStartupHandler>,
275    query_handler: Arc<EntHandler>,
276}
277
278pub enum EntStartupHandler {
279    Md5(
280        Md5PasswordAuthStartupHandler<
281            crate::server::auth::EntAuthSource,
282            DefaultServerParameterProvider,
283        >,
284    ),
285    Scram(
286        SASLScramAuthStartupHandler<
287            crate::server::auth::EntAuthSource,
288            DefaultServerParameterProvider,
289        >,
290    ),
291}
292
293#[async_trait::async_trait]
294impl StartupHandler for EntStartupHandler {
295    async fn on_startup<C>(
296        &self,
297        client: &mut C,
298        message: PgWireFrontendMessage,
299    ) -> pgwire::error::PgWireResult<()>
300    where
301        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
302        C::Error: std::fmt::Debug,
303        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
304    {
305        match self {
306            EntStartupHandler::Md5(h) => h.on_startup(client, message).await,
307            EntStartupHandler::Scram(h) => h.on_startup(client, message).await,
308        }
309    }
310}
311
312impl EntHandlerFactory {
313    fn new(config: Arc<ServerConfig>, db: Arc<Database>, metrics: Arc<ServerMetrics>) -> Self {
314        let auth_source = Arc::new(crate::server::auth::EntAuthSource {
315            method: config.auth_method,
316            expected_user: config.auth_user.clone(),
317            expected_password: config.auth_password.clone(),
318            scram_iterations: config.scram_iterations,
319        });
320        let params = Arc::new(DefaultServerParameterProvider::default());
321        let startup_handler = match config.auth_method {
322            crate::server::auth::AuthMethod::Md5 => Arc::new(EntStartupHandler::Md5(
323                Md5PasswordAuthStartupHandler::new(auth_source, params),
324            )),
325            crate::server::auth::AuthMethod::ScramSha256 => {
326                let mut scram = SASLScramAuthStartupHandler::new(auth_source, params);
327                scram.set_iterations(config.scram_iterations);
328                Arc::new(EntStartupHandler::Scram(scram))
329            }
330        };
331        Self {
332            startup_handler,
333            query_handler: Arc::new(EntHandler::new(
334                db,
335                config.max_statement_bytes,
336                config.query_timeout_ms,
337                metrics,
338                config.await_durable,
339            )),
340        }
341    }
342}
343
344impl PgWireServerHandlers for EntHandlerFactory {
345    type StartupHandler = EntStartupHandler;
346    type SimpleQueryHandler = EntHandler;
347    type ExtendedQueryHandler = EntHandler;
348    type CopyHandler = NoopCopyHandler;
349    type ErrorHandler = NoopErrorHandler;
350
351    fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
352        Arc::clone(&self.query_handler)
353    }
354
355    fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
356        Arc::clone(&self.query_handler)
357    }
358
359    fn startup_handler(&self) -> Arc<Self::StartupHandler> {
360        Arc::clone(&self.startup_handler)
361    }
362
363    fn copy_handler(&self) -> Arc<Self::CopyHandler> {
364        Arc::new(NoopCopyHandler)
365    }
366
367    fn error_handler(&self) -> Arc<Self::ErrorHandler> {
368        Arc::new(NoopErrorHandler)
369    }
370}
371
372pub async fn run(config: ServerConfig) -> Result<()> {
373    config.validate()?;
374    let config = Arc::new(config);
375    let database = Arc::new(Database::open(
376        &config.data_path,
377        config.buffer_pool_size,
378        config.durability_mode,
379    )?);
380    let tls_acceptor = build_tls_acceptor(&config)?;
381    let listener = TcpListener::bind(config.listen_addr()).await?;
382    serve(listener, config, database, tls_acceptor, None).await
383}
384
385pub async fn serve(
386    listener: TcpListener,
387    config: Arc<ServerConfig>,
388    database: Arc<Database>,
389    tls_acceptor: Option<Arc<TlsAcceptor>>,
390    mut shutdown: Option<oneshot::Receiver<()>>,
391) -> Result<()> {
392    let addr = listener.local_addr()?;
393    info!(%addr, "entdb server listening");
394    let conn_limit = Arc::new(Semaphore::new(config.max_connections));
395    let metrics = Arc::new(ServerMetrics::default());
396
397    loop {
398        tokio::select! {
399            _ = async {
400                if let Some(rx) = &mut shutdown {
401                    let _ = rx.await;
402                }
403            }, if shutdown.is_some() => {
404                info!(%addr, "shutdown signal received");
405                let bp_stats: BufferPoolStats = database.buffer_pool.stats();
406                metrics.set_buffer_pool_pressure(bp_stats);
407
408                let flush_started = std::time::Instant::now();
409                database.buffer_pool.flush_all()?;
410                metrics.on_shutdown_flush(flush_started.elapsed().as_nanos() as u64);
411
412                let persist_started = std::time::Instant::now();
413                database.txn_manager.persist_state()?;
414                metrics.on_shutdown_persist(persist_started.elapsed().as_nanos() as u64);
415                info!(?bp_stats, metrics=?metrics.snapshot(), "server shutdown metrics");
416                break;
417            }
418            accepted = listener.accept() => {
419                let (socket, peer) = accepted?;
420                let permit = match Arc::clone(&conn_limit).try_acquire_owned() {
421                    Ok(p) => p,
422                    Err(_) => {
423                        metrics.on_connection_refused();
424                        info!(%peer, "connection refused due to max_connections limit");
425                        continue;
426                    }
427                };
428                metrics.on_connection_accepted();
429                let factory = Arc::new(EntHandlerFactory::new(
430                    Arc::clone(&config),
431                    Arc::clone(&database),
432                    Arc::clone(&metrics),
433                ));
434                let tls_for_conn = tls_acceptor.clone();
435                let metrics_for_conn = Arc::clone(&metrics);
436                info!(%peer, "accepted connection");
437                tokio::spawn(async move {
438                    let _permit: OwnedSemaphorePermit = permit;
439                    if let Err(err) = process_socket(socket, tls_for_conn, factory).await {
440                        error!(%peer, error = %err, "connection processing error");
441                    }
442                    metrics_for_conn.on_connection_closed();
443                });
444            }
445        }
446    }
447
448    Ok(())
449}
450
451fn build_tls_acceptor(config: &ServerConfig) -> Result<Option<Arc<TlsAcceptor>>> {
452    let (Some(cert_path), Some(key_path)) = (&config.tls_cert, &config.tls_key) else {
453        return Ok(None);
454    };
455
456    let cert_file = std::fs::File::open(cert_path)?;
457    let mut cert_reader = std::io::BufReader::new(cert_file);
458    let cert_chain = certs(&mut cert_reader)
459        .collect::<std::result::Result<Vec<_>, _>>()
460        .map_err(|e| EntDbError::Query(format!("invalid tls cert PEM: {e}")))?;
461
462    let key_file = std::fs::File::open(key_path)?;
463    let mut key_reader = std::io::BufReader::new(key_file);
464    let mut keys = pkcs8_private_keys(&mut key_reader)
465        .collect::<std::result::Result<Vec<_>, _>>()
466        .map_err(|e| EntDbError::Query(format!("invalid tls key PEM: {e}")))?;
467    let Some(key) = keys.pop() else {
468        return Err(EntDbError::Query(
469            "tls key PEM has no PKCS8 private key".to_string(),
470        ));
471    };
472
473    let rustls = RustlsServerConfig::builder()
474        .with_no_client_auth()
475        .with_single_cert(cert_chain, key.into())
476        .map_err(|e| EntDbError::Query(format!("failed to build tls config: {e}")))?;
477
478    Ok(Some(Arc::new(TlsAcceptor::from(Arc::new(rustls)))))
479}
480
#[cfg(test)]
mod tests {
    use super::{serve, Database, ServerConfig};
    use entdb::catalog::{Column, Schema};
    use entdb::types::DataType;
    use entdb::DurabilityMode;
    use std::sync::Arc;
    use tempfile::tempdir;
    use tokio::net::TcpListener;
    use tokio::sync::oneshot;

    /// Plain column without default or primary-key flag.
    fn column(name: &str, data_type: DataType, nullable: bool) -> Column {
        Column {
            name: name.to_string(),
            data_type,
            nullable,
            default: None,
            primary_key: false,
        }
    }

    /// `serve` returns `Ok` once the shutdown channel fires.
    #[tokio::test]
    async fn server_accepts_and_stops_with_shutdown_signal() {
        let tmp = tempdir().expect("tempdir");
        let data_path = tmp.path().join("server.db");
        let opened = Database::open(&data_path, 64, DurabilityMode::Full).expect("open db");
        let db = Arc::new(opened);
        let cfg = Arc::new(ServerConfig {
            data_path,
            host: "127.0.0.1".to_string(),
            port: 0,
            buffer_pool_size: 64,
            max_connections: 4,
            max_statement_bytes: 1024 * 1024,
            query_timeout_ms: 30_000,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 4096,
            auth_user: "entdb".to_string(),
            auth_password: "entdb".to_string(),
            tls_cert: None,
            tls_key: None,
            durability_mode: DurabilityMode::Full,
            await_durable: false,
        });
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");

        let (stop_tx, stop_rx) = oneshot::channel();
        let server = tokio::spawn(async move {
            serve(listener, cfg, db, None, Some(stop_rx))
                .await
                .expect("serve")
        });

        stop_tx.send(()).expect("signal shutdown");
        server.await.expect("join");
    }

    /// A config with every limit zeroed/blank fails validation.
    #[test]
    fn server_config_validation_rejects_invalid_limits() {
        let cfg = ServerConfig {
            data_path: "x.db".into(),
            host: "".to_string(),
            port: 0,
            buffer_pool_size: 0,
            max_connections: 0,
            max_statement_bytes: 0,
            query_timeout_ms: 0,
            auth_method: crate::server::auth::AuthMethod::Md5,
            scram_iterations: 0,
            auth_user: "".to_string(),
            auth_password: "".to_string(),
            tls_cert: None,
            tls_key: None,
            durability_mode: DurabilityMode::Full,
            await_durable: false,
        };
        assert!(cfg.validate().is_err());
    }

    /// Reopening after a table's first page was deleted must fail the
    /// startup catalog check.
    #[test]
    fn database_open_rejects_catalog_with_missing_table_page() {
        let tmp = tempdir().expect("tempdir");
        let data_path = tmp.path().join("server-corrupt.db");
        let db = Database::open(&data_path, 64, DurabilityMode::Full).expect("open db");
        let schema = Schema::new(vec![
            column("id", DataType::Int32, false),
            column("name", DataType::Text, true),
        ]);
        let table = db
            .catalog
            .create_table("users", schema)
            .expect("create table");

        db.buffer_pool
            .delete_page(table.first_page_id)
            .expect("delete table root page to simulate corruption");
        drop(db);

        let Err(err) = Database::open(&data_path, 64, DurabilityMode::Full) else {
            panic!("expected startup integrity validation failure");
        };
        assert!(
            err.to_string().contains("references missing first_page_id"),
            "unexpected error: {err}"
        );
    }
}