Skip to main content

entdb_server/server/
handler.rs

1/*
2 * Copyright 2026 EntDB Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::server::type_map::{encode_value, ent_to_pg_type};
18use async_trait::async_trait;
19use entdb::catalog::Catalog;
20use entdb::error::EntDbError;
21use entdb::query::binder::Binder;
22use entdb::query::executor::{build_executor, ExecutionContext, TxExecutionContext};
23use entdb::query::history::OptimizerHistoryRecord;
24use entdb::query::optimizer::Optimizer;
25use entdb::query::planner::Planner;
26use entdb::query::polyglot::{transpile_with_meta, PolyglotOptions};
27use entdb::storage::table::Table;
28use entdb::tx::TransactionHandle;
29use futures::{stream, Sink};
30use parking_lot::Mutex;
31use pgwire::api::portal::{Format, Portal};
32use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
33use pgwire::api::results::{
34    DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo,
35    QueryResponse, Response, Tag,
36};
37use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
38use pgwire::api::{ClientInfo, Type};
39use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
40use pgwire::messages::PgWireBackendMessage;
41use sqlparser::ast::{visit_expressions_mut, Expr, Statement, Value as SqlValue};
42use sqlparser::dialect::PostgreSqlDialect;
43use sqlparser::parser::Parser;
44use std::fmt::Debug;
45use std::ops::ControlFlow;
46use std::sync::Arc;
47use tracing::{debug, info_span};
48
49use super::metrics::ServerMetrics;
50use super::{optimizer_history_schema_hash, Database};
51
/// pgwire handler that runs client SQL against an EntDB [`Database`].
///
/// Implements both the simple and the extended query protocol. Holds the explicit
/// transaction (if any) of the session it serves.
pub struct EntHandler {
    /// Shared database: catalog, transaction manager, optimizer config and history.
    db: Arc<Database>,
    /// Transaction opened by `BEGIN`; `None` means each statement runs in its own
    /// autocommit transaction.
    current_txn: Mutex<Option<TransactionHandle>>,
    /// Parser handed to pgwire for the extended protocol; prepared statements are kept
    /// as raw SQL text.
    query_parser: Arc<NoopQueryParser>,
    /// SQL strings longer than this many bytes are rejected with SQLSTATE `54000`.
    max_statement_bytes: usize,
    /// Elapsed-time budget in milliseconds, checked between statements and between
    /// execution stages; exceeding it fails with SQLSTATE `57014`.
    query_timeout_ms: u64,
    /// Receives per-query latency and SQLSTATE when a query finishes.
    metrics: Arc<ServerMetrics>,
    /// Options for the polyglot transpiler that runs on SQL before it is parsed.
    polyglot: PolyglotOptions,
}
61
62impl EntHandler {
63    pub fn new(
64        db: Arc<Database>,
65        max_statement_bytes: usize,
66        query_timeout_ms: u64,
67        metrics: Arc<ServerMetrics>,
68    ) -> Self {
69        Self {
70            db,
71            current_txn: Mutex::new(None),
72            query_parser: Arc::new(NoopQueryParser::new()),
73            max_statement_bytes,
74            query_timeout_ms,
75            metrics,
76            polyglot: PolyglotOptions {
77                enabled: std::env::var("ENTDB_POLYGLOT")
78                    .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
79                    .unwrap_or(false),
80            },
81        }
82    }
83
84    fn execute_sql(
85        &self,
86        sql: &str,
87        max_rows: Option<usize>,
88    ) -> PgWireResult<Vec<Response<'static>>> {
89        if sql.len() > self.max_statement_bytes {
90            return Err(user_error(
91                "54000",
92                format!(
93                    "statement exceeds configured max_statement_bytes ({} > {})",
94                    sql.len(),
95                    self.max_statement_bytes
96                ),
97            ));
98        }
99        let _span = info_span!(
100            "pgwire.execute_sql",
101            sql_len = sql.len(),
102            max_rows = max_rows.unwrap_or(0)
103        )
104        .entered();
105        let statements = parse_sql_with_polyglot(sql, self.polyglot)?;
106        let mut responses = Vec::new();
107        let started = std::time::Instant::now();
108
109        for stmt in &statements {
110            if started.elapsed().as_millis() as u64 > self.query_timeout_ms {
111                return Err(user_error("57014", "query timeout exceeded"));
112            }
113            debug!(statement = %tag_for_statement_name(stmt), "processing statement");
114            match stmt {
115                Statement::StartTransaction { .. } => {
116                    let mut guard = self.current_txn.lock();
117                    if guard.is_some() {
118                        return Err(user_error(
119                            "25001",
120                            "transaction already active for this session",
121                        ));
122                    }
123                    *guard = Some(self.db.txn_manager.begin());
124                    responses.push(Response::TransactionStart(Tag::new("BEGIN")));
125                }
126                Statement::Commit { .. } => {
127                    let tx = {
128                        let mut guard = self.current_txn.lock();
129                        guard
130                            .take()
131                            .ok_or_else(|| user_error("25000", "no active transaction to COMMIT"))?
132                    };
133                    self.db
134                        .txn_manager
135                        .commit(tx.txn_id)
136                        .map_err(map_entdb_error)?;
137                    responses.push(Response::TransactionEnd(Tag::new("COMMIT")));
138                }
139                Statement::Rollback { .. } => {
140                    let tx = {
141                        let mut guard = self.current_txn.lock();
142                        guard.take().ok_or_else(|| {
143                            user_error("25000", "no active transaction to ROLLBACK")
144                        })?
145                    };
146                    self.db.txn_manager.abort(tx.txn_id);
147                    responses.push(Response::TransactionEnd(Tag::new("ROLLBACK")));
148                }
149                _ => {
150                    let active_tx = *self.current_txn.lock();
151                    let exec_resp = if let Some(tx) = active_tx {
152                        self.execute_statement_in_txn(&tx, stmt, max_rows, None)?
153                    } else {
154                        let tx = self.db.txn_manager.begin();
155                        match self.execute_statement_in_txn(&tx, stmt, max_rows, None) {
156                            Ok(resp) => {
157                                if let Err(e) = self.db.txn_manager.commit(tx.txn_id) {
158                                    self.db.txn_manager.abort(tx.txn_id);
159                                    return Err(map_entdb_error(e));
160                                }
161                                resp
162                            }
163                            Err(e) => {
164                                self.db.txn_manager.abort(tx.txn_id);
165                                return Err(e);
166                            }
167                        }
168                    };
169                    responses.push(exec_resp);
170                }
171            }
172        }
173
174        Ok(responses)
175    }
176
177    fn execute_statement_in_txn(
178        &self,
179        tx: &TransactionHandle,
180        stmt: &Statement,
181        max_rows: Option<usize>,
182        row_format: Option<&Format>,
183    ) -> PgWireResult<Response<'static>> {
184        let stmt_started = std::time::Instant::now();
185        let _span = info_span!(
186            "pgwire.execute_statement",
187            statement = %tag_for_statement_name(stmt),
188            txn_id = tx.txn_id
189        )
190        .entered();
191        let binder = Binder::new(Arc::clone(&self.db.catalog));
192        let planner = Planner;
193
194        let bound = match binder.bind(stmt) {
195            Ok(b) => b,
196            Err(e) => {
197                self.record_optimizer_history_error(stmt, "bind_error", stmt_started.elapsed());
198                return Err(map_entdb_error(e));
199            }
200        };
201        let fingerprint = Optimizer::fingerprint_bound_statement(&bound);
202        let history = self.db.optimizer_history.read_for_fingerprint(&fingerprint);
203        self.ensure_statement_timeout(stmt_started)?;
204        let plan = match planner.plan(bound) {
205            Ok(p) => p,
206            Err(e) => {
207                self.record_optimizer_history(OptimizerHistoryRecord {
208                    fingerprint,
209                    plan_signature: "error".to_string(),
210                    schema_hash: optimizer_history_schema_hash().to_string(),
211                    captured_at_ms: now_epoch_millis(),
212                    rowcount_observed_json: "{\"root\":0}".to_string(),
213                    latency_ms: stmt_started.elapsed().as_millis() as u64,
214                    memory_peak_bytes: 0,
215                    success: false,
216                    error_class: Some("plan_error".to_string()),
217                    confidence: 0.0,
218                });
219                return Err(map_entdb_error(e));
220            }
221        };
222        self.ensure_statement_timeout(stmt_started)?;
223        let optimized_outcome = Optimizer::optimize_with_trace_and_history(
224            plan,
225            &fingerprint,
226            self.db.optimizer_config,
227            &history,
228        );
229        let chosen_plan_signature = optimized_outcome
230            .trace
231            .chosen_plan_signature
232            .clone()
233            .unwrap_or_else(|| "baseline".to_string());
234        let optimized = optimized_outcome.plan;
235        self.ensure_statement_timeout(stmt_started)?;
236
237        let ctx = ExecutionContext {
238            catalog: Arc::clone(&self.db.catalog),
239            tx: Some(TxExecutionContext {
240                txn_id: tx.txn_id,
241                snapshot_ts: tx.snapshot_ts,
242                txn_manager: Arc::clone(&self.db.txn_manager),
243            }),
244        };
245
246        let mut exec = match build_executor(&optimized, &ctx) {
247            Ok(exec) => exec,
248            Err(e) => {
249                self.record_optimizer_history(OptimizerHistoryRecord {
250                    fingerprint,
251                    plan_signature: chosen_plan_signature,
252                    schema_hash: optimizer_history_schema_hash().to_string(),
253                    captured_at_ms: now_epoch_millis(),
254                    rowcount_observed_json: "{\"root\":0}".to_string(),
255                    latency_ms: stmt_started.elapsed().as_millis() as u64,
256                    memory_peak_bytes: 0,
257                    success: false,
258                    error_class: Some("executor_build_error".to_string()),
259                    confidence: 0.0,
260                });
261                return Err(map_entdb_error(e));
262            }
263        };
264        self.ensure_statement_timeout(stmt_started)?;
265        if let Err(e) = exec.open() {
266            self.record_optimizer_history(OptimizerHistoryRecord {
267                fingerprint,
268                plan_signature: chosen_plan_signature,
269                schema_hash: optimizer_history_schema_hash().to_string(),
270                captured_at_ms: now_epoch_millis(),
271                rowcount_observed_json: "{\"root\":0}".to_string(),
272                latency_ms: stmt_started.elapsed().as_millis() as u64,
273                memory_peak_bytes: 0,
274                success: false,
275                error_class: Some("executor_open_error".to_string()),
276                confidence: 0.0,
277            });
278            return Err(map_entdb_error(e));
279        }
280        self.ensure_statement_timeout(stmt_started)?;
281
282        let schema = exec.schema().clone();
283        let mut rows = Vec::new();
284        while let Some(row) = match exec.next() {
285            Ok(r) => r,
286            Err(e) => {
287                self.record_optimizer_history(OptimizerHistoryRecord {
288                    fingerprint: fingerprint.clone(),
289                    plan_signature: chosen_plan_signature.clone(),
290                    schema_hash: optimizer_history_schema_hash().to_string(),
291                    captured_at_ms: now_epoch_millis(),
292                    rowcount_observed_json: format!("{{\"root\":{}}}", rows.len()),
293                    latency_ms: stmt_started.elapsed().as_millis() as u64,
294                    memory_peak_bytes: 0,
295                    success: false,
296                    error_class: Some("executor_next_error".to_string()),
297                    confidence: 0.0,
298                });
299                return Err(map_entdb_error(e));
300            }
301        } {
302            self.ensure_statement_timeout(stmt_started)?;
303            rows.push(row);
304            if let Some(limit) = max_rows {
305                if limit > 0 && rows.len() >= limit {
306                    break;
307                }
308            }
309        }
310        if let Err(e) = exec.close() {
311            self.record_optimizer_history(OptimizerHistoryRecord {
312                fingerprint,
313                plan_signature: chosen_plan_signature,
314                schema_hash: optimizer_history_schema_hash().to_string(),
315                captured_at_ms: now_epoch_millis(),
316                rowcount_observed_json: format!("{{\"root\":{}}}", rows.len()),
317                latency_ms: stmt_started.elapsed().as_millis() as u64,
318                memory_peak_bytes: 0,
319                success: false,
320                error_class: Some("executor_close_error".to_string()),
321                confidence: 0.0,
322            });
323            return Err(map_entdb_error(e));
324        }
325
326        let observed_rows = if schema.columns.is_empty() {
327            exec.affected_rows().unwrap_or(0)
328        } else {
329            rows.len() as u64
330        };
331        self.record_optimizer_history(OptimizerHistoryRecord {
332            fingerprint,
333            plan_signature: chosen_plan_signature,
334            schema_hash: optimizer_history_schema_hash().to_string(),
335            captured_at_ms: now_epoch_millis(),
336            rowcount_observed_json: format!("{{\"root\":{observed_rows}}}"),
337            latency_ms: stmt_started.elapsed().as_millis() as u64,
338            memory_peak_bytes: 0,
339            success: true,
340            error_class: None,
341            confidence: 1.0,
342        });
343
344        if schema.columns.is_empty() {
345            let affected = exec.affected_rows().unwrap_or(0) as usize;
346            let tag = execution_tag_for_statement(stmt, affected);
347            return Ok(Response::Execution(tag));
348        }
349
350        let fields = schema
351            .columns
352            .iter()
353            .enumerate()
354            .map(|(idx, c)| {
355                FieldInfo::new(
356                    c.name.clone(),
357                    None,
358                    None,
359                    ent_to_pg_type(&c.data_type),
360                    row_format
361                        .map(|f| f.format_for(idx))
362                        .unwrap_or(FieldFormat::Text),
363                )
364            })
365            .collect::<Vec<_>>();
366        let schema = Arc::new(fields);
367
368        let mut encoded = Vec::with_capacity(rows.len());
369        for row in &rows {
370            let mut encoder = DataRowEncoder::new(Arc::clone(&schema));
371            for (idx, value) in row.iter().enumerate() {
372                let field = &schema[idx];
373                encode_value(&mut encoder, value, field.datatype(), field.format())?;
374            }
375            encoded.push(encoder.finish()?);
376        }
377
378        let data_row_stream = stream::iter(encoded.into_iter().map(Ok));
379        let mut query = QueryResponse::new(schema, data_row_stream);
380        query.set_command_tag(tag_for_statement_name(stmt));
381        Ok(Response::Query(query))
382    }
383
384    fn describe_statement_from_stmt(
385        &self,
386        stmt: &Statement,
387        format: &Format,
388    ) -> PgWireResult<Vec<FieldInfo>> {
389        if !statement_returns_rows(stmt) {
390            return Ok(Vec::new());
391        }
392
393        let tx = self.db.txn_manager.begin();
394        let binder = Binder::new(Arc::clone(&self.db.catalog));
395        let planner = Planner;
396
397        let bound = binder.bind(stmt).map_err(map_entdb_error)?;
398        let fingerprint = Optimizer::fingerprint_bound_statement(&bound);
399        let history = self.db.optimizer_history.read_for_fingerprint(&fingerprint);
400        let plan = planner.plan(bound).map_err(map_entdb_error)?;
401        let optimized = Optimizer::optimize_with_trace_and_history(
402            plan,
403            &fingerprint,
404            self.db.optimizer_config,
405            &history,
406        )
407        .plan;
408
409        let ctx = ExecutionContext {
410            catalog: Arc::clone(&self.db.catalog),
411            tx: Some(TxExecutionContext {
412                txn_id: tx.txn_id,
413                snapshot_ts: tx.snapshot_ts,
414                txn_manager: Arc::clone(&self.db.txn_manager),
415            }),
416        };
417
418        let exec = build_executor(&optimized, &ctx).map_err(map_entdb_error)?;
419        let fields = exec
420            .schema()
421            .columns
422            .iter()
423            .enumerate()
424            .map(|(idx, col)| {
425                FieldInfo::new(
426                    col.name.clone(),
427                    None,
428                    None,
429                    ent_to_pg_type(&col.data_type),
430                    format.format_for(idx),
431                )
432            })
433            .collect::<Vec<_>>();
434
435        self.db.txn_manager.abort(tx.txn_id);
436        Ok(fields)
437    }
438
439    fn describe_statement_inner(&self, sql: &str, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
440        let stmt = parse_single_statement_with_polyglot(
441            sql,
442            "describe supports a single statement",
443            self.polyglot,
444        )?;
445        self.describe_statement_from_stmt(&stmt, format)
446    }
447}
448
449impl EntHandler {
450    fn record_optimizer_history(&self, entry: OptimizerHistoryRecord) {
451        self.db.optimizer_history.try_record(entry);
452    }
453
454    fn record_optimizer_history_error(
455        &self,
456        stmt: &Statement,
457        class: &str,
458        elapsed: std::time::Duration,
459    ) {
460        self.record_optimizer_history(OptimizerHistoryRecord {
461            fingerprint: format!("stmt:{}", tag_for_statement_name(stmt)),
462            plan_signature: "error".to_string(),
463            schema_hash: optimizer_history_schema_hash().to_string(),
464            captured_at_ms: now_epoch_millis(),
465            rowcount_observed_json: "{\"root\":0}".to_string(),
466            latency_ms: elapsed.as_millis() as u64,
467            memory_peak_bytes: 0,
468            success: false,
469            error_class: Some(class.to_string()),
470            confidence: 0.0,
471        });
472    }
473}
474
/// Milliseconds since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reads earlier than the epoch.
fn now_epoch_millis() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
481
482#[async_trait]
483impl SimpleQueryHandler for EntHandler {
484    async fn do_query<'a, C>(
485        &self,
486        _client: &mut C,
487        query: &'a str,
488    ) -> PgWireResult<Vec<Response<'a>>>
489    where
490        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
491        C::Error: Debug,
492        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
493    {
494        let started = std::time::Instant::now();
495        let result = self.execute_sql(query, None);
496        self.record_query_metric(started, result.as_ref().err());
497        result
498    }
499}
500
501#[async_trait]
502impl ExtendedQueryHandler for EntHandler {
503    type Statement = String;
504    type QueryParser = NoopQueryParser;
505
506    fn query_parser(&self) -> Arc<Self::QueryParser> {
507        Arc::clone(&self.query_parser)
508    }
509
510    async fn do_query<'a, C>(
511        &self,
512        _client: &mut C,
513        portal: &'a Portal<Self::Statement>,
514        max_rows: usize,
515    ) -> PgWireResult<Response<'a>>
516    where
517        C: ClientInfo + Unpin + Send + Sync,
518    {
519        let started = std::time::Instant::now();
520        let result = (|| -> PgWireResult<Response<'a>> {
521            let stmt = statement_from_portal(portal, self.polyglot)?;
522            if matches!(
523                stmt,
524                Statement::StartTransaction { .. }
525                    | Statement::Commit { .. }
526                    | Statement::Rollback { .. }
527            ) {
528                let mut responses = self.execute_sql(&stmt.to_string(), Some(max_rows))?;
529                if responses.is_empty() {
530                    return Ok(Response::EmptyQuery);
531                }
532                return Ok(responses.remove(0));
533            }
534
535            let active_tx = *self.current_txn.lock();
536            if let Some(tx) = active_tx {
537                return self.execute_statement_in_txn(
538                    &tx,
539                    &stmt,
540                    Some(max_rows),
541                    Some(&portal.result_column_format),
542                );
543            }
544
545            let tx = self.db.txn_manager.begin();
546            match self.execute_statement_in_txn(
547                &tx,
548                &stmt,
549                Some(max_rows),
550                Some(&portal.result_column_format),
551            ) {
552                Ok(resp) => {
553                    if let Err(e) = self.db.txn_manager.commit(tx.txn_id) {
554                        self.db.txn_manager.abort(tx.txn_id);
555                        Err(map_entdb_error(e))
556                    } else {
557                        Ok(resp)
558                    }
559                }
560                Err(e) => {
561                    self.db.txn_manager.abort(tx.txn_id);
562                    Err(e)
563                }
564            }
565        })();
566        self.record_query_metric(started, result.as_ref().err());
567        result
568    }
569
570    async fn do_describe_statement<C>(
571        &self,
572        _client: &mut C,
573        stmt: &StoredStatement<Self::Statement>,
574    ) -> PgWireResult<DescribeStatementResponse>
575    where
576        C: ClientInfo + Unpin + Send + Sync,
577    {
578        let mut parameter_types = stmt.parameter_types.clone();
579        let inferred_count = max_placeholder_index(&stmt.statement);
580        if parameter_types.len() < inferred_count {
581            parameter_types.resize(inferred_count, Type::INT4);
582        }
583        for ty in &mut parameter_types {
584            if *ty == Type::UNKNOWN {
585                *ty = Type::INT4;
586            }
587        }
588
589        let fields = self
590            .describe_statement_inner(&stmt.statement, &Format::UnifiedBinary)
591            .or_else(|_| {
592                let normalized = normalize_placeholders_for_describe(&stmt.statement);
593                self.describe_statement_inner(&normalized, &Format::UnifiedBinary)
594            })
595            .unwrap_or_default();
596        Ok(DescribeStatementResponse::new(parameter_types, fields))
597    }
598
599    async fn do_describe_portal<C>(
600        &self,
601        _client: &mut C,
602        portal: &Portal<Self::Statement>,
603    ) -> PgWireResult<DescribePortalResponse>
604    where
605        C: ClientInfo + Unpin + Send + Sync,
606    {
607        let stmt = statement_from_portal(portal, self.polyglot)?;
608        let fields = self.describe_statement_from_stmt(&stmt, &portal.result_column_format)?;
609        Ok(DescribePortalResponse::new(fields))
610    }
611}
612
/// Test helper that parses `sql` as PostgreSQL without polyglot transpilation.
///
/// Parse failures are reported as SQLSTATE `42601`.
#[cfg(test)]
fn parse_sql(sql: &str) -> PgWireResult<Vec<Statement>> {
    let parsed = Parser::parse_sql(&PostgreSqlDialect {}, sql);
    parsed.map_err(|e| user_error("42601", format!("parse error: {e}")))
}
618
619fn parse_sql_with_polyglot(sql: &str, opts: PolyglotOptions) -> PgWireResult<Vec<Statement>> {
620    let transpiled = transpile_with_meta(sql, opts).map_err(map_entdb_error)?;
621    let dialect = PostgreSqlDialect {};
622    Parser::parse_sql(&dialect, &transpiled.transpiled_sql).map_err(|e| {
623        if transpiled.changed {
624            user_error(
625                "42601",
626                format!(
627                    "parse error: {e}; original_sql={:?}; transpiled_sql={:?}",
628                    transpiled.original_sql, transpiled.transpiled_sql
629                ),
630            )
631        } else {
632            user_error("42601", format!("parse error: {e}"))
633        }
634    })
635}
636
637impl EntHandler {
638    fn ensure_statement_timeout(&self, started: std::time::Instant) -> PgWireResult<()> {
639        if started.elapsed().as_millis() as u64 > self.query_timeout_ms {
640            return Err(user_error("57014", "query timeout exceeded"));
641        }
642        Ok(())
643    }
644
645    fn record_query_metric(&self, started: std::time::Instant, err: Option<&PgWireError>) {
646        let elapsed = started.elapsed().as_nanos() as u64;
647        let sqlstate = err.and_then(sqlstate_from_error);
648        self.metrics.on_query_finished(elapsed, sqlstate);
649    }
650}
651
652fn sqlstate_from_error(err: &PgWireError) -> Option<&str> {
653    match err {
654        PgWireError::UserError(info) => Some(info.code.as_str()),
655        _ => Some("XX000"),
656    }
657}
658
/// Test helper that parses `sql` and requires exactly one statement.
///
/// Any other count fails with SQLSTATE `0A000` and `message`.
#[cfg(test)]
fn parse_single_statement(sql: &str, message: &'static str) -> PgWireResult<Statement> {
    let statements = parse_sql(sql)?;
    match <[Statement; 1]>::try_from(statements) {
        Ok([only]) => Ok(only),
        Err(_) => Err(user_error("0A000", message)),
    }
}
667
668fn parse_single_statement_with_polyglot(
669    sql: &str,
670    message: &'static str,
671    opts: PolyglotOptions,
672) -> PgWireResult<Statement> {
673    let mut statements = parse_sql_with_polyglot(sql, opts)?;
674    if statements.len() != 1 {
675        return Err(user_error("0A000", message));
676    }
677    Ok(statements.remove(0))
678}
679
680fn statement_from_portal(
681    portal: &Portal<String>,
682    opts: PolyglotOptions,
683) -> PgWireResult<Statement> {
684    let mut stmt = parse_single_statement_with_polyglot(
685        &portal.statement.statement,
686        "extended protocol supports a single statement per execute",
687        opts,
688    )?;
689    if portal.parameter_len() == 0 {
690        return Ok(stmt);
691    }
692
693    let values = collect_portal_values(portal)?;
694    bind_statement_placeholders(&mut stmt, &values)?;
695    Ok(stmt)
696}
697
698fn collect_portal_values(portal: &Portal<String>) -> PgWireResult<Vec<SqlValue>> {
699    let mut values = Vec::with_capacity(portal.parameter_len());
700    for i in 0..portal.parameter_len() {
701        let pg_type = portal
702            .statement
703            .parameter_types
704            .get(i)
705            .unwrap_or(&Type::UNKNOWN);
706        values.push(parameter_sql_value(portal, i, pg_type)?);
707    }
708    Ok(values)
709}
710
711fn bind_statement_placeholders(stmt: &mut Statement, values: &[SqlValue]) -> PgWireResult<()> {
712    let mut question_mark_idx = 0usize;
713    let outcome = visit_expressions_mut(stmt, |expr| {
714        if let Expr::Value(v) = expr {
715            if let SqlValue::Placeholder(placeholder) = &v.value {
716                let idx = match placeholder_index(placeholder, values.len(), &mut question_mark_idx)
717                {
718                    Ok(idx) => idx,
719                    Err(err) => return ControlFlow::Break(err),
720                };
721                v.value = values[idx - 1].clone();
722            }
723        }
724        ControlFlow::Continue(())
725    });
726
727    match outcome {
728        ControlFlow::Continue(()) => Ok(()),
729        ControlFlow::Break(err) => Err(err),
730    }
731}
732
733fn placeholder_index(
734    placeholder: &str,
735    total: usize,
736    question_mark_idx: &mut usize,
737) -> PgWireResult<usize> {
738    let idx = if placeholder == "?" {
739        *question_mark_idx += 1;
740        *question_mark_idx
741    } else if let Some(raw) = placeholder.strip_prefix('$') {
742        raw.parse::<usize>()
743            .map_err(|_| user_error("22023", format!("invalid placeholder '{placeholder}'")))?
744    } else {
745        return Err(user_error(
746            "22023",
747            format!("unsupported placeholder syntax '{placeholder}'"),
748        ));
749    };
750
751    if idx == 0 || idx > total {
752        return Err(user_error(
753            "22023",
754            format!("placeholder '{placeholder}' is out of range"),
755        ));
756    }
757    Ok(idx)
758}
759
760fn parameter_sql_value(
761    portal: &Portal<String>,
762    idx: usize,
763    pg_type: &Type,
764) -> PgWireResult<SqlValue> {
765    if *pg_type == Type::UNKNOWN {
766        return parameter_sql_value_unknown(portal, idx);
767    }
768    if *pg_type == Type::BOOL {
769        return Ok(match portal.parameter::<bool>(idx, pg_type)? {
770            Some(v) => SqlValue::Boolean(v),
771            None => SqlValue::Null,
772        });
773    }
774    if *pg_type == Type::INT2 {
775        return Ok(match portal.parameter::<i16>(idx, pg_type)? {
776            Some(v) => SqlValue::Number(v.to_string(), false),
777            None => SqlValue::Null,
778        });
779    }
780    if *pg_type == Type::INT4 {
781        return Ok(match portal.parameter::<i32>(idx, pg_type)? {
782            Some(v) => SqlValue::Number(v.to_string(), false),
783            None => SqlValue::Null,
784        });
785    }
786    if *pg_type == Type::INT8 {
787        return Ok(match portal.parameter::<i64>(idx, pg_type)? {
788            Some(v) => SqlValue::Number(v.to_string(), false),
789            None => SqlValue::Null,
790        });
791    }
792    if *pg_type == Type::FLOAT4 {
793        return Ok(match portal.parameter::<f32>(idx, pg_type)? {
794            Some(v) => SqlValue::Number((v as f64).to_string(), false),
795            None => SqlValue::Null,
796        });
797    }
798    if *pg_type == Type::FLOAT8 {
799        return Ok(match portal.parameter::<f64>(idx, pg_type)? {
800            Some(v) => SqlValue::Number(v.to_string(), false),
801            None => SqlValue::Null,
802        });
803    }
804
805    Ok(match portal.parameter::<String>(idx, pg_type)? {
806        Some(v) => SqlValue::SingleQuotedString(v),
807        None => SqlValue::Null,
808    })
809}
810
811fn parameter_sql_value_unknown(portal: &Portal<String>, idx: usize) -> PgWireResult<SqlValue> {
812    let Some(raw) = portal.parameters.get(idx) else {
813        return Ok(SqlValue::Null);
814    };
815    let Some(bytes) = raw else {
816        return Ok(SqlValue::Null);
817    };
818
819    if portal.parameter_format.is_binary(idx) {
820        match bytes.len() {
821            1 => {
822                let b = bytes[0];
823                if b == 0 || b == 1 {
824                    return Ok(SqlValue::Boolean(b == 1));
825                }
826            }
827            4 => {
828                let v = i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
829                return Ok(SqlValue::Number(v.to_string(), false));
830            }
831            8 => {
832                let v = i64::from_be_bytes([
833                    bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
834                ]);
835                return Ok(SqlValue::Number(v.to_string(), false));
836            }
837            _ => {}
838        }
839    }
840
841    match std::str::from_utf8(bytes) {
842        Ok(s) => {
843            if let Ok(i) = s.parse::<i64>() {
844                return Ok(SqlValue::Number(i.to_string(), false));
845            }
846            if let Ok(f) = s.parse::<f64>() {
847                return Ok(SqlValue::Number(f.to_string(), false));
848            }
849            if s.eq_ignore_ascii_case("true") || s.eq_ignore_ascii_case("false") {
850                return Ok(SqlValue::Boolean(s.eq_ignore_ascii_case("true")));
851            }
852            Ok(SqlValue::SingleQuotedString(s.to_string()))
853        }
854        Err(_) => Ok(SqlValue::SingleQuotedString(
855            String::from_utf8_lossy(bytes).into_owned(),
856        )),
857    }
858}
859
860fn tag_for_statement(stmt: &Statement) -> Tag {
861    Tag::new(tag_for_statement_name(stmt))
862}
863
864fn statement_returns_rows(stmt: &Statement) -> bool {
865    match stmt {
866        Statement::Query(_) => true,
867        Statement::Insert(insert) => insert.returning.is_some(),
868        Statement::Update { returning, .. } => returning.is_some(),
869        Statement::Delete(delete) => delete.returning.is_some(),
870        _ => false,
871    }
872}
873
874fn execution_tag_for_statement(stmt: &Statement, affected: usize) -> Tag {
875    match stmt {
876        // PostgreSQL command-complete format for INSERT is "INSERT 0 <rows>".
877        Statement::Insert(_) => Tag::new("INSERT 0").with_rows(affected),
878        _ => tag_for_statement(stmt).with_rows(affected),
879    }
880}
881
882fn tag_for_statement_name(stmt: &Statement) -> &'static str {
883    match stmt {
884        Statement::Query(_) => "SELECT",
885        Statement::Insert(_) => "INSERT",
886        Statement::Update { .. } => "UPDATE",
887        Statement::Delete { .. } => "DELETE",
888        Statement::Truncate { .. } => "TRUNCATE",
889        Statement::CreateTable { .. } => "CREATE TABLE",
890        Statement::CreateIndex(_) => "CREATE INDEX",
891        Statement::AlterTable { .. } => "ALTER TABLE",
892        Statement::Drop {
893            object_type: sqlparser::ast::ObjectType::Table,
894            ..
895        } => "DROP TABLE",
896        Statement::Drop {
897            object_type: sqlparser::ast::ObjectType::Index,
898            ..
899        } => "DROP INDEX",
900        Statement::StartTransaction { .. } => "BEGIN",
901        Statement::Commit { .. } => "COMMIT",
902        Statement::Rollback { .. } => "ROLLBACK",
903        _ => "OK",
904    }
905}
906
907fn map_entdb_error(err: EntDbError) -> PgWireError {
908    let (code, message) = match err {
909        EntDbError::Query(m) => (query_sqlstate(&m), m),
910        EntDbError::BufferPoolFull => ("53200", "buffer pool full".to_string()),
911        EntDbError::PagePinned(pid) => ("55006", format!("page {pid} is pinned")),
912        EntDbError::PageNotFound(pid) => ("XX000", format!("page {pid} not found")),
913        EntDbError::Io(e) => ("58000", format!("io error: {e}")),
914        EntDbError::Wal(m) => ("XX000", format!("wal error: {m}")),
915        EntDbError::Corruption(m) => ("XX001", format!("corruption: {m}")),
916        EntDbError::InvalidPage(m) => ("XX000", format!("invalid page: {m}")),
917        EntDbError::PageAlreadyPresent(pid) => ("XX000", format!("page {pid} already present")),
918    };
919    user_error(code, message)
920}
921
/// Heuristically derives a PostgreSQL SQLSTATE code from a query error message.
///
/// Matching is case-insensitive substring search, checked in this order:
/// * "parse error" -> 42601 (syntax_error)
/// * "already active" -> 25001 (active_sql_transaction)
/// * "no active transaction" -> 25000 (invalid_transaction_state)
/// * "write-write conflict" -> 40001 (serialization_failure)
/// * "does not exist" together with "table" / "column" -> 42P01 / 42703
/// * cast/invalid/unsupported/"must be numeric" -> 22000 (data_exception)
/// * anything else -> XX000 (internal_error)
fn query_sqlstate(message: &str) -> &'static str {
    let lower = message.to_ascii_lowercase();
    if lower.contains("parse error") {
        return "42601";
    }
    if lower.contains("already active") {
        return "25001";
    }
    if lower.contains("no active transaction") {
        return "25000";
    }
    if lower.contains("write-write conflict") {
        return "40001";
    }
    if lower.contains("does not exist") {
        // A message may name both kinds of object, e.g. `column v does not exist in table t`.
        // A fixed "table first" check would misreport that as undefined_table. Instead, this
        // classifies by whichever keyword is mentioned first. That assumes the missing object
        // is named first in the message.
        let first_mention = |needle: &str| lower.find(needle).unwrap_or(usize::MAX);
        let table_at = first_mention("table");
        let column_at = first_mention("column");
        if table_at < column_at {
            return "42P01";
        }
        if column_at < table_at {
            return "42703";
        }
    }
    if lower.contains("cannot cast value")
        || lower.contains("invalid")
        || lower.contains("unsupported")
        || lower.contains("must be numeric")
    {
        return "22000";
    }
    "XX000"
}
953
954fn user_error(code: impl Into<String>, message: impl Into<String>) -> PgWireError {
955    PgWireError::UserError(Box::new(ErrorInfo::new(
956        "ERROR".to_string(),
957        code.into(),
958        message.into(),
959    )))
960}
961
/// Returns the highest numbered placeholder (`$N`) appearing in `sql`, or 0 if there is none.
///
/// The scan is purely lexical: every `$` followed by one or more ASCII digits counts, including
/// occurrences inside string literals. Digit runs too large for `usize` are skipped.
fn max_placeholder_index(sql: &str) -> usize {
    let raw = sql.as_bytes();
    let mut highest = 0usize;
    let mut pos = 0usize;
    while pos < raw.len() {
        if raw[pos] != b'$' {
            pos += 1;
            continue;
        }
        let digits_start = pos + 1;
        let digit_count = raw[digits_start..]
            .iter()
            .take_while(|b| b.is_ascii_digit())
            .count();
        if digit_count == 0 {
            // A lone `$` (or `$` before a non-digit) is not a placeholder.
            pos += 1;
            continue;
        }
        let digits_end = digits_start + digit_count;
        if let Ok(n) = sql[digits_start..digits_end].parse::<usize>() {
            highest = highest.max(n);
        }
        pos = digits_end;
    }
    highest
}
985
/// Replaces every numbered placeholder (`$N`) in `sql` with the literal `0`.
///
/// Describe has no parameter values to plan with, so a numeric stand-in is used instead.
/// A `$` not followed by a digit is left untouched.
///
/// All other text is copied verbatim. Earlier versions pushed each byte as a `char`, which
/// garbled non-ASCII identifiers and string literals. Copying `&str` slices avoids that; the
/// split points are always at `$` or just past an ASCII digit run, which are UTF-8 char
/// boundaries.
fn normalize_placeholders_for_describe(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    // Start of the pending run of unchanged text not yet copied into `out`.
    let mut copied_up_to = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        if bytes[i] == b'$' {
            let digits = bytes[i + 1..]
                .iter()
                .take_while(|b| b.is_ascii_digit())
                .count();
            if digits > 0 {
                out.push_str(&sql[copied_up_to..i]);
                // Use a numeric placeholder value so planning/type checks succeed
                // for common predicates and LIMIT/OFFSET contexts.
                out.push('0');
                i += 1 + digits;
                copied_up_to = i;
                continue;
            }
        }
        i += 1;
    }
    out.push_str(&sql[copied_up_to..]);
    out
}
1010
1011pub fn scan_max_txn_id_from_storage(catalog: &Catalog) -> Result<u64, EntDbError> {
1012    let mut max_txn = 0_u64;
1013    for table in catalog.list_tables() {
1014        let t = Table::open(table.table_id, table.first_page_id, catalog.buffer_pool());
1015        for (_, tuple) in t.scan() {
1016            let decoded = entdb::query::executor::decode_stored_row(&tuple.data)?;
1017            if let entdb::query::executor::DecodedRow::Versioned(v) = decoded {
1018                max_txn = max_txn.max(v.created_txn);
1019                if let Some(d) = v.deleted_txn {
1020                    max_txn = max_txn.max(d);
1021                }
1022            }
1023        }
1024    }
1025    Ok(max_txn)
1026}
1027
#[cfg(test)]
mod tests {
    use super::{
        bind_statement_placeholders, execution_tag_for_statement, max_placeholder_index,
        normalize_placeholders_for_describe, parse_single_statement, parse_sql, EntHandler,
    };
    use crate::server::Database;
    use entdb::catalog::Catalog;
    use entdb::query::history::OptimizerHistoryRecorder;
    use entdb::query::optimizer::OptimizerConfig;
    use entdb::storage::buffer_pool::BufferPool;
    use entdb::storage::disk_manager::DiskManager;
    use entdb::tx::TransactionManager;
    use entdb::wal::log_manager::LogManager;
    use pgwire::api::results::Tag;
    use pgwire::error::PgWireError;
    use proptest::prelude::*;
    use sqlparser::ast::Value as SqlValue;
    use std::sync::Arc;
    use tempfile::{tempdir, TempDir};

    /// Builds an `EntHandler` over fresh database, WAL and optimizer-history files in a
    /// temporary directory.
    ///
    /// The returned `TempDir` guard must be kept alive (bind it to a named variable such as
    /// `_dir`, not `_`) for as long as the handler is used. Dropping it deletes the directory
    /// and every file the handler writes to.
    fn test_handler() -> (EntHandler, TempDir) {
        let dir = tempdir().expect("tempdir");
        let db_path = dir.path().join("handler.db");
        let wal_path = dir.path().join("handler.wal");

        let dm = Arc::new(DiskManager::new(&db_path).expect("disk"));
        let lm = Arc::new(LogManager::new(&wal_path, 4096).expect("wal"));
        let bp = Arc::new(BufferPool::with_log_manager(
            32,
            Arc::clone(&dm),
            Arc::clone(&lm),
        ));
        let catalog = Arc::new(Catalog::load(Arc::clone(&bp)).expect("catalog"));
        let txn = Arc::new(TransactionManager::new());
        let history_path = dir.path().join("handler.optimizer_history.json");
        let optimizer_history = Arc::new(
            OptimizerHistoryRecorder::new(
                &history_path,
                super::optimizer_history_schema_hash(),
                16,
                128,
            )
            .expect("optimizer history"),
        );
        let db = Arc::new(Database {
            disk_manager: dm,
            log_manager: lm,
            buffer_pool: bp,
            catalog,
            txn_manager: txn,
            optimizer_history,
            optimizer_config: OptimizerConfig::default(),
        });
        let handler = EntHandler::new(
            db,
            1024 * 1024,
            30_000,
            Arc::new(super::ServerMetrics::default()),
        );
        (handler, dir)
    }

    #[test]
    fn bind_statement_placeholders_replaces_parameter_nodes() {
        let mut stmt = parse_single_statement(
            "SELECT id FROM t WHERE v > $1 ORDER BY id LIMIT $2",
            "single",
        )
        .expect("parse");
        bind_statement_placeholders(
            &mut stmt,
            &[
                SqlValue::Number("7".to_string(), false),
                SqlValue::Number("3".to_string(), false),
            ],
        )
        .expect("bind placeholders");

        let rendered = stmt.to_string();
        assert!(rendered.contains("v > 7"), "rendered SQL: {rendered}");
        assert!(rendered.contains("LIMIT 3"), "rendered SQL: {rendered}");
    }

    #[test]
    fn bind_statement_placeholders_rejects_out_of_range_index() {
        let mut stmt = parse_single_statement("SELECT $3", "single").expect("parse");
        let err = match bind_statement_placeholders(
            &mut stmt,
            &[SqlValue::Number("1".to_string(), false)],
        ) {
            Ok(_) => panic!("expected out-of-range placeholder failure"),
            Err(e) => e,
        };
        match err {
            PgWireError::UserError(info) => {
                assert_eq!(info.code, "22023");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn parse_single_statement_rejects_multi_statement_sql() {
        let err = match parse_single_statement("SELECT 1; SELECT 2", "single") {
            Ok(_) => panic!("expected single-statement rejection"),
            Err(e) => e,
        };
        match err {
            PgWireError::UserError(info) => {
                assert_eq!(info.code, "0A000");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn max_placeholder_index_finds_highest_parameter() {
        assert_eq!(max_placeholder_index("SELECT $1, $2, $10"), 10);
        assert_eq!(max_placeholder_index("SELECT 1"), 0);
    }

    #[test]
    fn normalize_placeholders_for_describe_rewrites_all_numbered_params() {
        let sql = "SELECT id FROM t WHERE v > $1 ORDER BY id LIMIT $2";
        let normalized = normalize_placeholders_for_describe(sql);
        assert_eq!(
            normalized,
            "SELECT id FROM t WHERE v > 0 ORDER BY id LIMIT 0"
        );
    }

    #[test]
    fn execution_tag_for_insert_matches_postgres_format() {
        let stmt = parse_single_statement("INSERT INTO t VALUES (1)", "single").expect("parse");
        let tag = execution_tag_for_statement(&stmt, 2);
        assert_eq!(tag, Tag::new("INSERT 0").with_rows(2));
    }

    #[test]
    fn execution_tag_for_update_and_delete_include_row_count() {
        let update = parse_single_statement("UPDATE t SET v = 1", "single").expect("parse");
        let delete = parse_single_statement("DELETE FROM t", "single").expect("parse");
        assert_eq!(
            execution_tag_for_statement(&update, 3),
            Tag::new("UPDATE").with_rows(3)
        );
        assert_eq!(
            execution_tag_for_statement(&delete, 4),
            Tag::new("DELETE").with_rows(4)
        );
    }

    #[test]
    fn handler_begin_commit_round_trip() {
        // `_dir` keeps the backing directory alive until the end of the test.
        let (handler, _dir) = test_handler();
        let out = handler
            .execute_sql("BEGIN; COMMIT;", None)
            .expect("execute");
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn handler_enforces_max_statement_size() {
        // `_dir` keeps the backing directory alive until the end of the test.
        let (handler, _dir) = test_handler();
        let huge = "X".repeat(2 * 1024 * 1024);
        let err = match handler.execute_sql(&huge, None) {
            Ok(_) => panic!("max statement size should be enforced"),
            Err(e) => e,
        };
        assert!(err.to_string().contains("max_statement_bytes"));
    }

    proptest! {
        #[test]
        fn parse_sql_never_panics_on_random_input(input in ".*") {
            let _ = parse_sql(&input);
        }
    }
}