1use crate::server::type_map::{encode_value, ent_to_pg_type};
18use async_trait::async_trait;
19use entdb::catalog::Catalog;
20use entdb::error::EntDbError;
21use entdb::query::binder::Binder;
22use entdb::query::executor::{build_executor, ExecutionContext, TxExecutionContext};
23use entdb::query::history::OptimizerHistoryRecord;
24use entdb::query::optimizer::Optimizer;
25use entdb::query::planner::Planner;
26use entdb::query::polyglot::{transpile_with_meta, PolyglotOptions};
27use entdb::storage::table::Table;
28use entdb::tx::TransactionHandle;
29use futures::{stream, Sink};
30use parking_lot::Mutex;
31use pgwire::api::portal::{Format, Portal};
32use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
33use pgwire::api::results::{
34 DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo,
35 QueryResponse, Response, Tag,
36};
37use pgwire::api::stmt::{NoopQueryParser, StoredStatement};
38use pgwire::api::{ClientInfo, Type};
39use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
40use pgwire::messages::PgWireBackendMessage;
41use sqlparser::ast::{visit_expressions_mut, Expr, Statement, Value as SqlValue};
42use sqlparser::dialect::PostgreSqlDialect;
43use sqlparser::parser::Parser;
44use std::fmt::Debug;
45use std::ops::ControlFlow;
46use std::sync::Arc;
47use tracing::{debug, info_span};
48
49use super::metrics::ServerMetrics;
50use super::{optimizer_history_schema_hash, Database};
51
/// pgwire query handler that runs SQL against a shared [`Database`].
///
/// Tracks at most one explicit transaction at a time in `current_txn`.
pub struct EntHandler {
    /// Shared database state: catalog, transaction manager, optimizer history
    /// and optimizer configuration are all read through this handle.
    db: Arc<Database>,
    /// Transaction opened by `BEGIN`; taken out again by `COMMIT` / `ROLLBACK`.
    current_txn: Mutex<Option<TransactionHandle>>,
    /// Parser instance handed out by `ExtendedQueryHandler::query_parser`.
    query_parser: Arc<NoopQueryParser>,
    /// Upper bound on `sql.len()` (bytes) accepted by `execute_sql`.
    max_statement_bytes: usize,
    /// Elapsed-time budget in milliseconds checked by the timeout guards.
    query_timeout_ms: u64,
    /// Sink notified once per finished query with its latency and SQLSTATE.
    metrics: Arc<ServerMetrics>,
    /// Options forwarded to the polyglot transpiler before parsing.
    polyglot: PolyglotOptions,
}
61
62impl EntHandler {
63 pub fn new(
64 db: Arc<Database>,
65 max_statement_bytes: usize,
66 query_timeout_ms: u64,
67 metrics: Arc<ServerMetrics>,
68 ) -> Self {
69 Self {
70 db,
71 current_txn: Mutex::new(None),
72 query_parser: Arc::new(NoopQueryParser::new()),
73 max_statement_bytes,
74 query_timeout_ms,
75 metrics,
76 polyglot: PolyglotOptions {
77 enabled: std::env::var("ENTDB_POLYGLOT")
78 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
79 .unwrap_or(false),
80 },
81 }
82 }
83
84 fn execute_sql(
85 &self,
86 sql: &str,
87 max_rows: Option<usize>,
88 ) -> PgWireResult<Vec<Response<'static>>> {
89 if sql.len() > self.max_statement_bytes {
90 return Err(user_error(
91 "54000",
92 format!(
93 "statement exceeds configured max_statement_bytes ({} > {})",
94 sql.len(),
95 self.max_statement_bytes
96 ),
97 ));
98 }
99 let _span = info_span!(
100 "pgwire.execute_sql",
101 sql_len = sql.len(),
102 max_rows = max_rows.unwrap_or(0)
103 )
104 .entered();
105 let statements = parse_sql_with_polyglot(sql, self.polyglot)?;
106 let mut responses = Vec::new();
107 let started = std::time::Instant::now();
108
109 for stmt in &statements {
110 if started.elapsed().as_millis() as u64 > self.query_timeout_ms {
111 return Err(user_error("57014", "query timeout exceeded"));
112 }
113 debug!(statement = %tag_for_statement_name(stmt), "processing statement");
114 match stmt {
115 Statement::StartTransaction { .. } => {
116 let mut guard = self.current_txn.lock();
117 if guard.is_some() {
118 return Err(user_error(
119 "25001",
120 "transaction already active for this session",
121 ));
122 }
123 *guard = Some(self.db.txn_manager.begin());
124 responses.push(Response::TransactionStart(Tag::new("BEGIN")));
125 }
126 Statement::Commit { .. } => {
127 let tx = {
128 let mut guard = self.current_txn.lock();
129 guard
130 .take()
131 .ok_or_else(|| user_error("25000", "no active transaction to COMMIT"))?
132 };
133 self.db
134 .txn_manager
135 .commit(tx.txn_id)
136 .map_err(map_entdb_error)?;
137 responses.push(Response::TransactionEnd(Tag::new("COMMIT")));
138 }
139 Statement::Rollback { .. } => {
140 let tx = {
141 let mut guard = self.current_txn.lock();
142 guard.take().ok_or_else(|| {
143 user_error("25000", "no active transaction to ROLLBACK")
144 })?
145 };
146 self.db.txn_manager.abort(tx.txn_id);
147 responses.push(Response::TransactionEnd(Tag::new("ROLLBACK")));
148 }
149 _ => {
150 let active_tx = *self.current_txn.lock();
151 let exec_resp = if let Some(tx) = active_tx {
152 self.execute_statement_in_txn(&tx, stmt, max_rows, None)?
153 } else {
154 let tx = self.db.txn_manager.begin();
155 match self.execute_statement_in_txn(&tx, stmt, max_rows, None) {
156 Ok(resp) => {
157 if let Err(e) = self.db.txn_manager.commit(tx.txn_id) {
158 self.db.txn_manager.abort(tx.txn_id);
159 return Err(map_entdb_error(e));
160 }
161 resp
162 }
163 Err(e) => {
164 self.db.txn_manager.abort(tx.txn_id);
165 return Err(e);
166 }
167 }
168 };
169 responses.push(exec_resp);
170 }
171 }
172 }
173
174 Ok(responses)
175 }
176
177 fn execute_statement_in_txn(
178 &self,
179 tx: &TransactionHandle,
180 stmt: &Statement,
181 max_rows: Option<usize>,
182 row_format: Option<&Format>,
183 ) -> PgWireResult<Response<'static>> {
184 let stmt_started = std::time::Instant::now();
185 let _span = info_span!(
186 "pgwire.execute_statement",
187 statement = %tag_for_statement_name(stmt),
188 txn_id = tx.txn_id
189 )
190 .entered();
191 let binder = Binder::new(Arc::clone(&self.db.catalog));
192 let planner = Planner;
193
194 let bound = match binder.bind(stmt) {
195 Ok(b) => b,
196 Err(e) => {
197 self.record_optimizer_history_error(stmt, "bind_error", stmt_started.elapsed());
198 return Err(map_entdb_error(e));
199 }
200 };
201 let fingerprint = Optimizer::fingerprint_bound_statement(&bound);
202 let history = self.db.optimizer_history.read_for_fingerprint(&fingerprint);
203 self.ensure_statement_timeout(stmt_started)?;
204 let plan = match planner.plan(bound) {
205 Ok(p) => p,
206 Err(e) => {
207 self.record_optimizer_history(OptimizerHistoryRecord {
208 fingerprint,
209 plan_signature: "error".to_string(),
210 schema_hash: optimizer_history_schema_hash().to_string(),
211 captured_at_ms: now_epoch_millis(),
212 rowcount_observed_json: "{\"root\":0}".to_string(),
213 latency_ms: stmt_started.elapsed().as_millis() as u64,
214 memory_peak_bytes: 0,
215 success: false,
216 error_class: Some("plan_error".to_string()),
217 confidence: 0.0,
218 });
219 return Err(map_entdb_error(e));
220 }
221 };
222 self.ensure_statement_timeout(stmt_started)?;
223 let optimized_outcome = Optimizer::optimize_with_trace_and_history(
224 plan,
225 &fingerprint,
226 self.db.optimizer_config,
227 &history,
228 );
229 let chosen_plan_signature = optimized_outcome
230 .trace
231 .chosen_plan_signature
232 .clone()
233 .unwrap_or_else(|| "baseline".to_string());
234 let optimized = optimized_outcome.plan;
235 self.ensure_statement_timeout(stmt_started)?;
236
237 let ctx = ExecutionContext {
238 catalog: Arc::clone(&self.db.catalog),
239 tx: Some(TxExecutionContext {
240 txn_id: tx.txn_id,
241 snapshot_ts: tx.snapshot_ts,
242 txn_manager: Arc::clone(&self.db.txn_manager),
243 }),
244 };
245
246 let mut exec = match build_executor(&optimized, &ctx) {
247 Ok(exec) => exec,
248 Err(e) => {
249 self.record_optimizer_history(OptimizerHistoryRecord {
250 fingerprint,
251 plan_signature: chosen_plan_signature,
252 schema_hash: optimizer_history_schema_hash().to_string(),
253 captured_at_ms: now_epoch_millis(),
254 rowcount_observed_json: "{\"root\":0}".to_string(),
255 latency_ms: stmt_started.elapsed().as_millis() as u64,
256 memory_peak_bytes: 0,
257 success: false,
258 error_class: Some("executor_build_error".to_string()),
259 confidence: 0.0,
260 });
261 return Err(map_entdb_error(e));
262 }
263 };
264 self.ensure_statement_timeout(stmt_started)?;
265 if let Err(e) = exec.open() {
266 self.record_optimizer_history(OptimizerHistoryRecord {
267 fingerprint,
268 plan_signature: chosen_plan_signature,
269 schema_hash: optimizer_history_schema_hash().to_string(),
270 captured_at_ms: now_epoch_millis(),
271 rowcount_observed_json: "{\"root\":0}".to_string(),
272 latency_ms: stmt_started.elapsed().as_millis() as u64,
273 memory_peak_bytes: 0,
274 success: false,
275 error_class: Some("executor_open_error".to_string()),
276 confidence: 0.0,
277 });
278 return Err(map_entdb_error(e));
279 }
280 self.ensure_statement_timeout(stmt_started)?;
281
282 let schema = exec.schema().clone();
283 let mut rows = Vec::new();
284 while let Some(row) = match exec.next() {
285 Ok(r) => r,
286 Err(e) => {
287 self.record_optimizer_history(OptimizerHistoryRecord {
288 fingerprint: fingerprint.clone(),
289 plan_signature: chosen_plan_signature.clone(),
290 schema_hash: optimizer_history_schema_hash().to_string(),
291 captured_at_ms: now_epoch_millis(),
292 rowcount_observed_json: format!("{{\"root\":{}}}", rows.len()),
293 latency_ms: stmt_started.elapsed().as_millis() as u64,
294 memory_peak_bytes: 0,
295 success: false,
296 error_class: Some("executor_next_error".to_string()),
297 confidence: 0.0,
298 });
299 return Err(map_entdb_error(e));
300 }
301 } {
302 self.ensure_statement_timeout(stmt_started)?;
303 rows.push(row);
304 if let Some(limit) = max_rows {
305 if limit > 0 && rows.len() >= limit {
306 break;
307 }
308 }
309 }
310 if let Err(e) = exec.close() {
311 self.record_optimizer_history(OptimizerHistoryRecord {
312 fingerprint,
313 plan_signature: chosen_plan_signature,
314 schema_hash: optimizer_history_schema_hash().to_string(),
315 captured_at_ms: now_epoch_millis(),
316 rowcount_observed_json: format!("{{\"root\":{}}}", rows.len()),
317 latency_ms: stmt_started.elapsed().as_millis() as u64,
318 memory_peak_bytes: 0,
319 success: false,
320 error_class: Some("executor_close_error".to_string()),
321 confidence: 0.0,
322 });
323 return Err(map_entdb_error(e));
324 }
325
326 let observed_rows = if schema.columns.is_empty() {
327 exec.affected_rows().unwrap_or(0)
328 } else {
329 rows.len() as u64
330 };
331 self.record_optimizer_history(OptimizerHistoryRecord {
332 fingerprint,
333 plan_signature: chosen_plan_signature,
334 schema_hash: optimizer_history_schema_hash().to_string(),
335 captured_at_ms: now_epoch_millis(),
336 rowcount_observed_json: format!("{{\"root\":{observed_rows}}}"),
337 latency_ms: stmt_started.elapsed().as_millis() as u64,
338 memory_peak_bytes: 0,
339 success: true,
340 error_class: None,
341 confidence: 1.0,
342 });
343
344 if schema.columns.is_empty() {
345 let affected = exec.affected_rows().unwrap_or(0) as usize;
346 let tag = execution_tag_for_statement(stmt, affected);
347 return Ok(Response::Execution(tag));
348 }
349
350 let fields = schema
351 .columns
352 .iter()
353 .enumerate()
354 .map(|(idx, c)| {
355 FieldInfo::new(
356 c.name.clone(),
357 None,
358 None,
359 ent_to_pg_type(&c.data_type),
360 row_format
361 .map(|f| f.format_for(idx))
362 .unwrap_or(FieldFormat::Text),
363 )
364 })
365 .collect::<Vec<_>>();
366 let schema = Arc::new(fields);
367
368 let mut encoded = Vec::with_capacity(rows.len());
369 for row in &rows {
370 let mut encoder = DataRowEncoder::new(Arc::clone(&schema));
371 for (idx, value) in row.iter().enumerate() {
372 let field = &schema[idx];
373 encode_value(&mut encoder, value, field.datatype(), field.format())?;
374 }
375 encoded.push(encoder.finish()?);
376 }
377
378 let data_row_stream = stream::iter(encoded.into_iter().map(Ok));
379 let mut query = QueryResponse::new(schema, data_row_stream);
380 query.set_command_tag(tag_for_statement_name(stmt));
381 Ok(Response::Query(query))
382 }
383
384 fn describe_statement_from_stmt(
385 &self,
386 stmt: &Statement,
387 format: &Format,
388 ) -> PgWireResult<Vec<FieldInfo>> {
389 if !statement_returns_rows(stmt) {
390 return Ok(Vec::new());
391 }
392
393 let tx = self.db.txn_manager.begin();
394 let binder = Binder::new(Arc::clone(&self.db.catalog));
395 let planner = Planner;
396
397 let bound = binder.bind(stmt).map_err(map_entdb_error)?;
398 let fingerprint = Optimizer::fingerprint_bound_statement(&bound);
399 let history = self.db.optimizer_history.read_for_fingerprint(&fingerprint);
400 let plan = planner.plan(bound).map_err(map_entdb_error)?;
401 let optimized = Optimizer::optimize_with_trace_and_history(
402 plan,
403 &fingerprint,
404 self.db.optimizer_config,
405 &history,
406 )
407 .plan;
408
409 let ctx = ExecutionContext {
410 catalog: Arc::clone(&self.db.catalog),
411 tx: Some(TxExecutionContext {
412 txn_id: tx.txn_id,
413 snapshot_ts: tx.snapshot_ts,
414 txn_manager: Arc::clone(&self.db.txn_manager),
415 }),
416 };
417
418 let exec = build_executor(&optimized, &ctx).map_err(map_entdb_error)?;
419 let fields = exec
420 .schema()
421 .columns
422 .iter()
423 .enumerate()
424 .map(|(idx, col)| {
425 FieldInfo::new(
426 col.name.clone(),
427 None,
428 None,
429 ent_to_pg_type(&col.data_type),
430 format.format_for(idx),
431 )
432 })
433 .collect::<Vec<_>>();
434
435 self.db.txn_manager.abort(tx.txn_id);
436 Ok(fields)
437 }
438
439 fn describe_statement_inner(&self, sql: &str, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
440 let stmt = parse_single_statement_with_polyglot(
441 sql,
442 "describe supports a single statement",
443 self.polyglot,
444 )?;
445 self.describe_statement_from_stmt(&stmt, format)
446 }
447}
448
449impl EntHandler {
450 fn record_optimizer_history(&self, entry: OptimizerHistoryRecord) {
451 self.db.optimizer_history.try_record(entry);
452 }
453
454 fn record_optimizer_history_error(
455 &self,
456 stmt: &Statement,
457 class: &str,
458 elapsed: std::time::Duration,
459 ) {
460 self.record_optimizer_history(OptimizerHistoryRecord {
461 fingerprint: format!("stmt:{}", tag_for_statement_name(stmt)),
462 plan_signature: "error".to_string(),
463 schema_hash: optimizer_history_schema_hash().to_string(),
464 captured_at_ms: now_epoch_millis(),
465 rowcount_observed_json: "{\"root\":0}".to_string(),
466 latency_ms: elapsed.as_millis() as u64,
467 memory_peak_bytes: 0,
468 success: false,
469 error_class: Some(class.to_string()),
470 confidence: 0.0,
471 });
472 }
473}
474
/// Returns the current wall-clock time as milliseconds since the Unix epoch,
/// or `0` when the system clock reads earlier than the epoch.
fn now_epoch_millis() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
481
482#[async_trait]
483impl SimpleQueryHandler for EntHandler {
484 async fn do_query<'a, C>(
485 &self,
486 _client: &mut C,
487 query: &'a str,
488 ) -> PgWireResult<Vec<Response<'a>>>
489 where
490 C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
491 C::Error: Debug,
492 PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
493 {
494 let started = std::time::Instant::now();
495 let result = self.execute_sql(query, None);
496 self.record_query_metric(started, result.as_ref().err());
497 result
498 }
499}
500
501#[async_trait]
502impl ExtendedQueryHandler for EntHandler {
503 type Statement = String;
504 type QueryParser = NoopQueryParser;
505
506 fn query_parser(&self) -> Arc<Self::QueryParser> {
507 Arc::clone(&self.query_parser)
508 }
509
510 async fn do_query<'a, C>(
511 &self,
512 _client: &mut C,
513 portal: &'a Portal<Self::Statement>,
514 max_rows: usize,
515 ) -> PgWireResult<Response<'a>>
516 where
517 C: ClientInfo + Unpin + Send + Sync,
518 {
519 let started = std::time::Instant::now();
520 let result = (|| -> PgWireResult<Response<'a>> {
521 let stmt = statement_from_portal(portal, self.polyglot)?;
522 if matches!(
523 stmt,
524 Statement::StartTransaction { .. }
525 | Statement::Commit { .. }
526 | Statement::Rollback { .. }
527 ) {
528 let mut responses = self.execute_sql(&stmt.to_string(), Some(max_rows))?;
529 if responses.is_empty() {
530 return Ok(Response::EmptyQuery);
531 }
532 return Ok(responses.remove(0));
533 }
534
535 let active_tx = *self.current_txn.lock();
536 if let Some(tx) = active_tx {
537 return self.execute_statement_in_txn(
538 &tx,
539 &stmt,
540 Some(max_rows),
541 Some(&portal.result_column_format),
542 );
543 }
544
545 let tx = self.db.txn_manager.begin();
546 match self.execute_statement_in_txn(
547 &tx,
548 &stmt,
549 Some(max_rows),
550 Some(&portal.result_column_format),
551 ) {
552 Ok(resp) => {
553 if let Err(e) = self.db.txn_manager.commit(tx.txn_id) {
554 self.db.txn_manager.abort(tx.txn_id);
555 Err(map_entdb_error(e))
556 } else {
557 Ok(resp)
558 }
559 }
560 Err(e) => {
561 self.db.txn_manager.abort(tx.txn_id);
562 Err(e)
563 }
564 }
565 })();
566 self.record_query_metric(started, result.as_ref().err());
567 result
568 }
569
570 async fn do_describe_statement<C>(
571 &self,
572 _client: &mut C,
573 stmt: &StoredStatement<Self::Statement>,
574 ) -> PgWireResult<DescribeStatementResponse>
575 where
576 C: ClientInfo + Unpin + Send + Sync,
577 {
578 let mut parameter_types = stmt.parameter_types.clone();
579 let inferred_count = max_placeholder_index(&stmt.statement);
580 if parameter_types.len() < inferred_count {
581 parameter_types.resize(inferred_count, Type::INT4);
582 }
583 for ty in &mut parameter_types {
584 if *ty == Type::UNKNOWN {
585 *ty = Type::INT4;
586 }
587 }
588
589 let fields = self
590 .describe_statement_inner(&stmt.statement, &Format::UnifiedBinary)
591 .or_else(|_| {
592 let normalized = normalize_placeholders_for_describe(&stmt.statement);
593 self.describe_statement_inner(&normalized, &Format::UnifiedBinary)
594 })
595 .unwrap_or_default();
596 Ok(DescribeStatementResponse::new(parameter_types, fields))
597 }
598
599 async fn do_describe_portal<C>(
600 &self,
601 _client: &mut C,
602 portal: &Portal<Self::Statement>,
603 ) -> PgWireResult<DescribePortalResponse>
604 where
605 C: ClientInfo + Unpin + Send + Sync,
606 {
607 let stmt = statement_from_portal(portal, self.polyglot)?;
608 let fields = self.describe_statement_from_stmt(&stmt, &portal.result_column_format)?;
609 Ok(DescribePortalResponse::new(fields))
610 }
611}
612
/// Test-only parser: parses `sql` with the PostgreSQL dialect, mapping parser
/// failures to SQLSTATE `42601`.
#[cfg(test)]
fn parse_sql(sql: &str) -> PgWireResult<Vec<Statement>> {
    match Parser::parse_sql(&PostgreSqlDialect {}, sql) {
        Ok(statements) => Ok(statements),
        Err(e) => Err(user_error("42601", format!("parse error: {e}"))),
    }
}
618
619fn parse_sql_with_polyglot(sql: &str, opts: PolyglotOptions) -> PgWireResult<Vec<Statement>> {
620 let transpiled = transpile_with_meta(sql, opts).map_err(map_entdb_error)?;
621 let dialect = PostgreSqlDialect {};
622 Parser::parse_sql(&dialect, &transpiled.transpiled_sql).map_err(|e| {
623 if transpiled.changed {
624 user_error(
625 "42601",
626 format!(
627 "parse error: {e}; original_sql={:?}; transpiled_sql={:?}",
628 transpiled.original_sql, transpiled.transpiled_sql
629 ),
630 )
631 } else {
632 user_error("42601", format!("parse error: {e}"))
633 }
634 })
635}
636
637impl EntHandler {
638 fn ensure_statement_timeout(&self, started: std::time::Instant) -> PgWireResult<()> {
639 if started.elapsed().as_millis() as u64 > self.query_timeout_ms {
640 return Err(user_error("57014", "query timeout exceeded"));
641 }
642 Ok(())
643 }
644
645 fn record_query_metric(&self, started: std::time::Instant, err: Option<&PgWireError>) {
646 let elapsed = started.elapsed().as_nanos() as u64;
647 let sqlstate = err.and_then(sqlstate_from_error);
648 self.metrics.on_query_finished(elapsed, sqlstate);
649 }
650}
651
652fn sqlstate_from_error(err: &PgWireError) -> Option<&str> {
653 match err {
654 PgWireError::UserError(info) => Some(info.code.as_str()),
655 _ => Some("XX000"),
656 }
657}
658
/// Test-only helper: parses `sql` and requires exactly one statement,
/// otherwise fails with SQLSTATE `0A000` and `message`.
#[cfg(test)]
fn parse_single_statement(sql: &str, message: &'static str) -> PgWireResult<Statement> {
    let mut parsed = parse_sql(sql)?.into_iter();
    match (parsed.next(), parsed.next()) {
        (Some(only), None) => Ok(only),
        _ => Err(user_error("0A000", message)),
    }
}
667
668fn parse_single_statement_with_polyglot(
669 sql: &str,
670 message: &'static str,
671 opts: PolyglotOptions,
672) -> PgWireResult<Statement> {
673 let mut statements = parse_sql_with_polyglot(sql, opts)?;
674 if statements.len() != 1 {
675 return Err(user_error("0A000", message));
676 }
677 Ok(statements.remove(0))
678}
679
680fn statement_from_portal(
681 portal: &Portal<String>,
682 opts: PolyglotOptions,
683) -> PgWireResult<Statement> {
684 let mut stmt = parse_single_statement_with_polyglot(
685 &portal.statement.statement,
686 "extended protocol supports a single statement per execute",
687 opts,
688 )?;
689 if portal.parameter_len() == 0 {
690 return Ok(stmt);
691 }
692
693 let values = collect_portal_values(portal)?;
694 bind_statement_placeholders(&mut stmt, &values)?;
695 Ok(stmt)
696}
697
698fn collect_portal_values(portal: &Portal<String>) -> PgWireResult<Vec<SqlValue>> {
699 let mut values = Vec::with_capacity(portal.parameter_len());
700 for i in 0..portal.parameter_len() {
701 let pg_type = portal
702 .statement
703 .parameter_types
704 .get(i)
705 .unwrap_or(&Type::UNKNOWN);
706 values.push(parameter_sql_value(portal, i, pg_type)?);
707 }
708 Ok(values)
709}
710
711fn bind_statement_placeholders(stmt: &mut Statement, values: &[SqlValue]) -> PgWireResult<()> {
712 let mut question_mark_idx = 0usize;
713 let outcome = visit_expressions_mut(stmt, |expr| {
714 if let Expr::Value(v) = expr {
715 if let SqlValue::Placeholder(placeholder) = &v.value {
716 let idx = match placeholder_index(placeholder, values.len(), &mut question_mark_idx)
717 {
718 Ok(idx) => idx,
719 Err(err) => return ControlFlow::Break(err),
720 };
721 v.value = values[idx - 1].clone();
722 }
723 }
724 ControlFlow::Continue(())
725 });
726
727 match outcome {
728 ControlFlow::Continue(()) => Ok(()),
729 ControlFlow::Break(err) => Err(err),
730 }
731}
732
733fn placeholder_index(
734 placeholder: &str,
735 total: usize,
736 question_mark_idx: &mut usize,
737) -> PgWireResult<usize> {
738 let idx = if placeholder == "?" {
739 *question_mark_idx += 1;
740 *question_mark_idx
741 } else if let Some(raw) = placeholder.strip_prefix('$') {
742 raw.parse::<usize>()
743 .map_err(|_| user_error("22023", format!("invalid placeholder '{placeholder}'")))?
744 } else {
745 return Err(user_error(
746 "22023",
747 format!("unsupported placeholder syntax '{placeholder}'"),
748 ));
749 };
750
751 if idx == 0 || idx > total {
752 return Err(user_error(
753 "22023",
754 format!("placeholder '{placeholder}' is out of range"),
755 ));
756 }
757 Ok(idx)
758}
759
760fn parameter_sql_value(
761 portal: &Portal<String>,
762 idx: usize,
763 pg_type: &Type,
764) -> PgWireResult<SqlValue> {
765 if *pg_type == Type::UNKNOWN {
766 return parameter_sql_value_unknown(portal, idx);
767 }
768 if *pg_type == Type::BOOL {
769 return Ok(match portal.parameter::<bool>(idx, pg_type)? {
770 Some(v) => SqlValue::Boolean(v),
771 None => SqlValue::Null,
772 });
773 }
774 if *pg_type == Type::INT2 {
775 return Ok(match portal.parameter::<i16>(idx, pg_type)? {
776 Some(v) => SqlValue::Number(v.to_string(), false),
777 None => SqlValue::Null,
778 });
779 }
780 if *pg_type == Type::INT4 {
781 return Ok(match portal.parameter::<i32>(idx, pg_type)? {
782 Some(v) => SqlValue::Number(v.to_string(), false),
783 None => SqlValue::Null,
784 });
785 }
786 if *pg_type == Type::INT8 {
787 return Ok(match portal.parameter::<i64>(idx, pg_type)? {
788 Some(v) => SqlValue::Number(v.to_string(), false),
789 None => SqlValue::Null,
790 });
791 }
792 if *pg_type == Type::FLOAT4 {
793 return Ok(match portal.parameter::<f32>(idx, pg_type)? {
794 Some(v) => SqlValue::Number((v as f64).to_string(), false),
795 None => SqlValue::Null,
796 });
797 }
798 if *pg_type == Type::FLOAT8 {
799 return Ok(match portal.parameter::<f64>(idx, pg_type)? {
800 Some(v) => SqlValue::Number(v.to_string(), false),
801 None => SqlValue::Null,
802 });
803 }
804
805 Ok(match portal.parameter::<String>(idx, pg_type)? {
806 Some(v) => SqlValue::SingleQuotedString(v),
807 None => SqlValue::Null,
808 })
809}
810
811fn parameter_sql_value_unknown(portal: &Portal<String>, idx: usize) -> PgWireResult<SqlValue> {
812 let Some(raw) = portal.parameters.get(idx) else {
813 return Ok(SqlValue::Null);
814 };
815 let Some(bytes) = raw else {
816 return Ok(SqlValue::Null);
817 };
818
819 if portal.parameter_format.is_binary(idx) {
820 match bytes.len() {
821 1 => {
822 let b = bytes[0];
823 if b == 0 || b == 1 {
824 return Ok(SqlValue::Boolean(b == 1));
825 }
826 }
827 4 => {
828 let v = i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
829 return Ok(SqlValue::Number(v.to_string(), false));
830 }
831 8 => {
832 let v = i64::from_be_bytes([
833 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
834 ]);
835 return Ok(SqlValue::Number(v.to_string(), false));
836 }
837 _ => {}
838 }
839 }
840
841 match std::str::from_utf8(bytes) {
842 Ok(s) => {
843 if let Ok(i) = s.parse::<i64>() {
844 return Ok(SqlValue::Number(i.to_string(), false));
845 }
846 if let Ok(f) = s.parse::<f64>() {
847 return Ok(SqlValue::Number(f.to_string(), false));
848 }
849 if s.eq_ignore_ascii_case("true") || s.eq_ignore_ascii_case("false") {
850 return Ok(SqlValue::Boolean(s.eq_ignore_ascii_case("true")));
851 }
852 Ok(SqlValue::SingleQuotedString(s.to_string()))
853 }
854 Err(_) => Ok(SqlValue::SingleQuotedString(
855 String::from_utf8_lossy(bytes).into_owned(),
856 )),
857 }
858}
859
860fn tag_for_statement(stmt: &Statement) -> Tag {
861 Tag::new(tag_for_statement_name(stmt))
862}
863
864fn statement_returns_rows(stmt: &Statement) -> bool {
865 match stmt {
866 Statement::Query(_) => true,
867 Statement::Insert(insert) => insert.returning.is_some(),
868 Statement::Update { returning, .. } => returning.is_some(),
869 Statement::Delete(delete) => delete.returning.is_some(),
870 _ => false,
871 }
872}
873
874fn execution_tag_for_statement(stmt: &Statement, affected: usize) -> Tag {
875 match stmt {
876 Statement::Insert(_) => Tag::new("INSERT 0").with_rows(affected),
878 _ => tag_for_statement(stmt).with_rows(affected),
879 }
880}
881
882fn tag_for_statement_name(stmt: &Statement) -> &'static str {
883 match stmt {
884 Statement::Query(_) => "SELECT",
885 Statement::Insert(_) => "INSERT",
886 Statement::Update { .. } => "UPDATE",
887 Statement::Delete { .. } => "DELETE",
888 Statement::Truncate { .. } => "TRUNCATE",
889 Statement::CreateTable { .. } => "CREATE TABLE",
890 Statement::CreateIndex(_) => "CREATE INDEX",
891 Statement::AlterTable { .. } => "ALTER TABLE",
892 Statement::Drop {
893 object_type: sqlparser::ast::ObjectType::Table,
894 ..
895 } => "DROP TABLE",
896 Statement::Drop {
897 object_type: sqlparser::ast::ObjectType::Index,
898 ..
899 } => "DROP INDEX",
900 Statement::StartTransaction { .. } => "BEGIN",
901 Statement::Commit { .. } => "COMMIT",
902 Statement::Rollback { .. } => "ROLLBACK",
903 _ => "OK",
904 }
905}
906
907fn map_entdb_error(err: EntDbError) -> PgWireError {
908 let (code, message) = match err {
909 EntDbError::Query(m) => (query_sqlstate(&m), m),
910 EntDbError::BufferPoolFull => ("53200", "buffer pool full".to_string()),
911 EntDbError::PagePinned(pid) => ("55006", format!("page {pid} is pinned")),
912 EntDbError::PageNotFound(pid) => ("XX000", format!("page {pid} not found")),
913 EntDbError::Io(e) => ("58000", format!("io error: {e}")),
914 EntDbError::Wal(m) => ("XX000", format!("wal error: {m}")),
915 EntDbError::Corruption(m) => ("XX001", format!("corruption: {m}")),
916 EntDbError::InvalidPage(m) => ("XX000", format!("invalid page: {m}")),
917 EntDbError::PageAlreadyPresent(pid) => ("XX000", format!("page {pid} already present")),
918 };
919 user_error(code, message)
920}
921
/// Picks a SQLSTATE for a query-error message by case-insensitive substring
/// matching. Rules are tried in order and the first match wins; messages that
/// match nothing map to `XX000`.
fn query_sqlstate(message: &str) -> &'static str {
    // Phrases that map straight to a code, in priority order.
    const DIRECT: [(&str, &str); 4] = [
        ("parse error", "42601"),
        ("already active", "25001"),
        ("no active transaction", "25000"),
        ("write-write conflict", "40001"),
    ];
    // Phrases that all indicate a data error.
    const DATA_ERRORS: [&str; 4] = [
        "cannot cast value",
        "invalid",
        "unsupported",
        "must be numeric",
    ];

    let lower = message.to_ascii_lowercase();

    for (needle, code) in DIRECT {
        if lower.contains(needle) {
            return code;
        }
    }

    if lower.contains("does not exist") {
        if lower.contains("table") {
            return "42P01";
        }
        if lower.contains("column") {
            return "42703";
        }
        // Other missing objects fall through to the remaining rules.
    }

    if DATA_ERRORS.iter().any(|needle| lower.contains(needle)) {
        "22000"
    } else {
        "XX000"
    }
}
953
954fn user_error(code: impl Into<String>, message: impl Into<String>) -> PgWireError {
955 PgWireError::UserError(Box::new(ErrorInfo::new(
956 "ERROR".to_string(),
957 code.into(),
958 message.into(),
959 )))
960}
961
/// Returns the highest `n` among all `$n` tokens in `sql`, or `0` when there
/// are none.
///
/// This is a plain text scan: it does not skip string literals or comments,
/// and a number too large for `usize` is ignored.
fn max_placeholder_index(sql: &str) -> usize {
    sql.split('$')
        // The first segment precedes any `$`, so it can never be an index.
        .skip(1)
        .filter_map(|after_dollar| {
            let digit_count = after_dollar
                .bytes()
                .take_while(u8::is_ascii_digit)
                .count();
            after_dollar[..digit_count].parse::<usize>().ok()
        })
        .max()
        .unwrap_or(0)
}
985
/// Rewrites every `$n` placeholder in `sql` to the literal `0` so the
/// statement can be planned for describe without bound parameters.
///
/// A `$` that is not followed by a digit is kept as is. All other text,
/// including non-ASCII characters, is copied through unchanged. (Copying the
/// input byte by byte as `char`s would split multi-byte UTF-8 sequences into
/// separate Latin-1 characters, so whole string slices are copied instead.)
///
/// This is a plain text scan: placeholders inside string literals or comments
/// are rewritten too.
fn normalize_placeholders_for_describe(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut rest = sql;

    while let Some(dollar) = rest.find('$') {
        // `$` and ASCII digits are single bytes, so every slice boundary
        // below falls on a character boundary.
        out.push_str(&rest[..dollar]);
        let after = &rest[dollar + 1..];
        let digit_count = after.bytes().take_while(u8::is_ascii_digit).count();
        out.push(if digit_count == 0 { '$' } else { '0' });
        rest = &after[digit_count..];
    }

    out.push_str(rest);
    out
}
1010
1011pub fn scan_max_txn_id_from_storage(catalog: &Catalog) -> Result<u64, EntDbError> {
1012 let mut max_txn = 0_u64;
1013 for table in catalog.list_tables() {
1014 let t = Table::open(table.table_id, table.first_page_id, catalog.buffer_pool());
1015 for (_, tuple) in t.scan() {
1016 let decoded = entdb::query::executor::decode_stored_row(&tuple.data)?;
1017 if let entdb::query::executor::DecodedRow::Versioned(v) = decoded {
1018 max_txn = max_txn.max(v.created_txn);
1019 if let Some(d) = v.deleted_txn {
1020 max_txn = max_txn.max(d);
1021 }
1022 }
1023 }
1024 }
1025 Ok(max_txn)
1026}
1027
/// Unit tests for placeholder binding, statement parsing, command tags and
/// the handler's basic transaction and size-limit behavior.
#[cfg(test)]
mod tests {
    use super::{
        bind_statement_placeholders, execution_tag_for_statement, max_placeholder_index,
        normalize_placeholders_for_describe, parse_single_statement, parse_sql, EntHandler,
    };
    use crate::server::Database;
    use entdb::catalog::Catalog;
    use entdb::query::history::OptimizerHistoryRecorder;
    use entdb::query::optimizer::OptimizerConfig;
    use entdb::storage::buffer_pool::BufferPool;
    use entdb::storage::disk_manager::DiskManager;
    use entdb::tx::TransactionManager;
    use entdb::wal::log_manager::LogManager;
    use pgwire::api::results::Tag;
    use pgwire::error::PgWireError;
    use proptest::prelude::*;
    use sqlparser::ast::Value as SqlValue;
    use std::sync::Arc;
    use tempfile::{tempdir, TempDir};

