Skip to main content

entdb_server/server/
handler.rs

1/*
2 * Copyright 2026 EntDB Authors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::server::type_map::{encode_value, ent_to_pg_type};
18use async_trait::async_trait;
19use entdb::catalog::Catalog;
20use entdb::error::EntDbError;
21use entdb::query::binder::Binder;
22use entdb::query::executor::{build_executor, ExecutionContext, TxExecutionContext};
23use entdb::query::history::OptimizerHistoryRecord;
24use entdb::query::optimizer::Optimizer;
25use entdb::query::planner::Planner;
26use entdb::query::polyglot::{transpile_with_meta, PolyglotOptions};
27use entdb::storage::table::Table;
28use entdb::tx::TransactionHandle;
29use futures::{stream, Sink};
30use parking_lot::Mutex;
31use pgwire::api::portal::{Format, Portal};
32use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
33use pgwire::api::results::{
34    DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo,
35    QueryResponse, Response, Tag,
36};
37use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
38use pgwire::api::{ClientInfo, Type};
39use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
40use pgwire::messages::PgWireBackendMessage;
41use sqlparser::ast::{visit_expressions_mut, Expr, Statement, Value as SqlValue};
42use sqlparser::dialect::PostgreSqlDialect;
43use sqlparser::parser::Parser;
44use std::fmt::Debug;
45use std::ops::ControlFlow;
46use std::sync::Arc;
47use tracing::{debug, info_span};
48
49use super::metrics::ServerMetrics;
50use super::{optimizer_history_schema_hash, Database};
51
52pub struct EntHandler {
53    db: Arc<Database>,
54    current_txn: Mutex<Option<TransactionHandle>>,
55    query_parser: Arc<NoopQueryParser>,
56    max_statement_bytes: usize,
57    query_timeout_ms: u64,
58    metrics: Arc<ServerMetrics>,
59    polyglot: PolyglotOptions,
60    await_durable: bool,
61}
62
63impl EntHandler {
64    pub fn new(
65        db: Arc<Database>,
66        max_statement_bytes: usize,
67        query_timeout_ms: u64,
68        metrics: Arc<ServerMetrics>,
69        await_durable: bool,
70    ) -> Self {
71        Self {
72            db,
73            current_txn: Mutex::new(None),
74            query_parser: Arc::new(NoopQueryParser::new()),
75            max_statement_bytes,
76            query_timeout_ms,
77            metrics,
78            await_durable,
79            polyglot: PolyglotOptions {
80                enabled: std::env::var("ENTDB_POLYGLOT")
81                    .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
82                    .unwrap_or(false),
83            },
84        }
85    }
86
87    fn execute_sql(
88        &self,
89        sql: &str,
90        max_rows: Option<usize>,
91    ) -> PgWireResult<Vec<Response<'static>>> {
92        if sql.len() > self.max_statement_bytes {
93            return Err(user_error(
94                "54000",
95                format!(
96                    "statement exceeds configured max_statement_bytes ({} > {})",
97                    sql.len(),
98                    self.max_statement_bytes
99                ),
100            ));
101        }
102        let _span = info_span!(
103            "pgwire.execute_sql",
104            sql_len = sql.len(),
105            max_rows = max_rows.unwrap_or(0)
106        )
107        .entered();
108        let statements = parse_sql_with_polyglot(sql, self.polyglot)?;
109        let mut responses = Vec::new();
110        let started = std::time::Instant::now();
111
112        for stmt in &statements {
113            if started.elapsed().as_millis() as u64 > self.query_timeout_ms {
114                return Err(user_error("57014", "query timeout exceeded"));
115            }
116            debug!(statement = %tag_for_statement_name(stmt), "processing statement");
117            match stmt {
118                Statement::StartTransaction { .. } => {
119                    let mut guard = self.current_txn.lock();
120                    if guard.is_some() {
121                        return Err(user_error(
122                            "25001",
123                            "transaction already active for this session",
124                        ));
125                    }
126                    *guard = Some(self.db.txn_manager.begin());
127                    responses.push(Response::TransactionStart(Tag::new("BEGIN")));
128                }
129                Statement::Commit { .. } => {
130                    let tx = {
131                        let mut guard = self.current_txn.lock();
132                        guard
133                            .take()
134                            .ok_or_else(|| user_error("25000", "no active transaction to COMMIT"))?
135                    };
136                    self.db
137                        .txn_manager
138                        .commit_with_options(tx.txn_id, Some(self.await_durable))
139                        .map_err(map_entdb_error)?;
140                    responses.push(Response::TransactionEnd(Tag::new("COMMIT")));
141                }
142                Statement::Rollback { .. } => {
143                    let tx = {
144                        let mut guard = self.current_txn.lock();
145                        guard.take().ok_or_else(|| {
146                            user_error("25000", "no active transaction to ROLLBACK")
147                        })?
148                    };
149                    self.db.txn_manager.abort(tx.txn_id);
150                    responses.push(Response::TransactionEnd(Tag::new("ROLLBACK")));
151                }
152                _ => {
153                    let active_tx = *self.current_txn.lock();
154                    let exec_resp = if let Some(tx) = active_tx {
155                        self.execute_statement_in_txn(&tx, stmt, max_rows, None)?
156                    } else {
157                        let tx = self.db.txn_manager.begin();
158                        match self.execute_statement_in_txn(&tx, stmt, max_rows, None) {
159                            Ok(resp) => {
160                                if let Err(e) = self
161                                    .db
162                                    .txn_manager
163                                    .commit_with_options(tx.txn_id, Some(self.await_durable))
164                                {
165                                    self.db.txn_manager.abort(tx.txn_id);
166                                    return Err(map_entdb_error(e));
167                                }
168                                resp
169                            }
170                            Err(e) => {
171                                self.db.txn_manager.abort(tx.txn_id);
172                                return Err(e);
173                            }
174                        }
175                    };
176                    responses.push(exec_resp);
177                }
178            }
179        }
180
181        Ok(responses)
182    }
183
184    fn execute_statement_in_txn(
185        &self,
186        tx: &TransactionHandle,
187        stmt: &Statement,
188        max_rows: Option<usize>,
189        row_format: Option<&Format>,
190    ) -> PgWireResult<Response<'static>> {
191        let stmt_started = std::time::Instant::now();
192        let _span = info_span!(
193            "pgwire.execute_statement",
194            statement = %tag_for_statement_name(stmt),
195            txn_id = tx.txn_id
196        )
197        .entered();
198        let binder = Binder::new(Arc::clone(&self.db.catalog));
199        let planner = Planner;
200
201        let bound = match binder.bind(stmt) {
202            Ok(b) => b,
203            Err(e) => {
204                self.record_optimizer_history_error(stmt, "bind_error", stmt_started.elapsed());
205                return Err(map_entdb_error(e));
206            }
207        };
208        let fingerprint = Optimizer::fingerprint_bound_statement(&bound);
209        let history = self.db.optimizer_history.read_for_fingerprint(&fingerprint);
210        self.ensure_statement_timeout(stmt_started)?;
211        let plan = match planner.plan(bound) {
212            Ok(p) => p,
213            Err(e) => {
214                self.record_optimizer_history(OptimizerHistoryRecord {
215                    fingerprint,
216                    plan_signature: "error".to_string(),
217                    schema_hash: optimizer_history_schema_hash().to_string(),
218                    captured_at_ms: now_epoch_millis(),
219                    rowcount_observed_json: "{\"root\":0}".to_string(),
220                    latency_ms: stmt_started.elapsed().as_millis() as u64,
221                    memory_peak_bytes: 0,
222                    success: false,
223                    error_class: Some("plan_error".to_string()),
224                    confidence: 0.0,
225                });
226                return Err(map_entdb_error(e));
227            }
228        };
229        self.ensure_statement_timeout(stmt_started)?;
230        let optimized_outcome = Optimizer::optimize_with_trace_and_history(
231            plan,
232            &fingerprint,
233            self.db.optimizer_config,
234            &history,
235        );
236        let chosen_plan_signature = optimized_outcome
237            .trace
238            .chosen_plan_signature
239            .clone()
240            .unwrap_or_else(|| "baseline".to_string());
241        let optimized = optimized_outcome.plan;
242        self.ensure_statement_timeout(stmt_started)?;
243
244        let ctx = ExecutionContext {
245            catalog: Arc::clone(&self.db.catalog),
246            tx: Some(TxExecutionContext {
247                txn_id: tx.txn_id,
248                snapshot_ts: tx.snapshot_ts,
249                txn_manager: Arc::clone(&self.db.txn_manager),
250            }),
251        };
252
253        let mut exec = match build_executor(&optimized, &ctx) {
254            Ok(exec) => exec,
255            Err(e) => {
256                self.record_optimizer_history(OptimizerHistoryRecord {
257                    fingerprint,
258                    plan_signature: chosen_plan_signature,
259                    schema_hash: optimizer_history_schema_hash().to_string(),
260                    captured_at_ms: now_epoch_millis(),
261                    rowcount_observed_json: "{\"root\":0}".to_string(),
262                    latency_ms: stmt_started.elapsed().as_millis() as u64,
263                    memory_peak_bytes: 0,
264                    success: false,
265                    error_class: Some("executor_build_error".to_string()),
266                    confidence: 0.0,
267                });
268                return Err(map_entdb_error(e));
269            }
270        };
271        self.ensure_statement_timeout(stmt_started)?;
272        if let Err(e) = exec.open() {
273            self.record_optimizer_history(OptimizerHistoryRecord {
274                fingerprint,
275                plan_signature: chosen_plan_signature,
276                schema_hash: optimizer_history_schema_hash().to_string(),
277                captured_at_ms: now_epoch_millis(),
278                rowcount_observed_json: "{\"root\":0}".to_string(),
279                latency_ms: stmt_started.elapsed().as_millis() as u64,
280                memory_peak_bytes: 0,
281                success: false,
282                error_class: Some("executor_open_error".to_string()),
283                confidence: 0.0,
284            });
285            return Err(map_entdb_error(e));
286        }
287        self.ensure_statement_timeout(stmt_started)?;
288
289        let schema = exec.schema().clone();
290        let mut rows = Vec::new();
291        while let Some(row) = match exec.next() {
292            Ok(r) => r,
293            Err(e) => {
294                self.record_optimizer_history(OptimizerHistoryRecord {
295                    fingerprint: fingerprint.clone(),
296                    plan_signature: chosen_plan_signature.clone(),
297                    schema_hash: optimizer_history_schema_hash().to_string(),
298                    captured_at_ms: now_epoch_millis(),
299                    rowcount_observed_json: format!("{{\"root\":{}}}", rows.len()),
300                    latency_ms: stmt_started.elapsed().as_millis() as u64,
301                    memory_peak_bytes: 0,
302                    success: false,
303                    error_class: Some("executor_next_error".to_string()),
304                    confidence: 0.0,
305                });
306                return Err(map_entdb_error(e));
307            }
308        } {
309            self.ensure_statement_timeout(stmt_started)?;
310            rows.push(row);
311            if let Some(limit) = max_rows {
312                if limit > 0 && rows.len() >= limit {
313                    break;
314                }
315            }
316        }
317        if let Err(e) = exec.close() {
318            self.record_optimizer_history(OptimizerHistoryRecord {
319                fingerprint,
320                plan_signature: chosen_plan_signature,
321                schema_hash: optimizer_history_schema_hash().to_string(),
322                captured_at_ms: now_epoch_millis(),
323                rowcount_observed_json: format!("{{\"root\":{}}}", rows.len()),
324                latency_ms: stmt_started.elapsed().as_millis() as u64,
325                memory_peak_bytes: 0,
326                success: false,
327                error_class: Some("executor_close_error".to_string()),
328                confidence: 0.0,
329            });
330            return Err(map_entdb_error(e));
331        }
332
333        let observed_rows = if schema.columns.is_empty() {
334            exec.affected_rows().unwrap_or(0)
335        } else {
336            rows.len() as u64
337        };
338        self.record_optimizer_history(OptimizerHistoryRecord {
339            fingerprint,
340            plan_signature: chosen_plan_signature,
341            schema_hash: optimizer_history_schema_hash().to_string(),
342            captured_at_ms: now_epoch_millis(),
343            rowcount_observed_json: format!("{{\"root\":{observed_rows}}}"),
344            latency_ms: stmt_started.elapsed().as_millis() as u64,
345            memory_peak_bytes: 0,
346            success: true,
347            error_class: None,
348            confidence: 1.0,
349        });
350
351        if schema.columns.is_empty() {
352            let affected = exec.affected_rows().unwrap_or(0) as usize;
353            let tag = execution_tag_for_statement(stmt, affected);
354            return Ok(Response::Execution(tag));
355        }
356
357        let fields = schema
358            .columns
359            .iter()
360            .enumerate()
361            .map(|(idx, c)| {
362                FieldInfo::new(
363                    c.name.clone(),
364                    None,
365                    None,
366                    ent_to_pg_type(&c.data_type),
367                    row_format
368                        .map(|f| f.format_for(idx))
369                        .unwrap_or(FieldFormat::Text),
370                )
371            })
372            .collect::<Vec<_>>();
373        let schema = Arc::new(fields);
374
375        let mut encoded = Vec::with_capacity(rows.len());
376        for row in &rows {
377            let mut encoder = DataRowEncoder::new(Arc::clone(&schema));
378            for (idx, value) in row.iter().enumerate() {
379                let field = &schema[idx];
380                encode_value(&mut encoder, value, field.datatype(), field.format())?;
381            }
382            encoded.push(encoder.finish()?);
383        }
384
385        let data_row_stream = stream::iter(encoded.into_iter().map(Ok));
386        let mut query = QueryResponse::new(schema, data_row_stream);
387        query.set_command_tag(tag_for_statement_name(stmt));
388        Ok(Response::Query(query))
389    }
390
391    fn describe_statement_from_stmt(
392        &self,
393        stmt: &Statement,
394        format: &Format,
395    ) -> PgWireResult<Vec<FieldInfo>> {
396        if !statement_returns_rows(stmt) {
397            return Ok(Vec::new());
398        }
399
400        let tx = self.db.txn_manager.begin();
401        let binder = Binder::new(Arc::clone(&self.db.catalog));
402        let planner = Planner;
403
404        let bound = binder.bind(stmt).map_err(map_entdb_error)?;
405        let fingerprint = Optimizer::fingerprint_bound_statement(&bound);
406        let history = self.db.optimizer_history.read_for_fingerprint(&fingerprint);
407        let plan = planner.plan(bound).map_err(map_entdb_error)?;
408        let optimized = Optimizer::optimize_with_trace_and_history(
409            plan,
410            &fingerprint,
411            self.db.optimizer_config,
412            &history,
413        )
414        .plan;
415
416        let ctx = ExecutionContext {
417            catalog: Arc::clone(&self.db.catalog),
418            tx: Some(TxExecutionContext {
419                txn_id: tx.txn_id,
420                snapshot_ts: tx.snapshot_ts,
421                txn_manager: Arc::clone(&self.db.txn_manager),
422            }),
423        };
424
425        let exec = build_executor(&optimized, &ctx).map_err(map_entdb_error)?;
426        let fields = exec
427            .schema()
428            .columns
429            .iter()
430            .enumerate()
431            .map(|(idx, col)| {
432                FieldInfo::new(
433                    col.name.clone(),
434                    None,
435                    None,
436                    ent_to_pg_type(&col.data_type),
437                    format.format_for(idx),
438                )
439            })
440            .collect::<Vec<_>>();
441
442        self.db.txn_manager.abort(tx.txn_id);
443        Ok(fields)
444    }
445
446    fn describe_statement_inner(&self, sql: &str, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
447        let stmt = parse_single_statement_with_polyglot(
448            sql,
449            "describe supports a single statement",
450            self.polyglot,
451        )?;
452        self.describe_statement_from_stmt(&stmt, format)
453    }
454}
455
456impl EntHandler {
457    fn record_optimizer_history(&self, entry: OptimizerHistoryRecord) {
458        self.db.optimizer_history.try_record(entry);
459    }
460
461    fn record_optimizer_history_error(
462        &self,
463        stmt: &Statement,
464        class: &str,
465        elapsed: std::time::Duration,
466    ) {
467        self.record_optimizer_history(OptimizerHistoryRecord {
468            fingerprint: format!("stmt:{}", tag_for_statement_name(stmt)),
469            plan_signature: "error".to_string(),
470            schema_hash: optimizer_history_schema_hash().to_string(),
471            captured_at_ms: now_epoch_millis(),
472            rowcount_observed_json: "{\"root\":0}".to_string(),
473            latency_ms: elapsed.as_millis() as u64,
474            memory_peak_bytes: 0,
475            success: false,
476            error_class: Some(class.to_string()),
477            confidence: 0.0,
478        });
479    }
480}
481
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn now_epoch_millis() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
488
489#[async_trait]
490impl SimpleQueryHandler for EntHandler {
491    async fn do_query<'a, C>(
492        &self,
493        _client: &mut C,
494        query: &'a str,
495    ) -> PgWireResult<Vec<Response<'a>>>
496    where
497        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
498        C::Error: Debug,
499        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
500    {
501        let started = std::time::Instant::now();
502        let result = self.execute_sql(query, None);
503        self.record_query_metric(started, result.as_ref().err());
504        result
505    }
506}
507
508#[async_trait]
509impl ExtendedQueryHandler for EntHandler {
510    type Statement = String;
511    type QueryParser = NoopQueryParser;
512
513    fn query_parser(&self) -> Arc<Self::QueryParser> {
514        Arc::clone(&self.query_parser)
515    }
516
517    async fn do_query<'a, C>(
518        &self,
519        _client: &mut C,
520        portal: &'a Portal<Self::Statement>,
521        max_rows: usize,
522    ) -> PgWireResult<Response<'a>>
523    where
524        C: ClientInfo + Unpin + Send + Sync,
525    {
526        let started = std::time::Instant::now();
527        let result = (|| -> PgWireResult<Response<'a>> {
528            let stmt = statement_from_portal(portal, self.polyglot)?;
529            if matches!(
530                stmt,
531                Statement::StartTransaction { .. }
532                    | Statement::Commit { .. }
533                    | Statement::Rollback { .. }
534            ) {
535                let mut responses = self.execute_sql(&stmt.to_string(), Some(max_rows))?;
536                if responses.is_empty() {
537                    return Ok(Response::EmptyQuery);
538                }
539                return Ok(responses.remove(0));
540            }
541
542            let active_tx = *self.current_txn.lock();
543            if let Some(tx) = active_tx {
544                return self.execute_statement_in_txn(
545                    &tx,
546                    &stmt,
547                    Some(max_rows),
548                    Some(&portal.result_column_format),
549                );
550            }
551
552            let tx = self.db.txn_manager.begin();
553            match self.execute_statement_in_txn(
554                &tx,
555                &stmt,
556                Some(max_rows),
557                Some(&portal.result_column_format),
558            ) {
559                Ok(resp) => {
560                    if let Err(e) = self
561                        .db
562                        .txn_manager
563                        .commit_with_options(tx.txn_id, Some(self.await_durable))
564                    {
565                        self.db.txn_manager.abort(tx.txn_id);
566                        Err(map_entdb_error(e))
567                    } else {
568                        Ok(resp)
569                    }
570                }
571                Err(e) => {
572                    self.db.txn_manager.abort(tx.txn_id);
573                    Err(e)
574                }
575            }
576        })();
577        self.record_query_metric(started, result.as_ref().err());
578        result
579    }
580
581    async fn do_describe_statement<C>(
582        &self,
583        _client: &mut C,
584        stmt: &StoredStatement<Self::Statement>,
585    ) -> PgWireResult<DescribeStatementResponse>
586    where
587        C: ClientInfo + Unpin + Send + Sync,
588    {
589        let mut parameter_types = stmt.parameter_types.clone();
590        let inferred_count = max_placeholder_index(&stmt.statement);
591        if parameter_types.len() < inferred_count {
592            parameter_types.resize(inferred_count, Type::INT4);
593        }
594        for ty in &mut parameter_types {
595            if *ty == Type::UNKNOWN {
596                *ty = Type::INT4;
597            }
598        }
599
600        let fields = self
601            .describe_statement_inner(&stmt.statement, &Format::UnifiedBinary)
602            .or_else(|_| {
603                let normalized = normalize_placeholders_for_describe(&stmt.statement);
604                self.describe_statement_inner(&normalized, &Format::UnifiedBinary)
605            })
606            .unwrap_or_default();
607        Ok(DescribeStatementResponse::new(parameter_types, fields))
608    }
609
610    async fn do_describe_portal<C>(
611        &self,
612        _client: &mut C,
613        portal: &Portal<Self::Statement>,
614    ) -> PgWireResult<DescribePortalResponse>
615    where
616        C: ClientInfo + Unpin + Send + Sync,
617    {
618        let stmt = statement_from_portal(portal, self.polyglot)?;
619        let fields = self.describe_statement_from_stmt(&stmt, &portal.result_column_format)?;
620        Ok(DescribePortalResponse::new(fields))
621    }
622}
623
/// Test-only helper: parses `sql` with the PostgreSQL dialect, skipping
/// polyglot transpilation. Parse failures become a `42601` user error.
#[cfg(test)]
fn parse_sql(sql: &str) -> PgWireResult<Vec<Statement>> {
    match Parser::parse_sql(&PostgreSqlDialect {}, sql) {
        Ok(statements) => Ok(statements),
        Err(e) => Err(user_error("42601", format!("parse error: {e}"))),
    }
}
629
630fn parse_sql_with_polyglot(sql: &str, opts: PolyglotOptions) -> PgWireResult<Vec<Statement>> {
631    let transpiled = transpile_with_meta(sql, opts).map_err(map_entdb_error)?;
632    let dialect = PostgreSqlDialect {};
633    Parser::parse_sql(&dialect, &transpiled.transpiled_sql).map_err(|e| {
634        if transpiled.changed {
635            user_error(
636                "42601",
637                format!(
638                    "parse error: {e}; original_sql={:?}; transpiled_sql={:?}",
639                    transpiled.original_sql, transpiled.transpiled_sql
640                ),
641            )
642        } else {
643            user_error("42601", format!("parse error: {e}"))
644        }
645    })
646}
647
648impl EntHandler {
649    fn ensure_statement_timeout(&self, started: std::time::Instant) -> PgWireResult<()> {
650        if started.elapsed().as_millis() as u64 > self.query_timeout_ms {
651            return Err(user_error("57014", "query timeout exceeded"));
652        }
653        Ok(())
654    }
655
656    fn record_query_metric(&self, started: std::time::Instant, err: Option<&PgWireError>) {
657        let elapsed = started.elapsed().as_nanos() as u64;
658        let sqlstate = err.and_then(sqlstate_from_error);
659        self.metrics.on_query_finished(elapsed, sqlstate);
660    }
661}
662
663fn sqlstate_from_error(err: &PgWireError) -> Option<&str> {
664    match err {
665        PgWireError::UserError(info) => Some(info.code.as_str()),
666        _ => Some("XX000"),
667    }
668}
669
/// Test-only helper: parses `sql` and requires it to contain exactly one
/// statement, otherwise fails with SQLSTATE `0A000` and `message`.
#[cfg(test)]
fn parse_single_statement(sql: &str, message: &'static str) -> PgWireResult<Statement> {
    let mut parsed = parse_sql(sql)?.into_iter();
    match (parsed.next(), parsed.next()) {
        (Some(only), None) => Ok(only),
        _ => Err(user_error("0A000", message)),
    }
}
678
679fn parse_single_statement_with_polyglot(
680    sql: &str,
681    message: &'static str,
682    opts: PolyglotOptions,
683) -> PgWireResult<Statement> {
684    let mut statements = parse_sql_with_polyglot(sql, opts)?;
685    if statements.len() != 1 {
686        return Err(user_error("0A000", message));
687    }
688    Ok(statements.remove(0))
689}
690
691fn statement_from_portal(
692    portal: &Portal<String>,
693    opts: PolyglotOptions,
694) -> PgWireResult<Statement> {
695    let mut stmt = parse_single_statement_with_polyglot(
696        &portal.statement.statement,
697        "extended protocol supports a single statement per execute",
698        opts,
699    )?;
700    if portal.parameter_len() == 0 {
701        return Ok(stmt);
702    }
703
704    let values = collect_portal_values(portal)?;
705    bind_statement_placeholders(&mut stmt, &values)?;
706    Ok(stmt)
707}
708
709fn collect_portal_values(portal: &Portal<String>) -> PgWireResult<Vec<SqlValue>> {
710    let mut values = Vec::with_capacity(portal.parameter_len());
711    for i in 0..portal.parameter_len() {
712        let pg_type = portal
713            .statement
714            .parameter_types
715            .get(i)
716            .unwrap_or(&Type::UNKNOWN);
717        values.push(parameter_sql_value(portal, i, pg_type)?);
718    }
719    Ok(values)
720}
721
722fn bind_statement_placeholders(stmt: &mut Statement, values: &[SqlValue]) -> PgWireResult<()> {
723    let mut question_mark_idx = 0usize;
724    let outcome = visit_expressions_mut(stmt, |expr| {
725        if let Expr::Value(v) = expr {
726            if let SqlValue::Placeholder(placeholder) = &v.value {
727                let idx = match placeholder_index(placeholder, values.len(), &mut question_mark_idx)
728                {
729                    Ok(idx) => idx,
730                    Err(err) => return ControlFlow::Break(err),
731                };
732                v.value = values[idx - 1].clone();
733            }
734        }
735        ControlFlow::Continue(())
736    });
737
738    match outcome {
739        ControlFlow::Continue(()) => Ok(()),
740        ControlFlow::Break(err) => Err(err),
741    }
742}
743
744fn placeholder_index(
745    placeholder: &str,
746    total: usize,
747    question_mark_idx: &mut usize,
748) -> PgWireResult<usize> {
749    let idx = if placeholder == "?" {
750        *question_mark_idx += 1;
751        *question_mark_idx
752    } else if let Some(raw) = placeholder.strip_prefix('$') {
753        raw.parse::<usize>()
754            .map_err(|_| user_error("22023", format!("invalid placeholder '{placeholder}'")))?
755    } else {
756        return Err(user_error(
757            "22023",
758            format!("unsupported placeholder syntax '{placeholder}'"),
759        ));
760    };
761
762    if idx == 0 || idx > total {
763        return Err(user_error(
764            "22023",
765            format!("placeholder '{placeholder}' is out of range"),
766        ));
767    }
768    Ok(idx)
769}
770
771fn parameter_sql_value(
772    portal: &Portal<String>,
773    idx: usize,
774    pg_type: &Type,
775) -> PgWireResult<SqlValue> {
776    if *pg_type == Type::UNKNOWN {
777        return parameter_sql_value_unknown(portal, idx);
778    }
779    if *pg_type == Type::BOOL {
780        return Ok(match portal.parameter::<bool>(idx, pg_type)? {
781            Some(v) => SqlValue::Boolean(v),
782            None => SqlValue::Null,
783        });
784    }
785    if *pg_type == Type::INT2 {
786        return Ok(match portal.parameter::<i16>(idx, pg_type)? {
787            Some(v) => SqlValue::Number(v.to_string(), false),
788            None => SqlValue::Null,
789        });
790    }
791    if *pg_type == Type::INT4 {
792        return Ok(match portal.parameter::<i32>(idx, pg_type)? {
793            Some(v) => SqlValue::Number(v.to_string(), false),
794            None => SqlValue::Null,
795        });
796    }
797    if *pg_type == Type::INT8 {
798        return Ok(match portal.parameter::<i64>(idx, pg_type)? {
799            Some(v) => SqlValue::Number(v.to_string(), false),
800            None => SqlValue::Null,
801        });
802    }
803    if *pg_type == Type::FLOAT4 {
804        return Ok(match portal.parameter::<f32>(idx, pg_type)? {
805            Some(v) => SqlValue::Number((v as f64).to_string(), false),
806            None => SqlValue::Null,
807        });
808    }
809    if *pg_type == Type::FLOAT8 {
810        return Ok(match portal.parameter::<f64>(idx, pg_type)? {
811            Some(v) => SqlValue::Number(v.to_string(), false),
812            None => SqlValue::Null,
813        });
814    }
815
816    Ok(match portal.parameter::<String>(idx, pg_type)? {
817        Some(v) => SqlValue::SingleQuotedString(v),
818        None => SqlValue::Null,
819    })
820}
821
822fn parameter_sql_value_unknown(portal: &Portal<String>, idx: usize) -> PgWireResult<SqlValue> {
823    let Some(raw) = portal.parameters.get(idx) else {
824        return Ok(SqlValue::Null);
825    };
826    let Some(bytes) = raw else {
827        return Ok(SqlValue::Null);
828    };
829
830    if portal.parameter_format.is_binary(idx) {
831        match bytes.len() {
832            1 => {
833                let b = bytes[0];
834                if b == 0 || b == 1 {
835                    return Ok(SqlValue::Boolean(b == 1));
836                }
837            }
838            4 => {
839                let v = i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
840                return Ok(SqlValue::Number(v.to_string(), false));
841            }
842            8 => {
843                let v = i64::from_be_bytes([
844                    bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
845                ]);
846                return Ok(SqlValue::Number(v.to_string(), false));
847            }
848            _ => {}
849        }
850    }
851
852    match std::str::from_utf8(bytes) {
853        Ok(s) => {
854            if let Ok(i) = s.parse::<i64>() {
855                return Ok(SqlValue::Number(i.to_string(), false));
856            }
857            if let Ok(f) = s.parse::<f64>() {
858                return Ok(SqlValue::Number(f.to_string(), false));
859            }
860            if s.eq_ignore_ascii_case("true") || s.eq_ignore_ascii_case("false") {
861                return Ok(SqlValue::Boolean(s.eq_ignore_ascii_case("true")));
862            }
863            Ok(SqlValue::SingleQuotedString(s.to_string()))
864        }
865        Err(_) => Ok(SqlValue::SingleQuotedString(
866            String::from_utf8_lossy(bytes).into_owned(),
867        )),
868    }
869}
870
871fn tag_for_statement(stmt: &Statement) -> Tag {
872    Tag::new(tag_for_statement_name(stmt))
873}
874
875fn statement_returns_rows(stmt: &Statement) -> bool {
876    match stmt {
877        Statement::Query(_) => true,
878        Statement::Insert(insert) => insert.returning.is_some(),
879        Statement::Update { returning, .. } => returning.is_some(),
880        Statement::Delete(delete) => delete.returning.is_some(),
881        _ => false,
882    }
883}
884
885fn execution_tag_for_statement(stmt: &Statement, affected: usize) -> Tag {
886    match stmt {
887        // PostgreSQL command-complete format for INSERT is "INSERT 0 <rows>".
888        Statement::Insert(_) => Tag::new("INSERT 0").with_rows(affected),
889        _ => tag_for_statement(stmt).with_rows(affected),
890    }
891}
892
893fn tag_for_statement_name(stmt: &Statement) -> &'static str {
894    match stmt {
895        Statement::Query(_) => "SELECT",
896        Statement::Insert(_) => "INSERT",
897        Statement::Update { .. } => "UPDATE",
898        Statement::Delete { .. } => "DELETE",
899        Statement::Truncate { .. } => "TRUNCATE",
900        Statement::CreateTable { .. } => "CREATE TABLE",
901        Statement::CreateIndex(_) => "CREATE INDEX",
902        Statement::AlterTable { .. } => "ALTER TABLE",
903        Statement::Drop {
904            object_type: sqlparser::ast::ObjectType::Table,
905            ..
906        } => "DROP TABLE",
907        Statement::Drop {
908            object_type: sqlparser::ast::ObjectType::Index,
909            ..
910        } => "DROP INDEX",
911        Statement::StartTransaction { .. } => "BEGIN",
912        Statement::Commit { .. } => "COMMIT",
913        Statement::Rollback { .. } => "ROLLBACK",
914        _ => "OK",
915    }
916}
917
918fn map_entdb_error(err: EntDbError) -> PgWireError {
919    let (code, message) = match err {
920        EntDbError::Query(m) => (query_sqlstate(&m), m),
921        EntDbError::BufferPoolFull => ("53200", "buffer pool full".to_string()),
922        EntDbError::PagePinned(pid) => ("55006", format!("page {pid} is pinned")),
923        EntDbError::PageNotFound(pid) => ("XX000", format!("page {pid} not found")),
924        EntDbError::Io(e) => ("58000", format!("io error: {e}")),
925        EntDbError::Wal(m) => ("XX000", format!("wal error: {m}")),
926        EntDbError::Corruption(m) => ("XX001", format!("corruption: {m}")),
927        EntDbError::InvalidPage(m) => ("XX000", format!("invalid page: {m}")),
928        EntDbError::PageAlreadyPresent(pid) => ("XX000", format!("page {pid} already present")),
929    };
930    user_error(code, message)
931}
932
/// Picks a SQLSTATE code for a query error by looking for known phrases in its
/// message (ASCII case-insensitive).
///
/// Checks run in a fixed order and the first hit wins; messages matching
/// nothing are reported as `XX000`.
fn query_sqlstate(message: &str) -> &'static str {
    let lowered = message.to_ascii_lowercase();
    let mentions = |needle: &str| lowered.contains(needle);

    // Phrases that map straight to a code, in priority order.
    let direct = [
        ("parse error", "42601"),
        ("already active", "25001"),
        ("no active transaction", "25000"),
        ("write-write conflict", "40001"),
    ];
    for &(needle, code) in direct.iter() {
        if mentions(needle) {
            return code;
        }
    }

    // "does not exist" is refined by the kind of object named in the message;
    // when neither kind is mentioned the remaining checks still apply.
    if mentions("does not exist") {
        if mentions("table") {
            return "42P01";
        }
        if mentions("column") {
            return "42703";
        }
    }

    let data_exception = [
        "cannot cast value",
        "invalid",
        "unsupported",
        "must be numeric",
    ]
    .iter()
    .any(|&needle| mentions(needle));

    if data_exception {
        "22000"
    } else {
        "XX000"
    }
}
964
965fn user_error(code: impl Into<String>, message: impl Into<String>) -> PgWireError {
966    PgWireError::UserError(Box::new(ErrorInfo::new(
967        "ERROR".to_string(),
968        code.into(),
969        message.into(),
970    )))
971}
972
/// Locates every `$N` placeholder token in `sql`.
///
/// Each entry is a `(start, end)` byte range covering the `$` and its digits,
/// in source order. Both ends sit next to ASCII bytes, so they are always
/// valid `str` slice boundaries.
///
/// Text inside single-quoted strings (including `E'...'` strings, where a
/// backslash escapes the next byte), double-quoted identifiers, `--` line
/// comments and (nestable) `/* ... */` block comments is skipped: a `$N`
/// there is literal text, not a parameter. Dollar-quoted strings are not
/// recognised.
fn scan_placeholder_spans(sql: &str) -> Vec<(usize, usize)> {
    let bytes = sql.as_bytes();
    let is_ident_byte =
        |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80;

    let mut spans = Vec::new();
    let mut i = 0usize;
    while i < bytes.len() {
        match bytes[i] {
            quote @ (b'\'' | b'"') => {
                // Only a stand-alone `E`/`e` prefix turns on backslash escapes.
                let backslash_escapes = quote == b'\''
                    && i > 0
                    && matches!(bytes[i - 1], b'e' | b'E')
                    && (i < 2 || !is_ident_byte(bytes[i - 2]));
                i += 1;
                // A doubled quote closes and immediately reopens the literal,
                // which skips the same bytes as treating it as an escape.
                while i < bytes.len() {
                    if backslash_escapes && bytes[i] == b'\\' {
                        i += 2;
                    } else if bytes[i] == quote {
                        i += 1;
                        break;
                    } else {
                        i += 1;
                    }
                }
            }
            b'-' if bytes.get(i + 1) == Some(&b'-') => {
                while i < bytes.len() && bytes[i] != b'\n' {
                    i += 1;
                }
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                let mut depth = 1usize;
                i += 2;
                while i < bytes.len() && depth > 0 {
                    if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'*') {
                        depth += 1;
                        i += 2;
                    } else if bytes[i] == b'*' && bytes.get(i + 1) == Some(&b'/') {
                        depth -= 1;
                        i += 2;
                    } else {
                        i += 1;
                    }
                }
            }
            b'$' => {
                let digits = bytes[i + 1..]
                    .iter()
                    .take_while(|b| b.is_ascii_digit())
                    .count();
                if digits > 0 {
                    spans.push((i, i + 1 + digits));
                }
                i += 1 + digits;
            }
            _ => i += 1,
        }
    }
    spans
}

/// Returns the highest `$N` placeholder number used in `sql`, or `0` when the
/// statement has no placeholders.
///
/// Placeholders inside string literals, quoted identifiers and comments are
/// ignored. Numbers too large for `usize` are ignored as well.
fn max_placeholder_index(sql: &str) -> usize {
    scan_placeholder_spans(sql)
        .into_iter()
        .filter_map(|(start, end)| sql[start + 1..end].parse::<usize>().ok())
        .max()
        .unwrap_or(0)
}

/// Returns `sql` with every `$N` placeholder replaced by the literal `0`.
///
/// All other text, including non-ASCII characters and any `$N` inside string
/// literals, quoted identifiers or comments, is copied through unchanged.
fn normalize_placeholders_for_describe(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut copied = 0usize;
    for (start, end) in scan_placeholder_spans(sql) {
        out.push_str(&sql[copied..start]);
        // Use a numeric placeholder value so planning/type checks succeed
        // for common predicates and LIMIT/OFFSET contexts.
        out.push('0');
        copied = end;
    }
    out.push_str(&sql[copied..]);
    out
}
1021
1022pub fn scan_max_txn_id_from_storage(catalog: &Catalog) -> Result<u64, EntDbError> {
1023    let mut max_txn = 0_u64;
1024    for table in catalog.list_tables() {
1025        let t = Table::open(table.table_id, table.first_page_id, catalog.buffer_pool());
1026        for (_, tuple) in t.scan() {
1027            let decoded = entdb::query::executor::decode_stored_row(&tuple.data)?;
1028            if let entdb::query::executor::DecodedRow::Versioned(v) = decoded {
1029                max_txn = max_txn.max(v.created_txn);
1030                if let Some(d) = v.deleted_txn {
1031                    max_txn = max_txn.max(d);
1032                }
1033            }
1034        }
1035    }
1036    Ok(max_txn)
1037}
1038
#[cfg(test)]
mod tests {
    use super::{
        bind_statement_placeholders, execution_tag_for_statement, max_placeholder_index,
        normalize_placeholders_for_describe, parse_single_statement, parse_sql, EntHandler,
    };
    use crate::server::Database;
    use entdb::catalog::Catalog;
    use entdb::query::history::OptimizerHistoryRecorder;
    use entdb::query::optimizer::OptimizerConfig;
    use entdb::storage::buffer_pool::BufferPool;
    use entdb::storage::disk_manager::DiskManager;
    use entdb::tx::TransactionManager;
    use entdb::wal::log_manager::LogManager;
    use pgwire::api::results::Tag;
    use pgwire::error::PgWireError;
    use proptest::prelude::*;
    use sqlparser::ast::Value as SqlValue;
    use std::sync::Arc;
    use tempfile::{tempdir, TempDir};

    /// Builds an `EntHandler` backed by fresh database, WAL and
    /// optimizer-history files inside a new temporary directory.
    ///
    /// The `TempDir` guard is returned next to the handler. Dropping it deletes
    /// the directory, so callers must keep it alive for as long as they use
    /// the handler.
    fn test_handler() -> (EntHandler, TempDir) {
        let dir = tempdir().expect("tempdir");
        let db_path = dir.path().join("handler.db");
        let wal_path = dir.path().join("handler.wal");

        let dm = Arc::new(DiskManager::new(&db_path).expect("disk"));
        let lm = Arc::new(LogManager::new(&wal_path, 4096).expect("wal"));
        let bp = Arc::new(BufferPool::with_log_manager(
            32,
            Arc::clone(&dm),
            Arc::clone(&lm),
        ));
        let catalog = Arc::new(Catalog::load(Arc::clone(&bp)).expect("catalog"));
        let txn = Arc::new(TransactionManager::new());
        let history_path = dir.path().join("handler.optimizer_history.json");
        let optimizer_history = Arc::new(
            OptimizerHistoryRecorder::new(
                &history_path,
                super::optimizer_history_schema_hash(),
                16,
                128,
            )
            .expect("optimizer history"),
        );
        let db = Arc::new(Database {
            disk_manager: dm,
            log_manager: lm,
            buffer_pool: bp,
            catalog,
            txn_manager: txn,
            optimizer_history,
            optimizer_config: OptimizerConfig::default(),
        });
        let handler = EntHandler::new(
            db,
            1024 * 1024,
            30_000,
            Arc::new(super::ServerMetrics::default()),
            false,
        );
        (handler, dir)
    }

    /// Bound values must replace the `$N` nodes in the rendered statement.
    #[test]
    fn bind_statement_placeholders_replaces_parameter_nodes() {
        let mut stmt = parse_single_statement(
            "SELECT id FROM t WHERE v > $1 ORDER BY id LIMIT $2",
            "single",
        )
        .expect("parse");
        bind_statement_placeholders(
            &mut stmt,
            &[
                SqlValue::Number("7".to_string(), false),
                SqlValue::Number("3".to_string(), false),
            ],
        )
        .expect("bind placeholders");

        let rendered = stmt.to_string();
        assert!(rendered.contains("v > 7"), "rendered SQL: {rendered}");
        assert!(rendered.contains("LIMIT 3"), "rendered SQL: {rendered}");
    }

    /// Referencing `$3` with a single bound value must fail with SQLSTATE 22023.
    #[test]
    fn bind_statement_placeholders_rejects_out_of_range_index() {
        let mut stmt = parse_single_statement("SELECT $3", "single").expect("parse");
        let err = match bind_statement_placeholders(
            &mut stmt,
            &[SqlValue::Number("1".to_string(), false)],
        ) {
            Ok(_) => panic!("expected out-of-range placeholder failure"),
            Err(e) => e,
        };
        match err {
            PgWireError::UserError(info) => {
                assert_eq!(info.code, "22023");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    /// Two statements in one string must be rejected with SQLSTATE 0A000.
    #[test]
    fn parse_single_statement_rejects_multi_statement_sql() {
        let err = match parse_single_statement("SELECT 1; SELECT 2", "single") {
            Ok(_) => panic!("expected single-statement rejection"),
            Err(e) => e,
        };
        match err {
            PgWireError::UserError(info) => {
                assert_eq!(info.code, "0A000");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn max_placeholder_index_finds_highest_parameter() {
        assert_eq!(max_placeholder_index("SELECT $1, $2, $10"), 10);
        assert_eq!(max_placeholder_index("SELECT 1"), 0);
    }

    #[test]
    fn normalize_placeholders_for_describe_rewrites_all_numbered_params() {
        let sql = "SELECT id FROM t WHERE v > $1 ORDER BY id LIMIT $2";
        let normalized = normalize_placeholders_for_describe(sql);
        assert_eq!(
            normalized,
            "SELECT id FROM t WHERE v > 0 ORDER BY id LIMIT 0"
        );
    }

    #[test]
    fn execution_tag_for_insert_matches_postgres_format() {
        let stmt = parse_single_statement("INSERT INTO t VALUES (1)", "single").expect("parse");
        let tag = execution_tag_for_statement(&stmt, 2);
        assert_eq!(tag, Tag::new("INSERT 0").with_rows(2));
    }

    #[test]
    fn execution_tag_for_update_and_delete_include_row_count() {
        let update = parse_single_statement("UPDATE t SET v = 1", "single").expect("parse");
        let delete = parse_single_statement("DELETE FROM t", "single").expect("parse");
        assert_eq!(
            execution_tag_for_statement(&update, 3),
            Tag::new("UPDATE").with_rows(3)
        );
        assert_eq!(
            execution_tag_for_statement(&delete, 4),
            Tag::new("DELETE").with_rows(4)
        );
    }

    /// `BEGIN; COMMIT;` must yield one response per statement.
    #[test]
    fn handler_begin_commit_round_trip() {
        // `_dir` keeps the backing files on disk until the test finishes.
        let (handler, _dir) = test_handler();
        let out = handler
            .execute_sql("BEGIN; COMMIT;", None)
            .expect("execute");
        assert_eq!(out.len(), 2);
    }

    /// Input larger than the 1 MiB limit passed to `EntHandler::new` must be
    /// rejected with an error mentioning `max_statement_bytes`.
    #[test]
    fn handler_enforces_max_statement_size() {
        // `_dir` keeps the backing files on disk until the test finishes.
        let (handler, _dir) = test_handler();
        let huge = "X".repeat(2 * 1024 * 1024);
        let err = match handler.execute_sql(&huge, None) {
            Ok(_) => panic!("max statement size should be enforced"),
            Err(e) => e,
        };
        assert!(err.to_string().contains("max_statement_bytes"));
    }

    proptest! {
        /// Arbitrary input may fail to parse but must never panic.
        #[test]
        fn parse_sql_never_panics_on_random_input(input in ".*") {
            let _ = parse_sql(&input);
        }
    }
}