1use crate::server::type_map::{encode_value, ent_to_pg_type};
18use async_trait::async_trait;
19use entdb::catalog::Catalog;
20use entdb::error::EntDbError;
21use entdb::query::binder::Binder;
22use entdb::query::executor::{build_executor, ExecutionContext, TxExecutionContext};
23use entdb::query::history::OptimizerHistoryRecord;
24use entdb::query::optimizer::Optimizer;
25use entdb::query::planner::Planner;
26use entdb::query::polyglot::{transpile_with_meta, PolyglotOptions};
27use entdb::storage::table::Table;
28use entdb::tx::TransactionHandle;
29use futures::{stream, Sink};
30use parking_lot::Mutex;
31use pgwire::api::portal::{Format, Portal};
32use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
33use pgwire::api::results::{
34 DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo,
35 QueryResponse, Response, Tag,
36};
37use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
38use pgwire::api::{ClientInfo, Type};
39use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
40use pgwire::messages::PgWireBackendMessage;
41use sqlparser::ast::{visit_expressions_mut, Expr, Statement, Value as SqlValue};
42use sqlparser::dialect::PostgreSqlDialect;
43use sqlparser::parser::Parser;
44use std::fmt::Debug;
45use std::ops::ControlFlow;
46use std::sync::Arc;
47use tracing::{debug, info_span};
48
49use super::metrics::ServerMetrics;
50use super::{optimizer_history_schema_hash, Database};
51
52pub struct EntHandler {
53 db: Arc<Database>,
54 current_txn: Mutex<Option<TransactionHandle>>,
55 query_parser: Arc<NoopQueryParser>,
56 max_statement_bytes: usize,
57 query_timeout_ms: u64,
58 metrics: Arc<ServerMetrics>,
59 polyglot: PolyglotOptions,
60 await_durable: bool,
61}
62
63impl EntHandler {
64 pub fn new(
65 db: Arc<Database>,
66 max_statement_bytes: usize,
67 query_timeout_ms: u64,
68 metrics: Arc<ServerMetrics>,
69 await_durable: bool,
70 ) -> Self {
71 Self {
72 db,
73 current_txn: Mutex::new(None),
74 query_parser: Arc::new(NoopQueryParser::new()),
75 max_statement_bytes,
76 query_timeout_ms,
77 metrics,
78 await_durable,
79 polyglot: PolyglotOptions {
80 enabled: std::env::var("ENTDB_POLYGLOT")
81 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
82 .unwrap_or(false),
83 },
84 }
85 }
86
87 fn execute_sql(
88 &self,
89 sql: &str,
90 max_rows: Option<usize>,
91 ) -> PgWireResult<Vec<Response<'static>>> {
92 if sql.len() > self.max_statement_bytes {
93 return Err(user_error(
94 "54000",
95 format!(
96 "statement exceeds configured max_statement_bytes ({} > {})",
97 sql.len(),
98 self.max_statement_bytes
99 ),
100 ));
101 }
102 let _span = info_span!(
103 "pgwire.execute_sql",
104 sql_len = sql.len(),
105 max_rows = max_rows.unwrap_or(0)
106 )
107 .entered();
108 let statements = parse_sql_with_polyglot(sql, self.polyglot)?;
109 let mut responses = Vec::new();
110 let started = std::time::Instant::now();
111
112 for stmt in &statements {
113 if started.elapsed().as_millis() as u64 > self.query_timeout_ms {
114 return Err(user_error("57014", "query timeout exceeded"));
115 }
116 debug!(statement = %tag_for_statement_name(stmt), "processing statement");
117 match stmt {
118 Statement::StartTransaction { .. } => {
119 let mut guard = self.current_txn.lock();
120 if guard.is_some() {
121 return Err(user_error(
122 "25001",
123 "transaction already active for this session",
124 ));
125 }
126 *guard = Some(self.db.txn_manager.begin());
127 responses.push(Response::TransactionStart(Tag::new("BEGIN")));
128 }
129 Statement::Commit { .. } => {
130 let tx = {
131 let mut guard = self.current_txn.lock();
132 guard
133 .take()
134 .ok_or_else(|| user_error("25000", "no active transaction to COMMIT"))?
135 };
136 self.db
137 .txn_manager
138 .commit_with_options(tx.txn_id, Some(self.await_durable))
139 .map_err(map_entdb_error)?;
140 responses.push(Response::TransactionEnd(Tag::new("COMMIT")));
141 }
142 Statement::Rollback { .. } => {
143 let tx = {
144 let mut guard = self.current_txn.lock();
145 guard.take().ok_or_else(|| {
146 user_error("25000", "no active transaction to ROLLBACK")
147 })?
148 };
149 self.db.txn_manager.abort(tx.txn_id);
150 responses.push(Response::TransactionEnd(Tag::new("ROLLBACK")));
151 }
152 _ => {
153 let active_tx = *self.current_txn.lock();
154 let exec_resp = if let Some(tx) = active_tx {
155 self.execute_statement_in_txn(&tx, stmt, max_rows, None)?
156 } else {
157 let tx = self.db.txn_manager.begin();
158 match self.execute_statement_in_txn(&tx, stmt, max_rows, None) {
159 Ok(resp) => {
160 if let Err(e) = self
161 .db
162 .txn_manager
163 .commit_with_options(tx.txn_id, Some(self.await_durable))
164 {
165 self.db.txn_manager.abort(tx.txn_id);
166 return Err(map_entdb_error(e));
167 }
168 resp
169 }
170 Err(e) => {
171 self.db.txn_manager.abort(tx.txn_id);
172 return Err(e);
173 }
174 }
175 };
176 responses.push(exec_resp);
177 }
178 }
179 }
180
181 Ok(responses)
182 }
183
184 fn execute_statement_in_txn(
185 &self,
186 tx: &TransactionHandle,
187 stmt: &Statement,
188 max_rows: Option<usize>,
189 row_format: Option<&Format>,
190 ) -> PgWireResult<Response<'static>> {
191 let stmt_started = std::time::Instant::now();
192 let _span = info_span!(
193 "pgwire.execute_statement",
194 statement = %tag_for_statement_name(stmt),
195 txn_id = tx.txn_id
196 )
197 .entered();
198 let binder = Binder::new(Arc::clone(&self.db.catalog));
199 let planner = Planner;
200
201 let bound = match binder.bind(stmt) {
202 Ok(b) => b,
203 Err(e) => {
204 self.record_optimizer_history_error(stmt, "bind_error", stmt_started.elapsed());
205 return Err(map_entdb_error(e));
206 }
207 };
208 let fingerprint = Optimizer::fingerprint_bound_statement(&bound);
209 let history = self.db.optimizer_history.read_for_fingerprint(&fingerprint);
210 self.ensure_statement_timeout(stmt_started)?;
211 let plan = match planner.plan(bound) {
212 Ok(p) => p,
213 Err(e) => {
214 self.record_optimizer_history(OptimizerHistoryRecord {
215 fingerprint,
216 plan_signature: "error".to_string(),
217 schema_hash: optimizer_history_schema_hash().to_string(),
218 captured_at_ms: now_epoch_millis(),
219 rowcount_observed_json: "{\"root\":0}".to_string(),
220 latency_ms: stmt_started.elapsed().as_millis() as u64,
221 memory_peak_bytes: 0,
222 success: false,
223 error_class: Some("plan_error".to_string()),
224 confidence: 0.0,
225 });
226 return Err(map_entdb_error(e));
227 }
228 };
229 self.ensure_statement_timeout(stmt_started)?;
230 let optimized_outcome = Optimizer::optimize_with_trace_and_history(
231 plan,
232 &fingerprint,
233 self.db.optimizer_config,
234 &history,
235 );
236 let chosen_plan_signature = optimized_outcome
237 .trace
238 .chosen_plan_signature
239 .clone()
240 .unwrap_or_else(|| "baseline".to_string());
241 let optimized = optimized_outcome.plan;
242 self.ensure_statement_timeout(stmt_started)?;
243
244 let ctx = ExecutionContext {
245 catalog: Arc::clone(&self.db.catalog),
246 tx: Some(TxExecutionContext {
247 txn_id: tx.txn_id,
248 snapshot_ts: tx.snapshot_ts,
249 txn_manager: Arc::clone(&self.db.txn_manager),
250 }),
251 };
252
253 let mut exec = match build_executor(&optimized, &ctx) {
254 Ok(exec) => exec,
255 Err(e) => {
256 self.record_optimizer_history(OptimizerHistoryRecord {
257 fingerprint,
258 plan_signature: chosen_plan_signature,
259 schema_hash: optimizer_history_schema_hash().to_string(),
260 captured_at_ms: now_epoch_millis(),
261 rowcount_observed_json: "{\"root\":0}".to_string(),
262 latency_ms: stmt_started.elapsed().as_millis() as u64,
263 memory_peak_bytes: 0,
264 success: false,
265 error_class: Some("executor_build_error".to_string()),
266 confidence: 0.0,
267 });
268 return Err(map_entdb_error(e));
269 }
270 };
271 self.ensure_statement_timeout(stmt_started)?;
272 if let Err(e) = exec.open() {
273 self.record_optimizer_history(OptimizerHistoryRecord {
274 fingerprint,
275 plan_signature: chosen_plan_signature,
276 schema_hash: optimizer_history_schema_hash().to_string(),
277 captured_at_ms: now_epoch_millis(),
278 rowcount_observed_json: "{\"root\":0}".to_string(),
279 latency_ms: stmt_started.elapsed().as_millis() as u64,
280 memory_peak_bytes: 0,
281 success: false,
282 error_class: Some("executor_open_error".to_string()),
283 confidence: 0.0,
284 });
285 return Err(map_entdb_error(e));
286 }
287 self.ensure_statement_timeout(stmt_started)?;
288
289 let schema = exec.schema().clone();
290 let mut rows = Vec::new();
291 while let Some(row) = match exec.next() {
292 Ok(r) => r,
293 Err(e) => {
294 self.record_optimizer_history(OptimizerHistoryRecord {
295 fingerprint: fingerprint.clone(),
296 plan_signature: chosen_plan_signature.clone(),
297 schema_hash: optimizer_history_schema_hash().to_string(),
298 captured_at_ms: now_epoch_millis(),
299 rowcount_observed_json: format!("{{\"root\":{}}}", rows.len()),
300 latency_ms: stmt_started.elapsed().as_millis() as u64,
301 memory_peak_bytes: 0,
302 success: false,
303 error_class: Some("executor_next_error".to_string()),
304 confidence: 0.0,
305 });
306 return Err(map_entdb_error(e));
307 }
308 } {
309 self.ensure_statement_timeout(stmt_started)?;
310 rows.push(row);
311 if let Some(limit) = max_rows {
312 if limit > 0 && rows.len() >= limit {
313 break;
314 }
315 }
316 }
317 if let Err(e) = exec.close() {
318 self.record_optimizer_history(OptimizerHistoryRecord {
319 fingerprint,
320 plan_signature: chosen_plan_signature,
321 schema_hash: optimizer_history_schema_hash().to_string(),
322 captured_at_ms: now_epoch_millis(),
323 rowcount_observed_json: format!("{{\"root\":{}}}", rows.len()),
324 latency_ms: stmt_started.elapsed().as_millis() as u64,
325 memory_peak_bytes: 0,
326 success: false,
327 error_class: Some("executor_close_error".to_string()),
328 confidence: 0.0,
329 });
330 return Err(map_entdb_error(e));
331 }
332
333 let observed_rows = if schema.columns.is_empty() {
334 exec.affected_rows().unwrap_or(0)
335 } else {
336 rows.len() as u64
337 };
338 self.record_optimizer_history(OptimizerHistoryRecord {
339 fingerprint,
340 plan_signature: chosen_plan_signature,
341 schema_hash: optimizer_history_schema_hash().to_string(),
342 captured_at_ms: now_epoch_millis(),
343 rowcount_observed_json: format!("{{\"root\":{observed_rows}}}"),
344 latency_ms: stmt_started.elapsed().as_millis() as u64,
345 memory_peak_bytes: 0,
346 success: true,
347 error_class: None,
348 confidence: 1.0,
349 });
350
351 if schema.columns.is_empty() {
352 let affected = exec.affected_rows().unwrap_or(0) as usize;
353 let tag = execution_tag_for_statement(stmt, affected);
354 return Ok(Response::Execution(tag));
355 }
356
357 let fields = schema
358 .columns
359 .iter()
360 .enumerate()
361 .map(|(idx, c)| {
362 FieldInfo::new(
363 c.name.clone(),
364 None,
365 None,
366 ent_to_pg_type(&c.data_type),
367 row_format
368 .map(|f| f.format_for(idx))
369 .unwrap_or(FieldFormat::Text),
370 )
371 })
372 .collect::<Vec<_>>();
373 let schema = Arc::new(fields);
374
375 let mut encoded = Vec::with_capacity(rows.len());
376 for row in &rows {
377 let mut encoder = DataRowEncoder::new(Arc::clone(&schema));
378 for (idx, value) in row.iter().enumerate() {
379 let field = &schema[idx];
380 encode_value(&mut encoder, value, field.datatype(), field.format())?;
381 }
382 encoded.push(encoder.finish()?);
383 }
384
385 let data_row_stream = stream::iter(encoded.into_iter().map(Ok));
386 let mut query = QueryResponse::new(schema, data_row_stream);
387 query.set_command_tag(tag_for_statement_name(stmt));
388 Ok(Response::Query(query))
389 }
390
391 fn describe_statement_from_stmt(
392 &self,
393 stmt: &Statement,
394 format: &Format,
395 ) -> PgWireResult<Vec<FieldInfo>> {
396 if !statement_returns_rows(stmt) {
397 return Ok(Vec::new());
398 }
399
400 let tx = self.db.txn_manager.begin();
401 let binder = Binder::new(Arc::clone(&self.db.catalog));
402 let planner = Planner;
403
404 let bound = binder.bind(stmt).map_err(map_entdb_error)?;
405 let fingerprint = Optimizer::fingerprint_bound_statement(&bound);
406 let history = self.db.optimizer_history.read_for_fingerprint(&fingerprint);
407 let plan = planner.plan(bound).map_err(map_entdb_error)?;
408 let optimized = Optimizer::optimize_with_trace_and_history(
409 plan,
410 &fingerprint,
411 self.db.optimizer_config,
412 &history,
413 )
414 .plan;
415
416 let ctx = ExecutionContext {
417 catalog: Arc::clone(&self.db.catalog),
418 tx: Some(TxExecutionContext {
419 txn_id: tx.txn_id,
420 snapshot_ts: tx.snapshot_ts,
421 txn_manager: Arc::clone(&self.db.txn_manager),
422 }),
423 };
424
425 let exec = build_executor(&optimized, &ctx).map_err(map_entdb_error)?;
426 let fields = exec
427 .schema()
428 .columns
429 .iter()
430 .enumerate()
431 .map(|(idx, col)| {
432 FieldInfo::new(
433 col.name.clone(),
434 None,
435 None,
436 ent_to_pg_type(&col.data_type),
437 format.format_for(idx),
438 )
439 })
440 .collect::<Vec<_>>();
441
442 self.db.txn_manager.abort(tx.txn_id);
443 Ok(fields)
444 }
445
446 fn describe_statement_inner(&self, sql: &str, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
447 let stmt = parse_single_statement_with_polyglot(
448 sql,
449 "describe supports a single statement",
450 self.polyglot,
451 )?;
452 self.describe_statement_from_stmt(&stmt, format)
453 }
454}
455
456impl EntHandler {
457 fn record_optimizer_history(&self, entry: OptimizerHistoryRecord) {
458 self.db.optimizer_history.try_record(entry);
459 }
460
461 fn record_optimizer_history_error(
462 &self,
463 stmt: &Statement,
464 class: &str,
465 elapsed: std::time::Duration,
466 ) {
467 self.record_optimizer_history(OptimizerHistoryRecord {
468 fingerprint: format!("stmt:{}", tag_for_statement_name(stmt)),
469 plan_signature: "error".to_string(),
470 schema_hash: optimizer_history_schema_hash().to_string(),
471 captured_at_ms: now_epoch_millis(),
472 rowcount_observed_json: "{\"root\":0}".to_string(),
473 latency_ms: elapsed.as_millis() as u64,
474 memory_peak_bytes: 0,
475 success: false,
476 error_class: Some(class.to_string()),
477 confidence: 0.0,
478 });
479 }
480}
481
/// Milliseconds since the Unix epoch, or `0` when the system clock reports a
/// time before the epoch.
fn now_epoch_millis() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
488
489#[async_trait]
490impl SimpleQueryHandler for EntHandler {
491 async fn do_query<'a, C>(
492 &self,
493 _client: &mut C,
494 query: &'a str,
495 ) -> PgWireResult<Vec<Response<'a>>>
496 where
497 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
498 C::Error: Debug,
499 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
500 {
501 let started = std::time::Instant::now();
502 let result = self.execute_sql(query, None);
503 self.record_query_metric(started, result.as_ref().err());
504 result
505 }
506}
507
508#[async_trait]
509impl ExtendedQueryHandler for EntHandler {
510 type Statement = String;
511 type QueryParser = NoopQueryParser;
512
513 fn query_parser(&self) -> Arc<Self::QueryParser> {
514 Arc::clone(&self.query_parser)
515 }
516
517 async fn do_query<'a, C>(
518 &self,
519 _client: &mut C,
520 portal: &'a Portal<Self::Statement>,
521 max_rows: usize,
522 ) -> PgWireResult<Response<'a>>
523 where
524 C: ClientInfo + Unpin + Send + Sync,
525 {
526 let started = std::time::Instant::now();
527 let result = (|| -> PgWireResult<Response<'a>> {
528 let stmt = statement_from_portal(portal, self.polyglot)?;
529 if matches!(
530 stmt,
531 Statement::StartTransaction { .. }
532 | Statement::Commit { .. }
533 | Statement::Rollback { .. }
534 ) {
535 let mut responses = self.execute_sql(&stmt.to_string(), Some(max_rows))?;
536 if responses.is_empty() {
537 return Ok(Response::EmptyQuery);
538 }
539 return Ok(responses.remove(0));
540 }
541
542 let active_tx = *self.current_txn.lock();
543 if let Some(tx) = active_tx {
544 return self.execute_statement_in_txn(
545 &tx,
546 &stmt,
547 Some(max_rows),
548 Some(&portal.result_column_format),
549 );
550 }
551
552 let tx = self.db.txn_manager.begin();
553 match self.execute_statement_in_txn(
554 &tx,
555 &stmt,
556 Some(max_rows),
557 Some(&portal.result_column_format),
558 ) {
559 Ok(resp) => {
560 if let Err(e) = self
561 .db
562 .txn_manager
563 .commit_with_options(tx.txn_id, Some(self.await_durable))
564 {
565 self.db.txn_manager.abort(tx.txn_id);
566 Err(map_entdb_error(e))
567 } else {
568 Ok(resp)
569 }
570 }
571 Err(e) => {
572 self.db.txn_manager.abort(tx.txn_id);
573 Err(e)
574 }
575 }
576 })();
577 self.record_query_metric(started, result.as_ref().err());
578 result
579 }
580
581 async fn do_describe_statement<C>(
582 &self,
583 _client: &mut C,
584 stmt: &StoredStatement<Self::Statement>,
585 ) -> PgWireResult<DescribeStatementResponse>
586 where
587 C: ClientInfo + Unpin + Send + Sync,
588 {
589 let mut parameter_types = stmt.parameter_types.clone();
590 let inferred_count = max_placeholder_index(&stmt.statement);
591 if parameter_types.len() < inferred_count {
592 parameter_types.resize(inferred_count, Type::INT4);
593 }
594 for ty in &mut parameter_types {
595 if *ty == Type::UNKNOWN {
596 *ty = Type::INT4;
597 }
598 }
599
600 let fields = self
601 .describe_statement_inner(&stmt.statement, &Format::UnifiedBinary)
602 .or_else(|_| {
603 let normalized = normalize_placeholders_for_describe(&stmt.statement);
604 self.describe_statement_inner(&normalized, &Format::UnifiedBinary)
605 })
606 .unwrap_or_default();
607 Ok(DescribeStatementResponse::new(parameter_types, fields))
608 }
609
610 async fn do_describe_portal<C>(
611 &self,
612 _client: &mut C,
613 portal: &Portal<Self::Statement>,
614 ) -> PgWireResult<DescribePortalResponse>
615 where
616 C: ClientInfo + Unpin + Send + Sync,
617 {
618 let stmt = statement_from_portal(portal, self.polyglot)?;
619 let fields = self.describe_statement_from_stmt(&stmt, &portal.result_column_format)?;
620 Ok(DescribePortalResponse::new(fields))
621 }
622}
623
/// Test-only parser: parses `sql` with the PostgreSQL dialect, without the
/// polyglot transpilation step, reporting failures as `42601`.
#[cfg(test)]
fn parse_sql(sql: &str) -> PgWireResult<Vec<Statement>> {
    match Parser::parse_sql(&PostgreSqlDialect {}, sql) {
        Ok(statements) => Ok(statements),
        Err(e) => Err(user_error("42601", format!("parse error: {e}"))),
    }
}
629
630fn parse_sql_with_polyglot(sql: &str, opts: PolyglotOptions) -> PgWireResult<Vec<Statement>> {
631 let transpiled = transpile_with_meta(sql, opts).map_err(map_entdb_error)?;
632 let dialect = PostgreSqlDialect {};
633 Parser::parse_sql(&dialect, &transpiled.transpiled_sql).map_err(|e| {
634 if transpiled.changed {
635 user_error(
636 "42601",
637 format!(
638 "parse error: {e}; original_sql={:?}; transpiled_sql={:?}",
639 transpiled.original_sql, transpiled.transpiled_sql
640 ),
641 )
642 } else {
643 user_error("42601", format!("parse error: {e}"))
644 }
645 })
646}
647
648impl EntHandler {
649 fn ensure_statement_timeout(&self, started: std::time::Instant) -> PgWireResult<()> {
650 if started.elapsed().as_millis() as u64 > self.query_timeout_ms {
651 return Err(user_error("57014", "query timeout exceeded"));
652 }
653 Ok(())
654 }
655
656 fn record_query_metric(&self, started: std::time::Instant, err: Option<&PgWireError>) {
657 let elapsed = started.elapsed().as_nanos() as u64;
658 let sqlstate = err.and_then(sqlstate_from_error);
659 self.metrics.on_query_finished(elapsed, sqlstate);
660 }
661}
662
663fn sqlstate_from_error(err: &PgWireError) -> Option<&str> {
664 match err {
665 PgWireError::UserError(info) => Some(info.code.as_str()),
666 _ => Some("XX000"),
667 }
668}
669
/// Test-only helper: parses `sql` and requires exactly one statement,
/// failing with `0A000` and `message` otherwise.
#[cfg(test)]
fn parse_single_statement(sql: &str, message: &'static str) -> PgWireResult<Statement> {
    let mut parsed = parse_sql(sql)?.into_iter();
    match (parsed.next(), parsed.next()) {
        (Some(only), None) => Ok(only),
        _ => Err(user_error("0A000", message)),
    }
}
678
679fn parse_single_statement_with_polyglot(
680 sql: &str,
681 message: &'static str,
682 opts: PolyglotOptions,
683) -> PgWireResult<Statement> {
684 let mut statements = parse_sql_with_polyglot(sql, opts)?;
685 if statements.len() != 1 {
686 return Err(user_error("0A000", message));
687 }
688 Ok(statements.remove(0))
689}
690
691fn statement_from_portal(
692 portal: &Portal<String>,
693 opts: PolyglotOptions,
694) -> PgWireResult<Statement> {
695 let mut stmt = parse_single_statement_with_polyglot(
696 &portal.statement.statement,
697 "extended protocol supports a single statement per execute",
698 opts,
699 )?;
700 if portal.parameter_len() == 0 {
701 return Ok(stmt);
702 }
703
704 let values = collect_portal_values(portal)?;
705 bind_statement_placeholders(&mut stmt, &values)?;
706 Ok(stmt)
707}
708
709fn collect_portal_values(portal: &Portal<String>) -> PgWireResult<Vec<SqlValue>> {
710 let mut values = Vec::with_capacity(portal.parameter_len());
711 for i in 0..portal.parameter_len() {
712 let pg_type = portal
713 .statement
714 .parameter_types
715 .get(i)
716 .unwrap_or(&Type::UNKNOWN);
717 values.push(parameter_sql_value(portal, i, pg_type)?);
718 }
719 Ok(values)
720}
721
722fn bind_statement_placeholders(stmt: &mut Statement, values: &[SqlValue]) -> PgWireResult<()> {
723 let mut question_mark_idx = 0usize;
724 let outcome = visit_expressions_mut(stmt, |expr| {
725 if let Expr::Value(v) = expr {
726 if let SqlValue::Placeholder(placeholder) = &v.value {
727 let idx = match placeholder_index(placeholder, values.len(), &mut question_mark_idx)
728 {
729 Ok(idx) => idx,
730 Err(err) => return ControlFlow::Break(err),
731 };
732 v.value = values[idx - 1].clone();
733 }
734 }
735 ControlFlow::Continue(())
736 });
737
738 match outcome {
739 ControlFlow::Continue(()) => Ok(()),
740 ControlFlow::Break(err) => Err(err),
741 }
742}
743
744fn placeholder_index(
745 placeholder: &str,
746 total: usize,
747 question_mark_idx: &mut usize,
748) -> PgWireResult<usize> {
749 let idx = if placeholder == "?" {
750 *question_mark_idx += 1;
751 *question_mark_idx
752 } else if let Some(raw) = placeholder.strip_prefix('$') {
753 raw.parse::<usize>()
754 .map_err(|_| user_error("22023", format!("invalid placeholder '{placeholder}'")))?
755 } else {
756 return Err(user_error(
757 "22023",
758 format!("unsupported placeholder syntax '{placeholder}'"),
759 ));
760 };
761
762 if idx == 0 || idx > total {
763 return Err(user_error(
764 "22023",
765 format!("placeholder '{placeholder}' is out of range"),
766 ));
767 }
768 Ok(idx)
769}
770
771fn parameter_sql_value(
772 portal: &Portal<String>,
773 idx: usize,
774 pg_type: &Type,
775) -> PgWireResult<SqlValue> {
776 if *pg_type == Type::UNKNOWN {
777 return parameter_sql_value_unknown(portal, idx);
778 }
779 if *pg_type == Type::BOOL {
780 return Ok(match portal.parameter::<bool>(idx, pg_type)? {
781 Some(v) => SqlValue::Boolean(v),
782 None => SqlValue::Null,
783 });
784 }
785 if *pg_type == Type::INT2 {
786 return Ok(match portal.parameter::<i16>(idx, pg_type)? {
787 Some(v) => SqlValue::Number(v.to_string(), false),
788 None => SqlValue::Null,
789 });
790 }
791 if *pg_type == Type::INT4 {
792 return Ok(match portal.parameter::<i32>(idx, pg_type)? {
793 Some(v) => SqlValue::Number(v.to_string(), false),
794 None => SqlValue::Null,
795 });
796 }
797 if *pg_type == Type::INT8 {
798 return Ok(match portal.parameter::<i64>(idx, pg_type)? {
799 Some(v) => SqlValue::Number(v.to_string(), false),
800 None => SqlValue::Null,
801 });
802 }
803 if *pg_type == Type::FLOAT4 {
804 return Ok(match portal.parameter::<f32>(idx, pg_type)? {
805 Some(v) => SqlValue::Number((v as f64).to_string(), false),
806 None => SqlValue::Null,
807 });
808 }
809 if *pg_type == Type::FLOAT8 {
810 return Ok(match portal.parameter::<f64>(idx, pg_type)? {
811 Some(v) => SqlValue::Number(v.to_string(), false),
812 None => SqlValue::Null,
813 });
814 }
815
816 Ok(match portal.parameter::<String>(idx, pg_type)? {
817 Some(v) => SqlValue::SingleQuotedString(v),
818 None => SqlValue::Null,
819 })
820}
821
822fn parameter_sql_value_unknown(portal: &Portal<String>, idx: usize) -> PgWireResult<SqlValue> {
823 let Some(raw) = portal.parameters.get(idx) else {
824 return Ok(SqlValue::Null);
825 };
826 let Some(bytes) = raw else {
827 return Ok(SqlValue::Null);
828 };
829
830 if portal.parameter_format.is_binary(idx) {
831 match bytes.len() {
832 1 => {
833 let b = bytes[0];
834 if b == 0 || b == 1 {
835 return Ok(SqlValue::Boolean(b == 1));
836 }
837 }
838 4 => {
839 let v = i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
840 return Ok(SqlValue::Number(v.to_string(), false));
841 }
842 8 => {
843 let v = i64::from_be_bytes([
844 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
845 ]);
846 return Ok(SqlValue::Number(v.to_string(), false));
847 }
848 _ => {}
849 }
850 }
851
852 match std::str::from_utf8(bytes) {
853 Ok(s) => {
854 if let Ok(i) = s.parse::<i64>() {
855 return Ok(SqlValue::Number(i.to_string(), false));
856 }
857 if let Ok(f) = s.parse::<f64>() {
858 return Ok(SqlValue::Number(f.to_string(), false));
859 }
860 if s.eq_ignore_ascii_case("true") || s.eq_ignore_ascii_case("false") {
861 return Ok(SqlValue::Boolean(s.eq_ignore_ascii_case("true")));
862 }
863 Ok(SqlValue::SingleQuotedString(s.to_string()))
864 }
865 Err(_) => Ok(SqlValue::SingleQuotedString(
866 String::from_utf8_lossy(bytes).into_owned(),
867 )),
868 }
869}
870
871fn tag_for_statement(stmt: &Statement) -> Tag {
872 Tag::new(tag_for_statement_name(stmt))
873}
874
875fn statement_returns_rows(stmt: &Statement) -> bool {
876 match stmt {
877 Statement::Query(_) => true,
878 Statement::Insert(insert) => insert.returning.is_some(),
879 Statement::Update { returning, .. } => returning.is_some(),
880 Statement::Delete(delete) => delete.returning.is_some(),
881 _ => false,
882 }
883}
884
885fn execution_tag_for_statement(stmt: &Statement, affected: usize) -> Tag {
886 match stmt {
887 Statement::Insert(_) => Tag::new("INSERT 0").with_rows(affected),
889 _ => tag_for_statement(stmt).with_rows(affected),
890 }
891}
892
893fn tag_for_statement_name(stmt: &Statement) -> &'static str {
894 match stmt {
895 Statement::Query(_) => "SELECT",
896 Statement::Insert(_) => "INSERT",
897 Statement::Update { .. } => "UPDATE",
898 Statement::Delete { .. } => "DELETE",
899 Statement::Truncate { .. } => "TRUNCATE",
900 Statement::CreateTable { .. } => "CREATE TABLE",
901 Statement::CreateIndex(_) => "CREATE INDEX",
902 Statement::AlterTable { .. } => "ALTER TABLE",
903 Statement::Drop {
904 object_type: sqlparser::ast::ObjectType::Table,
905 ..
906 } => "DROP TABLE",
907 Statement::Drop {
908 object_type: sqlparser::ast::ObjectType::Index,
909 ..
910 } => "DROP INDEX",
911 Statement::StartTransaction { .. } => "BEGIN",
912 Statement::Commit { .. } => "COMMIT",
913 Statement::Rollback { .. } => "ROLLBACK",
914 _ => "OK",
915 }
916}
917
918fn map_entdb_error(err: EntDbError) -> PgWireError {
919 let (code, message) = match err {
920 EntDbError::Query(m) => (query_sqlstate(&m), m),
921 EntDbError::BufferPoolFull => ("53200", "buffer pool full".to_string()),
922 EntDbError::PagePinned(pid) => ("55006", format!("page {pid} is pinned")),
923 EntDbError::PageNotFound(pid) => ("XX000", format!("page {pid} not found")),
924 EntDbError::Io(e) => ("58000", format!("io error: {e}")),
925 EntDbError::Wal(m) => ("XX000", format!("wal error: {m}")),
926 EntDbError::Corruption(m) => ("XX001", format!("corruption: {m}")),
927 EntDbError::InvalidPage(m) => ("XX000", format!("invalid page: {m}")),
928 EntDbError::PageAlreadyPresent(pid) => ("XX000", format!("page {pid} already present")),
929 };
930 user_error(code, message)
931}
932
/// Derives a SQLSTATE from the text of a query error message.
///
/// Matching is case-insensitive and the first rule that applies wins, so
/// the order of the arms below is significant. Messages that match no rule
/// map to the generic `XX000`.
fn query_sqlstate(message: &str) -> &'static str {
    let lower = message.to_ascii_lowercase();
    let mentions = |needle: &str| lower.contains(needle);

    match () {
        _ if mentions("parse error") => "42601",
        _ if mentions("already active") => "25001",
        _ if mentions("no active transaction") => "25000",
        _ if mentions("write-write conflict") => "40001",
        // "table" is checked before "column" when both appear.
        _ if mentions("does not exist") && mentions("table") => "42P01",
        _ if mentions("does not exist") && mentions("column") => "42703",
        _ if mentions("cannot cast value")
            || mentions("invalid")
            || mentions("unsupported")
            || mentions("must be numeric") =>
        {
            "22000"
        }
        _ => "XX000",
    }
}
964
965fn user_error(code: impl Into<String>, message: impl Into<String>) -> PgWireError {
966 PgWireError::UserError(Box::new(ErrorInfo::new(
967 "ERROR".to_string(),
968 code.into(),
969 message.into(),
970 )))
971}
972
/// Returns the highest `$n` placeholder number found in `sql`, or `0` when
/// there is none.
///
/// This is a purely lexical scan: it does not parse SQL, so a `$n` inside a
/// string literal or comment is counted as well.
///
/// Placeholder numbers above `u16::MAX` are ignored. The wire protocol
/// carries parameter counts as 16-bit integers, so such a placeholder can
/// never be bound, and the caller uses the result to size a vector — without
/// the cap, `$4000000000` in client-supplied SQL would force a
/// multi-gigabyte allocation. Numbers too large for `usize` are ignored too.
fn max_placeholder_index(sql: &str) -> usize {
    const MAX_WIRE_PARAMETERS: usize = u16::MAX as usize;

    let bytes = sql.as_bytes();
    let mut max_idx = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        if bytes[i] != b'$' {
            i += 1;
            continue;
        }
        let start = i + 1;
        let digits = bytes[start..]
            .iter()
            .take_while(|b| b.is_ascii_digit())
            .count();
        let end = start + digits;
        if digits > 0 {
            // `start..end` covers ASCII digits only, so slicing the `str`
            // here always lands on character boundaries.
            if let Ok(n) = sql[start..end].parse::<usize>() {
                if n <= MAX_WIRE_PARAMETERS {
                    max_idx = max_idx.max(n);
                }
            }
        }
        // Resume after the digits (or right after a lone `$`).
        i = end;
    }
    max_idx
}
996
/// Rewrites every `$n` placeholder in `sql` to the literal `0`, so the
/// statement can be described without bound parameters.
///
/// A `$` that is not followed by at least one ASCII digit is left alone.
/// All other text is copied through unchanged, including non-ASCII
/// characters. Copying is done in `str` slices rather than byte by byte,
/// because pushing each UTF-8 byte as its own `char` would corrupt
/// multi-byte characters.
///
/// This is a purely lexical rewrite: it does not parse SQL, so a `$n` inside
/// a string literal or comment is rewritten as well.
fn normalize_placeholders_for_describe(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    // Start of the text that still has to be copied verbatim.
    let mut pending_from = 0usize;
    let mut i = 0usize;
    while i < bytes.len() {
        if bytes[i] == b'$' {
            let digits = bytes[i + 1..]
                .iter()
                .take_while(|b| b.is_ascii_digit())
                .count();
            if digits > 0 {
                // `$` and the digits are ASCII, so both slice ends fall on
                // character boundaries.
                out.push_str(&sql[pending_from..i]);
                out.push('0');
                i += 1 + digits;
                pending_from = i;
                continue;
            }
        }
        i += 1;
    }
    out.push_str(&sql[pending_from..]);
    out
}
1021
1022pub fn scan_max_txn_id_from_storage(catalog: &Catalog) -> Result<u64, EntDbError> {
1023 let mut max_txn = 0_u64;
1024 for table in catalog.list_tables() {
1025 let t = Table::open(table.table_id, table.first_page_id, catalog.buffer_pool());
1026 for (_, tuple) in t.scan() {
1027 let decoded = entdb::query::executor::decode_stored_row(&tuple.data)?;
1028 if let entdb::query::executor::DecodedRow::Versioned(v) = decoded {
1029 max_txn = max_txn.max(v.created_txn);
1030 if let Some(d) = v.deleted_txn {
1031 max_txn = max_txn.max(d);
1032 }
1033 }
1034 }
1035 }
1036 Ok(max_txn)
1037}
1038
#[cfg(test)]
mod tests {
    use super::{
        bind_statement_placeholders, execution_tag_for_statement, max_placeholder_index,
        normalize_placeholders_for_describe, parse_single_statement, parse_sql, EntHandler,
    };
    use crate::server::Database;
    use entdb::catalog::Catalog;
    use entdb::query::history::OptimizerHistoryRecorder;
    use entdb::query::optimizer::OptimizerConfig;
    use entdb::storage::buffer_pool::BufferPool;
    use entdb::storage::disk_manager::DiskManager;
    use entdb::tx::TransactionManager;
    use entdb::wal::log_manager::LogManager;
    use pgwire::api::results::Tag;
    use pgwire::error::PgWireError;
    use proptest::prelude::*;
    use sqlparser::ast::Value as SqlValue;
    use std::sync::Arc;
    use tempfile::tempdir;