    /// Builds a handler backed by a fresh on-disk database in a temporary
    /// directory.
    ///
    /// The `TempDir` guard is returned together with the handler: dropping it
    /// deletes the directory that holds the database, WAL and optimizer-history
    /// files, so it has to outlive the handler.
    fn test_handler() -> (TempDir, EntHandler) {
        let dir = tempdir().expect("tempdir");
        let db_path = dir.path().join("handler.db");
        let wal_path = dir.path().join("handler.wal");

        let dm = Arc::new(DiskManager::new(&db_path).expect("disk"));
        let lm = Arc::new(LogManager::new(&wal_path, 4096).expect("wal"));
        let bp = Arc::new(BufferPool::with_log_manager(
            32,
            Arc::clone(&dm),
            Arc::clone(&lm),
        ));
        let catalog = Arc::new(Catalog::load(Arc::clone(&bp)).expect("catalog"));
        let txn = Arc::new(TransactionManager::new());
        let history_path = dir.path().join("handler.optimizer_history.json");
        let optimizer_history = Arc::new(
            OptimizerHistoryRecorder::new(
                &history_path,
                super::optimizer_history_schema_hash(),
                16,
                128,
            )
            .expect("optimizer history"),
        );
        let db = Arc::new(Database {
            disk_manager: dm,
            log_manager: lm,
            buffer_pool: bp,
            catalog,
            txn_manager: txn,
            optimizer_history,
            optimizer_config: OptimizerConfig::default(),
        });
        let handler = EntHandler::new(
            db,
            1024 * 1024,
            30_000,
            Arc::new(super::ServerMetrics::default()),
        );
        (dir, handler)
    }

