Skip to main content

reddb_server/wire/postgres/
server.rs

1//! PostgreSQL wire-protocol listener (Phase 3.1 PG parity).
2//!
3//! Accepts TCP connections from PG-compatible clients, drives the startup
4//! handshake, and routes simple/extended-query frames into the existing
5//! `RedDBRuntime::execute_query` path. Results are adapted back into PG
6//! `RowDescription` + `DataRow` frames via `types::value_to_pg_wire_bytes`.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpListener;
13
14use super::catalog_views::translate_pg_catalog_query;
15use super::protocol::{
16    read_frame, read_startup, write_frame, write_raw_byte, BackendMessage, ColumnDescriptor,
17    DescribeTarget, FrontendMessage, PgWireError, TransactionStatus,
18};
19use super::types::{pg_param_to_value, value_to_pg_wire_bytes, PgOid};
20use crate::runtime::ai::ask_response_envelope::{
21    AskResult, Citation, Mode, SourceRow, Validation, ValidationError, ValidationWarning,
22};
23use crate::runtime::RedDBRuntime;
24use crate::storage::query::unified::{UnifiedRecord, UnifiedResult};
25use crate::storage::schema::Value;
26
/// Settings for the PostgreSQL wire listener, fixed when it starts.
#[derive(Debug, Clone)]
pub struct PgWireConfig {
    /// Socket address to bind, written as `"host:port"`.
    ///
    /// Choosing an address that does not collide with the native wire,
    /// gRPC, or HTTP listeners is the caller's job.
    pub bind_addr: String,
    /// Version string reported to clients in the `server_version`
    /// `ParameterStatus` frame. Drivers commonly inspect it to toggle
    /// features, so a reasonably recent PostgreSQL version is advertised.
    pub server_version: String,
}
38
/// A statement registered by an extended-protocol `Parse` message.
#[derive(Debug, Clone)]
struct PgPreparedStatement {
    /// SQL text as stored at Parse time, with `$n::type` casts removed.
    sql: String,
    /// One type OID per parameter. These are reported in
    /// `ParameterDescription` and used to decode `Bind` values.
    param_type_oids: Vec<u32>,
}
44
45#[derive(Debug, Clone)]
46struct PgPortal {
47    sql: String,
48    params: Vec<Value>,
49    #[allow(dead_code)]
50    result_format_codes: Vec<i16>,
51    row_description_sent: bool,
52    described_result: Option<crate::runtime::RuntimeQueryResult>,
53}
54
55impl Default for PgWireConfig {
56    fn default() -> Self {
57        Self {
58            bind_addr: "127.0.0.1:5432".to_string(),
59            server_version: "15.0 (RedDB 3.1)".to_string(),
60        }
61    }
62}
63
64/// Spawn the PG wire listener. Blocks until the listener errors out.
65/// Each connection is handled in its own tokio task.
66pub async fn start_pg_wire_listener(
67    config: PgWireConfig,
68    runtime: Arc<RedDBRuntime>,
69) -> Result<(), Box<dyn std::error::Error>> {
70    let listener = TcpListener::bind(&config.bind_addr).await?;
71    tracing::info!(
72        transport = "pg-wire",
73        bind = %config.bind_addr,
74        "listener online"
75    );
76    let cfg = Arc::new(config);
77    loop {
78        let (stream, peer) = listener.accept().await?;
79        let rt = Arc::clone(&runtime);
80        let cfg = Arc::clone(&cfg);
81        let peer_str = peer.to_string();
82        tokio::spawn(async move {
83            if let Err(e) = handle_connection(stream, rt, cfg).await {
84                tracing::warn!(
85                    transport = "pg-wire",
86                    peer = %peer_str,
87                    err = %e,
88                    "connection failed"
89                );
90            }
91        });
92    }
93}
94
95/// Drive one connection's lifetime: startup → authentication → query loop.
96pub(crate) async fn handle_connection<S>(
97    mut stream: S,
98    runtime: Arc<RedDBRuntime>,
99    config: Arc<PgWireConfig>,
100) -> Result<(), PgWireError>
101where
102    S: AsyncRead + AsyncWrite + Unpin + Send,
103{
104    // Handshake. The first frame may be SSLRequest / GSSENCRequest
105    // (pre-auth negotiation) or a plain Startup. Loop once to cover
106    // SSL-not-supported path: reply 'N' and expect the client to send
107    // a regular Startup next.
108    loop {
109        match read_startup(&mut stream).await? {
110            FrontendMessage::SslRequest | FrontendMessage::GssEncRequest => {
111                // 'N' = not supported — client continues in plaintext and
112                // re-sends a normal Startup on the same socket.
113                write_raw_byte(&mut stream, b'N').await?;
114                continue;
115            }
116            FrontendMessage::Startup(params) => {
117                send_auth_ok(&mut stream, &config, &params).await?;
118                break;
119            }
120            FrontendMessage::Unknown { .. } => {
121                // CancelRequest: no response expected; drop the socket.
122                return Ok(());
123            }
124            other => {
125                return Err(PgWireError::Protocol(format!(
126                    "unexpected startup frame: {other:?}"
127                )));
128            }
129        }
130    }
131
132    let mut prepared: HashMap<String, PgPreparedStatement> = HashMap::new();
133    let mut portals: HashMap<String, PgPortal> = HashMap::new();
134
135    // Main query loop.
136    loop {
137        let frame = match read_frame(&mut stream).await {
138            Ok(f) => f,
139            Err(PgWireError::Eof) => return Ok(()),
140            Err(e) => return Err(e),
141        };
142
143        match frame {
144            FrontendMessage::Query(sql) => {
145                handle_simple_query(&mut stream, &runtime, &sql).await?;
146            }
147            FrontendMessage::Parse(msg) => {
148                handle_parse(&mut stream, &mut prepared, msg).await?;
149            }
150            FrontendMessage::Bind(msg) => {
151                handle_bind(&mut stream, &prepared, &mut portals, msg).await?;
152            }
153            FrontendMessage::Describe(msg) => {
154                handle_describe(&mut stream, &runtime, &prepared, &mut portals, msg).await?;
155            }
156            FrontendMessage::Execute(msg) => {
157                handle_execute(&mut stream, &runtime, &mut portals, msg).await?;
158            }
159            FrontendMessage::Close(msg) => {
160                handle_close(&mut stream, &mut prepared, &mut portals, msg).await?;
161            }
162            FrontendMessage::Terminate => return Ok(()),
163            FrontendMessage::Flush => {
164                // Frames are written immediately; no additional marker is
165                // needed. ReadyForQuery belongs to Sync, not Flush.
166                continue;
167            }
168            FrontendMessage::Sync => {
169                write_frame(
170                    &mut stream,
171                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
172                )
173                .await?;
174            }
175            FrontendMessage::PasswordMessage(_) => {
176                // Should only arrive during auth. Ignore post-auth.
177                continue;
178            }
179            FrontendMessage::Unknown { tag, .. } => {
180                send_error(
181                    &mut stream,
182                    "0A000",
183                    &format!("unsupported frame tag 0x{tag:02x}"),
184                )
185                .await?;
186                write_frame(
187                    &mut stream,
188                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
189                )
190                .await?;
191            }
192            other => {
193                send_error(
194                    &mut stream,
195                    "0A000",
196                    &format!("unsupported frame {other:?}"),
197                )
198                .await?;
199                write_frame(
200                    &mut stream,
201                    &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
202                )
203                .await?;
204            }
205        }
206    }
207}
208
209async fn handle_parse<S>(
210    stream: &mut S,
211    prepared: &mut HashMap<String, PgPreparedStatement>,
212    msg: super::protocol::ParseMessage,
213) -> Result<(), PgWireError>
214where
215    S: AsyncWrite + Unpin,
216{
217    let inferred_param_type_oids = infer_pg_cast_param_type_oids(&msg.query);
218    let sql = rewrite_pg_parameter_casts(&msg.query);
219    let parsed_param_count = match crate::storage::query::modes::parse_multi(&sql) {
220        Ok(parsed) => Some(
221            crate::storage::query::user_params::scan_parameters(&parsed)
222                .into_iter()
223                .map(|param| param.index + 1)
224                .max()
225                .unwrap_or(0),
226        ),
227        Err(err) => {
228            if pg_scalar_select_param_index(&sql).is_none() {
229                send_error(stream, "42601", &err.to_string()).await?;
230                return Ok(());
231            }
232            None
233        }
234    };
235    let mut param_type_oids = msg.param_type_oids;
236    if param_type_oids.is_empty() {
237        let count = parsed_param_count
238            .or_else(|| pg_scalar_select_param_index(&sql).map(|idx| idx + 1))
239            .unwrap_or(0);
240        param_type_oids.resize(count, PgOid::Unknown.as_u32());
241    }
242    for (idx, oid) in inferred_param_type_oids {
243        if idx >= param_type_oids.len() {
244            param_type_oids.resize(idx + 1, PgOid::Unknown.as_u32());
245        }
246        if param_type_oids[idx] == PgOid::Unknown.as_u32() {
247            param_type_oids[idx] = oid;
248        }
249    }
250    prepared.insert(
251        msg.statement,
252        PgPreparedStatement {
253            sql,
254            param_type_oids,
255        },
256    );
257    write_frame(stream, &BackendMessage::ParseComplete).await
258}
259
260async fn handle_bind<S>(
261    stream: &mut S,
262    prepared: &HashMap<String, PgPreparedStatement>,
263    portals: &mut HashMap<String, PgPortal>,
264    msg: super::protocol::BindMessage,
265) -> Result<(), PgWireError>
266where
267    S: AsyncWrite + Unpin,
268{
269    let Some(stmt) = prepared.get(&msg.statement) else {
270        send_error(
271            stream,
272            "26000",
273            &format!("prepared statement {:?} does not exist", msg.statement),
274        )
275        .await?;
276        return Ok(());
277    };
278    let params = match bind_pg_params(stmt, &msg) {
279        Ok(params) => params,
280        Err(err) => {
281            send_error(stream, "22023", &err).await?;
282            return Ok(());
283        }
284    };
285    portals.insert(
286        msg.portal,
287        PgPortal {
288            sql: stmt.sql.clone(),
289            params,
290            result_format_codes: msg.result_format_codes,
291            row_description_sent: false,
292            described_result: None,
293        },
294    );
295    write_frame(stream, &BackendMessage::BindComplete).await
296}
297
298async fn handle_describe<S>(
299    stream: &mut S,
300    runtime: &RedDBRuntime,
301    prepared: &HashMap<String, PgPreparedStatement>,
302    portals: &mut HashMap<String, PgPortal>,
303    msg: super::protocol::DescribeMessage,
304) -> Result<(), PgWireError>
305where
306    S: AsyncWrite + Unpin,
307{
308    match msg.target {
309        DescribeTarget::Statement => {
310            let Some(stmt) = prepared.get(&msg.name) else {
311                send_error(
312                    stream,
313                    "26000",
314                    &format!("prepared statement {:?} does not exist", msg.name),
315                )
316                .await?;
317                return Ok(());
318            };
319            write_frame(
320                stream,
321                &BackendMessage::ParameterDescription(stmt.param_type_oids.clone()),
322            )
323            .await?;
324            if is_ask_query(&stmt.sql) {
325                emit_ask_row_description(stream).await
326            } else {
327                write_frame(stream, &BackendMessage::NoData).await
328            }
329        }
330        DescribeTarget::Portal => {
331            let Some(portal) = portals.get_mut(&msg.name) else {
332                send_error(
333                    stream,
334                    "34000",
335                    &format!("portal {:?} does not exist", msg.name),
336                )
337                .await?;
338                return Ok(());
339            };
340            if is_ask_query(&portal.sql) {
341                emit_ask_row_description(stream).await?;
342                portal.row_description_sent = true;
343                Ok(())
344            } else if is_row_returning_query(&portal.sql) {
345                let result = match execute_pg_query_result(runtime, &portal.sql, &portal.params) {
346                    Ok(result) => result,
347                    Err(err) => {
348                        let code = classify_sqlstate(&err);
349                        send_error(stream, code, &err).await?;
350                        return Ok(());
351                    }
352                };
353                emit_row_description_for_result(stream, &result).await?;
354                portal.row_description_sent = true;
355                portal.described_result = Some(result);
356                Ok(())
357            } else {
358                write_frame(stream, &BackendMessage::NoData).await
359            }
360        }
361    }
362}
363
364async fn handle_execute<S>(
365    stream: &mut S,
366    runtime: &RedDBRuntime,
367    portals: &mut HashMap<String, PgPortal>,
368    msg: super::protocol::ExecuteMessage,
369) -> Result<(), PgWireError>
370where
371    S: AsyncWrite + Unpin,
372{
373    let Some(portal) = portals.get_mut(&msg.portal) else {
374        send_error(
375            stream,
376            "34000",
377            &format!("portal {:?} does not exist", msg.portal),
378        )
379        .await?;
380        return Ok(());
381    };
382    let _max_rows = msg.max_rows;
383    let was_described = portal.row_description_sent || portal.described_result.is_some();
384    portal.row_description_sent = false;
385    let result = match portal.described_result.take() {
386        Some(result) => Ok(result),
387        None => execute_pg_query_result(runtime, &portal.sql, &portal.params),
388    };
389    match result {
390        Ok(result) if was_described => {
391            emit_success_result_without_row_description(stream, &result).await
392        }
393        Ok(result) => emit_success_result(stream, &result).await,
394        Err(err) => {
395            let code = classify_sqlstate(&err);
396            send_error(stream, code, &err).await
397        }
398    }
399}
400
401async fn handle_close<S>(
402    stream: &mut S,
403    prepared: &mut HashMap<String, PgPreparedStatement>,
404    portals: &mut HashMap<String, PgPortal>,
405    msg: super::protocol::CloseMessage,
406) -> Result<(), PgWireError>
407where
408    S: AsyncWrite + Unpin,
409{
410    match msg.target {
411        DescribeTarget::Statement => {
412            prepared.remove(&msg.name);
413        }
414        DescribeTarget::Portal => {
415            portals.remove(&msg.name);
416        }
417    }
418    write_frame(stream, &BackendMessage::CloseComplete).await
419}
420
421fn bind_pg_params(
422    stmt: &PgPreparedStatement,
423    msg: &super::protocol::BindMessage,
424) -> Result<Vec<Value>, String> {
425    if !matches!(msg.param_format_codes.len(), 0 | 1)
426        && msg.param_format_codes.len() != msg.params.len()
427    {
428        return Err("Bind format count must be 0, 1, or match parameter count".to_string());
429    }
430    msg.params
431        .iter()
432        .enumerate()
433        .map(|(idx, param)| {
434            let oid = stmt
435                .param_type_oids
436                .get(idx)
437                .copied()
438                .unwrap_or(PgOid::Unknown.as_u32());
439            let format_code = match msg.param_format_codes.as_slice() {
440                [] => 0,
441                [format] => *format,
442                formats => formats[idx],
443            };
444            pg_param_to_value(oid, format_code, param.as_deref())
445        })
446        .collect()
447}
448
449fn execute_pg_query_result(
450    runtime: &RedDBRuntime,
451    sql: &str,
452    params: &[Value],
453) -> Result<crate::runtime::RuntimeQueryResult, String> {
454    if let Some(result) = try_execute_pg_scalar_select(sql, params) {
455        return Ok(result);
456    }
457    if params.is_empty() {
458        return match translate_pg_catalog_query(runtime, sql) {
459            Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
460                query: sql.to_string(),
461                mode: crate::storage::query::modes::QueryMode::Sql,
462                statement: "select",
463                engine: "pg-catalog",
464                result,
465                affected_rows: 0,
466                statement_type: "select",
467            }),
468            Ok(None) => runtime.execute_query(sql).map_err(|err| err.to_string()),
469            Err(err) => Err(err.to_string()),
470        };
471    }
472
473    let parsed = crate::storage::query::modes::parse_multi(sql).map_err(|err| err.to_string())?;
474    let bound =
475        crate::storage::query::user_params::bind(&parsed, params).map_err(|err| err.to_string())?;
476    runtime
477        .execute_query_expr(bound)
478        .map_err(|err| err.to_string())
479}
480
481fn try_execute_pg_scalar_select(
482    sql: &str,
483    params: &[Value],
484) -> Option<crate::runtime::RuntimeQueryResult> {
485    let index = pg_scalar_select_param_index(sql)?;
486    let value = params.get(index)?.clone();
487    let mut result = UnifiedResult::with_columns(vec!["?column?".to_string()]);
488    let mut record = UnifiedRecord::new();
489    record.set("?column?", value);
490    result.push(record);
491    Some(crate::runtime::RuntimeQueryResult {
492        query: sql.to_string(),
493        mode: crate::storage::query::modes::QueryMode::Sql,
494        statement: "select",
495        engine: "pg-wire",
496        result,
497        affected_rows: 0,
498        statement_type: "select",
499    })
500}
501
/// Recognise the scalar "echo" statement and return the zero-based index
/// `n - 1` of the parameter it selects.
///
/// Accepted forms, case-insensitive, with an optional trailing `;`:
/// - `SELECT $n`
/// - `SELECT CAST($n AS type)`
/// - either of the above followed by `AS alias`
///
/// The parameter must be the *entire* select list. Anything else (a
/// `FROM` clause, an operator, further columns) returns `None`, so the
/// statement goes through the real parser. Otherwise e.g. `SELECT $1 + 1`
/// would be answered with the bare parameter value. `$0` and non-numeric
/// placeholders also return `None`.
fn pg_scalar_select_param_index(sql: &str) -> Option<usize> {
    /// `as <identifier>` with nothing after it.
    fn is_alias_clause(tail: &str) -> bool {
        let Some(rest) = tail.strip_prefix("as") else {
            return false;
        };
        if !rest.starts_with(|c: char| c.is_ascii_whitespace()) {
            return false;
        }
        let alias = rest.trim_start();
        !alias.is_empty()
            && alias
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '"')
    }

    let trimmed = sql.trim().trim_end_matches(';').trim();
    let lower = trimmed.to_ascii_lowercase();
    let body = lower.strip_prefix("select ")?.trim_start();

    let (param, tail) = match body.strip_prefix("cast(") {
        Some(inner) => {
            // Locate the parenthesis closing `cast(`, skipping nested type
            // modifiers such as `varchar(10)`.
            let mut depth = 0usize;
            let close = inner.char_indices().find_map(|(i, c)| match c {
                '(' => {
                    depth += 1;
                    None
                }
                ')' if depth == 0 => Some(i),
                ')' => {
                    depth -= 1;
                    None
                }
                _ => None,
            })?;
            let (param, ty) = inner[..close].split_once(" as ")?;
            if ty.trim().is_empty() {
                return None;
            }
            (param.trim(), inner[close + 1..].trim())
        }
        None => match body.split_once(|c: char| c.is_ascii_whitespace()) {
            Some((param, rest)) => (param, rest.trim()),
            None => (body, ""),
        },
    };

    if !tail.is_empty() && !is_alias_clause(tail) {
        return None;
    }

    let digits = param.strip_prefix('$')?;
    // `usize::from_str` would also accept a leading `+`; require digits only.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse::<usize>().ok()?.checked_sub(1)
}
516
/// Strip PostgreSQL `::type` casts that directly follow a positional
/// parameter (`$1::int4` → `$1`), so the RedDB parser only sees the bare
/// placeholder.
///
/// The whole cast type is consumed, not just its first word:
/// - multi-word names: `double precision`, `character varying`,
///   `timestamp with time zone`;
/// - type modifiers: `varchar(32)`, `numeric(10, 2)`;
/// - array suffixes: `int4[]`.
///
/// Leaving any of those behind would hand the parser invalid SQL such as
/// `$1 precision` or `$1(32)`. Casts on anything other than a `$n`
/// parameter, and `$n::` with no type name, are left untouched.
///
/// NOTE(review): the scan is purely lexical, so `$n::type` inside a string
/// literal or comment is rewritten too.
fn rewrite_pg_parameter_casts(sql: &str) -> String {
    /// Offset just past the cast type that starts at `start`, or `start`
    /// itself when no type name begins there. Only ASCII bytes are
    /// consumed, so the result is always a valid `str` boundary.
    fn cast_type_end(sql: &str, start: usize) -> usize {
        const PRECISION: &[&[&str]] = &[&["precision"]];
        const VARYING: &[&[&str]] = &[&["varying"]];
        const TIME_ZONE: &[&[&str]] = &[&["with", "time", "zone"], &["without", "time", "zone"]];

        let bytes = sql.as_bytes();
        let mut end = start;
        while end < bytes.len()
            && (bytes[end].is_ascii_alphanumeric() || matches!(bytes[end], b'_' | b'.'))
        {
            end += 1;
        }
        if end == start {
            return start;
        }
        let name = sql[start..end].to_ascii_lowercase();
        let base = name.rsplit('.').next().unwrap_or(name.as_str());

        // Second word of `double precision` / `character varying` / `bit varying`.
        let continuation: &[&[&str]] = match base {
            "double" => PRECISION,
            "character" | "char" | "bit" => VARYING,
            _ => &[],
        };
        end = skip_phrase(sql, end, continuation);

        // Type modifier written directly after the name: `(32)`, `(10, 2)`.
        if bytes.get(end) == Some(&b'(') {
            if let Some(len) = sql[end..].find(')') {
                let inner = &sql[end + 1..end + len];
                if inner
                    .bytes()
                    .all(|b| b.is_ascii_digit() || b == b',' || b.is_ascii_whitespace())
                {
                    end += len + 1;
                }
            }
        }

        // `timestamp(3) with time zone`: the qualifier follows the modifier.
        if matches!(base, "timestamp" | "time") {
            end = skip_phrase(sql, end, TIME_ZONE);
        }

        // Array suffixes, possibly repeated: `[]`, `[3]`, `[][]`.
        while bytes.get(end) == Some(&b'[') {
            let Some(len) = sql[end..].find(']') else {
                break;
            };
            if !sql[end + 1..end + len].bytes().all(|b| b.is_ascii_digit()) {
                break;
            }
            end += len + 1;
        }
        end
    }

    /// If one of `phrases` follows `from`, return the offset just past it;
    /// otherwise return `from`. A match requires each word to be preceded
    /// by whitespace, compared case-insensitively, and to end on a word
    /// boundary.
    fn skip_phrase(sql: &str, from: usize, phrases: &[&[&str]]) -> usize {
        let bytes = sql.as_bytes();
        'phrase: for words in phrases {
            let mut pos = from;
            for word in words.iter() {
                let gap_start = pos;
                while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
                    pos += 1;
                }
                let word_end = pos + word.len();
                let matched = pos > gap_start
                    && word_end <= bytes.len()
                    && bytes[pos..word_end].eq_ignore_ascii_case(word.as_bytes())
                    && !matches!(
                        bytes.get(word_end),
                        Some(b) if b.is_ascii_alphanumeric() || *b == b'_'
                    );
                if !matched {
                    continue 'phrase;
                }
                pos = word_end;
            }
            return pos;
        }
        from
    }

    let bytes = sql.as_bytes();
    let mut out = String::with_capacity(sql.len());
    let mut cursor = 0;
    let mut pos = 0;
    while pos < bytes.len() {
        if bytes[pos] != b'$' {
            pos += 1;
            continue;
        }
        pos += 1;
        let digits_start = pos;
        while pos < bytes.len() && bytes[pos].is_ascii_digit() {
            pos += 1;
        }
        let param_end = pos;
        if digits_start == param_end || !bytes[param_end..].starts_with(b"::") {
            continue;
        }
        let type_start = param_end + 2;
        let type_end = cast_type_end(sql, type_start);
        pos = type_end;
        if type_end == type_start {
            // `$n::` with no type name after it: leave the text alone.
            continue;
        }
        // Keep everything up to and including `$n`; drop the `::type`.
        out.push_str(&sql[cursor..param_end]);
        cursor = type_end;
    }
    out.push_str(&sql[cursor..]);
    out
}
556
557fn infer_pg_cast_param_type_oids(sql: &str) -> Vec<(usize, u32)> {
558    let mut out = Vec::new();
559    let bytes = sql.as_bytes();
560    let mut pos = 0;
561    while pos < bytes.len() {
562        if bytes[pos] != b'$' {
563            pos += 1;
564            continue;
565        }
566        pos += 1;
567        let digits_start = pos;
568        while pos < bytes.len() && bytes[pos].is_ascii_digit() {
569            pos += 1;
570        }
571        if digits_start == pos {
572            continue;
573        }
574        let Some(param_index) = sql[digits_start..pos]
575            .parse::<usize>()
576            .ok()
577            .and_then(|idx| idx.checked_sub(1))
578        else {
579            continue;
580        };
581        if pos + 2 > bytes.len() || &bytes[pos..pos + 2] != b"::" {
582            continue;
583        }
584        pos += 2;
585        let type_start = pos;
586        while pos < bytes.len()
587            && (bytes[pos].is_ascii_alphanumeric() || matches!(bytes[pos], b'_' | b'.'))
588        {
589            pos += 1;
590        }
591        if type_start == pos {
592            continue;
593        }
594        if let Some(oid) = pg_cast_type_oid(&sql[type_start..pos]) {
595            out.push((param_index, oid));
596        }
597    }
598    out
599}
600
601fn pg_cast_type_oid(ty: &str) -> Option<u32> {
602    let lower = ty.to_ascii_lowercase();
603    let short = lower.rsplit('.').next().unwrap_or(lower.as_str());
604    let oid = match short {
605        "bool" | "boolean" => PgOid::Bool,
606        "int2" | "smallint" => PgOid::Int2,
607        "int" | "int4" | "integer" => PgOid::Int4,
608        "int8" | "bigint" => PgOid::Int8,
609        "float4" | "real" => PgOid::Float4,
610        "float8" | "double" | "doubleprecision" => PgOid::Float8,
611        "numeric" | "decimal" => PgOid::Numeric,
612        "bytea" => PgOid::Bytea,
613        "json" => PgOid::Json,
614        "jsonb" => PgOid::Jsonb,
615        "text" => PgOid::Text,
616        "varchar" | "character varying" => PgOid::Varchar,
617        "uuid" => PgOid::Uuid,
618        "timestamp" => PgOid::Timestamp,
619        "timestamptz" | "timestampz" => PgOid::TimestampTz,
620        "vector" => PgOid::Vector,
621        _ => return None,
622    };
623    Some(oid.as_u32())
624}
625
/// Cheap keyword sniff: does `sql` look like a statement that yields rows?
///
/// Leading whitespace is skipped. The statement matches if its lowercased
/// text starts with `select`, `with`, `ask`, `search`, `vector` or
/// `hybrid`. This is a plain prefix test, not a parse.
fn is_row_returning_query(sql: &str) -> bool {
    const ROW_KEYWORDS: [&str; 6] = ["select", "with", "ask", "search", "vector", "hybrid"];
    let head = sql.trim_start().to_ascii_lowercase();
    ROW_KEYWORDS.iter().any(|keyword| head.starts_with(*keyword))
}
635
/// `true` when `sql`, after leading whitespace, begins with `ask`
/// (ASCII case-insensitive prefix test).
fn is_ask_query(sql: &str) -> bool {
    let head = sql.trim_start();
    matches!(head.get(..3), Some(prefix) if prefix.eq_ignore_ascii_case("ask"))
}
639
640async fn send_auth_ok<S>(
641    stream: &mut S,
642    config: &PgWireConfig,
643    params: &super::protocol::StartupParams,
644) -> Result<(), PgWireError>
645where
646    S: AsyncWrite + Unpin,
647{
648    // Phase 3.1: trust auth. We always send AuthenticationOk.
649    write_frame(stream, &BackendMessage::AuthenticationOk).await?;
650
651    // Standard ParameterStatus frames. Drivers gate capabilities on these.
652    for (name, value) in [
653        ("server_version", config.server_version.as_str()),
654        ("server_encoding", "UTF8"),
655        ("client_encoding", "UTF8"),
656        ("DateStyle", "ISO, MDY"),
657        ("TimeZone", "UTC"),
658        ("integer_datetimes", "on"),
659        ("standard_conforming_strings", "on"),
660        (
661            "application_name",
662            params.get("application_name").unwrap_or(""),
663        ),
664    ] {
665        write_frame(
666            stream,
667            &BackendMessage::ParameterStatus {
668                name: name.to_string(),
669                value: value.to_string(),
670            },
671        )
672        .await?;
673    }
674
675    // BackendKeyData: (pid, secret_key). Used by CancelRequest; we don't
676    // honour cancels in 3.1 so random-ish values are fine.
677    write_frame(
678        stream,
679        &BackendMessage::BackendKeyData {
680            pid: std::process::id(),
681            key: 0xDEADBEEF,
682        },
683    )
684    .await?;
685
686    write_frame(
687        stream,
688        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
689    )
690    .await?;
691    Ok(())
692}
693
694async fn handle_simple_query<S>(
695    stream: &mut S,
696    runtime: &RedDBRuntime,
697    sql: &str,
698) -> Result<(), PgWireError>
699where
700    S: AsyncWrite + Unpin,
701{
702    // Empty query convention: PG emits EmptyQueryResponse instead of a
703    // CommandComplete. Some clients (psql `\;`) rely on this.
704    if sql.trim().is_empty() {
705        write_frame(stream, &BackendMessage::EmptyQueryResponse).await?;
706        write_frame(
707            stream,
708            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
709        )
710        .await?;
711        return Ok(());
712    }
713
714    if let Some(tag) = pg_session_compat_command_tag(sql) {
715        write_frame(stream, &BackendMessage::CommandComplete(tag.to_string())).await?;
716        write_frame(
717            stream,
718            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
719        )
720        .await?;
721        return Ok(());
722    }
723
724    let query_result = match translate_pg_catalog_query(runtime, sql) {
725        Ok(Some(result)) => Ok(crate::runtime::RuntimeQueryResult {
726            query: sql.to_string(),
727            mode: crate::storage::query::modes::QueryMode::Sql,
728            statement: "select",
729            engine: "pg-catalog",
730            result,
731            affected_rows: 0,
732            statement_type: "select",
733        }),
734        Ok(None) => runtime.execute_query(sql),
735        Err(err) => Err(err),
736    };
737
738    match query_result {
739        Ok(result) => {
740            emit_success_result(stream, &result).await?;
741        }
742        Err(err) => {
743            // PG SQLSTATE class 42 covers syntax / binding errors; we use
744            // 42P01 (undefined_table) and 42601 (syntax_error) when we can
745            // detect; otherwise fall back to XX000 (internal error).
746            let code = classify_sqlstate(&err.to_string());
747            send_error(stream, code, &err.to_string()).await?;
748        }
749    }
750
751    write_frame(
752        stream,
753        &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
754    )
755    .await?;
756    Ok(())
757}
758
/// Command tag for session statements that are acknowledged without being
/// executed.
///
/// Currently this covers only `SET …` (case-insensitive, trailing `;`
/// ignored), which yields `"SET"`. Every other statement yields `None`.
fn pg_session_compat_command_tag(sql: &str) -> Option<&'static str> {
    let statement = sql.trim().trim_end_matches(';');
    let is_set = matches!(statement.get(..4), Some(head) if head.eq_ignore_ascii_case("set "));
    is_set.then_some("SET")
}
766
767async fn emit_success_result<S>(
768    stream: &mut S,
769    result: &crate::runtime::RuntimeQueryResult,
770) -> Result<(), PgWireError>
771where
772    S: AsyncWrite + Unpin,
773{
774    if result.statement == "ask" {
775        emit_ask_result_row(stream, result).await?;
776        write_frame(
777            stream,
778            &BackendMessage::CommandComplete("SELECT 1".to_string()),
779        )
780        .await?;
781    } else if result_returns_rows(result) {
782        emit_result_rows(stream, &result.result).await?;
783        write_frame(
784            stream,
785            &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
786        )
787        .await?;
788    } else {
789        // DDL / DML / config statements: echo the runtime's
790        // high-level statement tag back. PG format is
791        // "<CMD> [<OID>] <COUNT>"; we keep the count where
792        // applicable and fall back to the runtime's message.
793        let tag = match result.statement_type {
794            "insert" => format!("INSERT 0 {}", result.affected_rows),
795            "update" => format!("UPDATE {}", result.affected_rows),
796            "delete" => format!("DELETE {}", result.affected_rows),
797            other => other.to_uppercase(),
798        };
799        write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
800    }
801    Ok(())
802}
803
804async fn emit_success_result_without_row_description<S>(
805    stream: &mut S,
806    result: &crate::runtime::RuntimeQueryResult,
807) -> Result<(), PgWireError>
808where
809    S: AsyncWrite + Unpin,
810{
811    if result.statement == "ask" {
812        let row = ask_query_result_to_pg_wire_row(result)
813            .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
814        write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
815        write_frame(
816            stream,
817            &BackendMessage::CommandComplete("SELECT 1".to_string()),
818        )
819        .await?;
820    } else if result_returns_rows(result) {
821        emit_result_data_rows(stream, &result.result).await?;
822        write_frame(
823            stream,
824            &BackendMessage::CommandComplete(format!("SELECT {}", result.result.records.len())),
825        )
826        .await?;
827    } else {
828        let tag = match result.statement_type {
829            "insert" => format!("INSERT 0 {}", result.affected_rows),
830            "update" => format!("UPDATE {}", result.affected_rows),
831            "delete" => format!("DELETE {}", result.affected_rows),
832            other => other.to_uppercase(),
833        };
834        write_frame(stream, &BackendMessage::CommandComplete(tag)).await?;
835    }
836    Ok(())
837}
838
839async fn emit_row_description_for_result<S>(
840    stream: &mut S,
841    result: &crate::runtime::RuntimeQueryResult,
842) -> Result<(), PgWireError>
843where
844    S: AsyncWrite + Unpin,
845{
846    if result.statement == "ask" {
847        emit_ask_row_description(stream).await
848    } else if result_returns_rows(result) {
849        emit_result_row_description(stream, &result.result).await
850    } else {
851        write_frame(stream, &BackendMessage::NoData).await
852    }
853}
854
855fn result_returns_rows(result: &crate::runtime::RuntimeQueryResult) -> bool {
856    result.statement_type == "select"
857}
858
859async fn emit_result_rows<S>(
860    stream: &mut S,
861    result: &crate::storage::query::unified::UnifiedResult,
862) -> Result<(), PgWireError>
863where
864    S: AsyncWrite + Unpin,
865{
866    emit_result_row_description(stream, result).await?;
867    emit_result_data_rows(stream, result).await
868}
869
870async fn emit_result_row_description<S>(
871    stream: &mut S,
872    result: &crate::storage::query::unified::UnifiedResult,
873) -> Result<(), PgWireError>
874where
875    S: AsyncWrite + Unpin,
876{
877    // RowDescription: derived from the first record's column ordering.
878    // When `result.columns` is non-empty we honour that order; otherwise
879    // we synthesise one from the record's field order.
880    let columns: Vec<String> = if !result.columns.is_empty() {
881        result.columns.clone()
882    } else if let Some(first) = result.records.first() {
883        record_field_names(first)
884    } else {
885        Vec::new()
886    };
887
888    // Peek at the first record for per-column type OIDs. When there's no
889    // data row we fall back to TEXT for every column — clients render
890    // empty result sets happily.
891    let type_oids: Vec<PgOid> = columns
892        .iter()
893        .map(|col| {
894            result
895                .records
896                .first()
897                .and_then(|r| record_get(r, col))
898                .map(PgOid::from_value)
899                .unwrap_or(PgOid::Text)
900        })
901        .collect();
902
903    let descriptors: Vec<ColumnDescriptor> = columns
904        .iter()
905        .zip(type_oids.iter())
906        .map(|(name, oid)| ColumnDescriptor {
907            name: name.clone(),
908            table_oid: 0,
909            column_attr: 0,
910            type_oid: oid.as_u32(),
911            type_size: -1,
912            type_mod: -1,
913            format: 0,
914        })
915        .collect();
916
917    write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
918}
919
920async fn emit_result_data_rows<S>(
921    stream: &mut S,
922    result: &crate::storage::query::unified::UnifiedResult,
923) -> Result<(), PgWireError>
924where
925    S: AsyncWrite + Unpin,
926{
927    let columns: Vec<String> = if !result.columns.is_empty() {
928        result.columns.clone()
929    } else if let Some(first) = result.records.first() {
930        record_field_names(first)
931    } else {
932        Vec::new()
933    };
934    for record in &result.records {
935        let fields: Vec<Option<Vec<u8>>> = columns
936            .iter()
937            .map(|col| record_get(record, col).and_then(value_to_pg_wire_bytes))
938            .collect();
939        write_frame(stream, &BackendMessage::DataRow(fields)).await?;
940    }
941
942    Ok(())
943}
944
945async fn emit_ask_result_row<S>(
946    stream: &mut S,
947    result: &crate::runtime::RuntimeQueryResult,
948) -> Result<(), PgWireError>
949where
950    S: AsyncWrite + Unpin,
951{
952    let row = ask_query_result_to_pg_wire_row(result)
953        .ok_or_else(|| PgWireError::Protocol("ASK result missing row body".to_string()))?;
954
955    emit_ask_row_description(stream).await?;
956    write_frame(stream, &BackendMessage::DataRow(row.cells)).await?;
957    Ok(())
958}
959
960async fn emit_ask_row_description<S>(stream: &mut S) -> Result<(), PgWireError>
961where
962    S: AsyncWrite + Unpin,
963{
964    let descriptors: Vec<ColumnDescriptor> = crate::runtime::ai::pg_wire_ask_row_encoder::columns()
965        .iter()
966        .map(|col| ColumnDescriptor {
967            name: col.name.to_string(),
968            table_oid: 0,
969            column_attr: 0,
970            type_oid: col.oid.as_u32(),
971            type_size: -1,
972            type_mod: -1,
973            format: 0,
974        })
975        .collect();
976    write_frame(stream, &BackendMessage::RowDescription(descriptors)).await
977}
978
979fn ask_query_result_to_pg_wire_row(
980    result: &crate::runtime::RuntimeQueryResult,
981) -> Option<crate::runtime::ai::pg_wire_ask_row_encoder::AskRow> {
982    if result.statement != "ask" {
983        return None;
984    }
985    let record = result.result.records.first()?;
986    let sources_flat_json =
987        json_field(record, "sources_flat").unwrap_or(crate::json::Value::Array(Vec::new()));
988    let citations_json =
989        json_field(record, "citations").unwrap_or(crate::json::Value::Array(Vec::new()));
990    let validation_json = json_field(record, "validation")
991        .unwrap_or_else(|| crate::json::Value::Object(Default::default()));
992
993    let effective_mode = match text_field(record, "mode").as_deref() {
994        Some("lenient") => Mode::Lenient,
995        _ => Mode::Strict,
996    };
997
998    let ask = AskResult {
999        answer: text_field(record, "answer")?,
1000        sources_flat: ask_sources_flat(&sources_flat_json),
1001        citations: ask_citations(&citations_json),
1002        validation: ask_validation(&validation_json),
1003        cache_hit: bool_field(record, "cache_hit").unwrap_or(false),
1004        provider: text_field(record, "provider").unwrap_or_default(),
1005        model: text_field(record, "model").unwrap_or_default(),
1006        prompt_tokens: u32_field(record, "prompt_tokens").unwrap_or(0),
1007        completion_tokens: u32_field(record, "completion_tokens").unwrap_or(0),
1008        cost_usd: f64_field(record, "cost_usd").unwrap_or(0.0),
1009        effective_mode,
1010        retry_count: u32_field(record, "retry_count").unwrap_or(0),
1011    };
1012
1013    Some(crate::runtime::ai::pg_wire_ask_row_encoder::encode(&ask))
1014}
1015
1016fn record_field<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
1017    record.iter_fields().find_map(|(name, value)| {
1018        let name: &str = name;
1019        (name == key).then_some(value)
1020    })
1021}
1022
1023fn text_field(record: &UnifiedRecord, key: &str) -> Option<String> {
1024    match record_field(record, key)? {
1025        Value::Text(s) => Some(s.to_string()),
1026        Value::Email(s) | Value::Url(s) | Value::NodeRef(s) | Value::EdgeRef(s) => Some(s.clone()),
1027        other => Some(other.to_string()),
1028    }
1029}
1030
1031fn bool_field(record: &UnifiedRecord, key: &str) -> Option<bool> {
1032    match record_field(record, key)? {
1033        Value::Boolean(value) => Some(*value),
1034        _ => None,
1035    }
1036}
1037
1038fn u32_field(record: &UnifiedRecord, key: &str) -> Option<u32> {
1039    match record_field(record, key)? {
1040        Value::Integer(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1041        Value::UnsignedInteger(n) => Some((*n).min(u32::MAX as u64) as u32),
1042        Value::BigInt(n)
1043        | Value::TimestampMs(n)
1044        | Value::Timestamp(n)
1045        | Value::Duration(n)
1046        | Value::Decimal(n) => (*n >= 0).then_some((*n).min(u32::MAX as i64) as u32),
1047        Value::Float(n) => (*n >= 0.0).then_some((*n).min(u32::MAX as f64) as u32),
1048        _ => None,
1049    }
1050}
1051
1052fn f64_field(record: &UnifiedRecord, key: &str) -> Option<f64> {
1053    match record_field(record, key)? {
1054        Value::Integer(n) => Some(*n as f64),
1055        Value::UnsignedInteger(n) => Some(*n as f64),
1056        Value::BigInt(n)
1057        | Value::TimestampMs(n)
1058        | Value::Timestamp(n)
1059        | Value::Duration(n)
1060        | Value::Decimal(n) => Some(*n as f64),
1061        Value::Float(n) => Some(*n),
1062        _ => None,
1063    }
1064}
1065
1066fn json_field(record: &UnifiedRecord, key: &str) -> Option<crate::json::Value> {
1067    match record_field(record, key)? {
1068        Value::Json(bytes) => crate::json::from_slice(bytes).ok(),
1069        Value::Text(text) => crate::json::from_str(text).ok(),
1070        _ => None,
1071    }
1072}
1073
1074fn ask_sources_flat(value: &crate::json::Value) -> Vec<SourceRow> {
1075    value
1076        .as_array()
1077        .unwrap_or(&[])
1078        .iter()
1079        .filter_map(|source| {
1080            let urn = source
1081                .get("urn")
1082                .and_then(crate::json::Value::as_str)?
1083                .to_string();
1084            let payload = source
1085                .get("payload")
1086                .and_then(crate::json::Value::as_str)
1087                .map(ToString::to_string)
1088                .unwrap_or_else(|| source.to_string_compact());
1089            Some(SourceRow { urn, payload })
1090        })
1091        .collect()
1092}
1093
1094fn ask_citations(value: &crate::json::Value) -> Vec<Citation> {
1095    value
1096        .as_array()
1097        .unwrap_or(&[])
1098        .iter()
1099        .filter_map(|citation| {
1100            let marker = citation
1101                .get("marker")
1102                .and_then(crate::json::Value::as_u64)?;
1103            let urn = citation
1104                .get("urn")
1105                .and_then(crate::json::Value::as_str)?
1106                .to_string();
1107            Some(Citation {
1108                marker: marker.min(u32::MAX as u64) as u32,
1109                urn,
1110            })
1111        })
1112        .collect()
1113}
1114
1115fn ask_validation(value: &crate::json::Value) -> Validation {
1116    Validation {
1117        ok: value
1118            .get("ok")
1119            .and_then(crate::json::Value::as_bool)
1120            .unwrap_or(true),
1121        warnings: validation_items(value, "warnings")
1122            .into_iter()
1123            .map(|(kind, detail)| ValidationWarning { kind, detail })
1124            .collect(),
1125        errors: validation_items(value, "errors")
1126            .into_iter()
1127            .map(|(kind, detail)| ValidationError { kind, detail })
1128            .collect(),
1129    }
1130}
1131
1132fn validation_items(value: &crate::json::Value, key: &str) -> Vec<(String, String)> {
1133    value
1134        .get(key)
1135        .and_then(crate::json::Value::as_array)
1136        .unwrap_or(&[])
1137        .iter()
1138        .filter_map(|item| {
1139            Some((
1140                item.get("kind")
1141                    .and_then(crate::json::Value::as_str)?
1142                    .to_string(),
1143                item.get("detail")
1144                    .and_then(crate::json::Value::as_str)
1145                    .unwrap_or("")
1146                    .to_string(),
1147            ))
1148        })
1149        .collect()
1150}
1151
1152/// Best-effort field lookup on a `UnifiedRecord`. The record API lives in
1153/// `storage::query::unified` and today uses `HashMap<String, Value>` under
1154/// the hood — we use `get` if it exists, else fall back to serialised map.
1155fn record_get<'a>(record: &'a UnifiedRecord, key: &str) -> Option<&'a Value> {
1156    record.get(key)
1157}
1158
1159/// Extract column names in iteration order from a single record. When
1160/// the caller didn't supply an explicit `columns` projection we use the
1161/// first record's field ordering as the canonical tuple shape.
1162///
1163/// HashMap iteration order is non-deterministic — for Phase 3.1 we
1164/// accept the shuffle since PG clients receive the ordered header via
1165/// RowDescription and match cells positionally. A stable ordering
1166/// would require keeping an insertion-order index alongside `values`.
1167fn record_field_names(record: &UnifiedRecord) -> Vec<String> {
1168    // `column_names()` merges the columnar scan side-channel with
1169    // the HashMap so scan rows (which populate only columnar) still
1170    // surface their field names in PG wire output.
1171    record
1172        .column_names()
1173        .into_iter()
1174        .map(|k| k.to_string())
1175        .collect()
1176}
1177
1178async fn send_error<S>(stream: &mut S, code: &str, message: &str) -> Result<(), PgWireError>
1179where
1180    S: AsyncWrite + Unpin,
1181{
1182    write_frame(
1183        stream,
1184        &BackendMessage::ErrorResponse {
1185            severity: "ERROR".to_string(),
1186            code: code.to_string(),
1187            message: message.to_string(),
1188        },
1189    )
1190    .await
1191}
1192
/// Heuristically map a runtime error message onto a PostgreSQL SQLSTATE.
///
/// Matching is by case-insensitive substring and is checked in the order
/// listed, so the first match wins:
/// - "not found" / "does not exist" → `42P01` (undefined_table; close enough
///   for collection-not-found);
/// - "parse" / "expected" / "syntax" → `42601` (syntax_error);
/// - "already exists" → `42P07` (duplicate_table);
/// - "permission" → `42501` (insufficient_privilege), which is what
///   PostgreSQL reports for query-time access denials;
/// - "auth" → `28000` (invalid_authorization_specification);
/// - anything else → `XX000` (internal_error).
///
/// Full coverage would map every `RedDBError` variant. This is enough for
/// the common psql / JDBC paths.
fn classify_sqlstate(msg: &str) -> &'static str {
    let lower = msg.to_ascii_lowercase();
    let has = |needle: &str| lower.contains(needle);
    if has("not found") || has("does not exist") {
        "42P01"
    } else if has("parse") || has("expected") || has("syntax") {
        "42601"
    } else if has("already exists") {
        "42P07"
    } else if has("permission") {
        // Query-time privilege failures are class 42, not the connection-time
        // authorization class 28.
        "42501"
    } else if has("auth") {
        "28000"
    } else {
        "XX000"
    }
}
1211
// Unit tests for the PG wire listener. They either drive `handle_connection`
// over an in-memory `tokio::io::duplex` pipe, or call the emit helpers
// directly against a `Vec<u8>` sink. Raw backend frames are then decoded by
// the small helpers at the bottom of this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::RedDBOptions;
    use crate::runtime::RuntimeQueryResult;
    use crate::storage::query::modes::QueryMode;
    use crate::storage::query::unified::UnifiedResult;
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    /// Full extended-query cycle against an in-memory runtime: Parse, Bind,
    /// Describe portal, Execute, Sync. The `$1::int` parameter is bound as
    /// the text "42" and must come back as a single one-column row.
    #[tokio::test]
    async fn extended_parse_bind_execute_returns_rows() {
        let runtime = Arc::new(RedDBRuntime::with_options(RedDBOptions::in_memory()).unwrap());
        let config = Arc::new(PgWireConfig::default());
        let (server_io, mut client_io) = tokio::io::duplex(64 * 1024);
        let server = tokio::spawn(async move {
            handle_connection(server_io, runtime, config).await.unwrap();
        });

        // Complete the startup handshake and discard everything up to the
        // first ReadyForQuery.
        write_startup(&mut client_io).await;
        read_until_ready(&mut client_io).await;

        // Frontend tags: 'P' Parse, 'B' Bind, 'D' Describe, 'E' Execute,
        // 'S' Sync. All use the unnamed statement and portal ("").
        write_frontend_frame(
            &mut client_io,
            b'P',
            parse_body("", "SELECT $1::int", &[PgOid::Int4.as_u32()]),
        )
        .await;
        write_frontend_frame(
            &mut client_io,
            b'B',
            bind_body("", "", &[0], &[Some(b"42".as_slice())], &[]),
        )
        .await;
        write_frontend_frame(&mut client_io, b'D', describe_body(b'P', "")).await;
        write_frontend_frame(&mut client_io, b'E', execute_body("", 0)).await;
        write_frontend_frame(&mut client_io, b'S', Vec::new()).await;

        // Expected backend reply, in order: ParseComplete ('1'),
        // BindComplete ('2'), RowDescription ('T'), DataRow ('D'),
        // CommandComplete ('C'), ReadyForQuery ('Z').
        let frames = read_until_ready(&mut client_io).await;
        assert_eq!(
            frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>(),
            b"12TDCZ"
        );
        let columns = decode_row_description(&frames[2].1);
        assert_eq!(columns.len(), 1);
        let cells = decode_data_row(&frames[3].1);
        assert_eq!(cells.len(), 1);
        assert_eq!(cells[0].as_deref(), Some(b"42".as_slice()));
        assert_eq!(decode_command_complete(&frames[4].1), "SELECT 1");

        // 'X' = Terminate; the server task must then finish without error.
        write_frontend_frame(&mut client_io, b'X', Vec::new()).await;
        server.await.unwrap();
    }

    /// `$N::type` casts in the SQL text yield `(zero-based param index, OID)`
    /// pairs.
    #[test]
    fn infer_pg_cast_param_type_oids_from_parameter_casts() {
        assert_eq!(
            infer_pg_cast_param_type_oids("INSERT INTO t (id, name) VALUES ($1::int, $2::text)"),
            vec![(0, PgOid::Int4.as_u32()), (1, PgOid::Text.as_u32())]
        );
        assert_eq!(
            infer_pg_cast_param_type_oids("SEARCH SIMILAR [1.0] COLLECTION v LIMIT $1::int8"),
            vec![(0, PgOid::Int8.as_u32())]
        );
    }

    /// Session `SET` commands issued by drivers (psql, pgjdbc) are accepted
    /// with a bare `SET` command tag. Ordinary queries are not intercepted.
    #[test]
    fn pg_session_compat_accepts_driver_setup_set_commands() {
        assert_eq!(
            pg_session_compat_command_tag("SET extra_float_digits = 3"),
            Some("SET")
        );
        assert_eq!(
            pg_session_compat_command_tag("SET application_name = 'pgjdbc'"),
            Some("SET")
        );
        assert_eq!(pg_session_compat_command_tag("SELECT 1"), None);
    }

    /// An `ask` result is re-shaped into the canonical ASK row, not the
    /// generic projection. The record deliberately omits `cache_hit`,
    /// `cost_usd`, `mode` and `retry_count`, so the cell assertions below
    /// pin the defaults filled in by `ask_query_result_to_pg_wire_row`.
    #[tokio::test]
    async fn ask_success_result_uses_canonical_pg_wire_row_shape() {
        let mut result = UnifiedResult::with_columns(vec![
            "answer".into(),
            "provider".into(),
            "model".into(),
            "prompt_tokens".into(),
            "completion_tokens".into(),
            "sources_count".into(),
            "sources_flat".into(),
            "citations".into(),
            "validation".into(),
        ]);
        let mut record = UnifiedRecord::new();
        record.set("answer", Value::text("Deploy failed [^1]."));
        record.set("provider", Value::text("openai"));
        record.set("model", Value::text("gpt-4o-mini"));
        record.set("prompt_tokens", Value::Integer(11));
        record.set("completion_tokens", Value::Integer(7));
        // This source entry has no `payload` key, so `ask_sources_flat`
        // substitutes the entry's compact JSON as the payload.
        record.set(
            "sources_flat",
            Value::Json(
                br#"[{"urn":"urn:reddb:row:deployments:1","kind":"row","collection":"deployments","id":"1"}]"#
                    .to_vec(),
            ),
        );
        record.set(
            "citations",
            Value::Json(br#"[{"marker":1,"urn":"urn:reddb:row:deployments:1"}]"#.to_vec()),
        );
        record.set(
            "validation",
            Value::Json(br#"{"ok":true,"warnings":[],"errors":[]}"#.to_vec()),
        );
        result.push(record);

        let qr = RuntimeQueryResult {
            query: "ASK 'why did deploy fail?'".to_string(),
            mode: QueryMode::Sql,
            statement: "ask",
            engine: "runtime-ai",
            result,
            affected_rows: 0,
            statement_type: "select",
        };

        let mut out = Vec::new();
        emit_success_result(&mut out, &qr).await.unwrap();
        let frames = decode_frames(&out);

        // RowDescription ('T'), DataRow ('D'), CommandComplete ('C').
        assert_eq!(
            frames.iter().map(|(tag, _)| *tag).collect::<Vec<_>>(),
            b"TDC"
        );

        // The header follows the encoder's fixed column table (alphabetical
        // here), not the input result's `columns` list.
        let columns = decode_row_description(frames[0].1);
        assert_eq!(
            columns,
            vec![
                ("answer".to_string(), PgOid::Text.as_u32()),
                ("cache_hit".to_string(), PgOid::Bool.as_u32()),
                ("citations".to_string(), PgOid::Jsonb.as_u32()),
                ("completion_tokens".to_string(), PgOid::Int8.as_u32()),
                ("cost_usd".to_string(), PgOid::Numeric.as_u32()),
                ("mode".to_string(), PgOid::Text.as_u32()),
                ("model".to_string(), PgOid::Text.as_u32()),
                ("prompt_tokens".to_string(), PgOid::Int8.as_u32()),
                ("provider".to_string(), PgOid::Text.as_u32()),
                ("retry_count".to_string(), PgOid::Int8.as_u32()),
                ("sources_flat".to_string(), PgOid::Jsonb.as_u32()),
                ("validation".to_string(), PgOid::Jsonb.as_u32()),
            ]
        );

        // Cell indices follow the header above:
        // 1 cache_hit, 4 cost_usd, 5 mode, 9 retry_count, 10 sources_flat.
        let cells = decode_data_row(frames[1].1);
        assert_eq!(cells.len(), 12);
        assert_eq!(cells[0].as_deref(), Some(b"Deploy failed [^1].".as_slice()));
        assert_eq!(cells[1].as_deref(), Some(b"f".as_slice()));
        assert_eq!(cells[4].as_deref(), Some(b"0".as_slice()));
        assert_eq!(cells[5].as_deref(), Some(b"strict".as_slice()));
        assert_eq!(cells[9].as_deref(), Some(b"0".as_slice()));
        assert!(std::str::from_utf8(cells[10].as_deref().unwrap())
            .unwrap()
            .contains(r#""payload""#));
        assert_eq!(decode_command_complete(frames[2].1), "SELECT 1");
    }

    /// Split a buffer of backend messages into `(tag, body)` pairs.
    ///
    /// Each message is a 1-byte tag plus a big-endian 4-byte length that
    /// counts itself but not the tag. Panics on truncated input (test-only).
    fn decode_frames(bytes: &[u8]) -> Vec<(u8, &[u8])> {
        let mut pos = 0;
        let mut frames = Vec::new();
        while pos < bytes.len() {
            let tag = bytes[pos];
            let len = u32::from_be_bytes([
                bytes[pos + 1],
                bytes[pos + 2],
                bytes[pos + 3],
                bytes[pos + 4],
            ]) as usize;
            let body_start = pos + 5;
            let body_end = pos + 1 + len;
            frames.push((tag, &bytes[body_start..body_end]));
            pos = body_end;
        }
        frames
    }

    /// Decode a RowDescription body into `(name, type OID)` pairs, skipping
    /// the other per-field attributes.
    fn decode_row_description(body: &[u8]) -> Vec<(String, u32)> {
        let count = i16::from_be_bytes([body[0], body[1]]) as usize;
        let mut pos = 2;
        let mut columns = Vec::with_capacity(count);
        for _ in 0..count {
            // Field name is a NUL-terminated string.
            let end = body[pos..].iter().position(|&b| b == 0).unwrap() + pos;
            let name = std::str::from_utf8(&body[pos..end]).unwrap().to_string();
            pos = end + 1;
            pos += 4; // table oid
            pos += 2; // column attr
            let oid = u32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
            pos += 4;
            pos += 2; // type size
            pos += 4; // type mod
            pos += 2; // format
            columns.push((name, oid));
        }
        columns
    }

    /// Decode a DataRow body into cells. A length of -1 denotes SQL NULL
    /// (`None`).
    fn decode_data_row(body: &[u8]) -> Vec<Option<Vec<u8>>> {
        let count = i16::from_be_bytes([body[0], body[1]]) as usize;
        let mut pos = 2;
        let mut cells = Vec::with_capacity(count);
        for _ in 0..count {
            let len = i32::from_be_bytes([body[pos], body[pos + 1], body[pos + 2], body[pos + 3]]);
            pos += 4;
            if len < 0 {
                cells.push(None);
            } else {
                let len = len as usize;
                cells.push(Some(body[pos..pos + len].to_vec()));
                pos += len;
            }
        }
        cells
    }

    /// Return the CommandComplete tag text, up to its NUL terminator.
    fn decode_command_complete(body: &[u8]) -> &str {
        let nul = body.iter().position(|&b| b == 0).unwrap_or(body.len());
        std::str::from_utf8(&body[..nul]).unwrap()
    }

    /// Write an untagged StartupMessage: length, protocol version, then
    /// NUL-terminated key/value pairs (`user=reddb`) closed by a final NUL.
    async fn write_startup<W: AsyncWrite + Unpin>(stream: &mut W) {
        let mut payload = Vec::new();
        payload.extend_from_slice(&crate::wire::postgres::protocol::PG_PROTOCOL_V3.to_be_bytes());
        payload.extend_from_slice(b"user\0reddb\0");
        payload.push(0);
        let len = (payload.len() + 4) as u32;
        stream.write_all(&len.to_be_bytes()).await.unwrap();
        stream.write_all(&payload).await.unwrap();
    }

    /// Write a tagged frontend message: tag byte, big-endian length
    /// (payload + 4), payload.
    async fn write_frontend_frame<W: AsyncWrite + Unpin>(
        stream: &mut W,
        tag: u8,
        payload: Vec<u8>,
    ) {
        stream.write_all(&[tag]).await.unwrap();
        stream
            .write_all(&((payload.len() + 4) as u32).to_be_bytes())
            .await
            .unwrap();
        stream.write_all(&payload).await.unwrap();
    }

    /// Read exactly one tagged backend message and return `(tag, body)`.
    async fn read_backend_frame<R: AsyncRead + Unpin>(stream: &mut R) -> (u8, Vec<u8>) {
        let mut tag = [0u8; 1];
        stream.read_exact(&mut tag).await.unwrap();
        let mut len = [0u8; 4];
        stream.read_exact(&mut len).await.unwrap();
        let len = u32::from_be_bytes(len) as usize;
        // The length field counts its own 4 bytes.
        let mut body = vec![0u8; len - 4];
        stream.read_exact(&mut body).await.unwrap();
        (tag[0], body)
    }

    /// Read backend messages up to and including the next ReadyForQuery
    /// (`'Z'`).
    async fn read_until_ready<R: AsyncRead + Unpin>(stream: &mut R) -> Vec<(u8, Vec<u8>)> {
        let mut frames = Vec::new();
        loop {
            let frame = read_backend_frame(stream).await;
            let done = frame.0 == b'Z';
            frames.push(frame);
            if done {
                return frames;
            }
        }
    }

    /// Build a Parse body: statement name, query text, then the declared
    /// parameter type OIDs.
    fn parse_body(statement: &str, query: &str, oids: &[u32]) -> Vec<u8> {
        let mut out = Vec::new();
        push_pg_cstring(&mut out, statement);
        push_pg_cstring(&mut out, query);
        out.extend_from_slice(&(oids.len() as i16).to_be_bytes());
        for oid in oids {
            out.extend_from_slice(&oid.to_be_bytes());
        }
        out
    }

    /// Build a Bind body with the following fields, in order:
    /// - portal name and statement name;
    /// - parameter format codes;
    /// - length-prefixed parameter values (`None` is encoded as length -1,
    ///   i.e. NULL);
    /// - result format codes.
    fn bind_body(
        portal: &str,
        statement: &str,
        formats: &[i16],
        params: &[Option<&[u8]>],
        result_formats: &[i16],
    ) -> Vec<u8> {
        let mut out = Vec::new();
        push_pg_cstring(&mut out, portal);
        push_pg_cstring(&mut out, statement);
        out.extend_from_slice(&(formats.len() as i16).to_be_bytes());
        for format in formats {
            out.extend_from_slice(&format.to_be_bytes());
        }
        out.extend_from_slice(&(params.len() as i16).to_be_bytes());
        for param in params {
            match param {
                Some(bytes) => {
                    out.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
                    out.extend_from_slice(bytes);
                }
                None => out.extend_from_slice(&(-1i32).to_be_bytes()),
            }
        }
        out.extend_from_slice(&(result_formats.len() as i16).to_be_bytes());
        for format in result_formats {
            out.extend_from_slice(&format.to_be_bytes());
        }
        out
    }

    /// Build a Describe body. `target` is `b'S'` for a statement or `b'P'`
    /// for a portal, followed by its name.
    fn describe_body(target: u8, name: &str) -> Vec<u8> {
        let mut out = vec![target];
        push_pg_cstring(&mut out, name);
        out
    }

    /// Build an Execute body: portal name, then the max-rows limit (0 means
    /// no limit in the PG protocol).
    fn execute_body(portal: &str, max_rows: u32) -> Vec<u8> {
        let mut out = Vec::new();
        push_pg_cstring(&mut out, portal);
        out.extend_from_slice(&max_rows.to_be_bytes());
        out
    }

    /// Append `value` as a NUL-terminated C string.
    fn push_pg_cstring(out: &mut Vec<u8>, value: &str) {
        out.extend_from_slice(value.as_bytes());
        out.push(0);
    }
}