    /// Builds a handler over a fresh database stored in a temporary directory.
    fn test_handler() -> EntHandler {
        let dir = tempdir().expect("tempdir");
        let db_path = dir.path().join("handler.db");
        let wal_path = dir.path().join("handler.wal");
        let history_path = dir.path().join("handler.optimizer_history.json");

        let disk_manager = Arc::new(DiskManager::new(&db_path).expect("disk"));
        let log_manager = Arc::new(LogManager::new(&wal_path, 4096).expect("wal"));
        let buffer_pool = Arc::new(BufferPool::with_log_manager(
            32,
            Arc::clone(&disk_manager),
            Arc::clone(&log_manager),
        ));
        let catalog = Arc::new(Catalog::load(Arc::clone(&buffer_pool)).expect("catalog"));
        let txn_manager = Arc::new(TransactionManager::new());
        let optimizer_history = Arc::new(
            OptimizerHistoryRecorder::new(
                &history_path,
                super::optimizer_history_schema_hash(),
                16,
                128,
            )
            .expect("optimizer history"),
        );

        let db = Arc::new(Database {
            disk_manager,
            log_manager,
            buffer_pool,
            catalog,
            txn_manager,
            optimizer_history,
            optimizer_config: OptimizerConfig::default(),
        });
        EntHandler::new(
            db,
            1024 * 1024,
            30_000,
            Arc::new(super::ServerMetrics::default()),
            false,
        )
    }

