1use std::collections::HashMap;
9use std::sync::Arc;
10
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpListener;
13
14use super::catalog_views::translate_pg_catalog_query;
15use super::protocol::{
16 read_frame, read_startup, write_frame, write_raw_byte, BackendMessage, ColumnDescriptor,
17 DescribeTarget, FrontendMessage, PgWireError, TransactionStatus,
18};
19use super::types::{pg_param_to_value, value_to_pg_wire_bytes, PgOid};
20use crate::runtime::ai::ask_response_envelope::{
21 AskResult, Citation, Mode, SourceRow, Validation, ValidationError, ValidationWarning,
22};
23use crate::runtime::RedDBRuntime;
24use crate::storage::query::unified::{UnifiedRecord, UnifiedResult};
25use crate::storage::schema::Value;
26
/// Settings for the PostgreSQL wire-protocol listener.
#[derive(Debug, Clone)]
pub struct PgWireConfig {
    /// Address handed to `TcpListener::bind` (the `Default` impl uses `127.0.0.1:5432`).
    pub bind_addr: String,
    /// Version string announced to clients in the `server_version` ParameterStatus.
    pub server_version: String,
}
38
/// A statement registered by an extended-protocol `Parse` message.
#[derive(Debug, Clone)]
struct PgPreparedStatement {
    /// SQL with `$N::type` casts already stripped down to bare `$N` placeholders.
    sql: String,
    /// One type OID per parameter slot. It is reported back in ParameterDescription
    /// and used to decode the values carried by `Bind`.
    param_type_oids: Vec<u32>,
}
44
/// A bound portal created by `Bind`: a prepared statement's SQL plus its decoded
/// parameter values, ready for `Describe` (portal) and `Execute`.
#[derive(Debug, Clone)]
struct PgPortal {
    /// SQL copied from the prepared statement at bind time.
    sql: String,
    /// Parameter values decoded from the Bind message.
    params: Vec<Value>,
    // Stored from Bind but not read anywhere yet: result columns are always
    // described and sent in text format (`format: 0`).
    #[allow(dead_code)]
    result_format_codes: Vec<i16>,
    /// Set once a portal `Describe` has emitted RowDescription (or the ASK
    /// layout), so `Execute` must send only data rows.
    row_description_sent: bool,
    /// Result produced while describing a row-returning query. `Execute`
    /// consumes it instead of running the query a second time.
    described_result: Option<crate::runtime::RuntimeQueryResult>,
}
54
55impl Default for PgWireConfig {
56 fn default() -> Self {
57 Self {
58 bind_addr: "127.0.0.1:5432".to_string(),
59 server_version: "15.0 (RedDB 3.1)".to_string(),
60 }
61 }
62}
63
64fn run_runtime_blocking<T>(f: impl FnOnce() -> T) -> T {
84 use tokio::runtime::{Handle, RuntimeFlavor};
85 match Handle::try_current().map(|h| h.runtime_flavor()) {
86 Ok(RuntimeFlavor::MultiThread) => tokio::task::block_in_place(f),
87 _ => f(),
88 }
89}
90
91pub async fn start_pg_wire_listener(
94 config: PgWireConfig,
95 runtime: Arc<RedDBRuntime>,
96) -> Result<(), Box<dyn std::error::Error>> {
97 let listener = TcpListener::bind(&config.bind_addr).await?;
98 tracing::info!(
99 transport = "pg-wire",
100 bind = %config.bind_addr,
101 "listener online"
102 );
103 let cfg = Arc::new(config);
104 loop {
105 let (stream, peer) = listener.accept().await?;
106 let rt = Arc::clone(&runtime);
107 let cfg = Arc::clone(&cfg);
108 let peer_str = peer.to_string();
109 tokio::spawn(async move {
110 if let Err(e) = handle_connection(stream, rt, cfg).await {
111 tracing::warn!(
112 transport = "pg-wire",
113 peer = %peer_str,
114 err = %e,
115 "connection failed"
116 );
117 }
118 });
119 }
120}
121
122pub(crate) async fn handle_connection<S>(
124 mut stream: S,
125 runtime: Arc<RedDBRuntime>,
126 config: Arc<PgWireConfig>,
127) -> Result<(), PgWireError>
128where
129 S: AsyncRead + AsyncWrite + Unpin + Send,
130{
131 loop {
136 match read_startup(&mut stream).await? {
137 FrontendMessage::SslRequest | FrontendMessage::GssEncRequest => {
138 write_raw_byte(&mut stream, b'N').await?;
141 continue;
142 }
143 FrontendMessage::Startup(params) => {
144 send_auth_ok(&mut stream, &config, ¶ms).await?;
145 break;
146 }
147 FrontendMessage::Unknown { .. } => {
148 return Ok(());
150 }
151 other => {
152 return Err(PgWireError::Protocol(format!(
153 "unexpected startup frame: {other:?}"
154 )));
155 }
156 }
157 }
158
159 let mut prepared: HashMap<String, PgPreparedStatement> = HashMap::new();
160 let mut portals: HashMap<String, PgPortal> = HashMap::new();
161
162 loop {
164 let frame = match read_frame(&mut stream).await {
165 Ok(f) => f,
166 Err(PgWireError::Eof) => return Ok(()),
167 Err(e) => return Err(e),
168 };
169
170 match frame {
171 FrontendMessage::Query(sql) => {
172 handle_simple_query(&mut stream, &runtime, &sql).await?;
173 }
174 FrontendMessage::Parse(msg) => {
175 handle_parse(&mut stream, &mut prepared, msg).await?;
176 }
177 FrontendMessage::Bind(msg) => {
178 handle_bind(&mut stream, &prepared, &mut portals, msg).await?;
179 }
180 FrontendMessage::Describe(msg) => {
181 handle_describe(&mut stream, &runtime, &prepared, &mut portals, msg).await?;
182 }
183 FrontendMessage::Execute(msg) => {
184 handle_execute(&mut stream, &runtime, &mut portals, msg).await?;
185 }
186 FrontendMessage::Close(msg) => {
187 handle_close(&mut stream, &mut prepared, &mut portals, msg).await?;
188 }
189 FrontendMessage::Terminate => return Ok(()),
190 FrontendMessage::Flush => {
191 continue;
194 }
195 FrontendMessage::Sync => {
196 write_frame(
197 &mut stream,
198 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
199 )
200 .await?;
201 }
202 FrontendMessage::PasswordMessage(_) => {
203 continue;
205 }
206 FrontendMessage::Unknown { tag, .. } => {
207 send_error(
208 &mut stream,
209 "0A000",
210 &format!("unsupported frame tag 0x{tag:02x}"),
211 )
212 .await?;
213 write_frame(
214 &mut stream,
215 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
216 )
217 .await?;
218 }
219 other => {
220 send_error(
221 &mut stream,
222 "0A000",
223 &format!("unsupported frame {other:?}"),
224 )
225 .await?;
226 write_frame(
227 &mut stream,
228 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
229 )
230 .await?;
231 }
232 }
233 }
234}
235
236async fn handle_parse<S>(
237 stream: &mut S,
238 prepared: &mut HashMap<String, PgPreparedStatement>,
239 msg: super::protocol::ParseMessage,
240) -> Result<(), PgWireError>
241where
242 S: AsyncWrite + Unpin,
243{
244 let inferred_param_type_oids = infer_pg_cast_param_type_oids(&msg.query);
245 let sql = rewrite_pg_parameter_casts(&msg.query);
246 let parsed_param_count = match crate::storage::query::modes::parse_multi(&sql) {
247 Ok(parsed) => Some(
248 crate::storage::query::user_params::scan_parameters(&parsed)
249 .into_iter()
250 .map(|param| param.index + 1)
251 .max()
252 .unwrap_or(0),
253 ),
254 Err(err) => {
255 if pg_scalar_select_param_index(&sql).is_none() {
256 send_error(stream, "42601", &err.to_string()).await?;
257 return Ok(());
258 }
259 None
260 }
261 };
262 let mut param_type_oids = msg.param_type_oids;
263 if param_type_oids.is_empty() {
264 let count = parsed_param_count
265 .or_else(|| pg_scalar_select_param_index(&sql).map(|idx| idx + 1))
266 .unwrap_or(0);
267 param_type_oids.resize(count, PgOid::Unknown.as_u32());
268 }
269 for (idx, oid) in inferred_param_type_oids {
270 if idx >= param_type_oids.len() {
271 param_type_oids.resize(idx + 1, PgOid::Unknown.as_u32());
272 }
273 if param_type_oids[idx] == PgOid::Unknown.as_u32() {
274 param_type_oids[idx] = oid;
275 }
276 }
277 prepared.insert(
278 msg.statement,
279 PgPreparedStatement {
280 sql,
281 param_type_oids,
282 },
283 );
284 write_frame(stream, &BackendMessage::ParseComplete).await
285}
286
287async fn handle_bind<S>(
288 stream: &mut S,
289 prepared: &HashMap<String, PgPreparedStatement>,
290 portals: &mut HashMap<String, PgPortal>,
291 msg: super::protocol::BindMessage,
292) -> Result<(), PgWireError>
293where
294 S: AsyncWrite + Unpin,
295{
296 let Some(stmt) = prepared.get(&msg.statement) else {
297 send_error(
298 stream,
299 "26000",
300 &format!("prepared statement {:?} does not exist", msg.statement),
301 )
302 .await?;
303 return Ok(());
304 };
305 let params = match bind_pg_params(stmt, &msg) {
306 Ok(params) => params,
307 Err(err) => {
308 send_error(stream, "22023", &err).await?;
309 return Ok(());
310 }
311 };
312 portals.insert(
313 msg.portal,
314 PgPortal {
315 sql: stmt.sql.clone(),
316 params,
317 result_format_codes: msg.result_format_codes,
318 row_description_sent: false,
319 described_result: None,
320 },
321 );
322 write_frame(stream, &BackendMessage::BindComplete).await
323}
324
325async fn handle_describe<S>(
326 stream: &mut S,
327 runtime: &RedDBRuntime,
328 prepared: &HashMap<String, PgPreparedStatement>,
329 portals: &mut HashMap<String, PgPortal>,
330 msg: super::protocol::DescribeMessage,
331) -> Result<(), PgWireError>
332where
333 S: AsyncWrite + Unpin,
334{
335 match msg.target {
336 DescribeTarget::Statement => {
337 let Some(stmt) = prepared.get(&msg.name) else {
338 send_error(
339 stream,
340 "26000",
341 &format!("prepared statement {:?} does not exist", msg.name),
342 )
343 .await?;
344 return Ok(());
345 };
346 write_frame(
347 stream,
348 &BackendMessage::ParameterDescription(stmt.param_type_oids.clone()),
349 )
350 .await?;
351 if is_ask_query(&stmt.sql) {
352 emit_ask_row_description(stream).await
353 } else {
354 write_frame(stream, &BackendMessage::NoData).await
355 }
356 }
357 DescribeTarget::Portal => {
358 let Some(portal) = portals.get_mut(&msg.name) else {
359 send_error(
360 stream,
361 "34000",
362 &format!("portal {:?} does not exist", msg.name),
363 )
364 .await?;
365 return Ok(());
366 };
367 if is_ask_query(&portal.sql) {
368 emit_ask_row_description(stream).await?;
369 portal.row_description_sent = true;
370 Ok(())
371 } else if is_row_returning_query(&portal.sql) {
372 let result = match execute_pg_query_result(runtime, &portal.sql, &portal.params) {
373 Ok(result) => result,
374 Err(err) => {
375 let code = classify_sqlstate(&err);
376 send_error(stream, code, &err).await?;
377 return Ok(());
378 }
379 };
380 emit_row_description_for_result(stream, &result).await?;
381 portal.row_description_sent = true;
382 portal.described_result = Some(result);
383 Ok(())
384 } else {
385 write_frame(stream, &BackendMessage::NoData).await
386 }
387 }
388 }
389}
390
391async fn handle_execute<S>(
392 stream: &mut S,
393 runtime: &RedDBRuntime,
394 portals: &mut HashMap<String, PgPortal>,
395 msg: super::protocol::ExecuteMessage,
396) -> Result<(), PgWireError>
397where
398 S: AsyncWrite + Unpin,
399{
400 let Some(portal) = portals.get_mut(&msg.portal) else {
401 send_error(
402 stream,
403 "34000",
404 &format!("portal {:?} does not exist", msg.portal),
405 )
406 .await?;
407 return Ok(());
408 };
409 let _max_rows = msg.max_rows;
410 let was_described = portal.row_description_sent || portal.described_result.is_some();
411 portal.row_description_sent = false;
412 let result = match portal.described_result.take() {
413 Some(result) => Ok(result),
414 None => execute_pg_query_result(runtime, &portal.sql, &portal.params),
415 };
416 match result {
417 Ok(result) if was_described => {
418 emit_success_result_without_row_description(stream, &result).await
419 }
420 Ok(result) => emit_success_result(stream, &result).await,
421 Err(err) => {
422 let code = classify_sqlstate(&err);
423 send_error(stream, code, &err).await
424 }
425 }
426}
427
428async fn handle_close<S>(
429 stream: &mut S,
430 prepared: &mut HashMap<String, PgPreparedStatement>,
431 portals: &mut HashMap<String, PgPortal>,
432 msg: super::protocol::CloseMessage,
433) -> Result<(), PgWireError>
434where
435 S: AsyncWrite + Unpin,
436{
437 match msg.target {
438 DescribeTarget::Statement => {
439 prepared.remove(&msg.name);
440 }
441 DescribeTarget::Portal => {
442 portals.remove(&msg.name);
443 }
444 }
445 write_frame(stream, &BackendMessage::CloseComplete).await
446}
447
448fn bind_pg_params(
449 stmt: &PgPreparedStatement,
450 msg: &super::protocol::BindMessage,
451) -> Result<Vec<Value>, String> {
452 if !matches!(msg.param_format_codes.len(), 0 | 1)
453 && msg.param_format_codes.len() != msg.params.len()
454 {
455 return Err("Bind format count must be 0, 1, or match parameter count".to_string());
456 }
457 msg.params
458 .iter()
459 .enumerate()
460 .map(|(idx, param)| {
461 let oid = stmt
462 .param_type_oids
463 .get(idx)
464 .copied()
465 .unwrap_or(PgOid::Unknown.as_u32());
466 let format_code = match msg.param_format_codes.as_slice() {
467 [] => 0,
468 [format] => *format,
469 formats => formats[idx],
470 };
471 pg_param_to_value(oid, format_code, param.as_deref())
472 })
473 .collect()
474}
475
476fn execute_pg_query_result(
477 runtime: &RedDBRuntime,
478 sql: &str,
479 params: &[Value],
480) -> Result<crate::runtime::RuntimeQueryResult, String> {
481 if let Some(result) = try_execute_pg_scalar_select(sql, params) {
482 return Ok(result);
483 }
484 if params.is_empty() {
485 return match translate_pg_catalog_query(runtime, sql) {
486 Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
487 query: sql.to_string(),
488 mode: crate::storage::query::modes::QueryMode::Sql,
489 statement: "select",
490 engine: "pg-catalog",
491 result,
492 affected_rows: 0,
493 statement_type: "select",
494 bookmark: None,
495 }),
496 Ok(None) => {
497 run_runtime_blocking(|| runtime.execute_query(sql)).map_err(|err| err.to_string())
498 }
499 Err(err) => Err(err.to_string()),
500 };
501 }
502
503 let parsed = crate::storage::query::modes::parse_multi(sql).map_err(|err| err.to_string())?;
504 let bound =
505 crate::storage::query::user_params::bind(&parsed, params).map_err(|err| err.to_string())?;
506 run_runtime_blocking(|| runtime.execute_query_expr(bound)).map_err(|err| err.to_string())
507}
508
509fn try_execute_pg_scalar_select(
510 sql: &str,
511 params: &[Value],
512) -> Option<crate::runtime::RuntimeQueryResult> {
513 let index = pg_scalar_select_param_index(sql)?;
514 let value = params.get(index)?.clone();
515 let mut result = UnifiedResult::with_columns(vec!["?column?".to_string()]);
516 let mut record = UnifiedRecord::new();
517 record.set("?column?", value);
518 result.push(record);
519 Some(crate::runtime::RuntimeQueryResult {
520 query: sql.to_string(),
521 mode: crate::storage::query::modes::QueryMode::Sql,
522 statement: "select",
523 engine: "pg-wire",
524 result,
525 affected_rows: 0,
526 statement_type: "select",
527 bookmark: None,
528 })
529}
530
/// Recognises a query whose entire result is a single bound parameter and
/// returns that parameter's 0-based index.
///
/// Accepted shapes (case-insensitive, optional trailing `;`):
/// - `SELECT $N`
/// - `SELECT $N AS alias`
/// - `SELECT CAST($N AS type)`
///
/// Anything more — a `FROM` clause, extra select items, or arithmetic — returns
/// `None`. Such queries must go to the real engine; otherwise the raw parameter
/// would be returned as the whole result. `$0`, signed numbers (`$+1`), and
/// non-parameters also return `None`.
fn pg_scalar_select_param_index(sql: &str) -> Option<usize> {
    /// True when `ty` is a single type name (optionally with a balanced
    /// modifier list such as `numeric(10,2)`) with nothing after it.
    fn is_lone_type_name(ty: &str) -> bool {
        let mut depth = 0usize;
        for c in ty.chars() {
            match c {
                '(' => depth += 1,
                ')' => match depth.checked_sub(1) {
                    Some(d) => depth = d,
                    // A `)` that closes the CAST itself means more SQL follows.
                    None => return false,
                },
                ',' if depth == 0 => return false,
                _ => {}
            }
        }
        depth == 0 && !ty.trim().is_empty()
    }

    /// True for a simple (possibly double-quoted) column alias.
    fn is_plain_alias(alias: &str) -> bool {
        alias
            .chars()
            .all(|c| c.is_alphanumeric() || c == '_' || c == '"')
    }

    let trimmed = sql.trim().trim_end_matches(';').trim();
    let lower = trimmed.to_ascii_lowercase();
    let rest = lower.strip_prefix("select")?;
    if !rest.starts_with(|c: char| c.is_ascii_whitespace()) {
        return None;
    }
    let body = rest.trim();
    let param = match body.strip_prefix("cast(") {
        Some(inner) => {
            let (param, tail) = inner.split_once(" as ")?;
            // The CAST's closing paren must be the last character of the query.
            let type_name = tail.strip_suffix(')')?;
            if !is_lone_type_name(type_name) {
                return None;
            }
            param.trim()
        }
        None => {
            let mut tokens = body.split_whitespace();
            let param = tokens.next()?;
            match (tokens.next(), tokens.next(), tokens.next()) {
                (None, _, _) => {}
                (Some("as"), Some(alias), None) if is_plain_alias(alias) => {}
                _ => return None,
            }
            param
        }
    };
    let digits = param.strip_prefix('$')?;
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse::<usize>().ok()?.checked_sub(1)
}
545
/// Strips PostgreSQL `::type` casts that directly follow a `$N` placeholder,
/// so the engine sees bare `$N` parameters.
///
/// Examples:
/// - `$1::int` becomes `$1`.
/// - `$2::pg_catalog.varchar(255)` becomes `$2`.
/// - `$1::numeric(10, 2)` becomes `$1`.
///
/// Text inside single-quoted string literals and double-quoted identifiers is
/// copied verbatim; `''` and `""` are handled as escaped quotes. This keeps a
/// literal such as `'$1::int'` intact. All other text, including casts on
/// non-parameter expressions, is left unchanged.
fn rewrite_pg_parameter_casts(sql: &str) -> String {
    /// Index just past the quoted span opened at `open`. An unterminated span
    /// runs to the end of the input.
    fn skip_quoted(bytes: &[u8], open: usize, quote: u8) -> usize {
        let mut i = open + 1;
        while i < bytes.len() {
            if bytes[i] == quote {
                if bytes.get(i + 1) == Some(&quote) {
                    i += 2;
                    continue;
                }
                return i + 1;
            }
            i += 1;
        }
        bytes.len()
    }

    /// If a `::type` cast (with an optional `(modifier)` list) starts at
    /// `start`, returns the index just past it.
    fn cast_suffix_end(bytes: &[u8], start: usize) -> Option<usize> {
        if bytes.get(start..start + 2)? != b"::" {
            return None;
        }
        let type_start = start + 2;
        let mut end = type_start;
        while end < bytes.len()
            && (bytes[end].is_ascii_alphanumeric() || matches!(bytes[end], b'_' | b'.'))
        {
            end += 1;
        }
        if end == type_start {
            return None;
        }
        // A parenthesised group right after a type name is a type modifier,
        // e.g. varchar(255). Consume it so it is not left dangling after `$N`.
        if bytes.get(end) == Some(&b'(') {
            let modifier_start = end + 1;
            if let Some(len) = bytes[modifier_start..].iter().position(|&b| b == b')') {
                let modifier = &bytes[modifier_start..modifier_start + len];
                let plain = modifier
                    .iter()
                    .all(|&b| b.is_ascii_alphanumeric() || matches!(b, b',' | b' ' | b'_'));
                if !modifier.is_empty() && plain {
                    end = modifier_start + len + 1;
                }
            }
        }
        Some(end)
    }

    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    // `cursor` marks the start of input not yet copied to `out`.
    let mut cursor = 0;
    let mut pos = 0;
    while pos < bytes.len() {
        match bytes[pos] {
            quote @ (b'\'' | b'"') => pos = skip_quoted(bytes, pos, quote),
            b'$' => {
                let digits_start = pos + 1;
                let mut digits_end = digits_start;
                while digits_end < bytes.len() && bytes[digits_end].is_ascii_digit() {
                    digits_end += 1;
                }
                pos = digits_end;
                if digits_end > digits_start {
                    if let Some(cast_end) = cast_suffix_end(bytes, digits_end) {
                        // Keep everything up to and including `$N`; drop the cast.
                        out.push_str(&sql[cursor..digits_end]);
                        cursor = cast_end;
                        pos = cast_end;
                    }
                }
            }
            _ => pos += 1,
        }
    }
    out.push_str(&sql[cursor..]);
    out
}
585
586fn infer_pg_cast_param_type_oids(sql: &str) -> Vec<(usize, u32)> {
587 let mut out = Vec::new();
588 let bytes = sql.as_bytes();
589 let mut pos = 0;
590 while pos < bytes.len() {
591 if bytes[pos] != b'$' {
592 pos += 1;
593 continue;
594 }
595 pos += 1;
596 let digits_start = pos;
597 while pos < bytes.len() && bytes[pos].is_ascii_digit() {
598 pos += 1;
599 }
600 if digits_start == pos {
601 continue;
602 }
603 let Some(param_index) = sql[digits_start..pos]
604 .parse::<usize>()
605 .ok()
606 .and_then(|idx| idx.checked_sub(1))
607 else {
608 continue;
609 };
610 if pos + 2 > bytes.len() || &bytes[pos..pos + 2] != b"::" {
611 continue;
612 }
613 pos += 2;
614 let type_start = pos;
615 while pos < bytes.len()
616 && (bytes[pos].is_ascii_alphanumeric() || matches!(bytes[pos], b'_' | b'.'))
617 {
618 pos += 1;
619 }
620 if type_start == pos {
621 continue;
622 }
623 if let Some(oid) = pg_cast_type_oid(&sql[type_start..pos]) {
624 out.push((param_index, oid));
625 }
626 }
627 out
628}
629
630fn pg_cast_type_oid(ty: &str) -> Option<u32> {
631 let lower = ty.to_ascii_lowercase();
632 let short = lower.rsplit('.').next().unwrap_or(lower.as_str());
633 let oid = match short {
634 "bool" | "boolean" => PgOid::Bool,
635 "int2" | "smallint" => PgOid::Int2,
636 "int" | "int4" | "integer" => PgOid::Int4,
637 "int8" | "bigint" => PgOid::Int8,
638 "float4" | "real" => PgOid::Float4,
639 "float8" | "double" | "doubleprecision" => PgOid::Float8,
640 "numeric" | "decimal" => PgOid::Numeric,
641 "bytea" => PgOid::Bytea,
642 "json" => PgOid::Json,
643 "jsonb" => PgOid::Jsonb,
644 "text" => PgOid::Text,
645 "varchar" | "character varying" => PgOid::Varchar,
646 "uuid" => PgOid::Uuid,
647 "timestamp" => PgOid::Timestamp,
648 "timestamptz" | "timestampz" => PgOid::TimestampTz,
649 "vector" => PgOid::Vector,
650 _ => return None,
651 };
652 Some(oid.as_u32())
653}
654
/// Heuristic check for whether `sql` produces rows. It tests whether the
/// statement, after leading whitespace and ASCII-case-insensitively, starts
/// with one of the row-returning keywords.
fn is_row_returning_query(sql: &str) -> bool {
    const ROW_PREFIXES: [&str; 6] = ["select", "with", "ask", "search", "vector", "hybrid"];
    let head = sql.trim_start().to_ascii_lowercase();
    ROW_PREFIXES.iter().any(|prefix| head.starts_with(prefix))
}
664
/// True when `sql`, ignoring leading whitespace, starts with `ask` in any
/// ASCII case.
fn is_ask_query(sql: &str) -> bool {
    sql.trim_start()
        .get(..3)
        .map_or(false, |head| head.eq_ignore_ascii_case("ask"))
}
668
669async fn send_auth_ok<S>(
670 stream: &mut S,
671 config: &PgWireConfig,
672 params: &super::protocol::StartupParams,
673) -> Result<(), PgWireError>
674where
675 S: AsyncWrite + Unpin,
676{
677 write_frame(stream, &BackendMessage::AuthenticationOk).await?;
679
680 for (name, value) in [
682 ("server_version", config.server_version.as_str()),
683 ("server_encoding", "UTF8"),
684 ("client_encoding", "UTF8"),
685 ("DateStyle", "ISO, MDY"),
686 ("TimeZone", "UTC"),
687 ("integer_datetimes", "on"),
688 ("standard_conforming_strings", "on"),
689 (
690 "application_name",
691 params.get("application_name").unwrap_or(""),
692 ),
693 ] {
694 write_frame(
695 stream,
696 &BackendMessage::ParameterStatus {
697 name: name.to_string(),
698 value: value.to_string(),
699 },
700 )
701 .await?;
702 }
703
704 write_frame(
707 stream,
708 &BackendMessage::BackendKeyData {
709 pid: std::process::id(),
710 key: 0xDEADBEEF,
711 },
712 )
713 .await?;
714
715 write_frame(
716 stream,
717 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
718 )
719 .await?;
720 Ok(())
721}
722
723async fn handle_simple_query<S>(
724 stream: &mut S,
725 runtime: &RedDBRuntime,
726 sql: &str,
727) -> Result<(), PgWireError>
728where
729 S: AsyncWrite + Unpin,
730{
731 if sql.trim().is_empty() {
734 write_frame(stream, &BackendMessage::EmptyQueryResponse).await?;
735 write_frame(
736 stream,
737 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
738 )
739 .await?;
740 return Ok(());
741 }
742
743 if let Some(tag) = pg_session_compat_command_tag(sql) {
744 write_frame(stream, &BackendMessage::CommandComplete(tag.to_string())).await?;
745 write_frame(
746 stream,
747 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
748 )
749 .await?;
750 return Ok(());
751 }
752
753 let query_result = match translate_pg_catalog_query(runtime, sql) {
754 Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
755 query: sql.to_string(),
756 mode: crate::storage::query::modes::QueryMode::Sql,
757 statement: "select",
758 engine: "pg-catalog",
759 result,
760 affected_rows: 0,
761 statement_type: "select",
762 bookmark: None,
763 }),
764 Ok(None) => run_runtime_blocking(|| runtime.execute_query(sql)),
765 Err(err) => Err(err),
766 };
767
768 match query_result {
769 Ok(result) => {
770 emit_success_result(stream, &result).await?;
771 }
772 Err(err) => {
773 let code = classify_sqlstate(&err.to_string());
777 send_error(stream, code, &err.to_string()).await?;
778 }
779 }
780
781 write_frame(
782 stream,
783 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
784 )
785 .await?;
786 Ok(())
787}
788
/// Returns the command tag to acknowledge for session-setup statements that
/// are accepted but not executed. Today that is only `SET ...`, which yields
/// `"SET"`.
fn pg_session_compat_command_tag(sql: &str) -> Option<&'static str> {
    let normalized = sql.trim().trim_end_matches(';').to_ascii_lowercase();
    normalized.starts_with("set ").then_some("SET")
}
796
797async fn emit_success_result<S>(
798 stream: &mut S,
799 result: &crate::runtime::RuntimeQueryResult,
800) -> Result<(), PgWireError>
801where
802 S: AsyncWrite + Unpin,
803{
804 if result.statement == "ask" {
805 emit_ask_result_row(stream, result).await?;
806 write_frame(
807 stream,
808 &BackendMessage::CommandComplete("SELECT 1".to_string()),
809 )
810 .await?;
811 } else if result_returns_rows(result) {
812 emit_result_rows(stream, &result.result).await?;
813 write_frame(
814 stream,
815 &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
816 )
817 .await?;
818 } else {
819 let tag = match result.statement_type {
824 "insert" => format!("INSERT 0 {}", result.affected_rows),
825 "update" => format!("UPDATE {}", result.affected_rows),
826 "delete" => format!("DELETE {}", result.affected_rows),
827 other => other.to_uppercase(),
828 };
829 write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
830 }
831 Ok(())
832}
833
834async fn emit_success_result_without_row_description<S>(
835 stream: &mut S,
836 result: &crate::runtime::RuntimeQueryResult,
837) -> Result<(), PgWireError>
838where
839 S: AsyncWrite + Unpin,
840{
841 if result.statement == "ask" {
842 let row = ask_query_result_to_pg_wire_row(result)
843 .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
844 write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
845 write_frame(
846 stream,
847 &BackendMessage::CommandComplete("SELECT 1".to_string()),
848 )
849 .await?;
850 } else if result_returns_rows(result) {
851 emit_result_data_rows(stream, &result.result).await?;
852 write_frame(
853 stream,
854 &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
855 )
856 .await?;
857 } else {
858 let tag = match result.statement_type {
859 "insert" => format!("INSERT 0 {}", result.affected_rows),
860 "update" => format!("UPDATE {}", result.affected_rows),
861 "delete" => format!("DELETE {}", result.affected_rows),
862 other => other.to_uppercase(),
863 };
864 write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
865 }
866 Ok(())
867}
868
869async fn emit_row_description_for_result<S>(
870 stream: &mut S,
871 result: &crate::runtime::RuntimeQueryResult,
872) -> Result<(), PgWireError>
873where
874 S: AsyncWrite + Unpin,
875{
876 if result.statement == "ask" {
877 emit_ask_row_description(stream).await
878 } else if result_returns_rows(result) {
879 emit_result_row_description(stream, &result.result).await
880 } else {
881 write_frame(stream, &BackendMessage::NoData).await
882 }
883}
884
885fn result_returns_rows(result: &crate::runtime::RuntimeQueryResult) -> bool {
886 result.statement_type == "select"
887}
888
889async fn emit_result_rows<S>(
890 stream: &mut S,
891 result: &crate::storage::query::unified::UnifiedResult,
892) -> Result<(), PgWireError>
893where
894 S: AsyncWrite + Unpin,
895{
896 emit_result_row_description(stream, result).await?;
897 emit_result_data_rows(stream, result).await
898}
899
900async fn emit_result_row_description<S>(
901 stream: &mut S,
902 result: &crate::storage::query::unified::UnifiedResult,
903) -> Result<(), PgWireError>
904where
905 S: AsyncWrite + Unpin,
906{
907 let columns: Vec<String> = if !result.columns.is_empty() {
911 result.columns.clone()
912 } else if let Some(first) = result.records.first() {
913 record_field_names(first)
914 } else {
915 Vec::new()
916 };
917
918 let type_oids: Vec<PgOid> = columns
922 .iter()
923 .map(|col| {
924 result
925 .records
926 .first()
927 .and_then(|r| record_get(r, col))
928 .map(PgOid::from_value)
929 .unwrap_or(PgOid::Text)
930 })
931 .collect();
932
933 let descriptors: Vec<ColumnDescriptor> = columns
934 .iter()
935 .zip(type_oids.iter())
936 .map(|(name, oid)| ColumnDescriptor {
937 name: name.clone(),
938 table_oid: 0,
939 column_attr: 0,
940 type_oid: oid.as_u32(),
941 type_size: -1,
942 type_mod: -1,
943 format: 0,
944 })
945 .collect();
946
947 write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
948}
949
950async fn emit_result_data_rows<S>(
951 stream: &mut S,
952 result: &crate::storage::query::unified::UnifiedResult,
953) -> Result<(), PgWireError>
954where
955 S: AsyncWrite + Unpin,
956{
957 let columns: Vec<String> = if !result.columns.is_empty() {
958 result.columns.clone()
959 } else if let Some(first) = result.records.first() {
960 record_field_names(first)
961 } else {
962 Vec::new()
963 };
964 for record in &result.records {
965 let fields: Vec<Option<Vec<u8>>> = columns
966 .iter()
967 .map(|col| record_get(record, col).and_then(value_to_pg_wire_bytes))
968 .collect();
969 write_frame(stream, &BackendMessage::DataRow(fields)).await?;
970 }
971
972 Ok(())
973}
974
975async fn emit_ask_result_row<S>(
976 stream: &mut S,
977 result: &crate::runtime::RuntimeQueryResult,
978) -> Result<(), PgWireError>
979where
980 S: AsyncWrite + Unpin,
981{
982 let row = ask_query_result_to_pg_wire_row(result)
983 .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
984
985 emit_ask_row_description(stream).await?;
986 write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
987 Ok(())
988}
989
990async fn emit_ask_row_description<S>(stream: &mut S) -> Result<(), PgWireError>
991where
992 S: AsyncWrite + Unpin,
993{
994 let descriptors: Vec<ColumnDescriptor> = crate::runtime::ai::pg_wire_ask_row_encoder::columns()
995 .iter()
996 .map(|col| ColumnDescriptor {
997 name: col.name.to_string(),
998 table_oid: 0,
999 column_attr: 0,
1000 type_oid: col.oid.as_u32(),
1001 type_size: -1,
1002 type_mod: -1,
1003 format: 0,
1004 })
1005 .collect();
1006 write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
1007}
1008
1009fn ask_query_result_to_pg_wire_row(
1010 result: &crate::runtime::RuntimeQueryResult,
1011) -> Option<crate::runtime::ai::pg_wire_ask_row_encoder::AskRow> {
1012 if result.statement != "ask" {
1013 return None;
1014 }
1015 let record = result.result.records.first()?;
1016 let sources_flat_json =
1017 json_field(record, "sources_flat").unwrap_or(crate::json::Value::Array(Vec::new()));
1018 let citations_json =
1019 json_field(record, "citations").unwrap_or(crate::json::Value::Array(Vec::new()));
1020 let validation_json = json_field(record, "validation")
1021 .unwrap_or_else(|| crate::json::Value::Object(Default::default()));
1022
1023 let effective_mode = match text_field(record, "mode").as_deref() {
1024 Some("lenient") => Mode::Lenient,
1025 _ => Mode::Strict,
1026 };
1027
1028 let ask = AskResult {
1029 answer: text_field(record, "answer")?,
1030 sources_flat: ask_sources_flat(&sources_flat_json),
1031 citations: ask_citations(&citations_json),
1032 validation: ask_validation(&validation_json),
1033 cache_hit: bool_field(record, "cache_hit").unwrap_or(false),
1034 provider: text_field(record, "provider").unwrap_or_default(),
1035 model: text_field(record, "model").unwrap_or_default(),
1036 prompt_tokens: u32_field(record, "prompt_tokens").unwrap_or(0),
1037 completion_tokens: u32_field(record, "completion_tokens").unwrap_or(0),
1038 cost_usd: f64_field(record, "cost_usd").unwrap_or(0.0),
1039 effective_mode,
1040 retry_count: u32_field(record, "retry_count").unwrap_or(0),
1041 };
1042
1043 Some(crate::runtime::ai::pg_wire_ask_row_encoder::encode(&ask))
1044}
1045
1046fn record_field<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
1047 record.iter_fields().find_map(|(name, value)| {
1048 let name: &str = name;
1049 (name == key).then_some(value)
1050 })
1051}
1052
1053fn text_field(record: &UnifiedRecord, key: &str) -> Option<String> {
1054 match record_field(record, key)? {
1055 Value::Text(s) => Some(s.to_string()),
1056 Value::Email(s) | Value::Url(s) | Value::NodeRef(s) | Value::EdgeRef(s) => Some(s.clone()),
1057 other => Some(other.to_string()),
1058 }
1059}
1060
1061fn bool_field(record: &UnifiedRecord, key: &str) -> Option<bool> {
1062 match record_field(record, key)? {
1063 Value::Boolean(value) => Some(*value),
1064 _ => None,
1065 }
1066}
1067
1068fn u32_field(record: &UnifiedRecord, key: &str) -> Option<u32> {
1069 match record_field(record, key)? {
1070 Value::Integer(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1071 Value::UnsignedInteger(n) => Some((*n).min(u32::MAX as u64) as u32),
1072 Value::BigInt(n)
1073 | Value::TimestampMs(n)
1074 | Value::Timestamp(n)
1075 | Value::Duration(n)
1076 | Value::Decimal(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1077 Value::Float(n) => (*n >= 0.0).then_some((*n).min(u32::MAX as f64) as u32),
1078 _ => None,
1079 }
1080}
1081
1082fn f64_field(record: &UnifiedRecord, key: &str) -> Option<f64> {
1083 match record_field(record, key)? {
1084 Value::Integer(n) => Some(*n as f64),
1085 Value::UnsignedInteger(n) => Some(*n as f64),
1086 Value::BigInt(n)
1087 | Value::TimestampMs(n)
1088 | Value::Timestamp(n)
1089 | Value::Duration(n)
1090 | Value::Decimal(n) => Some(*n as f64),
1091 Value::Float(n) => Some(*n),
1092 _ => None,
1093 }
1094}
1095
1096fn json_field(record: &UnifiedRecord, key: &str) -> Option<crate::json::Value> {
1097 match record_field(record, key)? {
1098 Value::Json(bytes) => crate::json::from_slice(bytes).ok(),
1099 Value::Text(text) => crate::json::from_str(text).ok(),
1100 _ => None,
1101 }
1102}
1103
1104fn ask_sources_flat(value: &crate::json::Value) -> Vec<SourceRow> {
1105 value
1106 .as_array()
1107 .unwrap_or(&[])
1108 .iter()
1109 .filter_map(|source| {
1110 let urn = source
1111 .get("urn")
1112 .and_then(crate::json::Value::as_str)?
1113 .to_string();
1114 let payload = source
1115 .get("payload")
1116 .and_then(crate::json::Value::as_str)
1117 .map(ToString::to_string)
1118 .unwrap_or_else(|| source.to_string_compact());
1119 Some(SourceRow { urn, payload })
1120 })
1121 .collect()
1122}
1123
1124fn ask_citations(value: &crate::json::Value) -> Vec<Citation> {
1125 value
1126 .as_array()
1127 .unwrap_or(&[])
1128 .iter()
1129 .filter_map(|citation| {
1130 let marker = citation
1131 .get("marker")
1132 .and_then(crate::json::Value::as_u64)?;
1133 let urn = citation
1134 .get("urn")
1135 .and_then(crate::json::Value::as_str)?
1136 .to_string();
1137 Some(Citation {
1138 marker: marker.min(u32::MAX as u64) as u32,
1139 urn,
1140 })
1141 })
1142 .collect()
1143}
1144
1145fn ask_validation(value: &crate::json::Value) -> Validation {
1146 Validation {
1147 ok: value
1148 .get("ok")
1149 .and_then(crate::json::Value::as_bool)
1150 .unwrap_or(true),
1151 warnings: validation_items(value, "warnings")
1152 .into_iter()
1153 .map(|(kind, detail)| ValidationWarning { kind, detail })
1154 .collect(),
1155 errors: validation_items(value, "errors")
1156 .into_iter()
1157 .map(|(kind, detail)| ValidationError { kind, detail })
1158 .collect(),
1159 }
1160}
1161
1162fn validation_items(value: &crate::json::Value, key: &str) -> Vec<(String, String)> {
1163 value
1164 .get(key)
1165 .and_then(crate::json::Value::as_array)
1166 .unwrap_or(&[])
1167 .iter()
1168 .filter_map(|item| {
1169 Some((
1170 item.get("kind")
1171 .and_then(crate::json::Value::as_str)?
1172 .to_string(),
1173 item.get("detail")
1174 .and_then(crate::json::Value::as_str)
1175 .unwrap_or("")
1176 .to_string(),
1177 ))
1178 })
1179 .collect()
1180}
1181
/// Column lookup used when building RowDescription type OIDs and DataRow
/// cells. It delegates to `UnifiedRecord::get`.
fn record_get<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
    record.get(key)
}
1188
1189fn record_field_names(record: &UnifiedRecord) -> Vec<String> {
1198 record
1202 .column_names()
1203 .into_iter()
1204 .map(|k| k.to_string())
1205 .collect()
1206}
1207
1208async fn send_error<S>(stream: &mut S, code: &str, message: &str) -> Result<(), PgWireError>
1209where
1210 S: AsyncWrite + Unpin,
1211{
1212 write_frame(
1213 stream,
1214 &BackendMessage::ErrorResponse {
1215 severity: "ERROR".to_string(),
1216 code: code.to_string(),
1217 message: message.to_string(),
1218 },
1219 )
1220 .await
1221}
1222
/// Maps an engine error message to a SQLSTATE code using case-insensitive
/// substring heuristics. The first matching rule wins, and `XX000`
/// (internal_error) is the fallback.
///
/// NOTE(review): these are plain substring checks, so unrelated words can
/// match (e.g. "auth" inside "author").
fn classify_sqlstate(msg: &str) -> &'static str {
    const RULES: &[(&[&str], &str)] = &[
        (&["not found", "does not exist"], "42P01"),
        (&["parse", "expected", "syntax"], "42601"),
        (&["already exists"], "42P07"),
        (&["permission", "auth"], "28000"),
    ];
    let lower = msg.to_ascii_lowercase();
    RULES
        .iter()
        .find(|(needles, _)| needles.iter().any(|needle| lower.contains(needle)))
        .map_or("XX000", |&(_, code)| code)
}
1241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::RedDBOptions;
    use crate::runtime::RuntimeQueryResult;
    use crate::storage::query::modes::QueryMode;
    use crate::storage::query::unified::UnifiedResult;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    /// Runs Parse / Bind / Describe(portal) / Execute / Sync against a live
    /// connection handler over an in-memory duplex pipe and checks the replies.
    #[tokio::test]
    async fn extended_parse_bind_execute_returns_rows() {
        let runtime = Arc::new(RedDBRuntime::with_options(RedDBOptions::in_memory()).unwrap());
        let config = Arc::new(PgWireConfig::default());
        let (server_io, mut client_io) = tokio::io::duplex(64 * 1024);
        let server = tokio::spawn(async move {
            handle_connection(server_io, runtime, config).await.unwrap();
        });

        write_startup(&mut client_io).await;
        read_until_ready(&mut client_io).await;

        write_frontend_frame(
            &mut client_io,
            b'P',
            parse_body("", "SELECT $1::int", &[PgOid::Int4.as_u32()]),
        )
        .await;
        write_frontend_frame(
            &mut client_io,
            b'B',
            bind_body("", "", &[0], &[Some(b"42".as_slice())], &[]),
        )
        .await;
        write_frontend_frame(&mut client_io, b'D', describe_body(b'P', "")).await;
        write_frontend_frame(&mut client_io, b'E', execute_body("", 0)).await;
        write_frontend_frame(&mut client_io, b'S', Vec::new()).await;

        // ParseComplete, BindComplete, RowDescription, DataRow,
        // CommandComplete, ReadyForQuery.
        let frames = read_until_ready(&mut client_io).await;
        assert_eq!(
            frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>(),
            b"12TDCZ"
        );
        let columns = decode_row_description(&frames[2].1);
        assert_eq!(columns.len(), 1);
        let cells = decode_data_row(&frames[3].1);
        assert_eq!(cells.len(), 1);
        assert_eq!(cells[0].as_deref(), Some(b"42".as_slice()));
        assert_eq!(decode_command_complete(&frames[4].1), "SELECT 1");

        // Terminate so the handler returns and the spawned task completes.
        write_frontend_frame(&mut client_io, b'X', Vec::new()).await;
        server.await.unwrap();
    }

    #[test]
    fn infer_pg_cast_param_type_oids_from_parameter_casts() {
        assert_eq!(
            infer_pg_cast_param_type_oids("INSERT INTO t (id, name) VALUES ($1::int, $2::text)"),
            vec![(0, PgOid::Int4.as_u32()), (1, PgOid::Text.as_u32())]
        );
        assert_eq!(
            infer_pg_cast_param_type_oids("SEARCH SIMILAR [1.0] COLLECTION v LIMIT $1::int8"),
            vec![(0, PgOid::Int8.as_u32())]
        );
    }

    #[test]
    fn pg_session_compat_accepts_driver_setup_set_commands() {
        assert_eq!(
            pg_session_compat_command_tag("SET extra_float_digits = 3"),
            Some("SET")
        );
        assert_eq!(
            pg_session_compat_command_tag("SET application_name = 'pgjdbc'"),
            Some("SET")
        );
        assert_eq!(pg_session_compat_command_tag("SELECT 1"), None);
    }

    /// An ASK result is re-emitted with the canonical 12-column shape asserted
    /// below (alphabetical). Columns absent from the runtime record are filled
    /// in (`cache_hit` = f, `cost_usd` = 0, `mode` = strict, `retry_count` = 0).
    #[tokio::test]
    async fn ask_success_result_uses_canonical_pg_wire_row_shape() {
        let mut result = UnifiedResult::with_columns(vec![
            "answer".into(),
            "provider".into(),
            "model".into(),
            "prompt_tokens".into(),
            "completion_tokens".into(),
            "sources_count".into(),
            "sources_flat".into(),
            "citations".into(),
            "validation".into(),
        ]);
        let mut record = UnifiedRecord::new();
        record.set("answer", Value::text("Deploy failed [^1]."));
        record.set("provider", Value::text("openai"));
        record.set("model", Value::text("gpt-4o-mini"));
        record.set("prompt_tokens", Value::Integer(11));
        record.set("completion_tokens", Value::Integer(7));
        record.set(
            "sources_flat",
            Value::Json(
                br#"[{"urn":"urn:reddb:row:deployments:1","kind":"row","collection":"deployments","id":"1"}]"#
                    .to_vec(),
            ),
        );
        record.set(
            "citations",
            Value::Json(br#"[{"marker":1,"urn":"urn:reddb:row:deployments:1"}]"#.to_vec()),
        );
        record.set(
            "validation",
            Value::Json(br#"{"ok":true,"warnings":[],"errors":[]}"#.to_vec()),
        );
        result.push(record);

        let qr = RuntimeQueryResult {
            query: "ASK 'why did deploy fail?'".to_string(),
            mode: QueryMode::Sql,
            statement: "ask",
            engine: "runtime-ai",
            result,
            affected_rows: 0,
            statement_type: "select",
            bookmark: None,
        };

        let mut out = Vec::new();
        emit_success_result(&mut out, &qr).await.unwrap();
        let frames = decode_frames(&out);

        assert_eq!(
            frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>(),
            b"TDC"
        );

        let columns = decode_row_description(frames[0].1);
        assert_eq!(
            columns,
            vec![
                ("answer".to_string(), PgOid::Text.as_u32()),
                ("cache_hit".to_string(), PgOid::Bool.as_u32()),
                ("citations".to_string(), PgOid::Jsonb.as_u32()),
                ("completion_tokens".to_string(), PgOid::Int8.as_u32()),
                ("cost_usd".to_string(), PgOid::Numeric.as_u32()),
                ("mode".to_string(), PgOid::Text.as_u32()),
                ("model".to_string(), PgOid::Text.as_u32()),
                ("prompt_tokens".to_string(), PgOid::Int8.as_u32()),
                ("provider".to_string(), PgOid::Text.as_u32()),
                ("retry_count".to_string(), PgOid::Int8.as_u32()),
                ("sources_flat".to_string(), PgOid::Jsonb.as_u32()),
                ("validation".to_string(), PgOid::Jsonb.as_u32()),
            ]
        );

        let cells = decode_data_row(frames[1].1);
        assert_eq!(cells.len(), 12);
        assert_eq!(cells[0].as_deref(), Some(b"Deploy failed [^1].".as_slice()));
        assert_eq!(cells[1].as_deref(), Some(b"f".as_slice()));
        assert_eq!(cells[4].as_deref(), Some(b"0".as_slice()));
        assert_eq!(cells[5].as_deref(), Some(b"strict".as_slice()));
        assert_eq!(cells[9].as_deref(), Some(b"0".as_slice()));
        assert!(std::str::from_utf8(cells[10].as_deref().unwrap())
            .unwrap()
            .contains(r#""payload""#));
        assert_eq!(decode_command_complete(frames[2].1), "SELECT 1");
    }

    /// Splits a buffer of backend messages into `(tag, body)` pairs. Each
    /// message is a 1-byte tag, a big-endian u32 length that counts itself but
    /// not the tag, then the body.
    fn decode_frames(bytes: &[u8]) -> Vec<(u8, &[u8])> {
        let mut pos = 0;
        let mut frames = Vec::new();
        while pos < bytes.len() {
            let tag = bytes[pos];
            let len = u32::from_be_bytes([
                bytes[pos + 1],
                bytes[pos + 2],
                bytes[pos + 3],
                bytes[pos + 4],
            ]) as usize;
            let body_start = pos + 5;
            let body_end = pos + 1 + len;
            frames.push((tag, &bytes[body_start..body_end]));
            pos = body_end;
        }
        frames
    }

    /// Decodes a RowDescription body into `(column name, type OID)` pairs.
    fn decode_row_description(body: &[u8]) -> Vec<(String, u32)> {
        let count = i16::from_be_bytes([body[0], body[1]]) as usize;
        let mut pos = 2;
        let mut columns = Vec::with_capacity(count);
        for _ in 0..count {
            let end = body[pos..].iter().position(|&b| b == 0).unwrap() + pos;
            let name = std::str::from_utf8(&body[pos..end]).unwrap().to_string();
            pos = end + 1;
            // Skip table OID (Int32) and column attribute number (Int16).
            pos += 4 + 2;
            let oid = u32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
            pos += 4;
            // Skip type size (Int16), type modifier (Int32) and format code (Int16).
            pos += 2 + 4 + 2;
            columns.push((name, oid));
        }
        columns
    }

    /// Decodes a DataRow body; a negative cell length encodes SQL NULL.
    fn decode_data_row(body: &[u8]) -> Vec<Option<Vec<u8>>> {
        let count = i16::from_be_bytes([body[0], body[1]]) as usize;
        let mut pos = 2;
        let mut cells = Vec::with_capacity(count);
        for _ in 0..count {
            let len = i32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
            pos += 4;
            if len < 0 {
                cells.push(None);
            } else {
                let len = len as usize;
                cells.push(Some(body[pos..pos + len].to_vec()));
                pos += len;
            }
        }
        cells
    }

    /// Returns the NUL-terminated command tag of a CommandComplete body.
    fn decode_command_complete(body: &[u8]) -> &str {
        let nul = body.iter().position(|&b| b == 0).unwrap_or(body.len());
        std::str::from_utf8(&body[..nul]).unwrap()
    }

    /// Sends an untagged v3 StartupMessage with `user=reddb`.
    async fn write_startup<W: AsyncWrite + Unpin>(stream: &mut W) {
        let mut payload = Vec::new();
        payload.extend_from_slice(&crate::wire::postgres::protocol::PG_PROTOCOL_V3.to_be_bytes());
        payload.extend_from_slice(b"user\0reddb\0");
        payload.push(0);
        let len = (payload.len() + 4) as u32;
        stream.write_all(&len.to_be_bytes()).await.unwrap();
        stream.write_all(&payload).await.unwrap();
    }

    /// Sends one tagged frontend message (tag, length incl. itself, payload).
    async fn write_frontend_frame<W: AsyncWrite + Unpin>(
        stream: &mut W,
        tag: u8,
        payload: Vec<u8>,
    ) {
        stream.write_all(&[tag]).await.unwrap();
        stream
            .write_all(&((payload.len() + 4) as u32).to_be_bytes())
            .await
            .unwrap();
        stream.write_all(&payload).await.unwrap();
    }

    /// Reads one backend message and returns `(tag, body)`.
    async fn read_backend_frame<R: AsyncRead + Unpin>(stream: &mut R) -> (u8, Vec<u8>) {
        let mut tag = [0u8; 1];
        stream.read_exact(&mut tag).await.unwrap();
        let mut len = [0u8; 4];
        stream.read_exact(&mut len).await.unwrap();
        let len = u32::from_be_bytes(len) as usize;
        // The length field counts its own 4 bytes; guard against underflow so a
        // malformed frame fails with a clear message instead of an overflow panic.
        let body_len = len
            .checked_sub(4)
            .expect("backend frame length must include its own 4-byte header");
        let mut body = vec![0u8; body_len];
        stream.read_exact(&mut body).await.unwrap();
        (tag[0], body)
    }

    /// Collects backend messages up to and including ReadyForQuery (`Z`).
    async fn read_until_ready<R: AsyncRead + Unpin>(stream: &mut R) -> Vec<(u8, Vec<u8>)> {
        let mut frames = Vec::new();
        loop {
            let frame = read_backend_frame(stream).await;
            let done = frame.0 == b'Z';
            frames.push(frame);
            if done {
                return frames;
            }
        }
    }

    /// Builds a Parse body: statement name, query, parameter type OIDs.
    fn parse_body(statement: &str, query: &str, oids: &[u32]) -> Vec<u8> {
        let mut out = Vec::new();
        push_pg_cstring(&mut out, statement);
        push_pg_cstring(&mut out, query);
        out.extend_from_slice(&(oids.len() as i16).to_be_bytes());
        for oid in oids {
            out.extend_from_slice(&oid.to_be_bytes());
        }
        out
    }

    /// Builds a Bind body; a `None` parameter is encoded as length -1 (NULL).
    fn bind_body(
        portal: &str,
        statement: &str,
        formats: &[i16],
        params: &[Option<&[u8]>],
        result_formats: &[i16],
    ) -> Vec<u8> {
        let mut out = Vec::new();
        push_pg_cstring(&mut out, portal);
        push_pg_cstring(&mut out, statement);
        out.extend_from_slice(&(formats.len() as i16).to_be_bytes());
        for format in formats {
            out.extend_from_slice(&format.to_be_bytes());
        }
        out.extend_from_slice(&(params.len() as i16).to_be_bytes());
        for param in params {
            match param {
                Some(bytes) => {
                    out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
                    out.extend_from_slice(bytes);
                }
                None => out.extend_from_slice(&(-1i32).to_be_bytes()),
            }
        }
        out.extend_from_slice(&(result_formats.len() as i16).to_be_bytes());
        for format in result_formats {
            out.extend_from_slice(&format.to_be_bytes());
        }
        out
    }

    /// Builds a Describe body (`b'S'` statement or `b'P'` portal, then name).
    fn describe_body(target: u8, name: &str) -> Vec<u8> {
        let mut out = vec![target];
        push_pg_cstring(&mut out, name);
        out
    }

    /// Builds an Execute body: portal name and max row count (0 = no limit).
    fn execute_body(portal: &str, max_rows: u32) -> Vec<u8> {
        let mut out = Vec::new();
        push_pg_cstring(&mut out, portal);
        out.extend_from_slice(&max_rows.to_be_bytes());
        out
    }

    /// Appends `value` as a NUL-terminated protocol string.
    fn push_pg_cstring(out: &mut Vec<u8>, value: &str) {
        out.extend_from_slice(value.as_bytes());
        out.push(0);
    }
}