Skip to main content

reddb_server/wire/postgres/
protocol.rs

1//! PostgreSQL v3 wire protocol message framing (Phase 3.1 PG parity).
2//!
3//! Implements the bits of the PG v3 protocol RedDB needs for simple
4//! query support: startup negotiation, authentication (trust), the
5//! simple query flow (`Q` → `T`/`D`*/`C`/`Z`), and error reporting.
6//!
7//! The full PG reference lives at:
8//! <https://www.postgresql.org/docs/current/protocol-message-formats.html>
9//!
10//! # Frame format (v3)
11//!
12//! After the startup message, every frame is:
13//! ```text
14//! [u8 type] [i32 length (includes itself)] [payload]
15//! ```
16//! Frames are big-endian. We use `tokio::io::AsyncRead/Write` so the
17//! listener can plug into the same task model as the existing wire
18//! binary protocol.
19
20use std::io;
21
22use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
23
/// Protocol version constant: 3.0 → 196608 (major<<16 | minor).
///
/// `read_startup` compares the first four payload bytes of the startup
/// packet against this value to recognise a regular StartupMessage.
pub const PG_PROTOCOL_V3: u32 = 3 << 16;

/// Special startup-phase requests that share the StartupMessage length
/// header. The PG reference calls out three: SSLRequest (80877103),
/// GSSENCRequest (80877104), CancelRequest (80877102). Each code
/// occupies the slot where a StartupMessage carries its version.
///
/// SSLRequest code; decoded as `FrontendMessage::SslRequest`.
pub const PG_SSL_REQUEST: u32 = 80877103;
/// GSSENCRequest code; decoded as `FrontendMessage::GssEncRequest`.
pub const PG_GSSENC_REQUEST: u32 = 80877104;
/// CancelRequest code; `read_startup` surfaces it as
/// `FrontendMessage::Unknown` with tag `b'K'`.
pub const PG_CANCEL_REQUEST: u32 = 80877102;
33
/// Failure modes of the PG wire framing layer.
///
/// Covers transport-level IO errors, structural validation failures
/// (bad message tag, truncated frame), and a clean end of stream.
#[derive(Debug)]
pub enum PgWireError {
    /// Transport error other than an unexpected end of stream.
    Io(io::Error),
    /// The peer sent something structurally invalid; carries a description.
    Protocol(String),
    /// Client closed the connection cleanly (EOF before a frame).
    Eof,
}

impl From<io::Error> for PgWireError {
    /// Maps `UnexpectedEof` to [`PgWireError::Eof`]; every other IO error
    /// is wrapped unchanged in [`PgWireError::Io`].
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::UnexpectedEof => Self::Eof,
            _ => Self::Io(err),
        }
    }
}

impl std::fmt::Display for PgWireError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(source) => write!(f, "pg wire io: {source}"),
            Self::Protocol(detail) => write!(f, "pg wire protocol: {detail}"),
            Self::Eof => f.write_str("pg wire eof"),
        }
    }
}

impl std::error::Error for PgWireError {}
65
/// Frontend (client → server) messages we parse.
///
/// The startup-phase variants are produced by `read_startup`; the tagged
/// variants are produced by `read_frame` from the frame's type byte.
#[derive(Debug, Clone)]
pub enum FrontendMessage {
    /// Pre-handshake StartupMessage payload (parameters map).
    Startup(StartupParams),
    /// SSL handshake request — we reject with 'N' (not supported).
    SslRequest,
    /// GSSAPI encryption request — we reject with 'N'.
    GssEncRequest,
    /// `Q` — simple query.
    Query(String),
    /// `P` — Parse: name a prepared statement and its SQL text.
    Parse(ParseMessage),
    /// `B` — Bind: attach concrete parameter bytes to a named statement.
    Bind(BindMessage),
    /// `D` — Describe a prepared statement or portal.
    Describe(DescribeMessage),
    /// `E` — Execute a bound portal.
    Execute(ExecuteMessage),
    /// `C` — Close a prepared statement or portal.
    Close(CloseMessage),
    /// `p` — password / SASL response. Payload is ignored for `trust` auth.
    PasswordMessage(Vec<u8>),
    /// `X` — Terminate.
    Terminate,
    /// `H` — Flush. Send buffered results.
    Flush,
    /// `S` — Sync. End of extended query batch.
    Sync,
    /// Any other frame we don't implement yet; carries the raw tag for
    /// logging / ErrorResponse reply. `read_startup` also uses this
    /// variant (with tag `b'K'`) for a CancelRequest.
    Unknown { tag: u8, payload: Vec<u8> },
}
99
/// Decoded payload of a `P` (Parse) frame.
#[derive(Debug, Clone)]
pub struct ParseMessage {
    /// Prepared-statement name (first C-string of the payload).
    pub statement: String,
    /// SQL text (second C-string of the payload).
    pub query: String,
    /// Parameter type OIDs, one per declared parameter, in wire order.
    pub param_type_oids: Vec<u32>,
}
106
/// Decoded payload of a `B` (Bind) frame.
#[derive(Debug, Clone)]
pub struct BindMessage {
    /// Portal name (first C-string of the payload).
    pub portal: String,
    /// Prepared-statement name (second C-string of the payload).
    pub statement: String,
    /// Parameter format codes exactly as sent by the client.
    pub param_format_codes: Vec<i16>,
    /// Parameter values in wire order; `None` when the wire length is -1.
    pub params: Vec<Option<Vec<u8>>>,
    /// Result-column format codes exactly as sent by the client.
    pub result_format_codes: Vec<i16>,
}
115
/// Decoded payload of a `D` (Describe) frame.
#[derive(Debug, Clone)]
pub struct DescribeMessage {
    /// Whether `name` refers to a prepared statement or a portal.
    pub target: DescribeTarget,
    /// Name of the statement or portal (C-string following the target byte).
    pub name: String,
}
121
/// Object kind addressed by a Describe or Close frame, decoded from the
/// payload's leading byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DescribeTarget {
    /// Wire byte `'S'` — a prepared statement.
    Statement,
    /// Wire byte `'P'` — a portal.
    Portal,
}
127
/// Decoded payload of an `E` (Execute) frame.
#[derive(Debug, Clone)]
pub struct ExecuteMessage {
    /// Portal name (leading C-string of the payload).
    pub portal: String,
    /// Row-limit field, read as a big-endian `u32`.
    /// NOTE(review): PostgreSQL treats 0 as "no limit" — confirm the
    /// executor applies the same meaning.
    pub max_rows: u32,
}
133
/// Decoded payload of a `C` (Close) frame.
#[derive(Debug, Clone)]
pub struct CloseMessage {
    /// Whether `name` refers to a prepared statement or a portal.
    pub target: DescribeTarget,
    /// Name of the statement or portal (C-string following the target byte).
    pub name: String,
}
139
/// Parameters carried by a StartupMessage.
#[derive(Debug, Clone, Default)]
pub struct StartupParams {
    /// Key/value pairs from the startup message (user, database, etc.).
    pub params: Vec<(String, String)>,
}