    /// Asserts that `err` is a user error carrying SQLSTATE `expected`.
    fn assert_user_error_code(err: PgWireError, expected: &str) {
        match err {
            PgWireError::UserError(info) => {
                assert_eq!(info.code, expected);
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn bind_statement_placeholders_replaces_parameter_nodes() {
        let mut stmt = parse_single_statement(
            "SELECT id FROM t WHERE v > $1 ORDER BY id LIMIT $2",
            "single",
        )
        .expect("parse");
        let values = [
            SqlValue::Number("7".to_string(), false),
            SqlValue::Number("3".to_string(), false),
        ];
        bind_statement_placeholders(&mut stmt, &values).expect("bind placeholders");

        let rendered = stmt.to_string();
        assert!(rendered.contains("v > 7"), "rendered SQL: {rendered}");
        assert!(rendered.contains("LIMIT 3"), "rendered SQL: {rendered}");
    }

    #[test]
    fn bind_statement_placeholders_rejects_out_of_range_index() {
        let mut stmt = parse_single_statement("SELECT $3", "single").expect("parse");
        let err = bind_statement_placeholders(
            &mut stmt,
            &[SqlValue::Number("1".to_string(), false)],
        )
        .err()
        .expect("expected out-of-range placeholder failure");
        assert_user_error_code(err, "22023");
    }

    #[test]
    fn parse_single_statement_rejects_multi_statement_sql() {
        let err = parse_single_statement("SELECT 1; SELECT 2", "single")
            .err()
            .expect("expected single-statement rejection");
        assert_user_error_code(err, "0A000");
    }

    #[test]
    fn max_placeholder_index_finds_highest_parameter() {
        assert_eq!(max_placeholder_index("SELECT $1, $2, $10"), 10);
        assert_eq!(max_placeholder_index("SELECT 1"), 0);
    }

    #[test]
    fn normalize_placeholders_for_describe_rewrites_all_numbered_params() {
        let normalized = normalize_placeholders_for_describe(
            "SELECT id FROM t WHERE v > $1 ORDER BY id LIMIT $2",
        );
        assert_eq!(
            normalized,
            "SELECT id FROM t WHERE v > 0 ORDER BY id LIMIT 0"
        );
    }

    #[test]
    fn execution_tag_for_insert_matches_postgres_format() {
        let stmt = parse_single_statement("INSERT INTO t VALUES (1)", "single").expect("parse");
        assert_eq!(
            execution_tag_for_statement(&stmt, 2),
            Tag::new("INSERT 0").with_rows(2)
        );
    }

    #[test]
    fn execution_tag_for_update_and_delete_include_row_count() {
        let update = parse_single_statement("UPDATE t SET v = 1", "single").expect("parse");
        let delete = parse_single_statement("DELETE FROM t", "single").expect("parse");
        assert_eq!(
            execution_tag_for_statement(&update, 3),
            Tag::new("UPDATE").with_rows(3)
        );
        assert_eq!(
            execution_tag_for_statement(&delete, 4),
            Tag::new("DELETE").with_rows(4)
        );
    }

    #[test]
    fn handler_begin_commit_round_trip() {
        let handler = test_handler();
        let responses = handler
            .execute_sql("BEGIN; COMMIT;", None)
            .expect("execute");
        assert_eq!(responses.len(), 2);
    }

    #[test]
    fn handler_enforces_max_statement_size() {
        let handler = test_handler();
        // Twice the 1 MiB limit configured in `test_handler`.
        let oversized = "X".repeat(2 * 1024 * 1024);
        let err = handler
            .execute_sql(&oversized, None)
            .err()
            .expect("max statement size should be enforced");
        assert!(err.to_string().contains("max_statement_bytes"));
    }

    proptest! {
        #[test]
        fn parse_sql_never_panics_on_random_input(input in ".*") {
            let _ = parse_sql(&input);
        }
    }
}