    #[test]
    fn bind_statement_placeholders_replaces_parameter_nodes() {
        let mut stmt = parse_single_statement(
            "SELECT id FROM t WHERE v > $1 ORDER BY id LIMIT $2",
            "single",
        )
        .expect("parse");
        bind_statement_placeholders(
            &mut stmt,
            &[
                SqlValue::Number("7".to_string(), false),
                SqlValue::Number("3".to_string(), false),
            ],
        )
        .expect("bind placeholders");

        let rendered = stmt.to_string();
        assert!(rendered.contains("v > 7"), "rendered SQL: {rendered}");
        assert!(rendered.contains("LIMIT 3"), "rendered SQL: {rendered}");
    }

    #[test]
    fn bind_statement_placeholders_rejects_out_of_range_index() {
        let mut stmt = parse_single_statement("SELECT $3", "single").expect("parse");
        let err = match bind_statement_placeholders(
            &mut stmt,
            &[SqlValue::Number("1".to_string(), false)],
        ) {
            Ok(_) => panic!("expected out-of-range placeholder failure"),
            Err(e) => e,
        };
        match err {
            PgWireError::UserError(info) => {
                assert_eq!(info.code, "22023");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn parse_single_statement_rejects_multi_statement_sql() {
        let err = match parse_single_statement("SELECT 1; SELECT 2", "single") {
            Ok(_) => panic!("expected single-statement rejection"),
            Err(e) => e,
        };
        match err {
            PgWireError::UserError(info) => {
                assert_eq!(info.code, "0A000");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn max_placeholder_index_finds_highest_parameter() {
        assert_eq!(max_placeholder_index("SELECT $1, $2, $10"), 10);
        assert_eq!(max_placeholder_index("SELECT 1"), 0);
    }

    #[test]
    fn normalize_placeholders_for_describe_rewrites_all_numbered_params() {
        let sql = "SELECT id FROM t WHERE v > $1 ORDER BY id LIMIT $2";
        let normalized = normalize_placeholders_for_describe(sql);
        assert_eq!(
            normalized,
            "SELECT id FROM t WHERE v > 0 ORDER BY id LIMIT 0"
        );
    }

    #[test]
    fn execution_tag_for_insert_matches_postgres_format() {
        let stmt = parse_single_statement("INSERT INTO t VALUES (1)", "single").expect("parse");
        let tag = execution_tag_for_statement(&stmt, 2);
        assert_eq!(tag, Tag::new("INSERT 0").with_rows(2));
    }

    #[test]
    fn execution_tag_for_update_and_delete_include_row_count() {
        let update = parse_single_statement("UPDATE t SET v = 1", "single").expect("parse");
        let delete = parse_single_statement("DELETE FROM t", "single").expect("parse");
        assert_eq!(
            execution_tag_for_statement(&update, 3),
            Tag::new("UPDATE").with_rows(3)
        );
        assert_eq!(
            execution_tag_for_statement(&delete, 4),
            Tag::new("DELETE").with_rows(4)
        );
    }

    #[test]
    fn handler_begin_commit_round_trip() {
        // `_dir` keeps the temporary database directory alive for the test.
        let (_dir, handler) = test_handler();
        let out = handler
            .execute_sql("BEGIN; COMMIT;", None)
            .expect("execute");
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn handler_enforces_max_statement_size() {
        let (_dir, handler) = test_handler();
        let huge = "X".repeat(2 * 1024 * 1024);
        let err = match handler.execute_sql(&huge, None) {
            Ok(_) => panic!("max statement size should be enforced"),
            Err(e) => e,
        };
        assert!(err.to_string().contains("max_statement_bytes"));
    }

    proptest! {
        #[test]
        fn parse_sql_never_panics_on_random_input(input in ".*") {
            let _ = parse_sql(&input);
        }
    }
}
1206}