impl StartupParams {
    /// Returns the value of the first parameter whose key equals `key`,
    /// or `None` when no such parameter was sent.
    pub fn get(&self, key: &str) -> Option<&str> {
        for (name, value) in &self.params {
            if name == key {
                return Some(value.as_str());
            }
        }
        None
    }
}
154
/// Backend (server → client) messages we emit.
///
/// `encode_backend` turns each variant into a tag byte plus payload and
/// `write_frame` sends it.
#[derive(Debug, Clone)]
pub enum BackendMessage {
    /// `R` — AuthenticationOk (subtype 0).
    AuthenticationOk,
    /// `S` — ParameterStatus (server_version, client_encoding, ...).
    ParameterStatus { name: String, value: String },
    /// `K` — BackendKeyData (cancel key).
    BackendKeyData { pid: u32, key: u32 },
    /// `Z` — ReadyForQuery. Status: 'I' idle, 'T' in-txn, 'E' failed-txn.
    ReadyForQuery(TransactionStatus),
    /// `T` — RowDescription.
    RowDescription(Vec<ColumnDescriptor>),
    /// `D` — DataRow. Each field is `Some(bytes)` or `None` (NULL).
    DataRow(Vec<Option<Vec<u8>>>),
    /// `C` — CommandComplete (e.g. "SELECT 3", "INSERT 0 1").
    CommandComplete(String),
    /// `1` — ParseComplete.
    ParseComplete,
    /// `2` — BindComplete.
    BindComplete,
    /// `3` — CloseComplete.
    CloseComplete,
    /// `t` — ParameterDescription.
    ParameterDescription(Vec<u32>),
    /// `n` — NoData.
    NoData,
    /// `E` — ErrorResponse with severity + code + message. The severity
    /// is written twice on the wire, as both the `S` and `V` fields.
    ErrorResponse {
        severity: String,
        code: String,
        message: String,
    },
    /// `N` — NoticeResponse (non-fatal). The severity field is always
    /// the literal `NOTICE`.
    NoticeResponse { message: String },
    /// `I` — EmptyQueryResponse.
    EmptyQueryResponse,
}
193
/// Transaction state reported in a ReadyForQuery message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Not inside a transaction.
    Idle,
    /// Inside a transaction block.
    InTransaction,
    /// Failed transaction, awaiting ROLLBACK.
    Failed,
}

