Skip to main content

reddb_server/wire/postgres/
server.rs

1//! PostgreSQL wire-protocol listener (Phase 3.1 PG parity).
2//!
3//! Accepts TCP connections from PG-compatible clients, drives the startup
4//! handshake, and routes simple/extended-query frames into the existing
5//! `RedDBRuntime::execute_query` path. Results are adapted back into PG
6//! `RowDescription` + `DataRow` frames via `types::value_to_pg_wire_bytes`.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpListener;
13
14use super::catalog_views::translate_pg_catalog_query;
15use super::protocol::{
16    read_frame, read_startup, write_frame, write_raw_byte, BackendMessage, ColumnDescriptor,
17    DescribeTarget, FrontendMessage, PgWireError, TransactionStatus,
18};
19use super::types::{pg_param_to_value, value_to_pg_wire_bytes, PgOid};
20use crate::runtime::ai::ask_response_envelope::{
21    AskResult, Citation, Mode, SourceRow, Validation, ValidationError, ValidationWarning,
22};
23use crate::runtime::RedDBRuntime;
24use crate::storage::query::unified::{UnifiedRecord, UnifiedResult};
25use crate::storage::schema::Value;
26
/// Settings for the PostgreSQL wire listener, fixed at startup.
#[derive(Debug, Clone)]
pub struct PgWireConfig {
    /// Address the listener binds to, written as "host:port". Keeping it
    /// distinct from the native wire / gRPC / HTTP listeners is the
    /// caller's job.
    pub bind_addr: String,
    /// Version string reported to clients in the `server_version`
    /// `ParameterStatus` frame. Drivers commonly inspect it to switch
    /// features on or off, so RedDB reports a recent-enough release to
    /// get the broadest client support.
    pub server_version: String,
}
38
/// A statement registered by an extended-protocol `Parse` message.
#[derive(Debug, Clone)]
struct PgPreparedStatement {
    /// SQL text as stored by `handle_parse` (after its `$N::type` cast
    /// rewrite).
    sql: String,
    /// Type OID for each parameter, indexed by zero-based parameter
    /// position.
    param_type_oids: Vec<u32>,
}
44
45#[derive(Debug, Clone)]
46struct PgPortal {
47    sql: String,
48    params: Vec<Value>,
49    #[allow(dead_code)]
50    result_format_codes: Vec<i16>,
51    row_description_sent: bool,
52    described_result: Option<crate::runtime::RuntimeQueryResult>,
53}
54
55impl Default for PgWireConfig {
56    fn default() -> Self {
57        Self {
58            bind_addr: "127.0.0.1:5432".to_string(),
59            server_version: "15.0 (RedDB 3.1)".to_string(),
60        }
61    }
62}
63
64/// Run a synchronous, possibly long-parking runtime call without
65/// monopolising the async worker thread.
66///
67/// `QUEUE READ … WAIT <budget>` parks the calling thread on a condvar
68/// for up to the wait budget. Doing that directly on a tokio worker
69/// holds the worker for the whole budget, which serialises every other
70/// connection's query behind the parked waiter — a producer's
71/// `QUEUE PUSH` could not run to notify the waiter, so an enqueue during
72/// a WAIT would never be delivered — and it stalls the runtime's timer
73/// driver. `block_in_place` hands the remaining tasks on this worker to
74/// a sibling worker for the duration while keeping the current OS thread
75/// — and therefore the per-connection thread-locals `execute_query`
76/// relies on (current tenant, connection id, transaction context) —
77/// intact, which `spawn_blocking` would not.
78///
79/// `block_in_place` panics on a current-thread runtime, so we fall back
80/// to a direct call there: a single-threaded runtime has no sibling
81/// worker to hand work to and nothing else to starve (e.g. the
82/// `#[tokio::test]` harnesses that drive one connection at a time).
83fn run_runtime_blocking<T>(f: impl FnOnce() -> T) -> T {
84    use tokio::runtime::{Handle, RuntimeFlavor};
85    match Handle::try_current().map(|h| h.runtime_flavor()) {
86        Ok(RuntimeFlavor::MultiThread) => tokio::task::block_in_place(f),
87        _ => f(),
88    }
89}
90
91/// Spawn the PG wire listener. Blocks until the listener errors out.
92/// Each connection is handled in its own tokio task.
93pub async fn start_pg_wire_listener(
94    config: PgWireConfig,
95    runtime: Arc<RedDBRuntime>,
96) -> Result<(), Box<dyn std::error::Error>> {
97    let listener = TcpListener::bind(&config.bind_addr).await?;
98    tracing::info!(
99        transport = "pg-wire",
100        bind = %config.bind_addr,
101        "listener online"
102    );
103    let cfg = Arc::new(config);
104    loop {
105        let (stream, peer) = listener.accept().await?;
106        let rt = Arc::clone(&runtime);
107        let cfg = Arc::clone(&cfg);
108        let peer_str = peer.to_string();
109        tokio::spawn(async move {
110            if let Err(e) = handle_connection(stream, rt, cfg).await {
111                tracing::warn!(
112                    transport = "pg-wire",
113                    peer = %peer_str,
114                    err = %e,
115                    "connection failed"
116                );
117            }
118        });
119    }
120}
121
122/// Drive one connection's lifetime: startup → authentication → query loop.
123pub(crate) async fn handle_connection<S>(
124    mut stream: S,
125    runtime: Arc<RedDBRuntime>,
126    config: Arc<PgWireConfig>,
127) -> Result<(), PgWireError>
128where
129    S: AsyncRead + AsyncWrite + Unpin + Send,
130{
131    // Handshake. The first frame may be SSLRequest / GSSENCRequest
132    // (pre-auth negotiation) or a plain Startup. Loop once to cover
133    // SSL-not-supported path: reply 'N' and expect the client to send
134    // a regular Startup next.
135    loop {
136        match read_startup(&mut stream).await? {
137            FrontendMessage::SslRequest | FrontendMessage::GssEncRequest => {
138                // 'N' = not supported — client continues in plaintext and
139                // re-sends a normal Startup on the same socket.
140                write_raw_byte(&mut stream, b'N').await?;
141                continue;
142            }
143            FrontendMessage::Startup(params) => {
144                send_auth_ok(&mut stream, &config, &params).await?;
145                break;
146            }
147            FrontendMessage::Unknown { .. } => {
148                // CancelRequest: no response expected; drop the socket.
149                return Ok(());
150            }
151            other => {
152                return Err(PgWireError::Protocol(format!(
153                    "unexpected startup frame: {other:?}"
154                )));
155            }
156        }
157    }
158
159    let mut prepared: HashMap<String, PgPreparedStatement> = HashMap::new();
160    let mut portals: HashMap<String, PgPortal> = HashMap::new();
161
162    // Main query loop.
163    loop {
164        let frame = match read_frame(&mut stream).await {
165            Ok(f) => f,
166            Err(PgWireError::Eof) => return Ok(()),
167            Err(e) => return Err(e),
168        };
169
170        match frame {
171            FrontendMessage::Query(sql) => {
172                handle_simple_query(&mut stream, &runtime, &sql).await?;
173            }
174            FrontendMessage::Parse(msg) => {
175                handle_parse(&mut stream, &mut prepared, msg).await?;
176            }
177            FrontendMessage::Bind(msg) => {
178                handle_bind(&mut stream, &prepared, &mut portals, msg).await?;
179            }
180            FrontendMessage::Describe(msg) => {
181                handle_describe(&mut stream, &runtime, &prepared, &mut portals, msg).await?;
182            }
183            FrontendMessage::Execute(msg) => {
184                handle_execute(&mut stream, &runtime, &mut portals, msg).await?;
185            }
186            FrontendMessage::Close(msg) => {
187                handle_close(&mut stream, &mut prepared, &mut portals, msg).await?;
188            }
189            FrontendMessage::Terminate => return Ok(()),
190            FrontendMessage::Flush => {
191                // Frames are written immediately; no additional marker is
192                // needed. ReadyForQuery belongs to Sync, not Flush.
193                continue;
194            }
195            FrontendMessage::Sync => {
196                write_frame(
197                    &mut stream,
198                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
199                )
200                .await?;
201            }
202            FrontendMessage::PasswordMessage(_) => {
203                // Should only arrive during auth. Ignore post-auth.
204                continue;
205            }
206            FrontendMessage::Unknown { tag, .. } => {
207                send_error(
208                    &mut stream,
209                    "0A000",
210                    &format!("unsupported frame tag 0x{tag:02x}"),
211                )
212                .await?;
213                write_frame(
214                    &mut stream,
215                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
216                )
217                .await?;
218            }
219            other => {
220                send_error(
221                    &mut stream,
222                    "0A000",
223                    &format!("unsupported frame {other:?}"),
224                )
225                .await?;
226                write_frame(
227                    &mut stream,
228                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
229                )
230                .await?;
231            }
232        }
233    }
234}
235
236async fn handle_parse<S>(
237    stream: &mut S,
238    prepared: &mut HashMap<String, PgPreparedStatement>,
239    msg: super::protocol::ParseMessage,
240) -> Result<(), PgWireError>
241where
242    S: AsyncWrite + Unpin,
243{
244    let inferred_param_type_oids = infer_pg_cast_param_type_oids(&msg.query);
245    let sql = rewrite_pg_parameter_casts(&msg.query);
246    let parsed_param_count = match crate::storage::query::modes::parse_multi(&sql) {
247        Ok(parsed) => Some(
248            crate::storage::query::user_params::scan_parameters(&parsed)
249                .into_iter()
250                .map(|param| param.index + 1)
251                .max()
252                .unwrap_or(0),
253        ),
254        Err(err) => {
255            if pg_scalar_select_param_index(&sql).is_none() {
256                send_error(stream, "42601", &err.to_string()).await?;
257                return Ok(());
258            }
259            None
260        }
261    };
262    let mut param_type_oids = msg.param_type_oids;
263    if param_type_oids.is_empty() {
264        let count = parsed_param_count
265            .or_else(|| pg_scalar_select_param_index(&sql).map(|idx| idx + 1))
266            .unwrap_or(0);
267        param_type_oids.resize(count, PgOid::Unknown.as_u32());
268    }
269    for (idx, oid) in inferred_param_type_oids {
270        if idx >= param_type_oids.len() {
271            param_type_oids.resize(idx + 1, PgOid::Unknown.as_u32());
272        }
273        if param_type_oids[idx] == PgOid::Unknown.as_u32() {
274            param_type_oids[idx] = oid;
275        }
276    }
277    prepared.insert(
278        msg.statement,
279        PgPreparedStatement {
280            sql,
281            param_type_oids,
282        },
283    );
284    write_frame(stream, &BackendMessage::ParseComplete).await
285}
286
287async fn handle_bind<S>(
288    stream: &mut S,
289    prepared: &HashMap<String, PgPreparedStatement>,
290    portals: &mut HashMap<String, PgPortal>,
291    msg: super::protocol::BindMessage,
292) -> Result<(), PgWireError>
293where
294    S: AsyncWrite + Unpin,
295{
296    let Some(stmt) = prepared.get(&msg.statement) else {
297        send_error(
298            stream,
299            "26000",
300            &format!("prepared statement {:?} does not exist", msg.statement),
301        )
302        .await?;
303        return Ok(());
304    };
305    let params = match bind_pg_params(stmt, &msg) {
306        Ok(params) => params,
307        Err(err) => {
308            send_error(stream, "22023", &err).await?;
309            return Ok(());
310        }
311    };
312    portals.insert(
313        msg.portal,
314        PgPortal {
315            sql: stmt.sql.clone(),
316            params,
317            result_format_codes: msg.result_format_codes,
318            row_description_sent: false,
319            described_result: None,
320        },
321    );
322    write_frame(stream, &BackendMessage::BindComplete).await
323}
324
325async fn handle_describe<S>(
326    stream: &mut S,
327    runtime: &RedDBRuntime,
328    prepared: &HashMap<String, PgPreparedStatement>,
329    portals: &mut HashMap<String, PgPortal>,
330    msg: super::protocol::DescribeMessage,
331) -> Result<(), PgWireError>
332where
333    S: AsyncWrite + Unpin,
334{
335    match msg.target {
336        DescribeTarget::Statement => {
337            let Some(stmt) = prepared.get(&msg.name) else {
338                send_error(
339                    stream,
340                    "26000",
341                    &format!("prepared statement {:?} does not exist", msg.name),
342                )
343                .await?;
344                return Ok(());
345            };
346            write_frame(
347                stream,
348                &BackendMessage::ParameterDescription(stmt.param_type_oids.clone()),
349            )
350            .await?;
351            if is_ask_query(&stmt.sql) {
352                emit_ask_row_description(stream).await
353            } else {
354                write_frame(stream, &BackendMessage::NoData).await
355            }
356        }
357        DescribeTarget::Portal => {
358            let Some(portal) = portals.get_mut(&msg.name) else {
359                send_error(
360                    stream,
361                    "34000",
362                    &format!("portal {:?} does not exist", msg.name),
363                )
364                .await?;
365                return Ok(());
366            };
367            if is_ask_query(&portal.sql) {
368                emit_ask_row_description(stream).await?;
369                portal.row_description_sent = true;
370                Ok(())
371            } else if is_row_returning_query(&portal.sql) {
372                let result = match execute_pg_query_result(runtime, &portal.sql, &portal.params) {
373                    Ok(result) => result,
374                    Err(err) => {
375                        let code = classify_sqlstate(&err);
376                        send_error(stream, code, &err).await?;
377                        return Ok(());
378                    }
379                };
380                emit_row_description_for_result(stream, &result).await?;
381                portal.row_description_sent = true;
382                portal.described_result = Some(result);
383                Ok(())
384            } else {
385                write_frame(stream, &BackendMessage::NoData).await
386            }
387        }
388    }
389}
390
391async fn handle_execute<S>(
392    stream: &mut S,
393    runtime: &RedDBRuntime,
394    portals: &mut HashMap<String, PgPortal>,
395    msg: super::protocol::ExecuteMessage,
396) -> Result<(), PgWireError>
397where
398    S: AsyncWrite + Unpin,
399{
400    let Some(portal) = portals.get_mut(&msg.portal) else {
401        send_error(
402            stream,
403            "34000",
404            &format!("portal {:?} does not exist", msg.portal),
405        )
406        .await?;
407        return Ok(());
408    };
409    let _max_rows = msg.max_rows;
410    let was_described = portal.row_description_sent || portal.described_result.is_some();
411    portal.row_description_sent = false;
412    let result = match portal.described_result.take() {
413        Some(result) => Ok(result),
414        None => execute_pg_query_result(runtime, &portal.sql, &portal.params),
415    };
416    match result {
417        Ok(result) if was_described => {
418            emit_success_result_without_row_description(stream, &result).await
419        }
420        Ok(result) => emit_success_result(stream, &result).await,
421        Err(err) => {
422            let code = classify_sqlstate(&err);
423            send_error(stream, code, &err).await
424        }
425    }
426}
427
428async fn handle_close<S>(
429    stream: &mut S,
430    prepared: &mut HashMap<String, PgPreparedStatement>,
431    portals: &mut HashMap<String, PgPortal>,
432    msg: super::protocol::CloseMessage,
433) -> Result<(), PgWireError>
434where
435    S: AsyncWrite + Unpin,
436{
437    match msg.target {
438        DescribeTarget::Statement => {
439            prepared.remove(&msg.name);
440        }
441        DescribeTarget::Portal => {
442            portals.remove(&msg.name);
443        }
444    }
445    write_frame(stream, &BackendMessage::CloseComplete).await
446}
447
448fn bind_pg_params(
449    stmt: &PgPreparedStatement,
450    msg: &super::protocol::BindMessage,
451) -> Result<Vec<Value>, String> {
452    if !matches!(msg.param_format_codes.len(), 0 | 1)
453        && msg.param_format_codes.len() != msg.params.len()
454    {
455        return Err("Bind format count must be 0, 1, or match parameter count".to_string());
456    }
457    msg.params
458        .iter()
459        .enumerate()
460        .map(|(idx, param)| {
461            let oid = stmt
462                .param_type_oids
463                .get(idx)
464                .copied()
465                .unwrap_or(PgOid::Unknown.as_u32());
466            let format_code = match msg.param_format_codes.as_slice() {
467                [] => 0,
468                [format] => *format,
469                formats => formats[idx],
470            };
471            pg_param_to_value(oid, format_code, param.as_deref())
472        })
473        .collect()
474}
475
476fn execute_pg_query_result(
477    runtime: &RedDBRuntime,
478    sql: &str,
479    params: &[Value],
480) -> Result<crate::runtime::RuntimeQueryResult, String> {
481    if let Some(result) = try_execute_pg_scalar_select(sql, params) {
482        return Ok(result);
483    }
484    if params.is_empty() {
485        return match translate_pg_catalog_query(runtime, sql) {
486            Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
487                query: sql.to_string(),
488                mode: crate::storage::query::modes::QueryMode::Sql,
489                statement: "select",
490                engine: "pg-catalog",
491                result,
492                affected_rows: 0,
493                statement_type: "select",
494                bookmark: None,
495            }),
496            Ok(None) => {
497                run_runtime_blocking(|| runtime.execute_query(sql)).map_err(|err| err.to_string())
498            }
499            Err(err) => Err(err.to_string()),
500        };
501    }
502
503    run_runtime_blocking(|| runtime.execute_query_with_params(sql, params))
504        .map_err(|err| err.to_string())
505}
506
507fn try_execute_pg_scalar_select(
508    sql: &str,
509    params: &[Value],
510) -> Option<crate::runtime::RuntimeQueryResult> {
511    let index = pg_scalar_select_param_index(sql)?;
512    let value = params.get(index)?.clone();
513    let mut result = UnifiedResult::with_columns(vec!["?column?".to_string()]);
514    let mut record = UnifiedRecord::new();
515    record.set("?column?", value);
516    result.push(record);
517    Some(crate::runtime::RuntimeQueryResult {
518        query: sql.to_string(),
519        mode: crate::storage::query::modes::QueryMode::Sql,
520        statement: "select",
521        engine: "pg-wire",
522        result,
523        affected_rows: 0,
524        statement_type: "select",
525        bookmark: None,
526    })
527}
528
/// Recognise a scalar parameter echo and return the zero-based index of
/// the parameter it selects.
///
/// Accepted shapes (case-insensitive, surrounding whitespace and trailing
/// `;` ignored), each optionally followed by `AS <alias>`:
///
/// * `SELECT $N`
/// * `SELECT CAST($N AS <type>)`
///
/// Anything after the projected parameter other than an alias makes the
/// statement a real query and yields `None`. Callers answer a match with
/// the bare parameter value, so `SELECT $1 FROM t` or `SELECT $1 + 1`
/// must not match. `$0` and non-numeric parameter names also yield
/// `None`.
fn pg_scalar_select_param_index(sql: &str) -> Option<usize> {
    let trimmed = sql.trim().trim_end_matches(';').trim();
    let lower = trimmed.to_ascii_lowercase();

    // The keyword must be followed by at least one whitespace character.
    let after_keyword = lower.strip_prefix("select")?;
    let body = after_keyword.trim_start();
    if body.len() == after_keyword.len() {
        return None;
    }

    let (param, tail) = match body.strip_prefix("cast(") {
        Some(inner) => {
            let as_at = inner.find(" as ")?;
            // Find the ')' that closes `cast(`, stepping over any nested
            // parentheses from a type modifier such as `numeric(10,2)`.
            let mut depth = 1usize;
            let mut close_at = None;
            for (offset, ch) in inner[as_at..].char_indices() {
                match ch {
                    '(' => depth += 1,
                    ')' => {
                        depth -= 1;
                        if depth == 0 {
                            close_at = Some(as_at + offset);
                            break;
                        }
                    }
                    _ => {}
                }
            }
            let close_at = close_at?;
            (inner[..as_at].trim(), &inner[close_at + 1..])
        }
        None => {
            let end = body.find(char::is_whitespace).unwrap_or(body.len());
            body.split_at(end)
        }
    };

    // Only an optional `AS <alias>` may follow the projected parameter.
    let mut rest = tail.split_whitespace();
    match (rest.next(), rest.next(), rest.next()) {
        (None, _, _) => {}
        (Some("as"), Some(alias), None)
            if alias
                .chars()
                .all(|c| c.is_alphanumeric() || c == '_' || c == '"') => {}
        _ => return None,
    }

    let digits = param.strip_prefix('$')?;
    // `str::parse` would also accept a leading '+', which is not a valid
    // parameter reference.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse::<usize>().ok()?.checked_sub(1)
}
543
/// Strip PostgreSQL-style casts from parameter references, turning
/// `$1::int` into `$1`, and return the rewritten SQL.
///
/// The whole cast is removed, including a parenthesised type modifier
/// (`$1::varchar(10)`) and array suffixes (`$1::int[]`), so no dangling
/// `(10)` or `[]` is left behind in the statement. Text inside
/// single-quoted literals and double-quoted identifiers is copied through
/// untouched; a doubled quote inside a literal is handled because it
/// simply closes and immediately reopens the quoted run.
///
/// A `::` with no type name after it is left as written.
fn rewrite_pg_parameter_casts(sql: &str) -> String {
    /// If `bytes[at]` is `open`, return the index just past the matching
    /// `close`, provided everything in between is digits, commas or
    /// spaces. Otherwise return `at` unchanged.
    fn skip_numeric_group(bytes: &[u8], at: usize, open: u8, close: u8) -> usize {
        if bytes.get(at) != Some(&open) {
            return at;
        }
        let mut end = at + 1;
        while end < bytes.len() && matches!(bytes[end], b'0'..=b'9' | b',' | b' ') {
            end += 1;
        }
        if bytes.get(end) == Some(&close) {
            end + 1
        } else {
            at
        }
    }

    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    // Start of the input that has not been copied into `out` yet. Every
    // slice boundary below sits on an ASCII byte, so slicing `sql` at
    // these offsets always lands on a character boundary.
    let mut cursor = 0;
    let mut pos = 0;

    while pos < bytes.len() {
        let byte = bytes[pos];

        // Quoted text is opaque: jump past the closing quote.
        if byte == b'\'' || byte == b'"' {
            pos += 1;
            while pos < bytes.len() && bytes[pos] != byte {
                pos += 1;
            }
            pos += 1;
            continue;
        }
        if byte != b'$' {
            pos += 1;
            continue;
        }

        let digits_start = pos + 1;
        let mut param_end = digits_start;
        while param_end < bytes.len() && bytes[param_end].is_ascii_digit() {
            param_end += 1;
        }
        pos = param_end;
        if param_end == digits_start || !bytes[param_end..].starts_with(b"::") {
            continue;
        }

        let type_start = param_end + 2;
        let mut cast_end = type_start;
        while cast_end < bytes.len()
            && (bytes[cast_end].is_ascii_alphanumeric() || matches!(bytes[cast_end], b'_' | b'.'))
        {
            cast_end += 1;
        }
        pos = cast_end;
        if cast_end == type_start {
            continue;
        }

        // Swallow `(n[, m])` and any number of `[]` / `[n]` suffixes.
        cast_end = skip_numeric_group(bytes, cast_end, b'(', b')');
        loop {
            let next = skip_numeric_group(bytes, cast_end, b'[', b']');
            if next == cast_end {
                break;
            }
            cast_end = next;
        }

        out.push_str(&sql[cursor..param_end]);
        cursor = cast_end;
        pos = cast_end;
    }

    out.push_str(&sql[cursor..]);
    out
}
583
584fn infer_pg_cast_param_type_oids(sql: &str) -> Vec<(usize, u32)> {
585    let mut out = Vec::new();
586    let bytes = sql.as_bytes();
587    let mut pos = 0;
588    while pos < bytes.len() {
589        if bytes[pos] != b'$' {
590            pos += 1;
591            continue;
592        }
593        pos += 1;
594        let digits_start = pos;
595        while pos < bytes.len() && bytes[pos].is_ascii_digit() {
596            pos += 1;
597        }
598        if digits_start == pos {
599            continue;
600        }
601        let Some(param_index) = sql[digits_start..pos]
602            .parse::<usize>()
603            .ok()
604            .and_then(|idx| idx.checked_sub(1))
605        else {
606            continue;
607        };
608        if pos + 2 > bytes.len() || &bytes[pos..pos + 2] != b"::" {
609            continue;
610        }
611        pos += 2;
612        let type_start = pos;
613        while pos < bytes.len()
614            && (bytes[pos].is_ascii_alphanumeric() || matches!(bytes[pos], b'_' | b'.'))
615        {
616            pos += 1;
617        }
618        if type_start == pos {
619            continue;
620        }
621        if let Some(oid) = pg_cast_type_oid(&sql[type_start..pos]) {
622            out.push((param_index, oid));
623        }
624    }
625    out
626}
627
628fn pg_cast_type_oid(ty: &str) -> Option<u32> {
629    let lower = ty.to_ascii_lowercase();
630    let short = lower.rsplit('.').next().unwrap_or(lower.as_str());
631    let oid = match short {
632        "bool" | "boolean" => PgOid::Bool,
633        "int2" | "smallint" => PgOid::Int2,
634        "int" | "int4" | "integer" => PgOid::Int4,
635        "int8" | "bigint" => PgOid::Int8,
636        "float4" | "real" => PgOid::Float4,
637        "float8" | "double" | "doubleprecision" => PgOid::Float8,
638        "numeric" | "decimal" => PgOid::Numeric,
639        "bytea" => PgOid::Bytea,
640        "json" => PgOid::Json,
641        "jsonb" => PgOid::Jsonb,
642        "text" => PgOid::Text,
643        "varchar" | "character varying" => PgOid::Varchar,
644        "uuid" => PgOid::Uuid,
645        "timestamp" => PgOid::Timestamp,
646        "timestamptz" | "timestampz" => PgOid::TimestampTz,
647        "vector" => PgOid::Vector,
648        _ => return None,
649    };
650    Some(oid.as_u32())
651}
652
/// Report whether `sql` begins with one of the statement keywords this
/// listener treats as producing rows.
///
/// Leading whitespace is ignored and the comparison is ASCII
/// case-insensitive. This is a plain prefix test: no word boundary is
/// required after the keyword.
fn is_row_returning_query(sql: &str) -> bool {
    const ROW_RETURNING_PREFIXES: [&str; 6] =
        ["select", "with", "ask", "search", "vector", "hybrid"];

    let head = sql.trim_start().as_bytes();
    ROW_RETURNING_PREFIXES.iter().any(|keyword| {
        head.get(..keyword.len())
            .map_or(false, |prefix| prefix.eq_ignore_ascii_case(keyword.as_bytes()))
    })
}
662
/// Report whether `sql` begins with `ask`, ignoring leading whitespace
/// and ASCII case. This is a plain prefix test: no word boundary is
/// required after the keyword.
fn is_ask_query(sql: &str) -> bool {
    const KEYWORD: &[u8] = b"ask";

    let head = sql.trim_start().as_bytes();
    match head.get(..KEYWORD.len()) {
        Some(prefix) => prefix.eq_ignore_ascii_case(KEYWORD),
        None => false,
    }
}
666
667async fn send_auth_ok<S>(
668    stream: &mut S,
669    config: &PgWireConfig,
670    params: &super::protocol::StartupParams,
671) -> Result<(), PgWireError>
672where
673    S: AsyncWrite + Unpin,
674{
675    // Phase 3.1: trust auth. We always send AuthenticationOk.
676    write_frame(stream, &BackendMessage::AuthenticationOk).await?;
677
678    // Standard ParameterStatus frames. Drivers gate capabilities on these.
679    for (name, value) in [
680        ("server_version", config.server_version.as_str()),
681        ("server_encoding", "UTF8"),
682        ("client_encoding", "UTF8"),
683        ("DateStyle", "ISO, MDY"),
684        ("TimeZone", "UTC"),
685        ("integer_datetimes", "on"),
686        ("standard_conforming_strings", "on"),
687        (
688            "application_name",
689            params.get("application_name").unwrap_or(""),
690        ),
691    ] {
692        write_frame(
693            stream,
694            &BackendMessage::ParameterStatus {
695                name: name.to_string(),
696                value: value.to_string(),
697            },
698        )
699        .await?;
700    }
701
702    // BackendKeyData: (pid, secret_key). Used by CancelRequest; we don't
703    // honour cancels in 3.1 so random-ish values are fine.
704    write_frame(
705        stream,
706        &BackendMessage::BackendKeyData {
707            pid: std::process::id(),
708            key: 0xDEADBEEF,
709        },
710    )
711    .await?;
712
713    write_frame(
714        stream,
715        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
716    )
717    .await?;
718    Ok(())
719}
720
721async fn handle_simple_query<S>(
722    stream: &mut S,
723    runtime: &RedDBRuntime,
724    sql: &str,
725) -> Result<(), PgWireError>
726where
727    S: AsyncWrite + Unpin,
728{
729    // Empty query convention: PG emits EmptyQueryResponse instead of a
730    // CommandComplete. Some clients (psql `\;`) rely on this.
731    if sql.trim().is_empty() {
732        write_frame(stream, &BackendMessage::EmptyQueryResponse).await?;
733        write_frame(
734            stream,
735            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
736        )
737        .await?;
738        return Ok(());
739    }
740
741    if let Some(tag) = pg_session_compat_command_tag(sql) {
742        write_frame(stream, &BackendMessage::CommandComplete(tag.to_string())).await?;
743        write_frame(
744            stream,
745            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
746        )
747        .await?;
748        return Ok(());
749    }
750
751    let query_result = match translate_pg_catalog_query(runtime, sql) {
752        Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
753            query: sql.to_string(),
754            mode: crate::storage::query::modes::QueryMode::Sql,
755            statement: "select",
756            engine: "pg-catalog",
757            result,
758            affected_rows: 0,
759            statement_type: "select",
760            bookmark: None,
761        }),
762        Ok(None) => run_runtime_blocking(|| runtime.execute_query(sql)),
763        Err(err) => Err(err),
764    };
765
766    match query_result {
767        Ok(result) => {
768            emit_success_result(stream, &result).await?;
769        }
770        Err(err) => {
771            // PG SQLSTATE class 42 covers syntax / binding errors; we use
772            // 42P01 (undefined_table) and 42601 (syntax_error) when we can
773            // detect; otherwise fall back to XX000 (internal error).
774            let code = classify_sqlstate(&err.to_string());
775            send_error(stream, code, &err.to_string()).await?;
776        }
777    }
778
779    write_frame(
780        stream,
781        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
782    )
783    .await?;
784    Ok(())
785}
786
/// Return the CommandComplete tag for statements that are acknowledged
/// for client compatibility instead of being executed.
///
/// Currently that is any statement that, after trimming whitespace and
/// trailing semicolons, begins with `set ` (ASCII case-insensitive, with
/// a literal space after the keyword); its tag is `SET`.
fn pg_session_compat_command_tag(sql: &str) -> Option<&'static str> {
    const SET_PREFIX: &[u8] = b"set ";

    let statement = sql.trim().trim_end_matches(';');
    let is_set = statement
        .as_bytes()
        .get(..SET_PREFIX.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(SET_PREFIX));
    is_set.then_some("SET")
}
794
795async fn emit_success_result<S>(
796    stream: &mut S,
797    result: &crate::runtime::RuntimeQueryResult,
798) -> Result<(), PgWireError>
799where
800    S: AsyncWrite + Unpin,
801{
802    if result.statement == "ask" {
803        emit_ask_result_row(stream, result).await?;
804        write_frame(
805            stream,
806            &BackendMessage::CommandComplete("SELECT 1".to_string()),
807        )
808        .await?;
809    } else if result_returns_rows(result) {
810        emit_result_rows(stream, &result.result).await?;
811        write_frame(
812            stream,
813            &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
814        )
815        .await?;
816    } else {
817        // DDL / DML / config statements: echo the runtime's
818        // high-level statement tag back. PG format is
819        // "<CMD> [<OID>] <COUNT>"; we keep the count where
820        // applicable and fall back to the runtime's message.
821        let tag = match result.statement_type {
822            "insert" => format!("INSERT 0 {}", result.affected_rows),
823            "update" => format!("UPDATE {}", result.affected_rows),
824            "delete" => format!("DELETE {}", result.affected_rows),
825            other => other.to_uppercase(),
826        };
827        write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
828    }
829    Ok(())
830}
831
832async fn emit_success_result_without_row_description<S>(
833    stream: &mut S,
834    result: &crate::runtime::RuntimeQueryResult,
835) -> Result<(), PgWireError>
836where
837    S: AsyncWrite + Unpin,
838{
839    if result.statement == "ask" {
840        let row = ask_query_result_to_pg_wire_row(result)
841            .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
842        write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
843        write_frame(
844            stream,
845            &BackendMessage::CommandComplete("SELECT 1".to_string()),
846        )
847        .await?;
848    } else if result_returns_rows(result) {
849        emit_result_data_rows(stream, &result.result).await?;
850        write_frame(
851            stream,
852            &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
853        )
854        .await?;
855    } else {
856        let tag = match result.statement_type {
857            "insert" => format!("INSERT 0 {}", result.affected_rows),
858            "update" => format!("UPDATE {}", result.affected_rows),
859            "delete" => format!("DELETE {}", result.affected_rows),
860            other => other.to_uppercase(),
861        };
862        write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
863    }
864    Ok(())
865}
866
867async fn emit_row_description_for_result<S>(
868    stream: &mut S,
869    result: &crate::runtime::RuntimeQueryResult,
870) -> Result<(), PgWireError>
871where
872    S: AsyncWrite + Unpin,
873{
874    if result.statement == "ask" {
875        emit_ask_row_description(stream).await
876    } else if result_returns_rows(result) {
877        emit_result_row_description(stream, &result.result).await
878    } else {
879        write_frame(stream, &BackendMessage::NoData).await
880    }
881}
882
883fn result_returns_rows(result: &crate::runtime::RuntimeQueryResult) -> bool {
884    result.statement_type == "select"
885}
886
887async fn emit_result_rows<S>(
888    stream: &mut S,
889    result: &crate::storage::query::unified::UnifiedResult,
890) -> Result<(), PgWireError>
891where
892    S: AsyncWrite + Unpin,
893{
894    emit_result_row_description(stream, result).await?;
895    emit_result_data_rows(stream, result).await
896}
897
898async fn emit_result_row_description<S>(
899    stream: &mut S,
900    result: &crate::storage::query::unified::UnifiedResult,
901) -> Result<(), PgWireError>
902where
903    S: AsyncWrite + Unpin,
904{
905    // RowDescription: derived from the first record's column ordering.
906    // When `result.columns` is non-empty we honour that order; otherwise
907    // we synthesise one from the record's field order.
908    let columns: Vec<String> = if !result.columns.is_empty() {
909        result.columns.clone()
910    } else if let Some(first) = result.records.first() {
911        record_field_names(first)
912    } else {
913        Vec::new()
914    };
915
916    // Peek at the first record for per-column type OIDs. When there's no
917    // data row we fall back to TEXT for every column — clients render
918    // empty result sets happily.
919    let type_oids: Vec<PgOid> = columns
920        .iter()
921        .map(|col| {
922            result
923                .records
924                .first()
925                .and_then(|r| record_get(r, col))
926                .map(PgOid::from_value)
927                .unwrap_or(PgOid::Text)
928        })
929        .collect();
930
931    let descriptors: Vec<ColumnDescriptor> = columns
932        .iter()
933        .zip(type_oids.iter())
934        .map(|(name, oid)| ColumnDescriptor {
935            name: name.clone(),
936            table_oid: 0,
937            column_attr: 0,
938            type_oid: oid.as_u32(),
939            type_size: -1,
940            type_mod: -1,
941            format: 0,
942        })
943        .collect();
944
945    write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
946}
947
948async fn emit_result_data_rows<S>(
949    stream: &mut S,
950    result: &crate::storage::query::unified::UnifiedResult,
951) -> Result<(), PgWireError>
952where
953    S: AsyncWrite + Unpin,
954{
955    let columns: Vec<String> = if !result.columns.is_empty() {
956        result.columns.clone()
957    } else if let Some(first) = result.records.first() {
958        record_field_names(first)
959    } else {
960        Vec::new()
961    };
962    for record in &result.records {
963        let fields: Vec<Option<Vec<u8>>> = columns
964            .iter()
965            .map(|col| record_get(record, col).and_then(value_to_pg_wire_bytes))
966            .collect();
967        write_frame(stream, &BackendMessage::DataRow(fields)).await?;
968    }
969
970    Ok(())
971}
972
973async fn emit_ask_result_row<S>(
974    stream: &mut S,
975    result: &crate::runtime::RuntimeQueryResult,
976) -> Result<(), PgWireError>
977where
978    S: AsyncWrite + Unpin,
979{
980    let row = ask_query_result_to_pg_wire_row(result)
981        .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
982
983    emit_ask_row_description(stream).await?;
984    write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
985    Ok(())
986}
987
988async fn emit_ask_row_description<S>(stream: &mut S) -> Result<(), PgWireError>
989where
990    S: AsyncWrite + Unpin,
991{
992    let descriptors: Vec<ColumnDescriptor> = crate::runtime::ai::pg_wire_ask_row_encoder::columns()
993        .iter()
994        .map(|col| ColumnDescriptor {
995            name: col.name.to_string(),
996            table_oid: 0,
997            column_attr: 0,
998            type_oid: col.oid.as_u32(),
999            type_size: -1,
1000            type_mod: -1,
1001            format: 0,
1002        })
1003        .collect();
1004    write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
1005}
1006
1007fn ask_query_result_to_pg_wire_row(
1008    result: &crate::runtime::RuntimeQueryResult,
1009) -> Option<crate::runtime::ai::pg_wire_ask_row_encoder::AskRow> {
1010    if result.statement != "ask" {
1011        return None;
1012    }
1013    let record = result.result.records.first()?;
1014    let sources_flat_json =
1015        json_field(record, "sources_flat").unwrap_or(crate::json::Value::Array(Vec::new()));
1016    let citations_json =
1017        json_field(record, "citations").unwrap_or(crate::json::Value::Array(Vec::new()));
1018    let validation_json = json_field(record, "validation")
1019        .unwrap_or_else(|| crate::json::Value::Object(Default::default()));
1020
1021    let effective_mode = match text_field(record, "mode").as_deref() {
1022        Some("lenient") => Mode::Lenient,
1023        _ => Mode::Strict,
1024    };
1025
1026    let ask = AskResult {
1027        answer: text_field(record, "answer")?,
1028        sources_flat: ask_sources_flat(&sources_flat_json),
1029        citations: ask_citations(&citations_json),
1030        validation: ask_validation(&validation_json),
1031        cache_hit: bool_field(record, "cache_hit").unwrap_or(false),
1032        provider: text_field(record, "provider").unwrap_or_default(),
1033        model: text_field(record, "model").unwrap_or_default(),
1034        prompt_tokens: u32_field(record, "prompt_tokens").unwrap_or(0),
1035        completion_tokens: u32_field(record, "completion_tokens").unwrap_or(0),
1036        cost_usd: f64_field(record, "cost_usd").unwrap_or(0.0),
1037        effective_mode,
1038        retry_count: u32_field(record, "retry_count").unwrap_or(0),
1039    };
1040
1041    Some(crate::runtime::ai::pg_wire_ask_row_encoder::encode(&ask))
1042}
1043
1044fn record_field<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
1045    record.iter_fields().find_map(|(name, value)| {
1046        let name: &str = name;
1047        (name == key).then_some(value)
1048    })
1049}
1050
1051fn text_field(record: &UnifiedRecord, key: &str) -> Option<String> {
1052    match record_field(record, key)? {
1053        Value::Text(s) => Some(s.to_string()),
1054        Value::Email(s) | Value::Url(s) | Value::NodeRef(s) | Value::EdgeRef(s) => Some(s.clone()),
1055        other => Some(other.to_string()),
1056    }
1057}
1058
1059fn bool_field(record: &UnifiedRecord, key: &str) -> Option<bool> {
1060    match record_field(record, key)? {
1061        Value::Boolean(value) => Some(*value),
1062        _ => None,
1063    }
1064}
1065
1066fn u32_field(record: &UnifiedRecord, key: &str) -> Option<u32> {
1067    match record_field(record, key)? {
1068        Value::Integer(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1069        Value::UnsignedInteger(n) => Some((*n).min(u32::MAX as u64) as u32),
1070        Value::BigInt(n)
1071        | Value::TimestampMs(n)
1072        | Value::Timestamp(n)
1073        | Value::Duration(n)
1074        | Value::Decimal(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1075        Value::Float(n) => (*n >= 0.0).then_some((*n).min(u32::MAX as f64) as u32),
1076        _ => None,
1077    }
1078}
1079
1080fn f64_field(record: &UnifiedRecord, key: &str) -> Option<f64> {
1081    match record_field(record, key)? {
1082        Value::Integer(n) => Some(*n as f64),
1083        Value::UnsignedInteger(n) => Some(*n as f64),
1084        Value::BigInt(n)
1085        | Value::TimestampMs(n)
1086        | Value::Timestamp(n)
1087        | Value::Duration(n)
1088        | Value::Decimal(n) => Some(*n as f64),
1089        Value::Float(n) => Some(*n),
1090        _ => None,
1091    }
1092}
1093
1094fn json_field(record: &UnifiedRecord, key: &str) -> Option<crate::json::Value> {
1095    match record_field(record, key)? {
1096        Value::Json(bytes) => crate::json::from_slice(bytes).ok(),
1097        Value::Text(text) => crate::json::from_str(text).ok(),
1098        _ => None,
1099    }
1100}
1101
1102fn ask_sources_flat(value: &crate::json::Value) -> Vec<SourceRow> {
1103    value
1104        .as_array()
1105        .unwrap_or(&[])
1106        .iter()
1107        .filter_map(|source| {
1108            let urn = source
1109                .get("urn")
1110                .and_then(crate::json::Value::as_str)?
1111                .to_string();
1112            let payload = source
1113                .get("payload")
1114                .and_then(crate::json::Value::as_str)
1115                .map(ToString::to_string)
1116                .unwrap_or_else(|| source.to_string_compact());
1117            Some(SourceRow { urn, payload })
1118        })
1119        .collect()
1120}
1121
1122fn ask_citations(value: &crate::json::Value) -> Vec<Citation> {
1123    value
1124        .as_array()
1125        .unwrap_or(&[])
1126        .iter()
1127        .filter_map(|citation| {
1128            let marker = citation
1129                .get("marker")
1130                .and_then(crate::json::Value::as_u64)?;
1131            let urn = citation
1132                .get("urn")
1133                .and_then(crate::json::Value::as_str)?
1134                .to_string();
1135            Some(Citation {
1136                marker: marker.min(u32::MAX as u64) as u32,
1137                urn,
1138            })
1139        })
1140        .collect()
1141}
1142
1143fn ask_validation(value: &crate::json::Value) -> Validation {
1144    Validation {
1145        ok: value
1146            .get("ok")
1147            .and_then(crate::json::Value::as_bool)
1148            .unwrap_or(true),
1149        warnings: validation_items(value, "warnings")
1150            .into_iter()
1151            .map(|(kind, detail)| ValidationWarning { kind, detail })
1152            .collect(),
1153        errors: validation_items(value, "errors")
1154            .into_iter()
1155            .map(|(kind, detail)| ValidationError { kind, detail })
1156            .collect(),
1157    }
1158}
1159
1160fn validation_items(value: &crate::json::Value, key: &str) -> Vec<(String, String)> {
1161    value
1162        .get(key)
1163        .and_then(crate::json::Value::as_array)
1164        .unwrap_or(&[])
1165        .iter()
1166        .filter_map(|item| {
1167            Some((
1168                item.get("kind")
1169                    .and_then(crate::json::Value::as_str)?
1170                    .to_string(),
1171                item.get("detail")
1172                    .and_then(crate::json::Value::as_str)
1173                    .unwrap_or("")
1174                    .to_string(),
1175            ))
1176        })
1177        .collect()
1178}
1179
1180/// Best-effort field lookup on a `UnifiedRecord`. The record API lives in
1181/// `storage::query::unified` and today uses `HashMap<String, Value>` under
1182/// the hood — we use `get` if it exists, else fall back to serialised map.
1183fn record_get<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
1184    record.get(key)
1185}
1186
1187/// Extract column names in iteration order from a single record. When
1188/// the caller didn't supply an explicit `columns` projection we use the
1189/// first record's field ordering as the canonical tuple shape.
1190///
1191/// HashMap iteration order is non-deterministic — for Phase 3.1 we
1192/// accept the shuffle since PG clients receive the ordered header via
1193/// RowDescription and match cells positionally. A stable ordering
1194/// would require keeping an insertion-order index alongside `values`.
1195fn record_field_names(record: &UnifiedRecord) -> Vec<String> {
1196    // `column_names()` merges the columnar scan side-channel with
1197    // the HashMap so scan rows (which populate only columnar) still
1198    // surface their field names in PG wire output.
1199    record
1200        .column_names()
1201        .into_iter()
1202        .map(|k| k.to_string())
1203        .collect()
1204}
1205
1206async fn send_error<S>(stream: &mut S, code: &str, message: &str) -> Result<(), PgWireError>
1207where
1208    S: AsyncWrite + Unpin,
1209{
1210    write_frame(
1211        stream,
1212        &BackendMessage::ErrorResponse {
1213            severity: "ERROR".to_string(),
1214            code: code.to_string(),
1215            message: message.to_string(),
1216        },
1217    )
1218    .await
1219}
1220
/// Heuristically map a runtime error message onto a PG SQLSTATE. Full
/// coverage would map every `RedDBError` variant; this is enough for the
/// common psql / JDBC paths.
///
/// Matching is case-insensitive substring search, checked in this order
/// (the first hit wins):
///
/// | message mentions                    | SQLSTATE |
/// |-------------------------------------|----------|
/// | "not found", "does not exist"       | `42P01`  |
/// | "parse", "expected", "syntax"       | `42601`  |
/// | "already exists"                    | `42P07`  |
/// | "permission"                        | `42501`  |
/// | "auth"                              | `28000`  |
/// | anything else                       | `XX000`  |
fn classify_sqlstate(msg: &str) -> &'static str {
    let lower = msg.to_ascii_lowercase();
    let mentions = |needles: &[&str]| needles.iter().any(|needle| lower.contains(*needle));

    if mentions(&["not found", "does not exist"]) {
        // 42P01 undefined_table; close enough for collection-not-found.
        "42P01"
    } else if mentions(&["parse", "expected", "syntax"]) {
        // 42601 syntax_error.
        "42601"
    } else if mentions(&["already exists"]) {
        // 42P07 duplicate_table.
        "42P07"
    } else if mentions(&["permission"]) {
        // 42501 insufficient_privilege: a privilege failure on a
        // statement, not a failed login, so it must not be reported in
        // class 28. Checked before "auth" so that a message mentioning
        // both is still classified as a privilege error.
        "42501"
    } else if mentions(&["auth"]) {
        // 28000 invalid_authorization_specification.
        "28000"
    } else {
        // XX000 internal_error.
        "XX000"
    }
}
1239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::RedDBOptions;
    use crate::runtime::RuntimeQueryResult;
    use crate::storage::query::modes::QueryMode;
    use crate::storage::query::unified::UnifiedResult;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    /// Drives Parse / Bind / Describe / Execute / Sync over an in-memory
    /// duplex against `handle_connection` and checks the frames returned
    /// for a single-parameter `SELECT`.
    #[tokio::test]
    async fn extended_parse_bind_execute_returns_rows() {
        let (server_io, mut client_io) = tokio::io::duplex(64 * 1024);
        let runtime = Arc::new(RedDBRuntime::with_options(RedDBOptions::in_memory()).unwrap());
        let config = Arc::new(PgWireConfig::default());
        let server = tokio::spawn(async move {
            handle_connection(server_io, runtime, config).await.unwrap();
        });

        // Startup handshake: consume everything up to the first ReadyForQuery.
        write_startup(&mut client_io).await;
        read_until_ready(&mut client_io).await;

        let parse = parse_body("", "SELECT $1::int", &[PgOid::Int4.as_u32()]);
        write_frontend_frame(&mut client_io, b'P', parse).await;
        let bind = bind_body("", "", &[0], &[Some(b"42".as_slice())], &[]);
        write_frontend_frame(&mut client_io, b'B', bind).await;
        let describe = describe_body(b'P', "");
        write_frontend_frame(&mut client_io, b'D', describe).await;
        let execute = execute_body("", 0);
        write_frontend_frame(&mut client_io, b'E', execute).await;
        write_frontend_frame(&mut client_io, b'S', Vec::new()).await;

        let frames = read_until_ready(&mut client_io).await;
        let tags = frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>();
        assert_eq!(tags, b"12TDCZ");

        let columns = decode_row_description(&frames[2].1);
        assert_eq!(columns.len(), 1);

        let cells = decode_data_row(&frames[3].1);
        assert_eq!(cells.len(), 1);
        assert_eq!(cells[0].as_deref(), Some(b"42".as_slice()));

        assert_eq!(decode_command_complete(&frames[4].1), "SELECT 1");

        // Terminate so the server task finishes cleanly.
        write_frontend_frame(&mut client_io, b'X', Vec::new()).await;
        server.await.unwrap();
    }

    /// `$n::type` casts in the SQL text yield `(index, oid)` pairs.
    #[test]
    fn infer_pg_cast_param_type_oids_from_parameter_casts() {
        let insert_oids =
            infer_pg_cast_param_type_oids("INSERT INTO t (id, name) VALUES ($1::int, $2::text)");
        assert_eq!(
            insert_oids,
            vec![(0, PgOid::Int4.as_u32()), (1, PgOid::Text.as_u32())]
        );

        let search_oids =
            infer_pg_cast_param_type_oids("SEARCH SIMILAR [1.0] COLLECTION v LIMIT $1::int8");
        assert_eq!(search_oids, vec![(0, PgOid::Int8.as_u32())]);
    }

    /// Driver setup `SET` commands are acknowledged with the `SET` tag;
    /// ordinary queries are not treated as session-compat commands.
    #[test]
    fn pg_session_compat_accepts_driver_setup_set_commands() {
        for sql in [
            "SET extra_float_digits = 3",
            "SET application_name = 'pgjdbc'",
        ] {
            assert_eq!(pg_session_compat_command_tag(sql), Some("SET"));
        }
        assert_eq!(pg_session_compat_command_tag("SELECT 1"), None);
    }

    /// An ASK result is emitted as RowDescription + DataRow +
    /// CommandComplete using the fixed twelve-column ASK layout.
    #[tokio::test]
    async fn ask_success_result_uses_canonical_pg_wire_row_shape() {
        let mut record = UnifiedRecord::new();
        record.set("answer", Value::text("Deploy failed [^1]."));
        record.set("provider", Value::text("openai"));
        record.set("model", Value::text("gpt-4o-mini"));
        record.set("prompt_tokens", Value::Integer(11));
        record.set("completion_tokens", Value::Integer(7));
        record.set(
            "sources_flat",
            Value::Json(
                br#"[{"urn":"urn:reddb:row:deployments:1","kind":"row","collection":"deployments","id":"1"}]"#
                    .to_vec(),
            ),
        );
        record.set(
            "citations",
            Value::Json(br#"[{"marker":1,"urn":"urn:reddb:row:deployments:1"}]"#.to_vec()),
        );
        record.set(
            "validation",
            Value::Json(br#"{"ok":true,"warnings":[],"errors":[]}"#.to_vec()),
        );

        let mut result = UnifiedResult::with_columns(vec![
            "answer".into(),
            "provider".into(),
            "model".into(),
            "prompt_tokens".into(),
            "completion_tokens".into(),
            "sources_count".into(),
            "sources_flat".into(),
            "citations".into(),
            "validation".into(),
        ]);
        result.push(record);

        let qr = RuntimeQueryResult {
            query: "ASK 'why did deploy fail?'".to_string(),
            mode: QueryMode::Sql,
            statement: "ask",
            engine: "runtime-ai",
            result,
            affected_rows: 0,
            statement_type: "select",
            bookmark: None,
        };

        let mut out = Vec::new();
        emit_success_result(&mut out, &qr).await.unwrap();
        let frames = decode_frames(&out);

        let tags = frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>();
        assert_eq!(tags, b"TDC");

        let text = PgOid::Text.as_u32();
        let jsonb = PgOid::Jsonb.as_u32();
        let int8 = PgOid::Int8.as_u32();
        let expected_columns = vec![
            ("answer".to_string(), text),
            ("cache_hit".to_string(), PgOid::Bool.as_u32()),
            ("citations".to_string(), jsonb),
            ("completion_tokens".to_string(), int8),
            ("cost_usd".to_string(), PgOid::Numeric.as_u32()),
            ("mode".to_string(), text),
            ("model".to_string(), text),
            ("prompt_tokens".to_string(), int8),
            ("provider".to_string(), text),
            ("retry_count".to_string(), int8),
            ("sources_flat".to_string(), jsonb),
            ("validation".to_string(), jsonb),
        ];
        assert_eq!(decode_row_description(frames[0].1), expected_columns);

        let cells = decode_data_row(frames[1].1);
        assert_eq!(cells.len(), 12);
        assert_eq!(cells[0].as_deref(), Some(b"Deploy failed [^1].".as_slice()));
        assert_eq!(cells[1].as_deref(), Some(b"f".as_slice()));
        assert_eq!(cells[4].as_deref(), Some(b"0".as_slice()));
        assert_eq!(cells[5].as_deref(), Some(b"strict".as_slice()));
        assert_eq!(cells[9].as_deref(), Some(b"0".as_slice()));
        let sources_flat = std::str::from_utf8(cells[10].as_deref().unwrap()).unwrap();
        assert!(sources_flat.contains(r#""payload""#));

        assert_eq!(decode_command_complete(frames[2].1), "SELECT 1");
    }

    /// Splits a buffer of back-to-back backend frames into `(tag, body)`
    /// pairs. The 4-byte length counts itself but not the tag byte.
    fn decode_frames(bytes: &[u8]) -> Vec<(u8, &[u8])> {
        let mut frames = Vec::new();
        let mut cursor = 0;
        while cursor < bytes.len() {
            let tag = bytes[cursor];
            let mut len_bytes = [0u8; 4];
            len_bytes.copy_from_slice(&bytes[cursor + 1..cursor + 5]);
            let len = u32::from_be_bytes(len_bytes) as usize;
            let next = cursor + 1 + len;
            frames.push((tag, &bytes[cursor + 5..next]));
            cursor = next;
        }
        frames
    }

    /// Decodes a RowDescription body into `(column name, type oid)` pairs.
    fn decode_row_description(body: &[u8]) -> Vec<(String, u32)> {
        let count = i16::from_be_bytes([body[0], body[1]]) as usize;
        let mut columns = Vec::with_capacity(count);
        let mut pos = 2;
        for _ in 0..count {
            let name_len = body[pos..].iter().position(|&b| b == 0).unwrap();
            let name = std::str::from_utf8(&body[pos..pos + name_len])
                .unwrap()
                .to_string();
            // Skip the name's NUL, the table oid (4) and the column attr (2).
            pos += name_len + 1 + 4 + 2;
            let oid = u32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
            // Skip the type oid (4), type size (2), type mod (4) and format (2).
            pos += 4 + 2 + 4 + 2;
            columns.push((name, oid));
        }
        columns
    }

    /// Decodes a DataRow body; a negative cell length denotes NULL.
    fn decode_data_row(body: &[u8]) -> Vec<Option<Vec<u8>>> {
        let count = i16::from_be_bytes([body[0], body[1]]) as usize;
        let mut cells = Vec::with_capacity(count);
        let mut pos = 2;
        for _ in 0..count {
            let len = i32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
            pos += 4;
            let cell = if len < 0 {
                None
            } else {
                let end = pos + len as usize;
                let bytes = body[pos..end].to_vec();
                pos = end;
                Some(bytes)
            };
            cells.push(cell);
        }
        cells
    }

    /// Returns the command tag: everything before the first NUL, or the
    /// whole body when there is none.
    fn decode_command_complete(body: &[u8]) -> &str {
        let tag = body.split(|&b| b == 0).next().unwrap_or(body);
        std::str::from_utf8(tag).unwrap()
    }

    /// Sends a v3 startup packet for user `reddb`.
    async fn write_startup<W: AsyncWrite + Unpin>(stream: &mut W) {
        let mut params = Vec::new();
        params.extend_from_slice(&crate::wire::postgres::protocol::PG_PROTOCOL_V3.to_be_bytes());
        params.extend_from_slice(b"user\0reddb\0");
        params.push(0);

        // The startup length prefix counts its own four bytes.
        let total_len = (params.len() + 4) as u32;
        let mut packet = Vec::with_capacity(params.len() + 4);
        packet.extend_from_slice(&total_len.to_be_bytes());
        packet.extend_from_slice(&params);
        stream.write_all(&packet).await.unwrap();
    }

    /// Sends one tagged frontend frame: tag byte, length, payload.
    async fn write_frontend_frame<W: AsyncWrite + Unpin>(
        stream: &mut W,
        tag: u8,
        payload: Vec<u8>,
    ) {
        let frame_len = (payload.len() + 4) as u32;
        let mut frame = Vec::with_capacity(payload.len() + 5);
        frame.push(tag);
        frame.extend_from_slice(&frame_len.to_be_bytes());
        frame.extend_from_slice(&payload);
        stream.write_all(&frame).await.unwrap();
    }

    /// Reads exactly one backend frame and returns `(tag, body)`.
    async fn read_backend_frame<R: AsyncRead + Unpin>(stream: &mut R) -> (u8, Vec<u8>) {
        let mut header = [0u8; 5];
        stream.read_exact(&mut header).await.unwrap();
        let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
        let mut body = vec![0u8; len - 4];
        stream.read_exact(&mut body).await.unwrap();
        (header[0], body)
    }

    /// Collects backend frames up to and including the next `Z`
    /// (ReadyForQuery) frame.
    async fn read_until_ready<R: AsyncRead + Unpin>(stream: &mut R) -> Vec<(u8, Vec<u8>)> {
        let mut frames = Vec::new();
        loop {
            let (tag, body) = read_backend_frame(stream).await;
            frames.push((tag, body));
            if tag == b'Z' {
                break;
            }
        }
        frames
    }

    /// Builds a Parse body: statement name, query text, parameter OIDs.
    fn parse_body(statement: &str, query: &str, oids: &[u32]) -> Vec<u8> {
        let mut out = Vec::new();
        for text in [statement, query] {
            push_pg_cstring(&mut out, text);
        }
        out.extend_from_slice(&(oids.len() as i16).to_be_bytes());
        out.extend(oids.iter().flat_map(|oid| oid.to_be_bytes()));
        out
    }

    /// Builds a Bind body: portal, statement, parameter format codes,
    /// parameter values (`None` = NULL), result format codes.
    fn bind_body(
        portal: &str,
        statement: &str,
        formats: &[i16],
        params: &[Option<&[u8]>],
        result_formats: &[i16],
    ) -> Vec<u8> {
        let mut out = Vec::new();
        for name in [portal, statement] {
            push_pg_cstring(&mut out, name);
        }

        out.extend_from_slice(&(formats.len() as i16).to_be_bytes());
        out.extend(formats.iter().flat_map(|code| code.to_be_bytes()));

        out.extend_from_slice(&(params.len() as i16).to_be_bytes());
        for param in params {
            if let Some(bytes) = param {
                out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
                out.extend_from_slice(bytes);
            } else {
                // NULL is encoded as length -1 with no value bytes.
                out.extend_from_slice(&(-1i32).to_be_bytes());
            }
        }

        out.extend_from_slice(&(result_formats.len() as i16).to_be_bytes());
        out.extend(result_formats.iter().flat_map(|code| code.to_be_bytes()));
        out
    }

    /// Builds a Describe body: target byte followed by the name.
    fn describe_body(target: u8, name: &str) -> Vec<u8> {
        let mut out = Vec::with_capacity(name.len() + 2);
        out.push(target);
        push_pg_cstring(&mut out, name);
        out
    }

    /// Builds an Execute body: portal name followed by the row limit.
    fn execute_body(portal: &str, max_rows: u32) -> Vec<u8> {
        let mut out = Vec::with_capacity(portal.len() + 5);
        push_pg_cstring(&mut out, portal);
        out.extend_from_slice(&max_rows.to_be_bytes());
        out
    }

    /// Appends `value` as a NUL-terminated string.
    fn push_pg_cstring(out: &mut Vec<u8>, value: &str) {
        out.extend(value.bytes().chain(std::iter::once(0)));
    }
}