Skip to main content

reddb_server/wire/postgres/
server.rs

1//! PostgreSQL wire-protocol listener (Phase 3.1 PG parity).
2//!
3//! Accepts TCP connections from PG-compatible clients, drives the startup
4//! handshake, and routes simple/extended-query frames into the existing
5//! `RedDBRuntime::execute_query` path. Results are adapted back into PG
6//! `RowDescription` + `DataRow` frames via `types::value_to_pg_wire_bytes`.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpListener;
13
14use super::catalog_views::translate_pg_catalog_query;
15use super::protocol::{
16    read_frame, read_startup, write_frame, write_raw_byte, BackendMessage, ColumnDescriptor,
17    DescribeTarget, FrontendMessage, PgWireError, TransactionStatus,
18};
19use super::types::{pg_param_to_value, value_to_pg_wire_bytes, PgOid};
20use crate::runtime::ai::ask_response_envelope::{
21    AskResult, Citation, Mode, SourceRow, Validation, ValidationError, ValidationWarning,
22};
23use crate::runtime::RedDBRuntime;
24use crate::storage::query::unified::{UnifiedRecord, UnifiedResult};
25use crate::storage::schema::Value;
26
/// Settings for the PostgreSQL wire listener, fixed at startup.
#[derive(Debug, Clone)]
pub struct PgWireConfig {
    /// Address the TCP listener binds to, written as "host:port". The
    /// caller must keep it from colliding with the native wire, gRPC,
    /// and HTTP listeners.
    pub bind_addr: String,
    /// Version string reported to clients through `ParameterStatus`.
    /// Drivers often inspect it to switch features on or off, so RedDB
    /// reports a recent-enough PG version for broad client support.
    pub server_version: String,
}
38
/// A statement registered by a `Parse` message and kept per connection
/// so later `Bind` / `Describe` messages can refer to it by name.
#[derive(Debug, Clone)]
struct PgPreparedStatement {
    /// SQL text stored for the statement.
    sql: String,
    /// Type OID for each positional parameter, indexed from zero.
    param_type_oids: Vec<u32>,
}
44
45#[derive(Debug, Clone)]
46struct PgPortal {
47    sql: String,
48    params: Vec<Value>,
49    #[allow(dead_code)]
50    result_format_codes: Vec<i16>,
51    row_description_sent: bool,
52    described_result: Option<crate::runtime::RuntimeQueryResult>,
53}
54
55impl Default for PgWireConfig {
56    fn default() -> Self {
57        Self {
58            bind_addr: "127.0.0.1:5432".to_string(),
59            server_version: "15.0 (RedDB 3.1)".to_string(),
60        }
61    }
62}
63
64/// Run a synchronous, possibly long-parking runtime call without
65/// monopolising the async worker thread.
66///
67/// `QUEUE READ … WAIT <budget>` parks the calling thread on a condvar
68/// for up to the wait budget. Doing that directly on a tokio worker
69/// holds the worker for the whole budget, which serialises every other
70/// connection's query behind the parked waiter — a producer's
71/// `QUEUE PUSH` could not run to notify the waiter, so an enqueue during
72/// a WAIT would never be delivered — and it stalls the runtime's timer
73/// driver. `block_in_place` hands the remaining tasks on this worker to
74/// a sibling worker for the duration while keeping the current OS thread
75/// — and therefore the per-connection thread-locals `execute_query`
76/// relies on (current tenant, connection id, transaction context) —
77/// intact, which `spawn_blocking` would not.
78///
79/// `block_in_place` panics on a current-thread runtime, so we fall back
80/// to a direct call there: a single-threaded runtime has no sibling
81/// worker to hand work to and nothing else to starve (e.g. the
82/// `#[tokio::test]` harnesses that drive one connection at a time).
83fn run_runtime_blocking<T>(f: impl FnOnce() -> T) -> T {
84    use tokio::runtime::{Handle, RuntimeFlavor};
85    match Handle::try_current().map(|h| h.runtime_flavor()) {
86        Ok(RuntimeFlavor::MultiThread) => tokio::task::block_in_place(f),
87        _ => f(),
88    }
89}
90
91/// Spawn the PG wire listener. Blocks until the listener errors out.
92/// Each connection is handled in its own tokio task.
93pub async fn start_pg_wire_listener(
94    config: PgWireConfig,
95    runtime: Arc<RedDBRuntime>,
96) -> Result<(), Box<dyn std::error::Error>> {
97    let listener = TcpListener::bind(&config.bind_addr).await?;
98    tracing::info!(
99        transport = "pg-wire",
100        bind = %config.bind_addr,
101        "listener online"
102    );
103    let cfg = Arc::new(config);
104    loop {
105        let (stream, peer) = listener.accept().await?;
106        let rt = Arc::clone(&runtime);
107        let cfg = Arc::clone(&cfg);
108        let peer_str = peer.to_string();
109        tokio::spawn(async move {
110            if let Err(e) = handle_connection(stream, rt, cfg).await {
111                tracing::warn!(
112                    transport = "pg-wire",
113                    peer = %peer_str,
114                    err = %e,
115                    "connection failed"
116                );
117            }
118        });
119    }
120}
121
122/// Drive one connection's lifetime: startup → authentication → query loop.
123pub(crate) async fn handle_connection<S>(
124    mut stream: S,
125    runtime: Arc<RedDBRuntime>,
126    config: Arc<PgWireConfig>,
127) -> Result<(), PgWireError>
128where
129    S: AsyncRead + AsyncWrite + Unpin + Send,
130{
131    // Handshake. The first frame may be SSLRequest / GSSENCRequest
132    // (pre-auth negotiation) or a plain Startup. Loop once to cover
133    // SSL-not-supported path: reply 'N' and expect the client to send
134    // a regular Startup next.
135    loop {
136        match read_startup(&mut stream).await? {
137            FrontendMessage::SslRequest | FrontendMessage::GssEncRequest => {
138                // 'N' = not supported — client continues in plaintext and
139                // re-sends a normal Startup on the same socket.
140                write_raw_byte(&mut stream, b'N').await?;
141                continue;
142            }
143            FrontendMessage::Startup(params) => {
144                send_auth_ok(&mut stream, &config, &params).await?;
145                break;
146            }
147            FrontendMessage::Unknown { .. } => {
148                // CancelRequest: no response expected; drop the socket.
149                return Ok(());
150            }
151            other => {
152                return Err(PgWireError::Protocol(format!(
153                    "unexpected startup frame: {other:?}"
154                )));
155            }
156        }
157    }
158
159    let mut prepared: HashMap<String, PgPreparedStatement> = HashMap::new();
160    let mut portals: HashMap<String, PgPortal> = HashMap::new();
161
162    // Main query loop.
163    loop {
164        let frame = match read_frame(&mut stream).await {
165            Ok(f) => f,
166            Err(PgWireError::Eof) => return Ok(()),
167            Err(e) => return Err(e),
168        };
169
170        match frame {
171            FrontendMessage::Query(sql) => {
172                handle_simple_query(&mut stream, &runtime, &sql).await?;
173            }
174            FrontendMessage::Parse(msg) => {
175                handle_parse(&mut stream, &mut prepared, msg).await?;
176            }
177            FrontendMessage::Bind(msg) => {
178                handle_bind(&mut stream, &prepared, &mut portals, msg).await?;
179            }
180            FrontendMessage::Describe(msg) => {
181                handle_describe(&mut stream, &runtime, &prepared, &mut portals, msg).await?;
182            }
183            FrontendMessage::Execute(msg) => {
184                handle_execute(&mut stream, &runtime, &mut portals, msg).await?;
185            }
186            FrontendMessage::Close(msg) => {
187                handle_close(&mut stream, &mut prepared, &mut portals, msg).await?;
188            }
189            FrontendMessage::Terminate => return Ok(()),
190            FrontendMessage::Flush => {
191                // Frames are written immediately; no additional marker is
192                // needed. ReadyForQuery belongs to Sync, not Flush.
193                continue;
194            }
195            FrontendMessage::Sync => {
196                write_frame(
197                    &mut stream,
198                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
199                )
200                .await?;
201            }
202            FrontendMessage::PasswordMessage(_) => {
203                // Should only arrive during auth. Ignore post-auth.
204                continue;
205            }
206            FrontendMessage::Unknown { tag, .. } => {
207                send_error(
208                    &mut stream,
209                    "0A000",
210                    &format!("unsupported frame tag 0x{tag:02x}"),
211                )
212                .await?;
213                write_frame(
214                    &mut stream,
215                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
216                )
217                .await?;
218            }
219            other => {
220                send_error(
221                    &mut stream,
222                    "0A000",
223                    &format!("unsupported frame {other:?}"),
224                )
225                .await?;
226                write_frame(
227                    &mut stream,
228                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
229                )
230                .await?;
231            }
232        }
233    }
234}
235
236async fn handle_parse<S>(
237    stream: &mut S,
238    prepared: &mut HashMap<String, PgPreparedStatement>,
239    msg: super::protocol::ParseMessage,
240) -> Result<(), PgWireError>
241where
242    S: AsyncWrite + Unpin,
243{
244    let inferred_param_type_oids = infer_pg_cast_param_type_oids(&msg.query);
245    let sql = rewrite_pg_parameter_casts(&msg.query);
246    let parsed_param_count = match crate::storage::query::modes::parse_multi(&sql) {
247        Ok(parsed) => Some(
248            crate::storage::query::user_params::scan_parameters(&parsed)
249                .into_iter()
250                .map(|param| param.index + 1)
251                .max()
252                .unwrap_or(0),
253        ),
254        Err(err) => {
255            if pg_scalar_select_param_index(&sql).is_none() {
256                send_error(stream, "42601", &err.to_string()).await?;
257                return Ok(());
258            }
259            None
260        }
261    };
262    let mut param_type_oids = msg.param_type_oids;
263    if param_type_oids.is_empty() {
264        let count = parsed_param_count
265            .or_else(|| pg_scalar_select_param_index(&sql).map(|idx| idx + 1))
266            .unwrap_or(0);
267        param_type_oids.resize(count, PgOid::Unknown.as_u32());
268    }
269    for (idx, oid) in inferred_param_type_oids {
270        if idx >= param_type_oids.len() {
271            param_type_oids.resize(idx + 1, PgOid::Unknown.as_u32());
272        }
273        if param_type_oids[idx] == PgOid::Unknown.as_u32() {
274            param_type_oids[idx] = oid;
275        }
276    }
277    prepared.insert(
278        msg.statement,
279        PgPreparedStatement {
280            sql,
281            param_type_oids,
282        },
283    );
284    write_frame(stream, &BackendMessage::ParseComplete).await
285}
286
287async fn handle_bind<S>(
288    stream: &mut S,
289    prepared: &HashMap<String, PgPreparedStatement>,
290    portals: &mut HashMap<String, PgPortal>,
291    msg: super::protocol::BindMessage,
292) -> Result<(), PgWireError>
293where
294    S: AsyncWrite + Unpin,
295{
296    let Some(stmt) = prepared.get(&msg.statement) else {
297        send_error(
298            stream,
299            "26000",
300            &format!("prepared statement {:?} does not exist", msg.statement),
301        )
302        .await?;
303        return Ok(());
304    };
305    let params = match bind_pg_params(stmt, &msg) {
306        Ok(params) => params,
307        Err(err) => {
308            send_error(stream, "22023", &err).await?;
309            return Ok(());
310        }
311    };
312    portals.insert(
313        msg.portal,
314        PgPortal {
315            sql: stmt.sql.clone(),
316            params,
317            result_format_codes: msg.result_format_codes,
318            row_description_sent: false,
319            described_result: None,
320        },
321    );
322    write_frame(stream, &BackendMessage::BindComplete).await
323}
324
325async fn handle_describe<S>(
326    stream: &mut S,
327    runtime: &RedDBRuntime,
328    prepared: &HashMap<String, PgPreparedStatement>,
329    portals: &mut HashMap<String, PgPortal>,
330    msg: super::protocol::DescribeMessage,
331) -> Result<(), PgWireError>
332where
333    S: AsyncWrite + Unpin,
334{
335    match msg.target {
336        DescribeTarget::Statement => {
337            let Some(stmt) = prepared.get(&msg.name) else {
338                send_error(
339                    stream,
340                    "26000",
341                    &format!("prepared statement {:?} does not exist", msg.name),
342                )
343                .await?;
344                return Ok(());
345            };
346            write_frame(
347                stream,
348                &BackendMessage::ParameterDescription(stmt.param_type_oids.clone()),
349            )
350            .await?;
351            if is_ask_query(&stmt.sql) {
352                emit_ask_row_description(stream).await
353            } else {
354                write_frame(stream, &BackendMessage::NoData).await
355            }
356        }
357        DescribeTarget::Portal => {
358            let Some(portal) = portals.get_mut(&msg.name) else {
359                send_error(
360                    stream,
361                    "34000",
362                    &format!("portal {:?} does not exist", msg.name),
363                )
364                .await?;
365                return Ok(());
366            };
367            if is_ask_query(&portal.sql) {
368                emit_ask_row_description(stream).await?;
369                portal.row_description_sent = true;
370                Ok(())
371            } else if is_row_returning_query(&portal.sql) {
372                let result = match execute_pg_query_result(runtime, &portal.sql, &portal.params) {
373                    Ok(result) => result,
374                    Err(err) => {
375                        let code = classify_sqlstate(&err);
376                        send_error(stream, code, &err).await?;
377                        return Ok(());
378                    }
379                };
380                emit_row_description_for_result(stream, &result).await?;
381                portal.row_description_sent = true;
382                portal.described_result = Some(result);
383                Ok(())
384            } else {
385                write_frame(stream, &BackendMessage::NoData).await
386            }
387        }
388    }
389}
390
391async fn handle_execute<S>(
392    stream: &mut S,
393    runtime: &RedDBRuntime,
394    portals: &mut HashMap<String, PgPortal>,
395    msg: super::protocol::ExecuteMessage,
396) -> Result<(), PgWireError>
397where
398    S: AsyncWrite + Unpin,
399{
400    let Some(portal) = portals.get_mut(&msg.portal) else {
401        send_error(
402            stream,
403            "34000",
404            &format!("portal {:?} does not exist", msg.portal),
405        )
406        .await?;
407        return Ok(());
408    };
409    let _max_rows = msg.max_rows;
410    let was_described = portal.row_description_sent || portal.described_result.is_some();
411    portal.row_description_sent = false;
412    let result = match portal.described_result.take() {
413        Some(result) => Ok(result),
414        None => execute_pg_query_result(runtime, &portal.sql, &portal.params),
415    };
416    match result {
417        Ok(result) if was_described => {
418            emit_success_result_without_row_description(stream, &result).await
419        }
420        Ok(result) => emit_success_result(stream, &result).await,
421        Err(err) => {
422            let code = classify_sqlstate(&err);
423            send_error(stream, code, &err).await
424        }
425    }
426}
427
428async fn handle_close<S>(
429    stream: &mut S,
430    prepared: &mut HashMap<String, PgPreparedStatement>,
431    portals: &mut HashMap<String, PgPortal>,
432    msg: super::protocol::CloseMessage,
433) -> Result<(), PgWireError>
434where
435    S: AsyncWrite + Unpin,
436{
437    match msg.target {
438        DescribeTarget::Statement => {
439            prepared.remove(&msg.name);
440        }
441        DescribeTarget::Portal => {
442            portals.remove(&msg.name);
443        }
444    }
445    write_frame(stream, &BackendMessage::CloseComplete).await
446}
447
448fn bind_pg_params(
449    stmt: &PgPreparedStatement,
450    msg: &super::protocol::BindMessage,
451) -> Result<Vec<Value>, String> {
452    if !matches!(msg.param_format_codes.len(), 0 | 1)
453        && msg.param_format_codes.len() != msg.params.len()
454    {
455        return Err("Bind format count must be 0, 1, or match parameter count".to_string());
456    }
457    msg.params
458        .iter()
459        .enumerate()
460        .map(|(idx, param)| {
461            let oid = stmt
462                .param_type_oids
463                .get(idx)
464                .copied()
465                .unwrap_or(PgOid::Unknown.as_u32());
466            let format_code = match msg.param_format_codes.as_slice() {
467                [] => 0,
468                [format] => *format,
469                formats => formats[idx],
470            };
471            pg_param_to_value(oid, format_code, param.as_deref())
472        })
473        .collect()
474}
475
476fn execute_pg_query_result(
477    runtime: &RedDBRuntime,
478    sql: &str,
479    params: &[Value],
480) -> Result<crate::runtime::RuntimeQueryResult, String> {
481    if let Some(result) = try_execute_pg_scalar_select(sql, params) {
482        return Ok(result);
483    }
484    if params.is_empty() {
485        return match translate_pg_catalog_query(runtime, sql) {
486            Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
487                query: sql.to_string(),
488                mode: crate::storage::query::modes::QueryMode::Sql,
489                statement: "select",
490                engine: "pg-catalog",
491                result,
492                affected_rows: 0,
493                statement_type: "select",
494                bookmark: None,
495            }),
496            Ok(None) => {
497                run_runtime_blocking(|| runtime.execute_query(sql)).map_err(|err| err.to_string())
498            }
499            Err(err) => Err(err.to_string()),
500        };
501    }
502
503    let parsed = crate::storage::query::modes::parse_multi(sql).map_err(|err| err.to_string())?;
504    let bound =
505        crate::storage::query::user_params::bind(&parsed, params).map_err(|err| err.to_string())?;
506    run_runtime_blocking(|| runtime.execute_query_expr(bound)).map_err(|err| err.to_string())
507}
508
509fn try_execute_pg_scalar_select(
510    sql: &str,
511    params: &[Value],
512) -> Option<crate::runtime::RuntimeQueryResult> {
513    let index = pg_scalar_select_param_index(sql)?;
514    let value = params.get(index)?.clone();
515    let mut result = UnifiedResult::with_columns(vec!["?column?".to_string()]);
516    let mut record = UnifiedRecord::new();
517    record.set("?column?", value);
518    result.push(record);
519    Some(crate::runtime::RuntimeQueryResult {
520        query: sql.to_string(),
521        mode: crate::storage::query::modes::QueryMode::Sql,
522        statement: "select",
523        engine: "pg-wire",
524        result,
525        affected_rows: 0,
526        statement_type: "select",
527        bookmark: None,
528    })
529}
530
/// Recognise a statement that merely echoes one positional parameter and
/// return that parameter's zero-based index.
///
/// Accepted shapes (case-insensitive, optional trailing `;`):
///
/// * `SELECT $n`
/// * `SELECT CAST($n AS <type>)`
///
/// each optionally followed by `AS <alias>`. Anything else — a `FROM`
/// clause, an operator, further select-list items — returns `None`, so
/// statements like `SELECT $1 FROM users` are never mistaken for a bare
/// parameter echo and answered without running the real query.
///
/// `$0` and non-numeric parameter names also return `None`.
fn pg_scalar_select_param_index(sql: &str) -> Option<usize> {
    let trimmed = sql.trim().trim_end_matches(';').trim();
    let lower = trimmed.to_ascii_lowercase();
    let body = lower.strip_prefix("select ")?.trim_start();

    // Split the select list into the parameter token and whatever follows it.
    let (param, trailing) = if let Some(inner) = body.strip_prefix("cast(") {
        // Find the parenthesis closing `cast(`, allowing nested ones such
        // as `numeric(10,2)` in the target type.
        let mut depth = 1usize;
        let close = inner.char_indices().find_map(|(at, ch)| {
            match ch {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(at);
                    }
                }
                _ => {}
            }
            None
        })?;
        let (operand, _target_type) = inner[..close].split_once(" as ")?;
        (operand.trim(), &inner[close + 1..])
    } else {
        let end = body.find(char::is_whitespace).unwrap_or(body.len());
        (&body[..end], &body[end..])
    };

    // Only an optional column alias may follow the parameter.
    let mut rest = trailing.split_whitespace();
    match (rest.next(), rest.next(), rest.next()) {
        (None, _, _) => {}
        (Some("as"), Some(alias), None)
            if alias
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '"')) => {}
        _ => return None,
    }

    let digits = param.strip_prefix('$')?;
    // `str::parse` would also accept a leading `+`; insist on plain digits.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse::<usize>().ok()?.checked_sub(1)
}
545
/// Strip PostgreSQL-style casts from positional parameters, turning
/// `$1::int` into `$1`.
///
/// A cast is `$<digits>::<type>` where the type consists of ASCII
/// letters, digits, `_` and `.`. The two-word names `double precision`
/// and `character varying` are removed as a whole so no stray keyword is
/// left behind. Text inside single-quoted string literals and
/// double-quoted identifiers is copied through untouched, so data such
/// as `'costs $5::int'` is never altered. Inside quoted text a doubled
/// quote character is treated as an escaped quote.
///
/// Everything else is returned unchanged.
fn rewrite_pg_parameter_casts(sql: &str) -> String {
    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    // `cursor` marks the start of input not yet copied to `out`.
    let mut cursor = 0;
    let mut pos = 0;
    while pos < bytes.len() {
        match bytes[pos] {
            quote @ (b'\'' | b'"') => {
                // Skip the quoted region; an unterminated one runs to the end.
                pos += 1;
                while pos < bytes.len() {
                    if bytes[pos] != quote {
                        pos += 1;
                        continue;
                    }
                    pos += 1;
                    if pos < bytes.len() && bytes[pos] == quote {
                        // Doubled quote: still inside the quoted text.
                        pos += 1;
                        continue;
                    }
                    break;
                }
                continue;
            }
            b'$' => {}
            _ => {
                pos += 1;
                continue;
            }
        }

        pos += 1;
        let digits_start = pos;
        while pos < bytes.len() && bytes[pos].is_ascii_digit() {
            pos += 1;
        }
        if digits_start == pos {
            continue;
        }
        let param_end = pos;
        if !bytes[pos..].starts_with(b"::") {
            continue;
        }
        pos += 2;
        let type_start = pos;
        while pos < bytes.len()
            && (bytes[pos].is_ascii_alphanumeric() || matches!(bytes[pos], b'_' | b'.'))
        {
            pos += 1;
        }
        if type_start == pos {
            continue;
        }

        // Two-word type names: also swallow the second word.
        let type_name = &sql[type_start..pos];
        let second_word: Option<&[u8]> = if type_name.eq_ignore_ascii_case("double") {
            Some(&b"precision"[..])
        } else if type_name.eq_ignore_ascii_case("character") {
            Some(&b"varying"[..])
        } else {
            None
        };
        if let Some(word) = second_word {
            let mut look = pos;
            while look < bytes.len() && bytes[look].is_ascii_whitespace() {
                look += 1;
            }
            let end = look + word.len();
            let is_match = look > pos
                && end <= bytes.len()
                && bytes[look..end].eq_ignore_ascii_case(word)
                && bytes
                    .get(end)
                    .map_or(true, |b| !(b.is_ascii_alphanumeric() || *b == b'_'));
            if is_match {
                pos = end;
            }
        }

        // Copy everything up to and including `$n`, then drop the cast.
        out.push_str(&sql[cursor..param_end]);
        cursor = pos;
    }
    out.push_str(&sql[cursor..]);
    out
}
585
586fn infer_pg_cast_param_type_oids(sql: &str) -> Vec<(usize, u32)> {
587    let mut out = Vec::new();
588    let bytes = sql.as_bytes();
589    let mut pos = 0;
590    while pos < bytes.len() {
591        if bytes[pos] != b'$' {
592            pos += 1;
593            continue;
594        }
595        pos += 1;
596        let digits_start = pos;
597        while pos < bytes.len() && bytes[pos].is_ascii_digit() {
598            pos += 1;
599        }
600        if digits_start == pos {
601            continue;
602        }
603        let Some(param_index) = sql[digits_start..pos]
604            .parse::<usize>()
605            .ok()
606            .and_then(|idx| idx.checked_sub(1))
607        else {
608            continue;
609        };
610        if pos + 2 > bytes.len() || &bytes[pos..pos + 2] != b"::" {
611            continue;
612        }
613        pos += 2;
614        let type_start = pos;
615        while pos < bytes.len()
616            && (bytes[pos].is_ascii_alphanumeric() || matches!(bytes[pos], b'_' | b'.'))
617        {
618            pos += 1;
619        }
620        if type_start == pos {
621            continue;
622        }
623        if let Some(oid) = pg_cast_type_oid(&sql[type_start..pos]) {
624            out.push((param_index, oid));
625        }
626    }
627    out
628}
629
630fn pg_cast_type_oid(ty: &str) -> Option<u32> {
631    let lower = ty.to_ascii_lowercase();
632    let short = lower.rsplit('.').next().unwrap_or(lower.as_str());
633    let oid = match short {
634        "bool" | "boolean" => PgOid::Bool,
635        "int2" | "smallint" => PgOid::Int2,
636        "int" | "int4" | "integer" => PgOid::Int4,
637        "int8" | "bigint" => PgOid::Int8,
638        "float4" | "real" => PgOid::Float4,
639        "float8" | "double" | "doubleprecision" => PgOid::Float8,
640        "numeric" | "decimal" => PgOid::Numeric,
641        "bytea" => PgOid::Bytea,
642        "json" => PgOid::Json,
643        "jsonb" => PgOid::Jsonb,
644        "text" => PgOid::Text,
645        "varchar" | "character varying" => PgOid::Varchar,
646        "uuid" => PgOid::Uuid,
647        "timestamp" => PgOid::Timestamp,
648        "timestamptz" | "timestampz" => PgOid::TimestampTz,
649        "vector" => PgOid::Vector,
650        _ => return None,
651    };
652    Some(oid.as_u32())
653}
654
/// Report whether `sql` begins (after leading whitespace, ignoring ASCII
/// case) with one of the keywords treated as row-returning here:
/// `select`, `with`, `ask`, `search`, `vector`, `hybrid`.
///
/// This is a plain prefix test; no word boundary is required after the
/// keyword.
fn is_row_returning_query(sql: &str) -> bool {
    const ROW_RETURNING_PREFIXES: [&str; 6] =
        ["select", "with", "ask", "search", "vector", "hybrid"];

    let head = sql.trim_start().as_bytes();
    ROW_RETURNING_PREFIXES.iter().any(|keyword| {
        head.get(..keyword.len())
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(keyword.as_bytes()))
    })
}
664
/// Report whether `sql` begins with `ask` after leading whitespace,
/// ignoring ASCII case. This is a plain prefix test.
fn is_ask_query(sql: &str) -> bool {
    let head = sql.trim_start().as_bytes();
    match head.get(..3) {
        Some(prefix) => prefix.eq_ignore_ascii_case(b"ask"),
        None => false,
    }
}
668
669async fn send_auth_ok<S>(
670    stream: &mut S,
671    config: &PgWireConfig,
672    params: &super::protocol::StartupParams,
673) -> Result<(), PgWireError>
674where
675    S: AsyncWrite + Unpin,
676{
677    // Phase 3.1: trust auth. We always send AuthenticationOk.
678    write_frame(stream, &BackendMessage::AuthenticationOk).await?;
679
680    // Standard ParameterStatus frames. Drivers gate capabilities on these.
681    for (name, value) in [
682        ("server_version", config.server_version.as_str()),
683        ("server_encoding", "UTF8"),
684        ("client_encoding", "UTF8"),
685        ("DateStyle", "ISO, MDY"),
686        ("TimeZone", "UTC"),
687        ("integer_datetimes", "on"),
688        ("standard_conforming_strings", "on"),
689        (
690            "application_name",
691            params.get("application_name").unwrap_or(""),
692        ),
693    ] {
694        write_frame(
695            stream,
696            &BackendMessage::ParameterStatus {
697                name: name.to_string(),
698                value: value.to_string(),
699            },
700        )
701        .await?;
702    }
703
704    // BackendKeyData: (pid, secret_key). Used by CancelRequest; we don't
705    // honour cancels in 3.1 so random-ish values are fine.
706    write_frame(
707        stream,
708        &BackendMessage::BackendKeyData {
709            pid: std::process::id(),
710            key: 0xDEADBEEF,
711        },
712    )
713    .await?;
714
715    write_frame(
716        stream,
717        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
718    )
719    .await?;
720    Ok(())
721}
722
723async fn handle_simple_query<S>(
724    stream: &mut S,
725    runtime: &RedDBRuntime,
726    sql: &str,
727) -> Result<(), PgWireError>
728where
729    S: AsyncWrite + Unpin,
730{
731    // Empty query convention: PG emits EmptyQueryResponse instead of a
732    // CommandComplete. Some clients (psql `\;`) rely on this.
733    if sql.trim().is_empty() {
734        write_frame(stream, &BackendMessage::EmptyQueryResponse).await?;
735        write_frame(
736            stream,
737            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
738        )
739        .await?;
740        return Ok(());
741    }
742
743    if let Some(tag) = pg_session_compat_command_tag(sql) {
744        write_frame(stream, &BackendMessage::CommandComplete(tag.to_string())).await?;
745        write_frame(
746            stream,
747            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
748        )
749        .await?;
750        return Ok(());
751    }
752
753    let query_result = match translate_pg_catalog_query(runtime, sql) {
754        Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
755            query: sql.to_string(),
756            mode: crate::storage::query::modes::QueryMode::Sql,
757            statement: "select",
758            engine: "pg-catalog",
759            result,
760            affected_rows: 0,
761            statement_type: "select",
762            bookmark: None,
763        }),
764        Ok(None) => run_runtime_blocking(|| runtime.execute_query(sql)),
765        Err(err) => Err(err),
766    };
767
768    match query_result {
769        Ok(result) => {
770            emit_success_result(stream, &result).await?;
771        }
772        Err(err) => {
773            // PG SQLSTATE class 42 covers syntax / binding errors; we use
774            // 42P01 (undefined_table) and 42601 (syntax_error) when we can
775            // detect; otherwise fall back to XX000 (internal error).
776            let code = classify_sqlstate(&err.to_string());
777            send_error(stream, code, &err.to_string()).await?;
778        }
779    }
780
781    write_frame(
782        stream,
783        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
784    )
785    .await?;
786    Ok(())
787}
788
/// Return the command tag for session statements that are acknowledged
/// without being executed. Currently that is any statement starting with
/// `set ` (ASCII case-insensitive, after trimming whitespace and
/// trailing semicolons), which yields `"SET"`.
fn pg_session_compat_command_tag(sql: &str) -> Option<&'static str> {
    let statement = sql.trim().trim_end_matches(';');
    let is_set = statement
        .as_bytes()
        .get(..4)
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case(b"set "));
    is_set.then_some("SET")
}
796
797async fn emit_success_result<S>(
798    stream: &mut S,
799    result: &crate::runtime::RuntimeQueryResult,
800) -> Result<(), PgWireError>
801where
802    S: AsyncWrite + Unpin,
803{
804    if result.statement == "ask" {
805        emit_ask_result_row(stream, result).await?;
806        write_frame(
807            stream,
808            &BackendMessage::CommandComplete("SELECT 1".to_string()),
809        )
810        .await?;
811    } else if result_returns_rows(result) {
812        emit_result_rows(stream, &result.result).await?;
813        write_frame(
814            stream,
815            &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
816        )
817        .await?;
818    } else {
819        // DDL / DML / config statements: echo the runtime's
820        // high-level statement tag back. PG format is
821        // "<CMD> [<OID>] <COUNT>"; we keep the count where
822        // applicable and fall back to the runtime's message.
823        let tag = match result.statement_type {
824            "insert" => format!("INSERT 0 {}", result.affected_rows),
825            "update" => format!("UPDATE {}", result.affected_rows),
826            "delete" => format!("DELETE {}", result.affected_rows),
827            other => other.to_uppercase(),
828        };
829        write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
830    }
831    Ok(())
832}
833
834async fn emit_success_result_without_row_description<S>(
835    stream: &mut S,
836    result: &crate::runtime::RuntimeQueryResult,
837) -> Result<(), PgWireError>
838where
839    S: AsyncWrite + Unpin,
840{
841    if result.statement == "ask" {
842        let row = ask_query_result_to_pg_wire_row(result)
843            .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
844        write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
845        write_frame(
846            stream,
847            &BackendMessage::CommandComplete("SELECT 1".to_string()),
848        )
849        .await?;
850    } else if result_returns_rows(result) {
851        emit_result_data_rows(stream, &result.result).await?;
852        write_frame(
853            stream,
854            &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
855        )
856        .await?;
857    } else {
858        let tag = match result.statement_type {
859            "insert" => format!("INSERT 0 {}", result.affected_rows),
860            "update" => format!("UPDATE {}", result.affected_rows),
861            "delete" => format!("DELETE {}", result.affected_rows),
862            other => other.to_uppercase(),
863        };
864        write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
865    }
866    Ok(())
867}
868
869async fn emit_row_description_for_result<S>(
870    stream: &mut S,
871    result: &crate::runtime::RuntimeQueryResult,
872) -> Result<(), PgWireError>
873where
874    S: AsyncWrite + Unpin,
875{
876    if result.statement == "ask" {
877        emit_ask_row_description(stream).await
878    } else if result_returns_rows(result) {
879        emit_result_row_description(stream, &result.result).await
880    } else {
881        write_frame(stream, &BackendMessage::NoData).await
882    }
883}
884
885fn result_returns_rows(result: &crate::runtime::RuntimeQueryResult) -> bool {
886    result.statement_type == "select"
887}
888
889async fn emit_result_rows<S>(
890    stream: &mut S,
891    result: &crate::storage::query::unified::UnifiedResult,
892) -> Result<(), PgWireError>
893where
894    S: AsyncWrite + Unpin,
895{
896    emit_result_row_description(stream, result).await?;
897    emit_result_data_rows(stream, result).await
898}
899
900async fn emit_result_row_description<S>(
901    stream: &mut S,
902    result: &crate::storage::query::unified::UnifiedResult,
903) -> Result<(), PgWireError>
904where
905    S: AsyncWrite + Unpin,
906{
907    // RowDescription: derived from the first record's column ordering.
908    // When `result.columns` is non-empty we honour that order; otherwise
909    // we synthesise one from the record's field order.
910    let columns: Vec<String> = if !result.columns.is_empty() {
911        result.columns.clone()
912    } else if let Some(first) = result.records.first() {
913        record_field_names(first)
914    } else {
915        Vec::new()
916    };
917
918    // Peek at the first record for per-column type OIDs. When there's no
919    // data row we fall back to TEXT for every column — clients render
920    // empty result sets happily.
921    let type_oids: Vec<PgOid> = columns
922        .iter()
923        .map(|col| {
924            result
925                .records
926                .first()
927                .and_then(|r| record_get(r, col))
928                .map(PgOid::from_value)
929                .unwrap_or(PgOid::Text)
930        })
931        .collect();
932
933    let descriptors: Vec<ColumnDescriptor> = columns
934        .iter()
935        .zip(type_oids.iter())
936        .map(|(name, oid)| ColumnDescriptor {
937            name: name.clone(),
938            table_oid: 0,
939            column_attr: 0,
940            type_oid: oid.as_u32(),
941            type_size: -1,
942            type_mod: -1,
943            format: 0,
944        })
945        .collect();
946
947    write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
948}
949
950async fn emit_result_data_rows<S>(
951    stream: &mut S,
952    result: &crate::storage::query::unified::UnifiedResult,
953) -> Result<(), PgWireError>
954where
955    S: AsyncWrite + Unpin,
956{
957    let columns: Vec<String> = if !result.columns.is_empty() {
958        result.columns.clone()
959    } else if let Some(first) = result.records.first() {
960        record_field_names(first)
961    } else {
962        Vec::new()
963    };
964    for record in &result.records {
965        let fields: Vec<Option<Vec<u8>>> = columns
966            .iter()
967            .map(|col| record_get(record, col).and_then(value_to_pg_wire_bytes))
968            .collect();
969        write_frame(stream, &BackendMessage::DataRow(fields)).await?;
970    }
971
972    Ok(())
973}
974
975async fn emit_ask_result_row<S>(
976    stream: &mut S,
977    result: &crate::runtime::RuntimeQueryResult,
978) -> Result<(), PgWireError>
979where
980    S: AsyncWrite + Unpin,
981{
982    let row = ask_query_result_to_pg_wire_row(result)
983        .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
984
985    emit_ask_row_description(stream).await?;
986    write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
987    Ok(())
988}
989
990async fn emit_ask_row_description<S>(stream: &mut S) -> Result<(), PgWireError>
991where
992    S: AsyncWrite + Unpin,
993{
994    let descriptors: Vec<ColumnDescriptor> = crate::runtime::ai::pg_wire_ask_row_encoder::columns()
995        .iter()
996        .map(|col| ColumnDescriptor {
997            name: col.name.to_string(),
998            table_oid: 0,
999            column_attr: 0,
1000            type_oid: col.oid.as_u32(),
1001            type_size: -1,
1002            type_mod: -1,
1003            format: 0,
1004        })
1005        .collect();
1006    write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
1007}
1008
1009fn ask_query_result_to_pg_wire_row(
1010    result: &crate::runtime::RuntimeQueryResult,
1011) -> Option<crate::runtime::ai::pg_wire_ask_row_encoder::AskRow> {
1012    if result.statement != "ask" {
1013        return None;
1014    }
1015    let record = result.result.records.first()?;
1016    let sources_flat_json =
1017        json_field(record, "sources_flat").unwrap_or(crate::json::Value::Array(Vec::new()));
1018    let citations_json =
1019        json_field(record, "citations").unwrap_or(crate::json::Value::Array(Vec::new()));
1020    let validation_json = json_field(record, "validation")
1021        .unwrap_or_else(|| crate::json::Value::Object(Default::default()));
1022
1023    let effective_mode = match text_field(record, "mode").as_deref() {
1024        Some("lenient") => Mode::Lenient,
1025        _ => Mode::Strict,
1026    };
1027
1028    let ask = AskResult {
1029        answer: text_field(record, "answer")?,
1030        sources_flat: ask_sources_flat(&sources_flat_json),
1031        citations: ask_citations(&citations_json),
1032        validation: ask_validation(&validation_json),
1033        cache_hit: bool_field(record, "cache_hit").unwrap_or(false),
1034        provider: text_field(record, "provider").unwrap_or_default(),
1035        model: text_field(record, "model").unwrap_or_default(),
1036        prompt_tokens: u32_field(record, "prompt_tokens").unwrap_or(0),
1037        completion_tokens: u32_field(record, "completion_tokens").unwrap_or(0),
1038        cost_usd: f64_field(record, "cost_usd").unwrap_or(0.0),
1039        effective_mode,
1040        retry_count: u32_field(record, "retry_count").unwrap_or(0),
1041    };
1042
1043    Some(crate::runtime::ai::pg_wire_ask_row_encoder::encode(&ask))
1044}
1045
1046fn record_field<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
1047    record.iter_fields().find_map(|(name, value)| {
1048        let name: &str = name;
1049        (name == key).then_some(value)
1050    })
1051}
1052
1053fn text_field(record: &UnifiedRecord, key: &str) -> Option<String> {
1054    match record_field(record, key)? {
1055        Value::Text(s) => Some(s.to_string()),
1056        Value::Email(s) | Value::Url(s) | Value::NodeRef(s) | Value::EdgeRef(s) => Some(s.clone()),
1057        other => Some(other.to_string()),
1058    }
1059}
1060
1061fn bool_field(record: &UnifiedRecord, key: &str) -> Option<bool> {
1062    match record_field(record, key)? {
1063        Value::Boolean(value) => Some(*value),
1064        _ => None,
1065    }
1066}
1067
1068fn u32_field(record: &UnifiedRecord, key: &str) -> Option<u32> {
1069    match record_field(record, key)? {
1070        Value::Integer(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1071        Value::UnsignedInteger(n) => Some((*n).min(u32::MAX as u64) as u32),
1072        Value::BigInt(n)
1073        | Value::TimestampMs(n)
1074        | Value::Timestamp(n)
1075        | Value::Duration(n)
1076        | Value::Decimal(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1077        Value::Float(n) => (*n >= 0.0).then_some((*n).min(u32::MAX as f64) as u32),
1078        _ => None,
1079    }
1080}
1081
1082fn f64_field(record: &UnifiedRecord, key: &str) -> Option<f64> {
1083    match record_field(record, key)? {
1084        Value::Integer(n) => Some(*n as f64),
1085        Value::UnsignedInteger(n) => Some(*n as f64),
1086        Value::BigInt(n)
1087        | Value::TimestampMs(n)
1088        | Value::Timestamp(n)
1089        | Value::Duration(n)
1090        | Value::Decimal(n) => Some(*n as f64),
1091        Value::Float(n) => Some(*n),
1092        _ => None,
1093    }
1094}
1095
1096fn json_field(record: &UnifiedRecord, key: &str) -> Option<crate::json::Value> {
1097    match record_field(record, key)? {
1098        Value::Json(bytes) => crate::json::from_slice(bytes).ok(),
1099        Value::Text(text) => crate::json::from_str(text).ok(),
1100        _ => None,
1101    }
1102}
1103
1104fn ask_sources_flat(value: &crate::json::Value) -> Vec<SourceRow> {
1105    value
1106        .as_array()
1107        .unwrap_or(&[])
1108        .iter()
1109        .filter_map(|source| {
1110            let urn = source
1111                .get("urn")
1112                .and_then(crate::json::Value::as_str)?
1113                .to_string();
1114            let payload = source
1115                .get("payload")
1116                .and_then(crate::json::Value::as_str)
1117                .map(ToString::to_string)
1118                .unwrap_or_else(|| source.to_string_compact());
1119            Some(SourceRow { urn, payload })
1120        })
1121        .collect()
1122}
1123
1124fn ask_citations(value: &crate::json::Value) -> Vec<Citation> {
1125    value
1126        .as_array()
1127        .unwrap_or(&[])
1128        .iter()
1129        .filter_map(|citation| {
1130            let marker = citation
1131                .get("marker")
1132                .and_then(crate::json::Value::as_u64)?;
1133            let urn = citation
1134                .get("urn")
1135                .and_then(crate::json::Value::as_str)?
1136                .to_string();
1137            Some(Citation {
1138                marker: marker.min(u32::MAX as u64) as u32,
1139                urn,
1140            })
1141        })
1142        .collect()
1143}
1144
1145fn ask_validation(value: &crate::json::Value) -> Validation {
1146    Validation {
1147        ok: value
1148            .get("ok")
1149            .and_then(crate::json::Value::as_bool)
1150            .unwrap_or(true),
1151        warnings: validation_items(value, "warnings")
1152            .into_iter()
1153            .map(|(kind, detail)| ValidationWarning { kind, detail })
1154            .collect(),
1155        errors: validation_items(value, "errors")
1156            .into_iter()
1157            .map(|(kind, detail)| ValidationError { kind, detail })
1158            .collect(),
1159    }
1160}
1161
1162fn validation_items(value: &crate::json::Value, key: &str) -> Vec<(String, String)> {
1163    value
1164        .get(key)
1165        .and_then(crate::json::Value::as_array)
1166        .unwrap_or(&[])
1167        .iter()
1168        .filter_map(|item| {
1169            Some((
1170                item.get("kind")
1171                    .and_then(crate::json::Value::as_str)?
1172                    .to_string(),
1173                item.get("detail")
1174                    .and_then(crate::json::Value::as_str)
1175                    .unwrap_or("")
1176                    .to_string(),
1177            ))
1178        })
1179        .collect()
1180}
1181
/// Field lookup on a `UnifiedRecord`, used when building PG rows.
///
/// This is a thin wrapper that delegates entirely to `UnifiedRecord::get`
/// and returns whatever that call yields for `key`.
///
/// NOTE(review): earlier wording mentioned falling back to a "serialised
/// map" when `get` is unavailable. No such fallback exists in this function;
/// if one is still wanted it has to be added explicitly.
fn record_get<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
    record.get(key)
}
1188
1189/// Extract column names in iteration order from a single record. When
1190/// the caller didn't supply an explicit `columns` projection we use the
1191/// first record's field ordering as the canonical tuple shape.
1192///
1193/// HashMap iteration order is non-deterministic — for Phase 3.1 we
1194/// accept the shuffle since PG clients receive the ordered header via
1195/// RowDescription and match cells positionally. A stable ordering
1196/// would require keeping an insertion-order index alongside `values`.
1197fn record_field_names(record: &UnifiedRecord) -> Vec<String> {
1198    // `column_names()` merges the columnar scan side-channel with
1199    // the HashMap so scan rows (which populate only columnar) still
1200    // surface their field names in PG wire output.
1201    record
1202        .column_names()
1203        .into_iter()
1204        .map(|k| k.to_string())
1205        .collect()
1206}
1207
1208async fn send_error<S>(stream: &mut S, code: &str, message: &str) -> Result<(), PgWireError>
1209where
1210    S: AsyncWrite + Unpin,
1211{
1212    write_frame(
1213        stream,
1214        &BackendMessage::ErrorResponse {
1215            severity: "ERROR".to_string(),
1216            code: code.to_string(),
1217            message: message.to_string(),
1218        },
1219    )
1220    .await
1221}
1222
/// Heuristically map a runtime error message onto a PG SQLSTATE. Full
/// coverage would map every `RedDBError` variant; this is enough for the
/// common psql / JDBC paths.
///
/// Matching is case-insensitive (ASCII) substring search, tried in this
/// order; the first hit wins:
///
/// | substring(s)                      | SQLSTATE | PG condition name                   |
/// |-----------------------------------|----------|-------------------------------------|
/// | `not found`, `does not exist`     | `42P01`  | undefined_table                     |
/// | `parse`, `expected`, `syntax`     | `42601`  | syntax_error                        |
/// | `already exists`                  | `42P07`  | duplicate_table                     |
/// | `permission`                      | `42501`  | insufficient_privilege              |
/// | `auth`                            | `28000`  | invalid_authorization_specification |
/// | anything else                     | `XX000`  | internal_error                      |
///
/// Permission failures are reported as `42501` rather than `28000`: class 28
/// describes a rejected login, while a denied operation on an established
/// session is `insufficient_privilege` in PostgreSQL.
fn classify_sqlstate(msg: &str) -> &'static str {
    let lower = msg.to_ascii_lowercase();
    let mentions_any = |needles: &[&str]| needles.iter().any(|needle| lower.contains(*needle));

    if mentions_any(&["not found", "does not exist"]) {
        // 42P01 undefined_table; close enough for collection-not-found.
        "42P01"
    } else if mentions_any(&["parse", "expected", "syntax"]) {
        "42601"
    } else if lower.contains("already exists") {
        "42P07"
    } else if lower.contains("permission") {
        // Denied operation on an authenticated session.
        "42501"
    } else if lower.contains("auth") {
        // Authentication / authorization-specification failures.
        "28000"
    } else {
        "XX000"
    }
}
1241
1242#[cfg(test)]
1243mod tests {
1244    use super::*;
1245    use crate::api::RedDBOptions;
1246    use crate::runtime::RuntimeQueryResult;
1247    use crate::storage::query::modes::QueryMode;
1248    use crate::storage::query::unified::UnifiedResult;
1249    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
1250
1251    #[tokio::test]
1252    async fn extended_parse_bind_execute_returns_rows() {
1253        let runtime = Arc::new(RedDBRuntime::with_options(RedDBOptions::in_memory()).unwrap());
1254        let config = Arc::new(PgWireConfig::default());
1255        let (server_io, mut client_io) = tokio::io::duplex(64 * 1024);
1256        let server = tokio::spawn(async move {
1257            handle_connection(server_io, runtime, config).await.unwrap();
1258        });
1259
1260        write_startup(&mut client_io).await;
1261        read_until_ready(&mut client_io).await;
1262
1263        write_frontend_frame(
1264            &mut client_io,
1265            b'P',
1266            parse_body("", "SELECT $1::int", &[PgOid::Int4.as_u32()]),
1267        )
1268        .await;
1269        write_frontend_frame(
1270            &mut client_io,
1271            b'B',
1272            bind_body("", "", &[0], &[Some(b"42".as_slice())], &[]),
1273        )
1274        .await;
1275        write_frontend_frame(&mut client_io, b'D', describe_body(b'P', "")).await;
1276        write_frontend_frame(&mut client_io, b'E', execute_body("", 0)).await;
1277        write_frontend_frame(&mut client_io, b'S', Vec::new()).await;
1278
1279        let frames = read_until_ready(&mut client_io).await;
1280        assert_eq!(
1281            frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>(),
1282            b"12TDCZ"
1283        );
1284        let columns = decode_row_description(&frames[2].1);
1285        assert_eq!(columns.len(), 1);
1286        let cells = decode_data_row(&frames[3].1);
1287        assert_eq!(cells.len(), 1);
1288        assert_eq!(cells[0].as_deref(), Some(b"42".as_slice()));
1289        assert_eq!(decode_command_complete(&frames[4].1), "SELECT 1");
1290
1291        write_frontend_frame(&mut client_io, b'X', Vec::new()).await;
1292        server.await.unwrap();
1293    }
1294
/// `$n::type` casts in the SQL text must be turned into
/// `(zero-based parameter index, type OID)` pairs.
#[test]
fn infer_pg_cast_param_type_oids_from_parameter_casts() {
    let insert_oids =
        infer_pg_cast_param_type_oids("INSERT INTO t (id, name) VALUES ($1::int, $2::text)");
    assert_eq!(
        insert_oids,
        vec![(0, PgOid::Int4.as_u32()), (1, PgOid::Text.as_u32())]
    );

    let search_oids =
        infer_pg_cast_param_type_oids("SEARCH SIMILAR [1.0] COLLECTION v LIMIT $1::int8");
    assert_eq!(search_oids, vec![(0, PgOid::Int8.as_u32())]);
}
1306
/// Driver bootstrap `SET ...` statements must be acknowledged with the `SET`
/// tag, while ordinary queries are not intercepted.
#[test]
fn pg_session_compat_accepts_driver_setup_set_commands() {
    let driver_setup_statements = vec![
        "SET extra_float_digits = 3",
        "SET application_name = 'pgjdbc'",
    ];
    for sql in driver_setup_statements {
        assert_eq!(pg_session_compat_command_tag(sql), Some("SET"));
    }

    assert_eq!(pg_session_compat_command_tag("SELECT 1"), None);
}
1319
1320    #[tokio::test]
1321    async fn ask_success_result_uses_canonical_pg_wire_row_shape() {
1322        let mut result = UnifiedResult::with_columns(vec![
1323            "answer".into(),
1324            "provider".into(),
1325            "model".into(),
1326            "prompt_tokens".into(),
1327            "completion_tokens".into(),
1328            "sources_count".into(),
1329            "sources_flat".into(),
1330            "citations".into(),
1331            "validation".into(),
1332        ]);
1333        let mut record = UnifiedRecord::new();
1334        record.set("answer", Value::text("Deploy failed [^1]."));
1335        record.set("provider", Value::text("openai"));
1336        record.set("model", Value::text("gpt-4o-mini"));
1337        record.set("prompt_tokens", Value::Integer(11));
1338        record.set("completion_tokens", Value::Integer(7));
1339        record.set(
1340            "sources_flat",
1341            Value::Json(
1342                br#"[{"urn":"urn:reddb:row:deployments:1","kind":"row","collection":"deployments","id":"1"}]"#
1343                    .to_vec(),
1344            ),
1345        );
1346        record.set(
1347            "citations",
1348            Value::Json(br#"[{"marker":1,"urn":"urn:reddb:row:deployments:1"}]"#.to_vec()),
1349        );
1350        record.set(
1351            "validation",
1352            Value::Json(br#"{"ok":true,"warnings":[],"errors":[]}"#.to_vec()),
1353        );
1354        result.push(record);
1355
1356        let qr = RuntimeQueryResult {
1357            query: "ASK 'why did deploy fail?'".to_string(),
1358            mode: QueryMode::Sql,
1359            statement: "ask",
1360            engine: "runtime-ai",
1361            result,
1362            affected_rows: 0,
1363            statement_type: "select",
1364            bookmark: None,
1365        };
1366
1367        let mut out = Vec::new();
1368        emit_success_result(&mut out, &qr).await.unwrap();
1369        let frames = decode_frames(&out);
1370
1371        assert_eq!(
1372            frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>(),
1373            b"TDC"
1374        );
1375
1376        let columns = decode_row_description(frames[0].1);
1377        assert_eq!(
1378            columns,
1379            vec![
1380                ("answer".to_string(), PgOid::Text.as_u32()),
1381                ("cache_hit".to_string(), PgOid::Bool.as_u32()),
1382                ("citations".to_string(), PgOid::Jsonb.as_u32()),
1383                ("completion_tokens".to_string(), PgOid::Int8.as_u32()),
1384                ("cost_usd".to_string(), PgOid::Numeric.as_u32()),
1385                ("mode".to_string(), PgOid::Text.as_u32()),
1386                ("model".to_string(), PgOid::Text.as_u32()),
1387                ("prompt_tokens".to_string(), PgOid::Int8.as_u32()),
1388                ("provider".to_string(), PgOid::Text.as_u32()),
1389                ("retry_count".to_string(), PgOid::Int8.as_u32()),
1390                ("sources_flat".to_string(), PgOid::Jsonb.as_u32()),
1391                ("validation".to_string(), PgOid::Jsonb.as_u32()),
1392            ]
1393        );
1394
1395        let cells = decode_data_row(frames[1].1);
1396        assert_eq!(cells.len(), 12);
1397        assert_eq!(cells[0].as_deref(), Some(b"Deploy failed [^1].".as_slice()));
1398        assert_eq!(cells[1].as_deref(), Some(b"f".as_slice()));
1399        assert_eq!(cells[4].as_deref(), Some(b"0".as_slice()));
1400        assert_eq!(cells[5].as_deref(), Some(b"strict".as_slice()));
1401        assert_eq!(cells[9].as_deref(), Some(b"0".as_slice()));
1402        assert!(std::str::from_utf8(cells[10].as_deref().unwrap())
1403            .unwrap()
1404            .contains(r#""payload""#));
1405        assert_eq!(decode_command_complete(frames[2].1), "SELECT 1");
1406    }
1407
/// Splits a buffer of back-to-back backend messages into `(tag, body)` pairs.
///
/// Each message is one tag byte, a big-endian `u32` length that counts itself
/// plus the body, and then the body. Panics on truncated or malformed input,
/// which is the desired failure mode for a test helper.
fn decode_frames(bytes: &[u8]) -> Vec<(u8, &[u8])> {
    let mut frames = Vec::new();
    let mut rest = bytes;
    while !rest.is_empty() {
        let tag = rest[0];
        let mut len_bytes = [0u8; 4];
        len_bytes.copy_from_slice(&rest[1..5]);
        let declared_len = u32::from_be_bytes(len_bytes) as usize;

        // The declared length excludes the tag byte but includes itself.
        let frame_end = 1 + declared_len;
        frames.push((tag, &rest[5..frame_end]));
        rest = &rest[frame_end..];
    }
    frames
}
1426
/// Decodes a `RowDescription` body into `(column name, type OID)` pairs.
///
/// Layout per column: NUL-terminated name, table OID (4), column attribute
/// (2), type OID (4), type size (2), type modifier (4), format code (2). Only
/// the name and type OID are kept. Panics on malformed input.
fn decode_row_description(body: &[u8]) -> Vec<(String, u32)> {
    let column_count = i16::from_be_bytes([body[0], body[1]]) as usize;
    let mut columns = Vec::with_capacity(column_count);
    let mut cursor = 2;
    for _ in 0..column_count {
        let name_len = body[cursor..].iter().position(|&b| b == 0).unwrap();
        let name = std::str::from_utf8(&body[cursor..cursor + name_len])
            .unwrap()
            .to_string();

        // Skip the NUL terminator, the table OID and the column attribute.
        let oid_at = cursor + name_len + 1 + 4 + 2;
        let mut oid_bytes = [0u8; 4];
        oid_bytes.copy_from_slice(&body[oid_at..oid_at + 4]);
        columns.push((name, u32::from_be_bytes(oid_bytes)));

        // Step over the type OID, then type size, type modifier and format.
        cursor = oid_at + 4 + 2 + 4 + 2;
    }
    columns
}
1446
/// Decodes a `DataRow` body into its cells.
///
/// Each cell is a big-endian `i32` length followed by that many bytes; a
/// negative length marks SQL NULL and is returned as `None`. Panics on
/// truncated input.
fn decode_data_row(body: &[u8]) -> Vec<Option<Vec<u8>>> {
    let cell_count = i16::from_be_bytes([body[0], body[1]]) as usize;
    let mut cells = Vec::with_capacity(cell_count);
    let mut rest = &body[2..];
    for _ in 0..cell_count {
        let (len_bytes, after_len) = rest.split_at(4);
        let cell_len =
            i32::from_be_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]);
        rest = after_len;

        if cell_len < 0 {
            cells.push(None);
            continue;
        }
        let (payload, after_payload) = rest.split_at(cell_len as usize);
        cells.push(Some(payload.to_vec()));
        rest = after_payload;
    }
    cells
}
1464
/// Returns the command tag of a `CommandComplete` body: everything before the
/// first NUL byte, or the whole body when it has none. Panics when the tag is
/// not valid UTF-8.
fn decode_command_complete(body: &[u8]) -> &str {
    // `split` always yields at least one (possibly empty) piece.
    let tag_bytes = body.split(|&b| b == 0).next().unwrap_or(body);
    std::str::from_utf8(tag_bytes).unwrap()
}
1469
1470    async fn write_startup<W: AsyncWrite + Unpin>(stream: &mut W) {
1471        let mut payload = Vec::new();
1472        payload.extend_from_slice(&crate::wire::postgres::protocol::PG_PROTOCOL_V3.to_be_bytes());
1473        payload.extend_from_slice(b"user\0reddb\0");
1474        payload.push(0);
1475        let len = (payload.len() + 4) as u32;
1476        stream.write_all(&len.to_be_bytes()).await.unwrap();
1477        stream.write_all(&payload).await.unwrap();
1478    }
1479
1480    async fn write_frontend_frame<W: AsyncWrite + Unpin>(
1481        stream: &mut W,
1482        tag: u8,
1483        payload: Vec<u8>,
1484    ) {
1485        stream.write_all(&[tag]).await.unwrap();
1486        stream
1487            .write_all(&((payload.len() + 4) as u32).to_be_bytes())
1488            .await
1489            .unwrap();
1490        stream.write_all(&payload).await.unwrap();
1491    }
1492
1493    async fn read_backend_frame<R: AsyncRead + Unpin>(stream: &mut R) -> (u8, Vec<u8>) {
1494        let mut tag = [0u8; 1];
1495        stream.read_exact(&mut tag).await.unwrap();
1496        let mut len = [0u8; 4];
1497        stream.read_exact(&mut len).await.unwrap();
1498        let len = u32::from_be_bytes(len) as usize;
1499        let mut body = vec![0u8; len - 4];
1500        stream.read_exact(&mut body).await.unwrap();
1501        (tag[0], body)
1502    }
1503
1504    async fn read_until_ready<R: AsyncRead + Unpin>(stream: &mut R) -> Vec<(u8, Vec<u8>)> {
1505        let mut frames = Vec::new();
1506        loop {
1507            let frame = read_backend_frame(stream).await;
1508            let done = frame.0 == b'Z';
1509            frames.push(frame);
1510            if done {
1511                return frames;
1512            }
1513        }
1514    }
1515
1516    fn parse_body(statement: &str, query: &str, oids: &[u32]) -> Vec<u8> {
1517        let mut out = Vec::new();
1518        push_pg_cstring(&mut out, statement);
1519        push_pg_cstring(&mut out, query);
1520        out.extend_from_slice(&(oids.len() as i16).to_be_bytes());
1521        for oid in oids {
1522            out.extend_from_slice(&oid.to_be_bytes());
1523        }
1524        out
1525    }
1526
1527    fn bind_body(
1528        portal: &str,
1529        statement: &str,
1530        formats: &[i16],
1531        params: &[Option<&[u8]>],
1532        result_formats: &[i16],
1533    ) -> Vec<u8> {
1534        let mut out = Vec::new();
1535        push_pg_cstring(&mut out, portal);
1536        push_pg_cstring(&mut out, statement);
1537        out.extend_from_slice(&(formats.len() as i16).to_be_bytes());
1538        for format in formats {
1539            out.extend_from_slice(&format.to_be_bytes());
1540        }
1541        out.extend_from_slice(&(params.len() as i16).to_be_bytes());
1542        for param in params {
1543            match param {
1544                Some(bytes) => {
1545                    out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
1546                    out.extend_from_slice(bytes);
1547                }
1548                None => out.extend_from_slice(&(-1i32).to_be_bytes()),
1549            }
1550        }
1551        out.extend_from_slice(&(result_formats.len() as i16).to_be_bytes());
1552        for format in result_formats {
1553            out.extend_from_slice(&format.to_be_bytes());
1554        }
1555        out
1556    }
1557
1558    fn describe_body(target: u8, name: &str) -> Vec<u8> {
1559        let mut out = vec![target];
1560        push_pg_cstring(&mut out, name);
1561        out
1562    }
1563
1564    fn execute_body(portal: &str, max_rows: u32) -> Vec<u8> {
1565        let mut out = Vec::new();
1566        push_pg_cstring(&mut out, portal);
1567        out.extend_from_slice(&max_rows.to_be_bytes());
1568        out
1569    }
1570
/// Appends `value` to `out` as a NUL-terminated string.
fn push_pg_cstring(out: &mut Vec<u8>, value: &str) {
    out.extend(value.bytes().chain(std::iter::once(0u8)));
}
1575}