impl TransactionStatus {
    /// Wire byte for this status: `I`, `T`, or `E`.
    pub fn as_byte(self) -> u8 {
        let indicator = match self {
            Self::Idle => 'I',
            Self::InTransaction => 'T',
            Self::Failed => 'E',
        };
        indicator as u8
    }
}
213
/// One column entry of a RowDescription message. `encode_backend`
/// writes the fields in declaration order, integers big-endian.
#[derive(Debug, Clone)]
pub struct ColumnDescriptor {
    /// Column name, sent as a C-string (embedded NULs are replaced when
    /// the frame is encoded).
    pub name: String,
    /// Table OID (0 when not from a real table — common for computed columns).
    pub table_oid: u32,
    /// Column attribute number within the table (0 when synthetic).
    pub column_attr: i16,
    /// PG type OID (`pg_type.oid`).
    pub type_oid: u32,
    /// Fixed size of the data type, or -1 for variable length.
    pub type_size: i16,
    /// Type modifier (e.g. VARCHAR(n) → n+4). -1 when unused.
    pub type_mod: i32,
    /// Format code: 0 = text, 1 = binary. We always emit text in 3.1.
    pub format: i16,
}
230
231// ────────────────────────────────────────────────────────────────────
232// Frontend parsing
233// ────────────────────────────────────────────────────────────────────
234
235/// Read the initial StartupMessage (or SSL/GSS request). The startup
236/// frame has no type byte — just a length prefix followed by the
237/// payload. Returns either a decoded Startup/SSL/GSS message or an error.
238pub async fn read_startup<R: AsyncRead + Unpin>(
239    stream: &mut R,
240) -> Result<FrontendMessage, PgWireError> {
241    let mut len_buf = [0u8; 4];
242    stream.read_exact(&mut len_buf).await?;
243    let len = u32::from_be_bytes(len_buf);
244    if !(8..=65536).contains(&len) {
245        return Err(PgWireError::Protocol(format!(
246            "startup length {len} out of range"
247        )));
248    }
249    let body_len = (len as usize) - 4;
250    let mut body = vec![0u8; body_len];
251    stream.read_exact(&mut body).await?;
252    if body_len < 4 {
253        return Err(PgWireError::Protocol("startup payload too short".into()));
254    }
255    let version = u32::from_be_bytes([body[0], body[1], body[2], body[3]]);
256
257    match version {
258        PG_SSL_REQUEST => Ok(FrontendMessage::SslRequest),
259        PG_GSSENC_REQUEST => Ok(FrontendMessage::GssEncRequest),
260        PG_PROTOCOL_V3 => {
261            // Parameter map is a run of null-terminated strings terminated
262            // by an empty string.
263            let mut params: Vec<(String, String)> = Vec::new();
264            let mut pos = 4usize;
265            while pos < body_len {
266                if body[pos] == 0 {
267                    break;
268                }
269                let key = read_cstring(&body, &mut pos)?;
270                if pos >= body_len {
271                    return Err(PgWireError::Protocol(
272                        "startup parameter missing value".into(),
273                    ));
274                }
275                let value = read_cstring(&body, &mut pos)?;
276                params.push((key, value));
277            }
278            Ok(FrontendMessage::Startup(StartupParams { params }))
279        }
280        // CancelRequest is sent on a fresh connection and doesn't produce
281        // a response — surface as Unknown so caller can close.
282        PG_CANCEL_REQUEST => Ok(FrontendMessage::Unknown {
283            tag: b'K',
284            payload: body,
285        }),
286        _ => Err(PgWireError::Protocol(format!(
287            "unsupported protocol version {version}"
288        ))),
289    }
290}
291
292/// Read a regular tagged frame after the startup handshake.
293pub async fn read_frame<R: AsyncRead + Unpin>(
294    stream: &mut R,
295) -> Result<FrontendMessage, PgWireError> {
296    let mut tag_buf = [0u8; 1];
297    match stream.read_exact(&mut tag_buf).await {
298        Ok(_) => {}
299        Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Err(PgWireError::Eof),
300        Err(e) => return Err(PgWireError::Io(e)),
301    }
302    let tag = tag_buf[0];
303
304    let mut len_buf = [0u8; 4];
305    stream.read_exact(&mut len_buf).await?;
306    let len = u32::from_be_bytes(len_buf);
307    if !(4..=1_048_576).contains(&len) {
308        return Err(PgWireError::Protocol(format!(
309            "frame length {len} out of bounds"
310        )));
311    }
312    let payload_len = (len as usize) - 4;
313    let mut payload = vec![0u8; payload_len];
314    stream.read_exact(&mut payload).await?;
315
316    Ok(match tag {
317        b'Q' => {
318            // Null-terminated SQL string.
319            let mut pos = 0;
320            let query = read_cstring(&payload, &mut pos)?;
321            FrontendMessage::Query(query)
322        }
323        b'P' => FrontendMessage::Parse(parse_parse_message(&payload)?),
324        b'B' => FrontendMessage::Bind(parse_bind_message(&payload)?),
325        b'D' => FrontendMessage::Describe(parse_describe_message(&payload)?),
326        b'E' => FrontendMessage::Execute(parse_execute_message(&payload)?),
327        b'C' => FrontendMessage::Close(parse_close_message(&payload)?),
328        b'p' => FrontendMessage::PasswordMessage(payload),
329        b'X' => FrontendMessage::Terminate,
330        b'H' => FrontendMessage::Flush,
331        b'S' => FrontendMessage::Sync,
332        other => FrontendMessage::Unknown {
333            tag: other,
334            payload,
335        },
336    })
337}
338
339// ────────────────────────────────────────────────────────────────────
340// Backend emission
341// ────────────────────────────────────────────────────────────────────
342
343/// Emit a raw byte (used for the SSL/GSS negotiation response: 'N'
344/// meaning "not supported, continue in plaintext").
345pub async fn write_raw_byte<W: AsyncWrite + Unpin>(
346    stream: &mut W,
347    byte: u8,
348) -> Result<(), PgWireError> {
349    stream.write_all(&[byte]).await?;
350    Ok(())
351}
352
353/// Serialize + send a backend message.
354pub async fn write_frame<W: AsyncWrite + Unpin>(
355    stream: &mut W,
356    msg: &BackendMessage,
357) -> Result<(), PgWireError> {
358    let (tag, payload) = encode_backend(msg);
359    // Length includes the length field itself (4 bytes) + payload.
360    let length = (payload.len() + 4) as u32;
361    stream.write_all(&[tag]).await?;
362    stream.write_all(&length.to_be_bytes()).await?;
363    stream.write_all(&payload).await?;
364    Ok(())
365}
366
/// F-02 (audit doc, 2026-05-06): neutralise embedded NUL bytes before a
/// value is written as a NUL-terminated C-string.
///
/// PG3 wire encodes user-controlled bytes as `tag|value|NUL` C-strings
/// in `ErrorResponse`, `NoticeResponse`, `CommandComplete`,
/// `RowDescription` column names, and `ParameterStatus`. A NUL inside
/// such a value would end the C-string early and let an attacker
/// smuggle additional protocol fields into the frame.
///
/// Every NUL in `input` is therefore replaced by the UTF-8 encoding of
/// `U+FFFD` REPLACEMENT CHARACTER (`EF BF BD`); all other bytes are
/// copied through untouched. The substitution keeps the visible shape
/// of the message for debugging without leaving a path to inject a
/// synthetic field.
fn sanitize_cstring_bytes(input: &[u8]) -> Vec<u8> {
    // U+FFFD REPLACEMENT CHARACTER (UTF-8 EF BF BD)
    const REPLACEMENT: [u8; 3] = [0xEF, 0xBF, 0xBD];

    let mut cleaned = Vec::with_capacity(input.len() + 8);
    // Splitting on NUL yields the clean runs; one replacement goes
    // between each adjacent pair of runs, i.e. once per original NUL.
    for (index, run) in input.split(|&byte| byte == 0).enumerate() {
        if index > 0 {
            cleaned.extend_from_slice(&REPLACEMENT);
        }
        cleaned.extend_from_slice(run);
    }
    cleaned
}
397
398#[inline]
399fn push_cstring(buf: &mut Vec<u8>, value: &str) {
400    buf.extend_from_slice(&sanitize_cstring_bytes(value.as_bytes()));
401    buf.push(0);
402}
403
404fn encode_backend(msg: &BackendMessage) -> (u8, Vec<u8>) {
405    match msg {
406        BackendMessage::AuthenticationOk => {
407            // Subtype 0 = AuthenticationOk.
408            (b'R', vec![0, 0, 0, 0])
409        }
410        BackendMessage::ParameterStatus { name, value } => {
411            let mut buf = Vec::with_capacity(name.len() + value.len() + 2);
412            // F-02: name + value are user-controlled in some pathways.
413            push_cstring(&mut buf, name);
414            push_cstring(&mut buf, value);
415            (b'S', buf)
416        }
417        BackendMessage::BackendKeyData { pid, key } => {
418            let mut buf = Vec::with_capacity(8);
419            buf.extend_from_slice(&pid.to_be_bytes());
420            buf.extend_from_slice(&key.to_be_bytes());
421            (b'K', buf)
422        }
423        BackendMessage::ReadyForQuery(status) => (b'Z', vec![status.as_byte()]),
424        BackendMessage::RowDescription(cols) => {
425            let mut buf = Vec::new();
426            buf.extend_from_slice(&(cols.len() as i16).to_be_bytes());
427            for col in cols {
428                // F-02: column name is user-derived (SELECT ... AS "x\0y").
429                push_cstring(&mut buf, &col.name);
430                buf.extend_from_slice(&col.table_oid.to_be_bytes());
431                buf.extend_from_slice(&col.column_attr.to_be_bytes());
432                buf.extend_from_slice(&col.type_oid.to_be_bytes());
433                buf.extend_from_slice(&col.type_size.to_be_bytes());
434                buf.extend_from_slice(&col.type_mod.to_be_bytes());
435                buf.extend_from_slice(&col.format.to_be_bytes());
436            }
437            (b'T', buf)
438        }
439        BackendMessage::DataRow(fields) => {
440            let mut buf = Vec::new();
441            buf.extend_from_slice(&(fields.len() as i16).to_be_bytes());
442            for field in fields {
443                match field {
444                    None => {
445                        // -1 length signals NULL.
446                        buf.extend_from_slice(&(-1i32).to_be_bytes());
447                    }
448                    Some(bytes) => {
449                        // DataRow uses length-prefixed bytes, NOT
450                        // C-strings — embedded NULs are legal here
451                        // and must NOT be sanitized.
452                        buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
453                        buf.extend_from_slice(bytes);
454                    }
455                }
456            }
457            (b'D', buf)
458        }
459        BackendMessage::CommandComplete(tag) => {
460            let mut buf = Vec::with_capacity(tag.len() + 1);
461            // F-02: command tag includes user-influenced row counts /
462            // statement classes; sanitize before NUL-terminating.
463            push_cstring(&mut buf, tag);
464            (b'C', buf)
465        }
466        BackendMessage::ParseComplete => (b'1', Vec::new()),
467        BackendMessage::BindComplete => (b'2', Vec::new()),
468        BackendMessage::CloseComplete => (b'3', Vec::new()),
469        BackendMessage::ParameterDescription(oids) => {
470            let mut buf = Vec::with_capacity(2 + oids.len() * 4);
471            buf.extend_from_slice(&(oids.len() as i16).to_be_bytes());
472            for oid in oids {
473                buf.extend_from_slice(&oid.to_be_bytes());
474            }
475            (b't', buf)
476        }
477        BackendMessage::NoData => (b'n', Vec::new()),
478        BackendMessage::ErrorResponse {
479            severity,
480            code,
481            message,
482        } => {
483            let mut buf = Vec::new();
484            // Field 'S' = severity (ERROR, FATAL, PANIC, ...)
485            buf.push(b'S');
486            push_cstring(&mut buf, severity);
487            // Field 'V' = non-localized severity (PG 9.6+).
488            buf.push(b'V');
489            push_cstring(&mut buf, severity);
490            // Field 'C' = SQLSTATE.
491            buf.push(b'C');
492            push_cstring(&mut buf, code);
493            // Field 'M' = human message — F-02 primary attack surface.
494            buf.push(b'M');
495            push_cstring(&mut buf, message);
496            // Trailing null terminator ends the field list.
497            buf.push(0);
498            (b'E', buf)
499        }
500        BackendMessage::NoticeResponse { message } => {
501            let mut buf = Vec::new();
502            buf.push(b'S');
503            buf.extend_from_slice(b"NOTICE");
504            buf.push(0);
505            buf.push(b'M');
506            // F-02: message is user-influenced.
507            push_cstring(&mut buf, message);
508            buf.push(0);
509            (b'N', buf)
510        }
511        BackendMessage::EmptyQueryResponse => (b'I', Vec::new()),
512    }
513}
514
515// ────────────────────────────────────────────────────────────────────
516// Helpers
517// ────────────────────────────────────────────────────────────────────
518
519/// Read a C-string (null-terminated UTF-8) starting at `pos`. Advances
520/// `pos` past the terminator. Returns `Protocol` error when malformed.
521fn read_cstring(buf: &[u8], pos: &mut usize) -> Result<String, PgWireError> {
522    let start = *pos;
523    while *pos < buf.len() && buf[*pos] != 0 {
524        *pos += 1;
525    }
526    if *pos >= buf.len() {
527        return Err(PgWireError::Protocol("cstring missing terminator".into()));
528    }
529    let s = std::str::from_utf8(&buf[start..*pos])
530        .map_err(|e| PgWireError::Protocol(format!("invalid utf8: {e}")))?
531        .to_string();
532    *pos += 1; // skip null
533    Ok(s)
534}
535
536fn parse_parse_message(payload: &[u8]) -> Result<ParseMessage, PgWireError> {
537    let mut pos = 0;
538    let statement = read_cstring(payload, &mut pos)?;
539    let query = read_cstring(payload, &mut pos)?;
540    let nparams = read_i16(payload, &mut pos, "Parse parameter count")?;
541    if nparams < 0 {
542        return Err(PgWireError::Protocol(
543            "negative Parse parameter count".into(),
544        ));
545    }
546    let mut param_type_oids = Vec::with_capacity(nparams as usize);
547    for _ in 0..nparams {
548        param_type_oids.push(read_u32(payload, &mut pos, "Parse parameter type OID")?);
549    }
550    ensure_consumed(payload, pos, "Parse")?;
551    Ok(ParseMessage {
552        statement,
553        query,
554        param_type_oids,
555    })
556}
557
558fn parse_bind_message(payload: &[u8]) -> Result<BindMessage, PgWireError> {
559    let mut pos = 0;
560    let portal = read_cstring(payload, &mut pos)?;
561    let statement = read_cstring(payload, &mut pos)?;
562
563    let nformats = read_i16(payload, &mut pos, "Bind format count")?;
564    if nformats < 0 {
565        return Err(PgWireError::Protocol("negative Bind format count".into()));
566    }
567    let mut param_format_codes = Vec::with_capacity(nformats as usize);
568    for _ in 0..nformats {
569        param_format_codes.push(read_i16(payload, &mut pos, "Bind format code")?);
570    }
571
572    let nparams = read_i16(payload, &mut pos, "Bind parameter count")?;
573    if nparams < 0 {
574        return Err(PgWireError::Protocol(
575            "negative Bind parameter count".into(),
576        ));
577    }
578    let mut params = Vec::with_capacity(nparams as usize);
579    for _ in 0..nparams {
580        let len = read_i32(payload, &mut pos, "Bind parameter length")?;
581        if len == -1 {
582            params.push(None);
583        } else if len < -1 {
584            return Err(PgWireError::Protocol(
585                "invalid Bind parameter length".into(),
586            ));
587        } else {
588            params.push(Some(
589                read_bytes(payload, &mut pos, len as usize, "Bind parameter")?.to_vec(),
590            ));
591        }
592    }
593
594    let nresult_formats = read_i16(payload, &mut pos, "Bind result format count")?;
595    if nresult_formats < 0 {
596        return Err(PgWireError::Protocol(
597            "negative Bind result format count".into(),
598        ));
599    }
600    let mut result_format_codes = Vec::with_capacity(nresult_formats as usize);
601    for _ in 0..nresult_formats {
602        result_format_codes.push(read_i16(payload, &mut pos, "Bind result format code")?);
603    }
604    ensure_consumed(payload, pos, "Bind")?;
605
606    Ok(BindMessage {
607        portal,
608        statement,
609        param_format_codes,
610        params,
611        result_format_codes,
612    })
613}
614
615fn parse_describe_message(payload: &[u8]) -> Result<DescribeMessage, PgWireError> {
616    let mut pos = 0;
617    let target = read_describe_target(payload, &mut pos, "Describe")?;
618    let name = read_cstring(payload, &mut pos)?;
619    ensure_consumed(payload, pos, "Describe")?;
620    Ok(DescribeMessage { target, name })
621}
622
623fn parse_execute_message(payload: &[u8]) -> Result<ExecuteMessage, PgWireError> {
624    let mut pos = 0;
625    let portal = read_cstring(payload, &mut pos)?;
626    let max_rows = read_u32(payload, &mut pos, "Execute max rows")?;
627    ensure_consumed(payload, pos, "Execute")?;
628    Ok(ExecuteMessage { portal, max_rows })
629}
630
631fn parse_close_message(payload: &[u8]) -> Result<CloseMessage, PgWireError> {
632    let mut pos = 0;
633    let target = read_describe_target(payload, &mut pos, "Close")?;
634    let name = read_cstring(payload, &mut pos)?;
635    ensure_consumed(payload, pos, "Close")?;
636    Ok(CloseMessage { target, name })
637}
638
639fn read_describe_target(
640    payload: &[u8],
641    pos: &mut usize,
642    frame: &'static str,
643) -> Result<DescribeTarget, PgWireError> {
644    let byte = *read_bytes(payload, pos, 1, frame)?
645        .first()
646        .expect("one target byte");
647    match byte {
648        b'S' => Ok(DescribeTarget::Statement),
649        b'P' => Ok(DescribeTarget::Portal),
650        other => Err(PgWireError::Protocol(format!(
651            "{frame} target must be 'S' or 'P', got 0x{other:02x}"
652        ))),
653    }
654}
655
656fn read_i16(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<i16, PgWireError> {
657    let bytes = read_bytes(payload, pos, 2, field)?;
658    Ok(i16::from_be_bytes([bytes[0], bytes[1]]))
659}
660
661fn read_i32(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<i32, PgWireError> {
662    let bytes = read_bytes(payload, pos, 4, field)?;
663    Ok(i32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
664}
665
666fn read_u32(payload: &[u8], pos: &mut usize, field: &'static str) -> Result<u32, PgWireError> {
667    let bytes = read_bytes(payload, pos, 4, field)?;
668    Ok(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
669}
670
671fn read_bytes<'a>(
672    payload: &'a [u8],
673    pos: &mut usize,
674    len: usize,
675    field: &'static str,
676) -> Result<&'a [u8], PgWireError> {
677    let end = pos
678        .checked_add(len)
679        .ok_or_else(|| PgWireError::Protocol(format!("{field} length overflow")))?;
680    if end > payload.len() {
681        return Err(PgWireError::Protocol(format!("{field} truncated")));
682    }
683    let bytes = &payload[*pos..end];
684    *pos = end;
685    Ok(bytes)
686}
687
688fn ensure_consumed(payload: &[u8], pos: usize, frame: &'static str) -> Result<(), PgWireError> {
689    if pos == payload.len() {
690        Ok(())
691    } else {
692        Err(PgWireError::Protocol(format!(
693            "{frame} had {} trailing bytes",
694            payload.len() - pos
695        )))
696    }
697}
698
699#[cfg(test)]
700mod tests {
701    use super::*;
702
703    #[tokio::test]
704    async fn parse_startup_v3() {
705        // length (4) + version (4) + user\0val\0 + terminator\0
706        let mut payload: Vec<u8> = Vec::new();
707        payload.extend_from_slice(&PG_PROTOCOL_V3.to_be_bytes());
708        payload.extend_from_slice(b"user\0alice\0");
709        payload.push(0);
710        let len = (4 + payload.len()) as u32;
711        let mut frame = Vec::new();
712        frame.extend_from_slice(&len.to_be_bytes());
713        frame.extend_from_slice(&payload);
714
715        let mut cursor = std::io::Cursor::new(frame);
716        let msg = read_startup(&mut cursor).await.unwrap();
717        match msg {
718            FrontendMessage::Startup(params) => {
719                assert_eq!(params.get("user"), Some("alice"));
720            }
721            other => panic!("expected Startup, got {:?}", other),
722        }
723    }
724
725    #[tokio::test]
726    async fn parse_ssl_request() {
727        let mut frame: Vec<u8> = Vec::new();
728        frame.extend_from_slice(&8u32.to_be_bytes());
729        frame.extend_from_slice(&PG_SSL_REQUEST.to_be_bytes());
730        let mut cursor = std::io::Cursor::new(frame);
731        assert!(matches!(
732            read_startup(&mut cursor).await.unwrap(),
733            FrontendMessage::SslRequest
734        ));
735    }
736
737    #[tokio::test]
738    async fn parse_query_frame() {
739        let query = "SELECT 1\0";
740        let mut frame = Vec::new();
741        frame.push(b'Q');
742        let len = (4 + query.len()) as u32;
743        frame.extend_from_slice(&len.to_be_bytes());
744        frame.extend_from_slice(query.as_bytes());
745        let mut cursor = std::io::Cursor::new(frame);
746        match read_frame(&mut cursor).await.unwrap() {
747            FrontendMessage::Query(s) => assert_eq!(s, "SELECT 1"),
748            other => panic!("expected Query, got {:?}", other),
749        }
750    }
751
752    #[tokio::test]
753    async fn parse_extended_query_frames() {
754        let mut parse_payload = Vec::new();
755        push_test_cstring(&mut parse_payload, "");
756        push_test_cstring(&mut parse_payload, "SELECT $1");
757        parse_payload.extend_from_slice(&1i16.to_be_bytes());
758        parse_payload.extend_from_slice(&23u32.to_be_bytes());
759        let mut frame = tagged_frame(b'P', parse_payload);
760        let mut cursor = std::io::Cursor::new(frame);
761        match read_frame(&mut cursor).await.unwrap() {
762            FrontendMessage::Parse(msg) => {
763                assert_eq!(msg.statement, "");
764                assert_eq!(msg.query, "SELECT $1");
765                assert_eq!(msg.param_type_oids, vec![23]);
766            }
767            other => panic!("expected Parse, got {other:?}"),
768        }
769
770        let mut bind_payload = Vec::new();
771        push_test_cstring(&mut bind_payload, "");
772        push_test_cstring(&mut bind_payload, "");
773        bind_payload.extend_from_slice(&1i16.to_be_bytes());
774        bind_payload.extend_from_slice(&0i16.to_be_bytes());
775        bind_payload.extend_from_slice(&1i16.to_be_bytes());
776        bind_payload.extend_from_slice(&2i32.to_be_bytes());
777        bind_payload.extend_from_slice(b"42");
778        bind_payload.extend_from_slice(&0i16.to_be_bytes());
779        frame = tagged_frame(b'B', bind_payload);
780        let mut cursor = std::io::Cursor::new(frame);
781        match read_frame(&mut cursor).await.unwrap() {
782            FrontendMessage::Bind(msg) => {
783                assert_eq!(msg.portal, "");
784                assert_eq!(msg.statement, "");
785                assert_eq!(msg.param_format_codes, vec![0]);
786                assert_eq!(msg.params, vec![Some(b"42".to_vec())]);
787                assert!(msg.result_format_codes.is_empty());
788            }
789            other => panic!("expected Bind, got {other:?}"),
790        }
791
792        let mut describe_payload = vec![b'P'];
793        push_test_cstring(&mut describe_payload, "");
794        let mut cursor = std::io::Cursor::new(tagged_frame(b'D', describe_payload));
795        assert!(matches!(
796            read_frame(&mut cursor).await.unwrap(),
797            FrontendMessage::Describe(DescribeMessage {
798                target: DescribeTarget::Portal,
799                ..
800            })
801        ));
802    }
803
804    #[tokio::test]
805    async fn emit_ready_for_query() {
806        let mut out: Vec<u8> = Vec::new();
807        write_frame(
808            &mut out,
809            &BackendMessage::ReadyForQuery(TransactionStatus::Idle),
810        )
811        .await
812        .unwrap();
813        assert_eq!(out, vec![b'Z', 0, 0, 0, 5, b'I']);
814    }
815
816    #[tokio::test]
817    async fn emit_row_description_and_data_row() {
818        let mut out: Vec<u8> = Vec::new();
819        write_frame(
820            &mut out,
821            &BackendMessage::RowDescription(vec![ColumnDescriptor {
822                name: "id".to_string(),
823                table_oid: 0,
824                column_attr: 0,
825                type_oid: 23,
826                type_size: 4,
827                type_mod: -1,
828                format: 0,
829            }]),
830        )
831        .await
832        .unwrap();
833        assert_eq!(out[0], b'T');
834
835        let mut data: Vec<u8> = Vec::new();
836        write_frame(
837            &mut data,
838            &BackendMessage::DataRow(vec![Some(b"42".to_vec()), None]),
839        )
840        .await
841        .unwrap();
842        assert_eq!(data[0], b'D');
843    }
844
845    #[tokio::test]
846    async fn emit_extended_completion_frames() {
847        let mut out = Vec::new();
848        write_frame(&mut out, &BackendMessage::ParseComplete)
849            .await
850            .unwrap();
851        write_frame(&mut out, &BackendMessage::BindComplete)
852            .await
853            .unwrap();
854        write_frame(
855            &mut out,
856            &BackendMessage::ParameterDescription(vec![23, 25]),
857        )
858        .await
859        .unwrap();
860        write_frame(&mut out, &BackendMessage::NoData)
861            .await
862            .unwrap();
863        write_frame(&mut out, &BackendMessage::CloseComplete)
864            .await
865            .unwrap();
866        assert_eq!(collect_tags(&out), vec![b'1', b'2', b't', b'n', b'3']);
867    }
868
869    // ---------------------------------------------------------------
870    // F-02 (audit doc 2026-05-06): NUL-injection rejection in PG3
871    // C-string fields. Replacement codepoint U+FFFD is emitted
872    // instead of the raw NUL so the field cannot be terminated
873    // prematurely on the wire.
874    // ---------------------------------------------------------------
875
876    fn count_nul(buf: &[u8]) -> usize {
877        buf.iter().filter(|&&b| b == 0).count()
878    }
879
880    #[tokio::test]
881    async fn pg3_nul_error_response_message_field_sanitized() {
882        let mut out: Vec<u8> = Vec::new();
883        write_frame(
884            &mut out,
885            &BackendMessage::ErrorResponse {
886                severity: "ERROR".to_string(),
887                code: "42000".to_string(),
888                message: "smuggled\0M\x00injection".to_string(),
889            },
890        )
891        .await
892        .unwrap();
893        assert_eq!(out[0], b'E');
894        // ErrorResponse body: 4 inner C-string terminators (S/V/C/M)
895        // + 1 list-end terminator = 5 total NULs. The message field
896        // had 2 raw NULs in it; if not sanitized we'd see 7 NULs.
897        let body = &out[5..];
898        assert_eq!(
899            count_nul(body),
900            5,
901            "expected 5 NULs (4 field + 1 list-end), got {} :: body={:?}",
902            count_nul(body),
903            body
904        );
905        // U+FFFD must be present (EF BF BD).
906        assert!(
907            body.windows(3).any(|w| w == [0xEF, 0xBF, 0xBD]),
908            "expected U+FFFD substitution in body"
909        );
910    }
911
912    #[tokio::test]
913    async fn pg3_nul_notice_response_sanitized() {
914        let mut out: Vec<u8> = Vec::new();
915        write_frame(
916            &mut out,
917            &BackendMessage::NoticeResponse {
918                message: "evil\0field".to_string(),
919            },
920        )
921        .await
922        .unwrap();
923        assert_eq!(out[0], b'N');
924        let body = &out[5..];
925        // 2 inner C-string terminators (S, M) + 1 list-end = 3 NULs.
926        assert_eq!(count_nul(body), 3);
927        assert!(body.windows(3).any(|w| w == [0xEF, 0xBF, 0xBD]));
928    }
929
930    #[tokio::test]
931    async fn pg3_nul_command_complete_sanitized() {
932        let mut out: Vec<u8> = Vec::new();
933        write_frame(
934            &mut out,
935            &BackendMessage::CommandComplete("SELECT\0;DROP".to_string()),
936        )
937        .await
938        .unwrap();
939        assert_eq!(out[0], b'C');
940        let body = &out[5..];
941        // CommandComplete = single C-string + terminator -> 1 NUL.
942        assert_eq!(count_nul(body), 1);
943    }
944
945    #[tokio::test]
946    async fn pg3_nul_row_description_column_name_sanitized() {
947        let mut out: Vec<u8> = Vec::new();
948        write_frame(
949            &mut out,
950            &BackendMessage::RowDescription(vec![ColumnDescriptor {
951                name: "evil\0col".to_string(),
952                table_oid: 0,
953                column_attr: 0,
954                type_oid: 23,
955                type_size: 4,
956                type_mod: -1,
957                format: 0,
958            }]),
959        )
960        .await
961        .unwrap();
962        assert_eq!(out[0], b'T');
963        // The column-name region (after the i16 field count, before
964        // the OIDs) must contain exactly one terminator, not two.
965        let body = &out[5..];
966        // Skip 2 bytes (column count i16); next bytes up to the
967        // first NUL are the column name.
968        let name_region = &body[2..];
969        let first_nul = name_region.iter().position(|&b| b == 0).unwrap();
970        assert!(
971            name_region[..first_nul]
972                .windows(3)
973                .any(|w| w == [0xEF, 0xBF, 0xBD]),
974            "U+FFFD missing from sanitized column name"
975        );
976    }
977
978    #[test]
979    fn sanitize_cstring_fastpath_no_nul() {
980        let s = "no nuls here";
981        let out = sanitize_cstring_bytes(s.as_bytes());
982        assert_eq!(out, s.as_bytes());
983    }
984
985    #[test]
986    fn sanitize_cstring_substitutes_nul_with_replacement_codepoint() {
987        let s = b"a\0b\0c";
988        let out = sanitize_cstring_bytes(s);
989        // Each NUL becomes 3 bytes; total = 1 + 3 + 1 + 3 + 1 = 9.
990        assert_eq!(out.len(), 9);
991        assert!(!out.contains(&0));
992        assert_eq!(&out[1..4], &[0xEF, 0xBF, 0xBD]);
993        assert_eq!(&out[5..8], &[0xEF, 0xBF, 0xBD]);
994    }
995
996    fn tagged_frame(tag: u8, payload: Vec<u8>) -> Vec<u8> {
997        let mut frame = vec![tag];
998        frame.extend_from_slice(&((payload.len() + 4) as u32).to_be_bytes());
999        frame.extend_from_slice(&payload);
1000        frame
1001    }
1002
1003    fn push_test_cstring(out: &mut Vec<u8>, value: &str) {
1004        out.extend_from_slice(value.as_bytes());
1005        out.push(0);
1006    }
1007
1008    fn collect_tags(bytes: &[u8]) -> Vec<u8> {
1009        let mut tags = Vec::new();
1010        let mut pos = 0;
1011        while pos < bytes.len() {
1012            tags.push(bytes[pos]);
1013            let len = u32::from_be_bytes([
1014                bytes[pos + 1],
1015                bytes[pos + 2],
1016                bytes[pos + 3],
1017                bytes[pos + 4],
1018            ]) as usize;
1019            pos += 1 + len;
1020        }
1021        tags
1022    }
1023}