1use std::collections::HashMap;
9use std::sync::Arc;
10
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpListener;
13
14use super::catalog_views::translate_pg_catalog_query;
15use super::protocol::{
16 read_frame, read_startup, write_frame, write_raw_byte, BackendMessage, ColumnDescriptor,
17 DescribeTarget, FrontendMessage, PgWireError, TransactionStatus,
18};
19use super::types::{pg_param_to_value, value_to_pg_wire_bytes, PgOid};
20use crate::runtime::ai::ask_response_envelope::{
21 AskResult, Citation, Mode, SourceRow, Validation, ValidationError, ValidationWarning,
22};
23use crate::runtime::RedDBRuntime;
24use crate::storage::query::unified::{UnifiedRecord, UnifiedResult};
25use crate::storage::schema::Value;
26
/// Settings for the pg-wire listener.
#[derive(Debug, Clone)]
pub struct PgWireConfig {
    /// Address handed to `TcpListener::bind` when the listener starts.
    pub bind_addr: String,
    /// Text reported to clients as the `server_version` parameter status
    /// once startup completes.
    pub server_version: String,
}
38
/// A statement registered by a Parse message, stored per connection under
/// the statement name.
#[derive(Debug, Clone)]
struct PgPreparedStatement {
    /// Statement text as stored by `handle_parse`, i.e. after
    /// `rewrite_pg_parameter_casts` has been applied to the client's query.
    sql: String,
    /// One type OID per parameter position; positions the client left
    /// unspecified hold `PgOid::Unknown` unless a `$n::type` cast supplied one.
    param_type_oids: Vec<u32>,
}
44
/// A bound statement created by a Bind message, stored per connection under
/// the portal name.
#[derive(Debug, Clone)]
struct PgPortal {
    /// Statement text copied from the prepared statement at bind time.
    sql: String,
    /// Parameter values decoded from the Bind message.
    params: Vec<Value>,
    // Stored from the Bind message; nothing in the visible part of this file
    // reads it back, hence the lint suppression.
    #[allow(dead_code)]
    result_format_codes: Vec<i16>,
    /// Set when Describe (portal) already sent a RowDescription so Execute
    /// does not send a second one.
    row_description_sent: bool,
    /// Result produced while describing the portal, reused by the next
    /// Execute instead of running the query again.
    described_result: Option<crate::runtime::RuntimeQueryResult>,
}
54
55impl Default for PgWireConfig {
56 fn default() -> Self {
57 Self {
58 bind_addr: "127.0.0.1:5432".to_string(),
59 server_version: "15.0 (RedDB 3.1)".to_string(),
60 }
61 }
62}
63
64pub async fn start_pg_wire_listener(
67 config: PgWireConfig,
68 runtime: Arc<RedDBRuntime>,
69) -> Result<(), Box<dyn std::error::Error>> {
70 let listener = TcpListener::bind(&config.bind_addr).await?;
71 tracing::info!(
72 transport = "pg-wire",
73 bind = %config.bind_addr,
74 "listener online"
75 );
76 let cfg = Arc::new(config);
77 loop {
78 let (stream, peer) = listener.accept().await?;
79 let rt = Arc::clone(&runtime);
80 let cfg = Arc::clone(&cfg);
81 let peer_str = peer.to_string();
82 tokio::spawn(async move {
83 if let Err(e) = handle_connection(stream, rt, cfg).await {
84 tracing::warn!(
85 transport = "pg-wire",
86 peer = %peer_str,
87 err = %e,
88 "connection failed"
89 );
90 }
91 });
92 }
93}
94
95pub(crate) async fn handle_connection<S>(
97 mut stream: S,
98 runtime: Arc<RedDBRuntime>,
99 config: Arc<PgWireConfig>,
100) -> Result<(), PgWireError>
101where
102 S: AsyncRead + AsyncWrite + Unpin + Send,
103{
104 loop {
109 match read_startup(&mut stream).await? {
110 FrontendMessage::SslRequest | FrontendMessage::GssEncRequest => {
111 write_raw_byte(&mut stream, b'N').await?;
114 continue;
115 }
116 FrontendMessage::Startup(params) => {
117 send_auth_ok(&mut stream, &config, ¶ms).await?;
118 break;
119 }
120 FrontendMessage::Unknown { .. } => {
121 return Ok(());
123 }
124 other => {
125 return Err(PgWireError::Protocol(format!(
126 "unexpected startup frame: {other:?}"
127 )));
128 }
129 }
130 }
131
132 let mut prepared: HashMap<String, PgPreparedStatement> = HashMap::new();
133 let mut portals: HashMap<String, PgPortal> = HashMap::new();
134
135 loop {
137 let frame = match read_frame(&mut stream).await {
138 Ok(f) => f,
139 Err(PgWireError::Eof) => return Ok(()),
140 Err(e) => return Err(e),
141 };
142
143 match frame {
144 FrontendMessage::Query(sql) => {
145 handle_simple_query(&mut stream, &runtime, &sql).await?;
146 }
147 FrontendMessage::Parse(msg) => {
148 handle_parse(&mut stream, &mut prepared, msg).await?;
149 }
150 FrontendMessage::Bind(msg) => {
151 handle_bind(&mut stream, &prepared, &mut portals, msg).await?;
152 }
153 FrontendMessage::Describe(msg) => {
154 handle_describe(&mut stream, &runtime, &prepared, &mut portals, msg).await?;
155 }
156 FrontendMessage::Execute(msg) => {
157 handle_execute(&mut stream, &runtime, &mut portals, msg).await?;
158 }
159 FrontendMessage::Close(msg) => {
160 handle_close(&mut stream, &mut prepared, &mut portals, msg).await?;
161 }
162 FrontendMessage::Terminate => return Ok(()),
163 FrontendMessage::Flush => {
164 continue;
167 }
168 FrontendMessage::Sync => {
169 write_frame(
170 &mut stream,
171 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
172 )
173 .await?;
174 }
175 FrontendMessage::PasswordMessage(_) => {
176 continue;
178 }
179 FrontendMessage::Unknown { tag, .. } => {
180 send_error(
181 &mut stream,
182 "0A000",
183 &format!("unsupported frame tag 0x{tag:02x}"),
184 )
185 .await?;
186 write_frame(
187 &mut stream,
188 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
189 )
190 .await?;
191 }
192 other => {
193 send_error(
194 &mut stream,
195 "0A000",
196 &format!("unsupported frame {other:?}"),
197 )
198 .await?;
199 write_frame(
200 &mut stream,
201 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
202 )
203 .await?;
204 }
205 }
206 }
207}
208
209async fn handle_parse<S>(
210 stream: &mut S,
211 prepared: &mut HashMap<String, PgPreparedStatement>,
212 msg: super::protocol::ParseMessage,
213) -> Result<(), PgWireError>
214where
215 S: AsyncWrite + Unpin,
216{
217 let inferred_param_type_oids = infer_pg_cast_param_type_oids(&msg.query);
218 let sql = rewrite_pg_parameter_casts(&msg.query);
219 let parsed_param_count = match crate::storage::query::modes::parse_multi(&sql) {
220 Ok(parsed) => Some(
221 crate::storage::query::user_params::scan_parameters(&parsed)
222 .into_iter()
223 .map(|param| param.index + 1)
224 .max()
225 .unwrap_or(0),
226 ),
227 Err(err) => {
228 if pg_scalar_select_param_index(&sql).is_none() {
229 send_error(stream, "42601", &err.to_string()).await?;
230 return Ok(());
231 }
232 None
233 }
234 };
235 let mut param_type_oids = msg.param_type_oids;
236 if param_type_oids.is_empty() {
237 let count = parsed_param_count
238 .or_else(|| pg_scalar_select_param_index(&sql).map(|idx| idx + 1))
239 .unwrap_or(0);
240 param_type_oids.resize(count, PgOid::Unknown.as_u32());
241 }
242 for (idx, oid) in inferred_param_type_oids {
243 if idx >= param_type_oids.len() {
244 param_type_oids.resize(idx + 1, PgOid::Unknown.as_u32());
245 }
246 if param_type_oids[idx] == PgOid::Unknown.as_u32() {
247 param_type_oids[idx] = oid;
248 }
249 }
250 prepared.insert(
251 msg.statement,
252 PgPreparedStatement {
253 sql,
254 param_type_oids,
255 },
256 );
257 write_frame(stream, &BackendMessage::ParseComplete).await
258}
259
260async fn handle_bind<S>(
261 stream: &mut S,
262 prepared: &HashMap<String, PgPreparedStatement>,
263 portals: &mut HashMap<String, PgPortal>,
264 msg: super::protocol::BindMessage,
265) -> Result<(), PgWireError>
266where
267 S: AsyncWrite + Unpin,
268{
269 let Some(stmt) = prepared.get(&msg.statement) else {
270 send_error(
271 stream,
272 "26000",
273 &format!("prepared statement {:?} does not exist", msg.statement),
274 )
275 .await?;
276 return Ok(());
277 };
278 let params = match bind_pg_params(stmt, &msg) {
279 Ok(params) => params,
280 Err(err) => {
281 send_error(stream, "22023", &err).await?;
282 return Ok(());
283 }
284 };
285 portals.insert(
286 msg.portal,
287 PgPortal {
288 sql: stmt.sql.clone(),
289 params,
290 result_format_codes: msg.result_format_codes,
291 row_description_sent: false,
292 described_result: None,
293 },
294 );
295 write_frame(stream, &BackendMessage::BindComplete).await
296}
297
298async fn handle_describe<S>(
299 stream: &mut S,
300 runtime: &RedDBRuntime,
301 prepared: &HashMap<String, PgPreparedStatement>,
302 portals: &mut HashMap<String, PgPortal>,
303 msg: super::protocol::DescribeMessage,
304) -> Result<(), PgWireError>
305where
306 S: AsyncWrite + Unpin,
307{
308 match msg.target {
309 DescribeTarget::Statement => {
310 let Some(stmt) = prepared.get(&msg.name) else {
311 send_error(
312 stream,
313 "26000",
314 &format!("prepared statement {:?} does not exist", msg.name),
315 )
316 .await?;
317 return Ok(());
318 };
319 write_frame(
320 stream,
321 &BackendMessage::ParameterDescription(stmt.param_type_oids.clone()),
322 )
323 .await?;
324 if is_ask_query(&stmt.sql) {
325 emit_ask_row_description(stream).await
326 } else {
327 write_frame(stream, &BackendMessage::NoData).await
328 }
329 }
330 DescribeTarget::Portal => {
331 let Some(portal) = portals.get_mut(&msg.name) else {
332 send_error(
333 stream,
334 "34000",
335 &format!("portal {:?} does not exist", msg.name),
336 )
337 .await?;
338 return Ok(());
339 };
340 if is_ask_query(&portal.sql) {
341 emit_ask_row_description(stream).await?;
342 portal.row_description_sent = true;
343 Ok(())
344 } else if is_row_returning_query(&portal.sql) {
345 let result = match execute_pg_query_result(runtime, &portal.sql, &portal.params) {
346 Ok(result) => result,
347 Err(err) => {
348 let code = classify_sqlstate(&err);
349 send_error(stream, code, &err).await?;
350 return Ok(());
351 }
352 };
353 emit_row_description_for_result(stream, &result).await?;
354 portal.row_description_sent = true;
355 portal.described_result = Some(result);
356 Ok(())
357 } else {
358 write_frame(stream, &BackendMessage::NoData).await
359 }
360 }
361 }
362}
363
364async fn handle_execute<S>(
365 stream: &mut S,
366 runtime: &RedDBRuntime,
367 portals: &mut HashMap<String, PgPortal>,
368 msg: super::protocol::ExecuteMessage,
369) -> Result<(), PgWireError>
370where
371 S: AsyncWrite + Unpin,
372{
373 let Some(portal) = portals.get_mut(&msg.portal) else {
374 send_error(
375 stream,
376 "34000",
377 &format!("portal {:?} does not exist", msg.portal),
378 )
379 .await?;
380 return Ok(());
381 };
382 let _max_rows = msg.max_rows;
383 let was_described = portal.row_description_sent || portal.described_result.is_some();
384 portal.row_description_sent = false;
385 let result = match portal.described_result.take() {
386 Some(result) => Ok(result),
387 None => execute_pg_query_result(runtime, &portal.sql, &portal.params),
388 };
389 match result {
390 Ok(result) if was_described => {
391 emit_success_result_without_row_description(stream, &result).await
392 }
393 Ok(result) => emit_success_result(stream, &result).await,
394 Err(err) => {
395 let code = classify_sqlstate(&err);
396 send_error(stream, code, &err).await
397 }
398 }
399}
400
401async fn handle_close<S>(
402 stream: &mut S,
403 prepared: &mut HashMap<String, PgPreparedStatement>,
404 portals: &mut HashMap<String, PgPortal>,
405 msg: super::protocol::CloseMessage,
406) -> Result<(), PgWireError>
407where
408 S: AsyncWrite + Unpin,
409{
410 match msg.target {
411 DescribeTarget::Statement => {
412 prepared.remove(&msg.name);
413 }
414 DescribeTarget::Portal => {
415 portals.remove(&msg.name);
416 }
417 }
418 write_frame(stream, &BackendMessage::CloseComplete).await
419}
420
421fn bind_pg_params(
422 stmt: &PgPreparedStatement,
423 msg: &super::protocol::BindMessage,
424) -> Result<Vec<Value>, String> {
425 if !matches!(msg.param_format_codes.len(), 0 | 1)
426 && msg.param_format_codes.len() != msg.params.len()
427 {
428 return Err("Bind format count must be 0, 1, or match parameter count".to_string());
429 }
430 msg.params
431 .iter()
432 .enumerate()
433 .map(|(idx, param)| {
434 let oid = stmt
435 .param_type_oids
436 .get(idx)
437 .copied()
438 .unwrap_or(PgOid::Unknown.as_u32());
439 let format_code = match msg.param_format_codes.as_slice() {
440 [] => 0,
441 [format] => *format,
442 formats => formats[idx],
443 };
444 pg_param_to_value(oid, format_code, param.as_deref())
445 })
446 .collect()
447}
448
449fn execute_pg_query_result(
450 runtime: &RedDBRuntime,
451 sql: &str,
452 params: &[Value],
453) -> Result<crate::runtime::RuntimeQueryResult, String> {
454 if let Some(result) = try_execute_pg_scalar_select(sql, params) {
455 return Ok(result);
456 }
457 if params.is_empty() {
458 return match translate_pg_catalog_query(runtime, sql) {
459 Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
460 query: sql.to_string(),
461 mode: crate::storage::query::modes::QueryMode::Sql,
462 statement: "select",
463 engine: "pg-catalog",
464 result,
465 affected_rows: 0,
466 statement_type: "select",
467 }),
468 Ok(None) => runtime.execute_query(sql).map_err(|err| err.to_string()),
469 Err(err) => Err(err.to_string()),
470 };
471 }
472
473 let parsed = crate::storage::query::modes::parse_multi(sql).map_err(|err| err.to_string())?;
474 let bound =
475 crate::storage::query::user_params::bind(&parsed, params).map_err(|err| err.to_string())?;
476 runtime
477 .execute_query_expr(bound)
478 .map_err(|err| err.to_string())
479}
480
481fn try_execute_pg_scalar_select(
482 sql: &str,
483 params: &[Value],
484) -> Option<crate::runtime::RuntimeQueryResult> {
485 let index = pg_scalar_select_param_index(sql)?;
486 let value = params.get(index)?.clone();
487 let mut result = UnifiedResult::with_columns(vec!["?column?".to_string()]);
488 let mut record = UnifiedRecord::new();
489 record.set("?column?", value);
490 result.push(record);
491 Some(crate::runtime::RuntimeQueryResult {
492 query: sql.to_string(),
493 mode: crate::storage::query::modes::QueryMode::Sql,
494 statement: "select",
495 engine: "pg-wire",
496 result,
497 affected_rows: 0,
498 statement_type: "select",
499 })
500}
501
/// Recognises a statement that merely echoes one bind parameter and returns
/// that parameter's zero-based index.
///
/// Accepted shapes (case-insensitive, trailing `;` ignored):
/// * `SELECT $n`
/// * `SELECT CAST($n AS type)`
/// * either of the above followed by `AS alias`
///
/// Anything else after the parameter (`FROM ...`, operators, further select
/// items) makes the statement a real query, so `None` is returned and the
/// caller runs it through the normal execution path. `$0` is rejected
/// because parameter numbers start at 1.
fn pg_scalar_select_param_index(sql: &str) -> Option<usize> {
    /// True when `tail` is blank or exactly `as <alias>`.
    fn tail_is_empty_or_alias(tail: &str) -> bool {
        let mut tokens = tail.split_whitespace();
        match (tokens.next(), tokens.next(), tokens.next()) {
            (None, _, _) => true,
            (Some("as"), Some(alias), None) => alias
                .chars()
                .all(|c| c.is_alphanumeric() || matches!(c, '_' | '"')),
            _ => false,
        }
    }

    /// Byte offset of the `)` that closes the `cast(` whose type text starts
    /// at the beginning of `rest`; nested parentheses such as
    /// `numeric(10,2)` are skipped.
    fn cast_close_index(rest: &str) -> Option<usize> {
        let mut depth = 0usize;
        for (idx, byte) in rest.bytes().enumerate() {
            match byte {
                b'(' => depth += 1,
                b')' if depth == 0 => return Some(idx),
                b')' => depth -= 1,
                _ => {}
            }
        }
        None
    }

    let trimmed = sql.trim().trim_end_matches(';').trim();
    let lower = trimmed.to_ascii_lowercase();
    let body = lower.strip_prefix("select ")?;

    let (param, tail) = if let Some(inner) = body.strip_prefix("cast(") {
        let (param, rest) = inner.split_once(" as ")?;
        let close = cast_close_index(rest)?;
        (param.trim(), &rest[close + 1..])
    } else {
        let body = body.trim_start();
        match body.split_once(char::is_whitespace) {
            Some((param, tail)) => (param, tail),
            None => (body, ""),
        }
    };

    // Without this check `SELECT $1 FROM t` would be answered with the raw
    // parameter value instead of being executed.
    if !tail_is_empty_or_alias(tail) {
        return None;
    }

    let digits = param.strip_prefix('$')?;
    let n = digits.parse::<usize>().ok()?;
    n.checked_sub(1)
}
516
/// Returns `sql` with every `$<digits>::<type>` shortened to `$<digits>`.
///
/// A type name is the longest run of ASCII letters, digits, `_` and `.`
/// directly after the `::`. A `$` without digits, or a `::` without a type
/// name, is left untouched. The scan is purely textual: it does not
/// recognise string literals or comments.
fn rewrite_pg_parameter_casts(sql: &str) -> String {
    let bytes = sql.as_bytes();
    // Advances `at` past every byte accepted by `keep` and returns the new offset.
    let skip_while = |mut at: usize, keep: fn(u8) -> bool| {
        while at < bytes.len() && keep(bytes[at]) {
            at += 1;
        }
        at
    };

    let mut rewritten = String::with_capacity(sql.len());
    // Everything before `copied_upto` has already been emitted or dropped.
    let mut copied_upto = 0;
    let mut at = 0;
    while at < bytes.len() {
        if bytes[at] != b'$' {
            at += 1;
            continue;
        }
        let digits_start = at + 1;
        let digits_end = skip_while(digits_start, |b| b.is_ascii_digit());
        at = digits_end;
        if digits_end == digits_start {
            continue;
        }
        if !bytes[digits_end..].starts_with(b"::") {
            continue;
        }
        let type_start = digits_end + 2;
        let type_end = skip_while(type_start, |b| {
            b.is_ascii_alphanumeric() || matches!(b, b'_' | b'.')
        });
        at = type_end;
        if type_end > type_start {
            // Keep the text up to and including `$n`, drop `::type`.
            rewritten.push_str(&sql[copied_upto..digits_end]);
            copied_upto = type_end;
        }
    }
    rewritten.push_str(&sql[copied_upto..]);
    rewritten
}
556
557fn infer_pg_cast_param_type_oids(sql: &str) -> Vec<(usize, u32)> {
558 let mut out = Vec::new();
559 let bytes = sql.as_bytes();
560 let mut pos = 0;
561 while pos < bytes.len() {
562 if bytes[pos] != b'$' {
563 pos += 1;
564 continue;
565 }
566 pos += 1;
567 let digits_start = pos;
568 while pos < bytes.len() && bytes[pos].is_ascii_digit() {
569 pos += 1;
570 }
571 if digits_start == pos {
572 continue;
573 }
574 let Some(param_index) = sql[digits_start..pos]
575 .parse::<usize>()
576 .ok()
577 .and_then(|idx| idx.checked_sub(1))
578 else {
579 continue;
580 };
581 if pos + 2 > bytes.len() || &bytes[pos..pos + 2] != b"::" {
582 continue;
583 }
584 pos += 2;
585 let type_start = pos;
586 while pos < bytes.len()
587 && (bytes[pos].is_ascii_alphanumeric() || matches!(bytes[pos], b'_' | b'.'))
588 {
589 pos += 1;
590 }
591 if type_start == pos {
592 continue;
593 }
594 if let Some(oid) = pg_cast_type_oid(&sql[type_start..pos]) {
595 out.push((param_index, oid));
596 }
597 }
598 out
599}
600
601fn pg_cast_type_oid(ty: &str) -> Option<u32> {
602 let lower = ty.to_ascii_lowercase();
603 let short = lower.rsplit('.').next().unwrap_or(lower.as_str());
604 let oid = match short {
605 "bool" | "boolean" => PgOid::Bool,
606 "int2" | "smallint" => PgOid::Int2,
607 "int" | "int4" | "integer" => PgOid::Int4,
608 "int8" | "bigint" => PgOid::Int8,
609 "float4" | "real" => PgOid::Float4,
610 "float8" | "double" | "doubleprecision" => PgOid::Float8,
611 "numeric" | "decimal" => PgOid::Numeric,
612 "bytea" => PgOid::Bytea,
613 "json" => PgOid::Json,
614 "jsonb" => PgOid::Jsonb,
615 "text" => PgOid::Text,
616 "varchar" | "character varying" => PgOid::Varchar,
617 "uuid" => PgOid::Uuid,
618 "timestamp" => PgOid::Timestamp,
619 "timestamptz" | "timestampz" => PgOid::TimestampTz,
620 "vector" => PgOid::Vector,
621 _ => return None,
622 };
623 Some(oid.as_u32())
624}
625
/// Reports whether `sql` begins (after leading whitespace, ASCII
/// case-insensitively) with one of the keywords this listener treats as
/// row-returning. This is a plain prefix test; no word boundary is required.
fn is_row_returning_query(sql: &str) -> bool {
    const ROW_RETURNING_PREFIXES: [&str; 6] =
        ["select", "with", "ask", "search", "vector", "hybrid"];

    let normalized = sql.trim_start().to_ascii_lowercase();
    ROW_RETURNING_PREFIXES
        .iter()
        .any(|prefix| normalized.starts_with(*prefix))
}
635
/// Reports whether `sql` begins (after leading whitespace, ASCII
/// case-insensitively) with `ask`. This is a plain prefix test; no word
/// boundary is required.
fn is_ask_query(sql: &str) -> bool {
    let leading = sql.trim_start().as_bytes();
    match leading.get(..3) {
        Some(prefix) => prefix.eq_ignore_ascii_case(b"ask"),
        None => false,
    }
}
639
640async fn send_auth_ok<S>(
641 stream: &mut S,
642 config: &PgWireConfig,
643 params: &super::protocol::StartupParams,
644) -> Result<(), PgWireError>
645where
646 S: AsyncWrite + Unpin,
647{
648 write_frame(stream, &BackendMessage::AuthenticationOk).await?;
650
651 for (name, value) in [
653 ("server_version", config.server_version.as_str()),
654 ("server_encoding", "UTF8"),
655 ("client_encoding", "UTF8"),
656 ("DateStyle", "ISO, MDY"),
657 ("TimeZone", "UTC"),
658 ("integer_datetimes", "on"),
659 ("standard_conforming_strings", "on"),
660 (
661 "application_name",
662 params.get("application_name").unwrap_or(""),
663 ),
664 ] {
665 write_frame(
666 stream,
667 &BackendMessage::ParameterStatus {
668 name: name.to_string(),
669 value: value.to_string(),
670 },
671 )
672 .await?;
673 }
674
675 write_frame(
678 stream,
679 &BackendMessage::BackendKeyData {
680 pid: std::process::id(),
681 key: 0xDEADBEEF,
682 },
683 )
684 .await?;
685
686 write_frame(
687 stream,
688 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
689 )
690 .await?;
691 Ok(())
692}
693
694async fn handle_simple_query<S>(
695 stream: &mut S,
696 runtime: &RedDBRuntime,
697 sql: &str,
698) -> Result<(), PgWireError>
699where
700 S: AsyncWrite + Unpin,
701{
702 if sql.trim().is_empty() {
705 write_frame(stream, &BackendMessage::EmptyQueryResponse).await?;
706 write_frame(
707 stream,
708 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
709 )
710 .await?;
711 return Ok(());
712 }
713
714 if let Some(tag) = pg_session_compat_command_tag(sql) {
715 write_frame(stream, &BackendMessage::CommandComplete(tag.to_string())).await?;
716 write_frame(
717 stream,
718 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
719 )
720 .await?;
721 return Ok(());
722 }
723
724 let query_result = match translate_pg_catalog_query(runtime, sql) {
725 Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
726 query: sql.to_string(),
727 mode: crate::storage::query::modes::QueryMode::Sql,
728 statement: "select",
729 engine: "pg-catalog",
730 result,
731 affected_rows: 0,
732 statement_type: "select",
733 }),
734 Ok(None) => runtime.execute_query(sql),
735 Err(err) => Err(err),
736 };
737
738 match query_result {
739 Ok(result) => {
740 emit_success_result(stream, &result).await?;
741 }
742 Err(err) => {
743 let code = classify_sqlstate(&err.to_string());
747 send_error(stream, code, &err.to_string()).await?;
748 }
749 }
750
751 write_frame(
752 stream,
753 &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
754 )
755 .await?;
756 Ok(())
757}
758
/// Returns the command tag for statements that are acknowledged without
/// being executed. Currently that is any statement starting with `set `
/// (ASCII case-insensitive, surrounding whitespace ignored), which yields
/// `"SET"`.
fn pg_session_compat_command_tag(sql: &str) -> Option<&'static str> {
    let normalized = sql.trim().trim_end_matches(';').to_ascii_lowercase();
    normalized.starts_with("set ").then_some("SET")
}
766
767async fn emit_success_result<S>(
768 stream: &mut S,
769 result: &crate::runtime::RuntimeQueryResult,
770) -> Result<(), PgWireError>
771where
772 S: AsyncWrite + Unpin,
773{
774 if result.statement == "ask" {
775 emit_ask_result_row(stream, result).await?;
776 write_frame(
777 stream,
778 &BackendMessage::CommandComplete("SELECT 1".to_string()),
779 )
780 .await?;
781 } else if result_returns_rows(result) {
782 emit_result_rows(stream, &result.result).await?;
783 write_frame(
784 stream,
785 &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
786 )
787 .await?;
788 } else {
789 let tag = match result.statement_type {
794 "insert" => format!("INSERT 0 {}", result.affected_rows),
795 "update" => format!("UPDATE {}", result.affected_rows),
796 "delete" => format!("DELETE {}", result.affected_rows),
797 other => other.to_uppercase(),
798 };
799 write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
800 }
801 Ok(())
802}
803
804async fn emit_success_result_without_row_description<S>(
805 stream: &mut S,
806 result: &crate::runtime::RuntimeQueryResult,
807) -> Result<(), PgWireError>
808where
809 S: AsyncWrite + Unpin,
810{
811 if result.statement == "ask" {
812 let row = ask_query_result_to_pg_wire_row(result)
813 .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
814 write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
815 write_frame(
816 stream,
817 &BackendMessage::CommandComplete("SELECT 1".to_string()),
818 )
819 .await?;
820 } else if result_returns_rows(result) {
821 emit_result_data_rows(stream, &result.result).await?;
822 write_frame(
823 stream,
824 &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
825 )
826 .await?;
827 } else {
828 let tag = match result.statement_type {
829 "insert" => format!("INSERT 0 {}", result.affected_rows),
830 "update" => format!("UPDATE {}", result.affected_rows),
831 "delete" => format!("DELETE {}", result.affected_rows),
832 other => other.to_uppercase(),
833 };
834 write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
835 }
836 Ok(())
837}
838
839async fn emit_row_description_for_result<S>(
840 stream: &mut S,
841 result: &crate::runtime::RuntimeQueryResult,
842) -> Result<(), PgWireError>
843where
844 S: AsyncWrite + Unpin,
845{
846 if result.statement == "ask" {
847 emit_ask_row_description(stream).await
848 } else if result_returns_rows(result) {
849 emit_result_row_description(stream, &result.result).await
850 } else {
851 write_frame(stream, &BackendMessage::NoData).await
852 }
853}
854
855fn result_returns_rows(result: &crate::runtime::RuntimeQueryResult) -> bool {
856 result.statement_type == "select"
857}
858
859async fn emit_result_rows<S>(
860 stream: &mut S,
861 result: &crate::storage::query::unified::UnifiedResult,
862) -> Result<(), PgWireError>
863where
864 S: AsyncWrite + Unpin,
865{
866 emit_result_row_description(stream, result).await?;
867 emit_result_data_rows(stream, result).await
868}
869
870async fn emit_result_row_description<S>(
871 stream: &mut S,
872 result: &crate::storage::query::unified::UnifiedResult,
873) -> Result<(), PgWireError>
874where
875 S: AsyncWrite + Unpin,
876{
877 let columns: Vec<String> = if !result.columns.is_empty() {
881 result.columns.clone()
882 } else if let Some(first) = result.records.first() {
883 record_field_names(first)
884 } else {
885 Vec::new()
886 };
887
888 let type_oids: Vec<PgOid> = columns
892 .iter()
893 .map(|col| {
894 result
895 .records
896 .first()
897 .and_then(|r| record_get(r, col))
898 .map(PgOid::from_value)
899 .unwrap_or(PgOid::Text)
900 })
901 .collect();
902
903 let descriptors: Vec<ColumnDescriptor> = columns
904 .iter()
905 .zip(type_oids.iter())
906 .map(|(name, oid)| ColumnDescriptor {
907 name: name.clone(),
908 table_oid: 0,
909 column_attr: 0,
910 type_oid: oid.as_u32(),
911 type_size: -1,
912 type_mod: -1,
913 format: 0,
914 })
915 .collect();
916
917 write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
918}
919
920async fn emit_result_data_rows<S>(
921 stream: &mut S,
922 result: &crate::storage::query::unified::UnifiedResult,
923) -> Result<(), PgWireError>
924where
925 S: AsyncWrite + Unpin,
926{
927 let columns: Vec<String> = if !result.columns.is_empty() {
928 result.columns.clone()
929 } else if let Some(first) = result.records.first() {
930 record_field_names(first)
931 } else {
932 Vec::new()
933 };
934 for record in &result.records {
935 let fields: Vec<Option<Vec<u8>>> = columns
936 .iter()
937 .map(|col| record_get(record, col).and_then(value_to_pg_wire_bytes))
938 .collect();
939 write_frame(stream, &BackendMessage::DataRow(fields)).await?;
940 }
941
942 Ok(())
943}
944
945async fn emit_ask_result_row<S>(
946 stream: &mut S,
947 result: &crate::runtime::RuntimeQueryResult,
948) -> Result<(), PgWireError>
949where
950 S: AsyncWrite + Unpin,
951{
952 let row = ask_query_result_to_pg_wire_row(result)
953 .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
954
955 emit_ask_row_description(stream).await?;
956 write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
957 Ok(())
958}
959
960async fn emit_ask_row_description<S>(stream: &mut S) -> Result<(), PgWireError>
961where
962 S: AsyncWrite + Unpin,
963{
964 let descriptors: Vec<ColumnDescriptor> = crate::runtime::ai::pg_wire_ask_row_encoder::columns()
965 .iter()
966 .map(|col| ColumnDescriptor {
967 name: col.name.to_string(),
968 table_oid: 0,
969 column_attr: 0,
970 type_oid: col.oid.as_u32(),
971 type_size: -1,
972 type_mod: -1,
973 format: 0,
974 })
975 .collect();
976 write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
977}
978
979fn ask_query_result_to_pg_wire_row(
980 result: &crate::runtime::RuntimeQueryResult,
981) -> Option<crate::runtime::ai::pg_wire_ask_row_encoder::AskRow> {
982 if result.statement != "ask" {
983 return None;
984 }
985 let record = result.result.records.first()?;
986 let sources_flat_json =
987 json_field(record, "sources_flat").unwrap_or(crate::json::Value::Array(Vec::new()));
988 let citations_json =
989 json_field(record, "citations").unwrap_or(crate::json::Value::Array(Vec::new()));
990 let validation_json = json_field(record, "validation")
991 .unwrap_or_else(|| crate::json::Value::Object(Default::default()));
992
993 let effective_mode = match text_field(record, "mode").as_deref() {
994 Some("lenient") => Mode::Lenient,
995 _ => Mode::Strict,
996 };
997
998 let ask = AskResult {
999 answer: text_field(record, "answer")?,
1000 sources_flat: ask_sources_flat(&sources_flat_json),
1001 citations: ask_citations(&citations_json),
1002 validation: ask_validation(&validation_json),
1003 cache_hit: bool_field(record, "cache_hit").unwrap_or(false),
1004 provider: text_field(record, "provider").unwrap_or_default(),
1005 model: text_field(record, "model").unwrap_or_default(),
1006 prompt_tokens: u32_field(record, "prompt_tokens").unwrap_or(0),
1007 completion_tokens: u32_field(record, "completion_tokens").unwrap_or(0),
1008 cost_usd: f64_field(record, "cost_usd").unwrap_or(0.0),
1009 effective_mode,
1010 retry_count: u32_field(record, "retry_count").unwrap_or(0),
1011 };
1012
1013 Some(crate::runtime::ai::pg_wire_ask_row_encoder::encode(&ask))
1014}
1015
1016fn record_field<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
1017 record.iter_fields().find_map(|(name, value)| {
1018 let name: &str = name;
1019 (name == key).then_some(value)
1020 })
1021}
1022
1023fn text_field(record: &UnifiedRecord, key: &str) -> Option<String> {
1024 match record_field(record, key)? {
1025 Value::Text(s) => Some(s.to_string()),
1026 Value::Email(s) | Value::Url(s) | Value::NodeRef(s) | Value::EdgeRef(s) => Some(s.clone()),
1027 other => Some(other.to_string()),
1028 }
1029}
1030
1031fn bool_field(record: &UnifiedRecord, key: &str) -> Option<bool> {
1032 match record_field(record, key)? {
1033 Value::Boolean(value) => Some(*value),
1034 _ => None,
1035 }
1036}
1037
1038fn u32_field(record: &UnifiedRecord, key: &str) -> Option<u32> {
1039 match record_field(record, key)? {
1040 Value::Integer(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1041 Value::UnsignedInteger(n) => Some((*n).min(u32::MAX as u64) as u32),
1042 Value::BigInt(n)
1043 | Value::TimestampMs(n)
1044 | Value::Timestamp(n)
1045 | Value::Duration(n)
1046 | Value::Decimal(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1047 Value::Float(n) => (*n >= 0.0).then_some((*n).min(u32::MAX as f64) as u32),
1048 _ => None,
1049 }
1050}
1051
1052fn f64_field(record: &UnifiedRecord, key: &str) -> Option<f64> {
1053 match record_field(record, key)? {
1054 Value::Integer(n) => Some(*n as f64),
1055 Value::UnsignedInteger(n) => Some(*n as f64),
1056 Value::BigInt(n)
1057 | Value::TimestampMs(n)
1058 | Value::Timestamp(n)
1059 | Value::Duration(n)
1060 | Value::Decimal(n) => Some(*n as f64),
1061 Value::Float(n) => Some(*n),
1062 _ => None,
1063 }
1064}
1065
1066fn json_field(record: &UnifiedRecord, key: &str) -> Option<crate::json::Value> {
1067 match record_field(record, key)? {
1068 Value::Json(bytes) => crate::json::from_slice(bytes).ok(),
1069 Value::Text(text) => crate::json::from_str(text).ok(),
1070 _ => None,
1071 }
1072}
1073
1074fn ask_sources_flat(value: &crate::json::Value) -> Vec<SourceRow> {
1075 value
1076 .as_array()
1077 .unwrap_or(&[])
1078 .iter()
1079 .filter_map(|source| {
1080 let urn = source
1081 .get("urn")
1082 .and_then(crate::json::Value::as_str)?
1083 .to_string();
1084 let payload = source
1085 .get("payload")
1086 .and_then(crate::json::Value::as_str)
1087 .map(ToString::to_string)
1088 .unwrap_or_else(|| source.to_string_compact());
1089 Some(SourceRow { urn, payload })
1090 })
1091 .collect()
1092}
1093
1094fn ask_citations(value: &crate::json::Value) -> Vec<Citation> {
1095 value
1096 .as_array()
1097 .unwrap_or(&[])
1098 .iter()
1099 .filter_map(|citation| {
1100 let marker = citation
1101 .get("marker")
1102 .and_then(crate::json::Value::as_u64)?;
1103 let urn = citation
1104 .get("urn")
1105 .and_then(crate::json::Value::as_str)?
1106 .to_string();
1107 Some(Citation {
1108 marker: marker.min(u32::MAX as u64) as u32,
1109 urn,
1110 })
1111 })
1112 .collect()
1113}
1114
1115fn ask_validation(value: &crate::json::Value) -> Validation {
1116 Validation {
1117 ok: value
1118 .get("ok")
1119 .and_then(crate::json::Value::as_bool)
1120 .unwrap_or(true),
1121 warnings: validation_items(value, "warnings")
1122 .into_iter()
1123 .map(|(kind, detail)| ValidationWarning { kind, detail })
1124 .collect(),
1125 errors: validation_items(value, "errors")
1126 .into_iter()
1127 .map(|(kind, detail)| ValidationError { kind, detail })
1128 .collect(),
1129 }
1130}
1131
1132fn validation_items(value: &crate::json::Value, key: &str) -> Vec<(String, String)> {
1133 value
1134 .get(key)
1135 .and_then(crate::json::Value::as_array)
1136 .unwrap_or(&[])
1137 .iter()
1138 .filter_map(|item| {
1139 Some((
1140 item.get("kind")
1141 .and_then(crate::json::Value::as_str)?
1142 .to_string(),
1143 item.get("detail")
1144 .and_then(crate::json::Value::as_str)
1145 .unwrap_or("")
1146 .to_string(),
1147 ))
1148 })
1149 .collect()
1150}
1151
/// Looks up field `key` on `record` by delegating to `UnifiedRecord::get`.
/// Kept as a single seam so the row emitters share one lookup path.
fn record_get<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
    record.get(key)
}
1158
1159fn record_field_names(record: &UnifiedRecord) -> Vec<String> {
1168 record
1172 .column_names()
1173 .into_iter()
1174 .map(|k| k.to_string())
1175 .collect()
1176}
1177
1178async fn send_error<S>(stream: &mut S, code: &str, message: &str) -> Result<(), PgWireError>
1179where
1180 S: AsyncWrite + Unpin,
1181{
1182 write_frame(
1183 stream,
1184 &BackendMessage::ErrorResponse {
1185 severity: "ERROR".to_string(),
1186 code: code.to_string(),
1187 message: message.to_string(),
1188 },
1189 )
1190 .await
1191}
1192
/// Picks a SQLSTATE for an error message by ASCII case-insensitive
/// substring matching. Rules are tried in this order and the first hit wins:
///
/// 1. "not found" / "does not exist" -> `42P01`
/// 2. "parse" / "expected" / "syntax" -> `42601`
/// 3. "already exists"                -> `42P07`
/// 4. "permission" / "auth"           -> `28000`
/// 5. anything else                   -> `XX000`
fn classify_sqlstate(msg: &str) -> &'static str {
    let lower = msg.to_ascii_lowercase();
    let mentions = |needles: &[&str]| needles.iter().any(|needle| lower.contains(*needle));

    if mentions(&["not found", "does not exist"]) {
        return "42P01";
    }
    if mentions(&["parse", "expected", "syntax"]) {
        return "42601";
    }
    if mentions(&["already exists"]) {
        return "42P07";
    }
    if mentions(&["permission", "auth"]) {
        return "28000";
    }
    "XX000"
}
1211
1212#[cfg(test)]
1213mod tests {
1214 use super::*;
1215 use crate::api::RedDBOptions;
1216 use crate::runtime::RuntimeQueryResult;
1217 use crate::storage::query::modes::QueryMode;
1218 use crate::storage::query::unified::UnifiedResult;
1219 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
1220
1221 #[tokio::test]
1222 async fn extended_parse_bind_execute_returns_rows() {
1223 let runtime = Arc::new(RedDBRuntime::with_options(RedDBOptions::in_memory()).unwrap());
1224 let config = Arc::new(PgWireConfig::default());
1225 let (server_io, mut client_io) = tokio::io::duplex(64 * 1024);
1226 let server = tokio::spawn(async move {
1227 handle_connection(server_io, runtime, config).await.unwrap();
1228 });
1229
1230 write_startup(&mut client_io).await;
1231 read_until_ready(&mut client_io).await;
1232
1233 write_frontend_frame(
1234 &mut client_io,
1235 b'P',
1236 parse_body("", "SELECT $1::int", &[PgOid::Int4.as_u32()]),
1237 )
1238 .await;
1239 write_frontend_frame(
1240 &mut client_io,
1241 b'B',
1242 bind_body("", "", &[0], &[Some(b"42".as_slice())], &[]),
1243 )
1244 .await;
1245 write_frontend_frame(&mut client_io, b'D', describe_body(b'P', "")).await;
1246 write_frontend_frame(&mut client_io, b'E', execute_body("", 0)).await;
1247 write_frontend_frame(&mut client_io, b'S', Vec::new()).await;
1248
1249 let frames = read_until_ready(&mut client_io).await;
1250 assert_eq!(
1251 frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>(),
1252 b"12TDCZ"
1253 );
1254 let columns = decode_row_description(&frames[2].1);
1255 assert_eq!(columns.len(), 1);
1256 let cells = decode_data_row(&frames[3].1);
1257 assert_eq!(cells.len(), 1);
1258 assert_eq!(cells[0].as_deref(), Some(b"42".as_slice()));
1259 assert_eq!(decode_command_complete(&frames[4].1), "SELECT 1");
1260
1261 write_frontend_frame(&mut client_io, b'X', Vec::new()).await;
1262 server.await.unwrap();
1263 }
1264
1265 #[test]
1266 fn infer_pg_cast_param_type_oids_from_parameter_casts() {
1267 assert_eq!(
1268 infer_pg_cast_param_type_oids("INSERT INTO t (id, name) VALUES ($1::int, $2::text)"),
1269 vec![(0, PgOid::Int4.as_u32()), (1, PgOid::Text.as_u32())]
1270 );
1271 assert_eq!(
1272 infer_pg_cast_param_type_oids("SEARCH SIMILAR [1.0] COLLECTION v LIMIT $1::int8"),
1273 vec![(0, PgOid::Int8.as_u32())]
1274 );
1275 }
1276
1277 #[test]
1278 fn pg_session_compat_accepts_driver_setup_set_commands() {
1279 assert_eq!(
1280 pg_session_compat_command_tag("SET extra_float_digits = 3"),
1281 Some("SET")
1282 );
1283 assert_eq!(
1284 pg_session_compat_command_tag("SET application_name = 'pgjdbc'"),
1285 Some("SET")
1286 );
1287 assert_eq!(pg_session_compat_command_tag("SELECT 1"), None);
1288 }
1289
1290 #[tokio::test]
1291 async fn ask_success_result_uses_canonical_pg_wire_row_shape() {
1292 let mut result = UnifiedResult::with_columns(vec![
1293 "answer".into(),
1294 "provider".into(),
1295 "model".into(),
1296 "prompt_tokens".into(),
1297 "completion_tokens".into(),
1298 "sources_count".into(),
1299 "sources_flat".into(),
1300 "citations".into(),
1301 "validation".into(),
1302 ]);
1303 let mut record = UnifiedRecord::new();
1304 record.set("answer", Value::text("Deploy failed [^1]."));
1305 record.set("provider", Value::text("openai"));
1306 record.set("model", Value::text("gpt-4o-mini"));
1307 record.set("prompt_tokens", Value::Integer(11));
1308 record.set("completion_tokens", Value::Integer(7));
1309 record.set(
1310 "sources_flat",
1311 Value::Json(
1312 br#"[{"urn":"urn:reddb:row:deployments:1","kind":"row","collection":"deployments","id":"1"}]"#
1313 .to_vec(),
1314 ),
1315 );
1316 record.set(
1317 "citations",
1318 Value::Json(br#"[{"marker":1,"urn":"urn:reddb:row:deployments:1"}]"#.to_vec()),
1319 );
1320 record.set(
1321 "validation",
1322 Value::Json(br#"{"ok":true,"warnings":[],"errors":[]}"#.to_vec()),
1323 );
1324 result.push(record);
1325
1326 let qr = RuntimeQueryResult {
1327 query: "ASK 'why did deploy fail?'".to_string(),
1328 mode: QueryMode::Sql,
1329 statement: "ask",
1330 engine: "runtime-ai",
1331 result,
1332 affected_rows: 0,
1333 statement_type: "select",
1334 };
1335
1336 let mut out = Vec::new();
1337 emit_success_result(&mut out, &qr).await.unwrap();
1338 let frames = decode_frames(&out);
1339
1340 assert_eq!(
1341 frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>(),
1342 b"TDC"
1343 );
1344
1345 let columns = decode_row_description(frames[0].1);
1346 assert_eq!(
1347 columns,
1348 vec![
1349 ("answer".to_string(), PgOid::Text.as_u32()),
1350 ("cache_hit".to_string(), PgOid::Bool.as_u32()),
1351 ("citations".to_string(), PgOid::Jsonb.as_u32()),
1352 ("completion_tokens".to_string(), PgOid::Int8.as_u32()),
1353 ("cost_usd".to_string(), PgOid::Numeric.as_u32()),
1354 ("mode".to_string(), PgOid::Text.as_u32()),
1355 ("model".to_string(), PgOid::Text.as_u32()),
1356 ("prompt_tokens".to_string(), PgOid::Int8.as_u32()),
1357 ("provider".to_string(), PgOid::Text.as_u32()),
1358 ("retry_count".to_string(), PgOid::Int8.as_u32()),
1359 ("sources_flat".to_string(), PgOid::Jsonb.as_u32()),
1360 ("validation".to_string(), PgOid::Jsonb.as_u32()),
1361 ]
1362 );
1363
1364 let cells = decode_data_row(frames[1].1);
1365 assert_eq!(cells.len(), 12);
1366 assert_eq!(cells[0].as_deref(), Some(b"Deploy failed [^1].".as_slice()));
1367 assert_eq!(cells[1].as_deref(), Some(b"f".as_slice()));
1368 assert_eq!(cells[4].as_deref(), Some(b"0".as_slice()));
1369 assert_eq!(cells[5].as_deref(), Some(b"strict".as_slice()));
1370 assert_eq!(cells[9].as_deref(), Some(b"0".as_slice()));
1371 assert!(std::str::from_utf8(cells[10].as_deref().unwrap())
1372 .unwrap()
1373 .contains(r#""payload""#));
1374 assert_eq!(decode_command_complete(frames[2].1), "SELECT 1");
1375 }
1376
1377 fn decode_frames(bytes: &[u8]) -> Vec<(u8, &[u8])> {
1378 let mut pos = 0;
1379 let mut frames = Vec::new();
1380 while pos < bytes.len() {
1381 let tag = bytes[pos];
1382 let len = u32::from_be_bytes([
1383 bytes[pos + 1],
1384 bytes[pos + 2],
1385 bytes[pos + 3],
1386 bytes[pos + 4],
1387 ]) as usize;
1388 let body_start = pos + 5;
1389 let body_end = pos + 1 + len;
1390 frames.push((tag, &bytes[body_start..body_end]));
1391 pos = body_end;
1392 }
1393 frames
1394 }
1395
1396 fn decode_row_description(body: &[u8]) -> Vec<(String, u32)> {
1397 let count = i16::from_be_bytes([body[0], body[1]]) as usize;
1398 let mut pos = 2;
1399 let mut columns = Vec::with_capacity(count);
1400 for _ in 0..count {
1401 let end = body[pos..].iter().position(|&b| b == 0).unwrap() + pos;
1402 let name = std::str::from_utf8(&body[pos..end]).unwrap().to_string();
1403 pos = end + 1;
1404 pos += 4; pos += 2; let oid = u32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
1407 pos += 4;
1408 pos += 2; pos += 4; pos += 2; columns.push((name, oid));
1412 }
1413 columns
1414 }
1415
1416 fn decode_data_row(body: &[u8]) -> Vec<Option<Vec<u8>>> {
1417 let count = i16::from_be_bytes([body[0], body[1]]) as usize;
1418 let mut pos = 2;
1419 let mut cells = Vec::with_capacity(count);
1420 for _ in 0..count {
1421 let len = i32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
1422 pos += 4;
1423 if len < 0 {
1424 cells.push(None);
1425 } else {
1426 let len = len as usize;
1427 cells.push(Some(body[pos..pos + len].to_vec()));
1428 pos += len;
1429 }
1430 }
1431 cells
1432 }
1433
1434 fn decode_command_complete(body: &[u8]) -> &str {
1435 let nul = body.iter().position(|&b| b == 0).unwrap_or(body.len());
1436 std::str::from_utf8(&body[..nul]).unwrap()
1437 }
1438
1439 async fn write_startup<W: AsyncWrite + Unpin>(stream: &mut W) {
1440 let mut payload = Vec::new();
1441 payload.extend_from_slice(&crate::wire::postgres::protocol::PG_PROTOCOL_V3.to_be_bytes());
1442 payload.extend_from_slice(b"user\0reddb\0");
1443 payload.push(0);
1444 let len = (payload.len() + 4) as u32;
1445 stream.write_all(&len.to_be_bytes()).await.unwrap();
1446 stream.write_all(&payload).await.unwrap();
1447 }
1448
1449 async fn write_frontend_frame<W: AsyncWrite + Unpin>(
1450 stream: &mut W,
1451 tag: u8,
1452 payload: Vec<u8>,
1453 ) {
1454 stream.write_all(&[tag]).await.unwrap();
1455 stream
1456 .write_all(&((payload.len() + 4) as u32).to_be_bytes())
1457 .await
1458 .unwrap();
1459 stream.write_all(&payload).await.unwrap();
1460 }
1461
1462 async fn read_backend_frame<R: AsyncRead + Unpin>(stream: &mut R) -> (u8, Vec<u8>) {
1463 let mut tag = [0u8; 1];
1464 stream.read_exact(&mut tag).await.unwrap();
1465 let mut len = [0u8; 4];
1466 stream.read_exact(&mut len).await.unwrap();
1467 let len = u32::from_be_bytes(len) as usize;
1468 let mut body = vec![0u8; len - 4];
1469 stream.read_exact(&mut body).await.unwrap();
1470 (tag[0], body)
1471 }
1472
1473 async fn read_until_ready<R: AsyncRead + Unpin>(stream: &mut R) -> Vec<(u8, Vec<u8>)> {
1474 let mut frames = Vec::new();
1475 loop {
1476 let frame = read_backend_frame(stream).await;
1477 let done = frame.0 == b'Z';
1478 frames.push(frame);
1479 if done {
1480 return frames;
1481 }
1482 }
1483 }
1484
1485 fn parse_body(statement: &str, query: &str, oids: &[u32]) -> Vec<u8> {
1486 let mut out = Vec::new();
1487 push_pg_cstring(&mut out, statement);
1488 push_pg_cstring(&mut out, query);
1489 out.extend_from_slice(&(oids.len() as i16).to_be_bytes());
1490 for oid in oids {
1491 out.extend_from_slice(&oid.to_be_bytes());
1492 }
1493 out
1494 }
1495
1496 fn bind_body(
1497 portal: &str,
1498 statement: &str,
1499 formats: &[i16],
1500 params: &[Option<&[u8]>],
1501 result_formats: &[i16],
1502 ) -> Vec<u8> {
1503 let mut out = Vec::new();
1504 push_pg_cstring(&mut out, portal);
1505 push_pg_cstring(&mut out, statement);
1506 out.extend_from_slice(&(formats.len() as i16).to_be_bytes());
1507 for format in formats {
1508 out.extend_from_slice(&format.to_be_bytes());
1509 }
1510 out.extend_from_slice(&(params.len() as i16).to_be_bytes());
1511 for param in params {
1512 match param {
1513 Some(bytes) => {
1514 out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
1515 out.extend_from_slice(bytes);
1516 }
1517 None => out.extend_from_slice(&(-1i32).to_be_bytes()),
1518 }
1519 }
1520 out.extend_from_slice(&(result_formats.len() as i16).to_be_bytes());
1521 for format in result_formats {
1522 out.extend_from_slice(&format.to_be_bytes());
1523 }
1524 out
1525 }
1526
1527 fn describe_body(target: u8, name: &str) -> Vec<u8> {
1528 let mut out = vec![target];
1529 push_pg_cstring(&mut out, name);
1530 out
1531 }
1532
1533 fn execute_body(portal: &str, max_rows: u32) -> Vec<u8> {
1534 let mut out = Vec::new();
1535 push_pg_cstring(&mut out, portal);
1536 out.extend_from_slice(&max_rows.to_be_bytes());
1537 out
1538 }
1539
1540 fn push_pg_cstring(out: &mut Vec<u8>, value: &str) {
1541 out.extend_from_slice(value.as_bytes());
1542 out.push(0);
1543 }
1544}