Skip to main content

qail_pg/protocol/
wire.rs

1//! PostgreSQL Wire Protocol Messages
2//!
3//! Implementation of the PostgreSQL Frontend/Backend Protocol.
4//! Reference: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
5
/// Frontend (client → server) message types.
///
/// Serialized with [`FrontendMessage::encode`]. Equality is structural, which makes
/// it easy to assert on the exact message a caller intends to send.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrontendMessage {
    /// Startup message (sent first, no type byte).
    Startup {
        /// Database role / user name.
        user: String,
        /// Target database name.
        database: String,
    },
    /// Password response (MD5 or cleartext), tag `'p'`.
    PasswordMessage(String),
    /// Simple query (SQL text), tag `'Q'`.
    Query(String),
    /// Parse (prepared statement), tag `'P'`.
    Parse {
        /// Prepared statement name (empty string = unnamed).
        name: String,
        /// SQL query text with `$1`-style parameter placeholders.
        query: String,
        /// OIDs of the parameter types (empty = server infers).
        param_types: Vec<u32>,
    },
    /// Bind parameters to a prepared statement, tag `'B'`.
    Bind {
        /// Destination portal name (empty = unnamed).
        portal: String,
        /// Source prepared statement name.
        statement: String,
        /// Parameter values (`None` = SQL NULL).
        params: Vec<Option<Vec<u8>>>,
    },
    /// Execute portal, tag `'E'`.
    Execute {
        /// Portal name to execute.
        portal: String,
        /// Maximum rows to return (0 = no limit).
        max_rows: i32,
    },
    /// Sync — marks the end of an extended-query pipeline, tag `'S'`.
    Sync,
    /// Terminate — closes the connection, tag `'X'`.
    Terminate,
    /// SASL initial response (first message in SCRAM), tag `'p'`.
    SASLInitialResponse {
        /// SASL mechanism name (e.g. `SCRAM-SHA-256`).
        mechanism: String,
        /// Client-first message bytes.
        data: Vec<u8>,
    },
    /// SASL response (subsequent messages in SCRAM), tag `'p'`.
    SASLResponse(Vec<u8>),
    /// CopyFail — abort a COPY IN with an error message, tag `'f'`.
    CopyFail(String),
    /// Close — explicitly release a prepared statement or portal, tag `'C'`.
    Close {
        /// `true` for portal, `false` for prepared statement.
        is_portal: bool,
        /// Name of the portal or statement to close.
        name: String,
    },
}
68
69/// Backend (server → client) message types
70#[derive(Debug, Clone)]
71pub enum BackendMessage {
72    /// Authentication request
73    /// Authentication succeeded.
74    AuthenticationOk,
75    /// Server requests MD5-hashed password; salt provided.
76    AuthenticationMD5Password([u8; 4]),
77    /// Server initiates SASL handshake with supported mechanisms.
78    AuthenticationSASL(Vec<String>),
79    /// SASL challenge from server.
80    AuthenticationSASLContinue(Vec<u8>),
81    /// SASL authentication complete; final server data.
82    AuthenticationSASLFinal(Vec<u8>),
83    /// Parameter status (server config)
84    ParameterStatus {
85        /// Parameter name (e.g. `server_version`, `TimeZone`).
86        name: String,
87        /// Current parameter value.
88        value: String,
89    },
90    /// Backend key data (for cancel)
91    BackendKeyData {
92        /// Backend process ID (used for cancel requests).
93        process_id: i32,
94        /// Cancel secret key.
95        secret_key: i32,
96    },
97    /// Server is ready; transaction state indicated.
98    ReadyForQuery(TransactionStatus),
99    /// Column metadata for the upcoming data rows.
100    RowDescription(Vec<FieldDescription>),
101    /// One data row; each element is `None` for SQL NULL or the raw bytes.
102    DataRow(Vec<Option<Vec<u8>>>),
103    /// Command completed with a tag like `SELECT 5` or `INSERT 0 1`.
104    CommandComplete(String),
105    /// Error response with structured fields (severity, code, message, etc.).
106    ErrorResponse(ErrorFields),
107    /// Parse step succeeded.
108    ParseComplete,
109    /// Bind step succeeded.
110    BindComplete,
111    /// Describe returned no row description (e.g. for DML statements).
112    NoData,
113    /// Copy in response (server ready to receive COPY data)
114    CopyInResponse {
115        /// Overall format: 0 = text, 1 = binary.
116        format: u8,
117        /// Per-column format codes.
118        column_formats: Vec<u8>,
119    },
120    /// Copy out response (server will send COPY data)
121    CopyOutResponse {
122        /// Overall format: 0 = text, 1 = binary.
123        format: u8,
124        /// Per-column format codes.
125        column_formats: Vec<u8>,
126    },
127    /// Raw COPY data chunk from the server.
128    CopyData(Vec<u8>),
129    /// COPY transfer complete.
130    CopyDone,
131    /// Notification response (async notification from LISTEN/NOTIFY)
132    NotificationResponse {
133        /// Backend process ID that sent the notification.
134        process_id: i32,
135        /// Channel name.
136        channel: String,
137        /// Notification payload string.
138        payload: String,
139    },
140    /// Empty query string was submitted.
141    EmptyQueryResponse,
142    /// Notice response (warning/info messages, not errors)
143    NoticeResponse(ErrorFields),
144    /// Parameter description (OIDs of parameters in a prepared statement)
145    /// Sent by server in response to Describe(Statement)
146    ParameterDescription(Vec<u32>),
147    /// Close complete (server confirmation that a prepared statement/portal was released)
148    CloseComplete,
149}
150
/// Transaction status reported by `ReadyForQuery`.
///
/// Comparable and hashable so callers can check state with `==` instead of
/// pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionStatus {
    /// Not inside a transaction block (`I`).
    Idle,
    /// Inside a transaction block (`T`).
    InBlock,
    /// Inside a failed transaction block (`E`).
    Failed,
}
161
/// Field description in RowDescription (one entry per result column).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Column name (or alias).
    pub name: String,
    /// OID of the source table (0 if not a table column).
    pub table_oid: u32,
    /// Column attribute number within the table (0 if not a table column).
    pub column_attr: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Data type size in bytes (negative = variable-length).
    pub type_size: i16,
    /// Type-specific modifier (e.g. precision for `numeric`).
    pub type_modifier: i32,
    /// Format code: 0 = text, 1 = binary.
    pub format: i16,
}
180
/// Structured fields from an ErrorResponse (also used for NoticeResponse).
///
/// Fields not present in the message keep their `Default` value (empty string /
/// `None`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorFields {
    /// Severity level (e.g. `ERROR`, `FATAL`, `WARNING`).
    pub severity: String,
    /// SQLSTATE error code (e.g. `23505` for unique violation).
    pub code: String,
    /// Human-readable error message.
    pub message: String,
    /// Optional detailed error description.
    pub detail: Option<String>,
    /// Optional hint for resolving the error.
    pub hint: Option<String>,
}
195
196impl FrontendMessage {
197    /// Encode message to bytes for sending over the wire.
198    pub fn encode(&self) -> Vec<u8> {
199        match self {
200            FrontendMessage::Startup { user, database } => {
201                let mut buf = Vec::new();
202                // Protocol version 3.0
203                buf.extend_from_slice(&196608i32.to_be_bytes());
204                // Parameters
205                buf.extend_from_slice(b"user\0");
206                buf.extend_from_slice(user.as_bytes());
207                buf.push(0);
208                buf.extend_from_slice(b"database\0");
209                buf.extend_from_slice(database.as_bytes());
210                buf.push(0);
211                buf.push(0); // Terminator
212
213                // Prepend length (includes length itself)
214                let len = (buf.len() + 4) as i32;
215                let mut result = len.to_be_bytes().to_vec();
216                result.extend(buf);
217                result
218            }
219            FrontendMessage::Query(sql) => {
220                let mut buf = Vec::new();
221                buf.push(b'Q');
222                let content = format!("{}\0", sql);
223                let len = (content.len() + 4) as i32;
224                buf.extend_from_slice(&len.to_be_bytes());
225                buf.extend_from_slice(content.as_bytes());
226                buf
227            }
228            FrontendMessage::Terminate => {
229                vec![b'X', 0, 0, 0, 4]
230            }
231            FrontendMessage::SASLInitialResponse { mechanism, data } => {
232                let mut buf = Vec::new();
233                buf.push(b'p'); // SASLInitialResponse uses 'p'
234
235                let mut content = Vec::new();
236                content.extend_from_slice(mechanism.as_bytes());
237                content.push(0); // null-terminated mechanism
238                content.extend_from_slice(&(data.len() as i32).to_be_bytes());
239                content.extend_from_slice(data);
240
241                let len = (content.len() + 4) as i32;
242                buf.extend_from_slice(&len.to_be_bytes());
243                buf.extend_from_slice(&content);
244                buf
245            }
246            FrontendMessage::SASLResponse(data) => {
247                let mut buf = Vec::new();
248                buf.push(b'p');
249
250                let len = (data.len() + 4) as i32;
251                buf.extend_from_slice(&len.to_be_bytes());
252                buf.extend_from_slice(data);
253                buf
254            }
255            FrontendMessage::PasswordMessage(password) => {
256                let mut buf = Vec::new();
257                buf.push(b'p');
258                let content = format!("{}\0", password);
259                let len = (content.len() + 4) as i32;
260                buf.extend_from_slice(&len.to_be_bytes());
261                buf.extend_from_slice(content.as_bytes());
262                buf
263            }
264            FrontendMessage::Parse { name, query, param_types } => {
265                let mut buf = Vec::new();
266                buf.push(b'P');
267
268                let mut content = Vec::new();
269                content.extend_from_slice(name.as_bytes());
270                content.push(0);
271                content.extend_from_slice(query.as_bytes());
272                content.push(0);
273                content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
274                for oid in param_types {
275                    content.extend_from_slice(&oid.to_be_bytes());
276                }
277
278                let len = (content.len() + 4) as i32;
279                buf.extend_from_slice(&len.to_be_bytes());
280                buf.extend_from_slice(&content);
281                buf
282            }
283            FrontendMessage::Bind { portal, statement, params } => {
284                let mut buf = Vec::new();
285                buf.push(b'B');
286
287                let mut content = Vec::new();
288                content.extend_from_slice(portal.as_bytes());
289                content.push(0);
290                content.extend_from_slice(statement.as_bytes());
291                content.push(0);
292                // Format codes (0 = all text)
293                content.extend_from_slice(&0i16.to_be_bytes());
294                // Parameter count
295                content.extend_from_slice(&(params.len() as i16).to_be_bytes());
296                for param in params {
297                    match param {
298                        Some(data) => {
299                            content.extend_from_slice(&(data.len() as i32).to_be_bytes());
300                            content.extend_from_slice(data);
301                        }
302                        None => content.extend_from_slice(&(-1i32).to_be_bytes()),
303                    }
304                }
305                // Result format codes (0 = all text)
306                content.extend_from_slice(&0i16.to_be_bytes());
307
308                let len = (content.len() + 4) as i32;
309                buf.extend_from_slice(&len.to_be_bytes());
310                buf.extend_from_slice(&content);
311                buf
312            }
313            FrontendMessage::Execute { portal, max_rows } => {
314                let mut buf = Vec::new();
315                buf.push(b'E');
316
317                let mut content = Vec::new();
318                content.extend_from_slice(portal.as_bytes());
319                content.push(0);
320                content.extend_from_slice(&max_rows.to_be_bytes());
321
322                let len = (content.len() + 4) as i32;
323                buf.extend_from_slice(&len.to_be_bytes());
324                buf.extend_from_slice(&content);
325                buf
326            }
327            FrontendMessage::Sync => {
328                vec![b'S', 0, 0, 0, 4]
329            }
330            FrontendMessage::CopyFail(msg) => {
331                let mut buf = Vec::new();
332                buf.push(b'f');
333                let content = format!("{}\0", msg);
334                let len = (content.len() + 4) as i32;
335                buf.extend_from_slice(&len.to_be_bytes());
336                buf.extend_from_slice(content.as_bytes());
337                buf
338            }
339            FrontendMessage::Close { is_portal, name } => {
340                let mut buf = Vec::new();
341                buf.push(b'C');
342                let type_byte = if *is_portal { b'P' } else { b'S' };
343                let mut content = vec![type_byte];
344                content.extend_from_slice(name.as_bytes());
345                content.push(0);
346                let len = (content.len() + 4) as i32;
347                buf.extend_from_slice(&len.to_be_bytes());
348                buf.extend_from_slice(&content);
349                buf
350            }
351        }
352    }
353}
354
355impl BackendMessage {
356    /// Decode a message from wire bytes.
357    pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
358        if buf.len() < 5 {
359            return Err("Buffer too short".to_string());
360        }
361
362        let msg_type = buf[0];
363        let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
364
365        if buf.len() < len + 1 {
366            return Err("Incomplete message".to_string());
367        }
368
369        let payload = &buf[5..len + 1];
370
371        let message = match msg_type {
372            b'R' => Self::decode_auth(payload)?,
373            b'S' => Self::decode_parameter_status(payload)?,
374            b'K' => Self::decode_backend_key(payload)?,
375            b'Z' => Self::decode_ready_for_query(payload)?,
376            b'T' => Self::decode_row_description(payload)?,
377            b'D' => Self::decode_data_row(payload)?,
378            b'C' => Self::decode_command_complete(payload)?,
379            b'E' => Self::decode_error_response(payload)?,
380            b'1' => BackendMessage::ParseComplete,
381            b'2' => BackendMessage::BindComplete,
382            b'3' => BackendMessage::CloseComplete,
383            b'n' => BackendMessage::NoData,
384            b't' => Self::decode_parameter_description(payload)?,
385            b'G' => Self::decode_copy_in_response(payload)?,
386            b'H' => Self::decode_copy_out_response(payload)?,
387            b'd' => BackendMessage::CopyData(payload.to_vec()),
388            b'c' => BackendMessage::CopyDone,
389            b'A' => Self::decode_notification_response(payload)?,
390            b'I' => BackendMessage::EmptyQueryResponse,
391            b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
392            _ => return Err(format!("Unknown message type: {}", msg_type as char)),
393        };
394
395        Ok((message, len + 1))
396    }
397
398    fn decode_auth(payload: &[u8]) -> Result<Self, String> {
399        if payload.len() < 4 {
400            return Err("Auth payload too short".to_string());
401        }
402        let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
403        match auth_type {
404            0 => Ok(BackendMessage::AuthenticationOk),
405            5 => {
406                if payload.len() < 8 {
407                    return Err("MD5 auth payload too short (need salt)".to_string());
408                }
409                // SAFETY: Length is verified on the check above (payload.len() < 8 returns Err).
410                let salt: [u8; 4] = payload[4..8].try_into().expect("salt slice is exactly 4 bytes");
411                Ok(BackendMessage::AuthenticationMD5Password(salt))
412            }
413            10 => {
414                // SASL - parse mechanism list
415                let mut mechanisms = Vec::new();
416                let mut pos = 4;
417                while pos < payload.len() && payload[pos] != 0 {
418                    let end = payload[pos..]
419                        .iter()
420                        .position(|&b| b == 0)
421                        .map(|p| pos + p)
422                        .unwrap_or(payload.len());
423                    mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
424                    pos = end + 1;
425                }
426                Ok(BackendMessage::AuthenticationSASL(mechanisms))
427            }
428            11 => {
429                // SASL Continue - server challenge
430                Ok(BackendMessage::AuthenticationSASLContinue(
431                    payload[4..].to_vec(),
432                ))
433            }
434            12 => {
435                // SASL Final - server signature
436                Ok(BackendMessage::AuthenticationSASLFinal(
437                    payload[4..].to_vec(),
438                ))
439            }
440            _ => Err(format!("Unknown auth type: {}", auth_type)),
441        }
442    }
443
444    fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
445        let parts: Vec<&[u8]> = payload.split(|&b| b == 0).collect();
446        let empty: &[u8] = b"";
447        Ok(BackendMessage::ParameterStatus {
448            name: String::from_utf8_lossy(parts.first().unwrap_or(&empty)).to_string(),
449            value: String::from_utf8_lossy(parts.get(1).unwrap_or(&empty)).to_string(),
450        })
451    }
452
453    fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
454        if payload.len() < 8 {
455            return Err("BackendKeyData payload too short".to_string());
456        }
457        Ok(BackendMessage::BackendKeyData {
458            process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
459            secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
460        })
461    }
462
463    fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
464        if payload.is_empty() {
465            return Err("ReadyForQuery payload empty".to_string());
466        }
467        let status = match payload[0] {
468            b'I' => TransactionStatus::Idle,
469            b'T' => TransactionStatus::InBlock,
470            b'E' => TransactionStatus::Failed,
471            _ => return Err("Unknown transaction status".to_string()),
472        };
473        Ok(BackendMessage::ReadyForQuery(status))
474    }
475
476    fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
477        if payload.len() < 2 {
478            return Err("RowDescription payload too short".to_string());
479        }
480
481        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
482        if raw_count < 0 {
483            return Err(format!("RowDescription invalid field count: {}", raw_count));
484        }
485        let field_count = raw_count as usize;
486        let mut fields = Vec::with_capacity(field_count);
487        let mut pos = 2;
488
489        for _ in 0..field_count {
490            // Field name (null-terminated string)
491            let name_end = payload[pos..]
492                .iter()
493                .position(|&b| b == 0)
494                .ok_or("Missing null terminator in field name")?;
495            let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
496            pos += name_end + 1; // Skip null terminator
497
498            // Ensure we have enough bytes for the fixed fields
499            if pos + 18 > payload.len() {
500                return Err("RowDescription field truncated".to_string());
501            }
502
503            let table_oid = u32::from_be_bytes([
504                payload[pos],
505                payload[pos + 1],
506                payload[pos + 2],
507                payload[pos + 3],
508            ]);
509            pos += 4;
510
511            let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
512            pos += 2;
513
514            let type_oid = u32::from_be_bytes([
515                payload[pos],
516                payload[pos + 1],
517                payload[pos + 2],
518                payload[pos + 3],
519            ]);
520            pos += 4;
521
522            let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
523            pos += 2;
524
525            let type_modifier = i32::from_be_bytes([
526                payload[pos],
527                payload[pos + 1],
528                payload[pos + 2],
529                payload[pos + 3],
530            ]);
531            pos += 4;
532
533            let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
534            pos += 2;
535
536            fields.push(FieldDescription {
537                name,
538                table_oid,
539                column_attr,
540                type_oid,
541                type_size,
542                type_modifier,
543                format,
544            });
545        }
546
547        Ok(BackendMessage::RowDescription(fields))
548    }
549
550    fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
551        if payload.len() < 2 {
552            return Err("DataRow payload too short".to_string());
553        }
554
555        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
556        if raw_count < 0 {
557            return Err(format!("DataRow invalid column count: {}", raw_count));
558        }
559        let column_count = raw_count as usize;
560        // Sanity check: each column needs at least 4 bytes (length field)
561        if column_count > (payload.len() - 2) / 4 + 1 {
562            return Err(format!(
563                "DataRow claims {} columns but payload is only {} bytes",
564                column_count,
565                payload.len()
566            ));
567        }
568        let mut columns = Vec::with_capacity(column_count);
569        let mut pos = 2;
570
571        for _ in 0..column_count {
572            if pos + 4 > payload.len() {
573                return Err("DataRow truncated".to_string());
574            }
575
576            let len = i32::from_be_bytes([
577                payload[pos],
578                payload[pos + 1],
579                payload[pos + 2],
580                payload[pos + 3],
581            ]);
582            pos += 4;
583
584            if len == -1 {
585                // NULL value
586                columns.push(None);
587            } else {
588                let len = len as usize;
589                if pos + len > payload.len() {
590                    return Err("DataRow column data truncated".to_string());
591                }
592                let data = payload[pos..pos + len].to_vec();
593                pos += len;
594                columns.push(Some(data));
595            }
596        }
597
598        Ok(BackendMessage::DataRow(columns))
599    }
600
601    fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
602        let tag = String::from_utf8_lossy(payload)
603            .trim_end_matches('\0')
604            .to_string();
605        Ok(BackendMessage::CommandComplete(tag))
606    }
607
608    fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
609        Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
610            payload,
611        )?))
612    }
613
614    fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
615        let mut fields = ErrorFields::default();
616        let mut i = 0;
617        while i < payload.len() && payload[i] != 0 {
618            let field_type = payload[i];
619            i += 1;
620            let end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
621            let value = String::from_utf8_lossy(&payload[i..end]).to_string();
622            i = end + 1;
623
624            match field_type {
625                b'S' => fields.severity = value,
626                b'C' => fields.code = value,
627                b'M' => fields.message = value,
628                b'D' => fields.detail = Some(value),
629                b'H' => fields.hint = Some(value),
630                _ => {}
631            }
632        }
633        Ok(fields)
634    }
635
636    fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
637        let count = if payload.len() >= 2 {
638            i16::from_be_bytes([payload[0], payload[1]]) as usize
639        } else {
640            0
641        };
642        let mut oids = Vec::with_capacity(count);
643        let mut pos = 2;
644        for _ in 0..count {
645            if pos + 4 <= payload.len() {
646                oids.push(u32::from_be_bytes([
647                    payload[pos], payload[pos + 1], payload[pos + 2], payload[pos + 3],
648                ]));
649                pos += 4;
650            }
651        }
652        Ok(BackendMessage::ParameterDescription(oids))
653    }
654
655    fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
656        if payload.is_empty() {
657            return Err("Empty CopyInResponse payload".to_string());
658        }
659        let format = payload[0];
660        let num_columns = if payload.len() >= 3 {
661            i16::from_be_bytes([payload[1], payload[2]]) as usize
662        } else {
663            0
664        };
665        let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
666            payload[3..].iter().take(num_columns).copied().collect()
667        } else {
668            vec![]
669        };
670        Ok(BackendMessage::CopyInResponse {
671            format,
672            column_formats,
673        })
674    }
675
676    fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
677        if payload.is_empty() {
678            return Err("Empty CopyOutResponse payload".to_string());
679        }
680        let format = payload[0];
681        let num_columns = if payload.len() >= 3 {
682            i16::from_be_bytes([payload[1], payload[2]]) as usize
683        } else {
684            0
685        };
686        let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
687            payload[3..].iter().take(num_columns).copied().collect()
688        } else {
689            vec![]
690        };
691        Ok(BackendMessage::CopyOutResponse {
692            format,
693            column_formats,
694        })
695    }
696
697    fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
698        if payload.len() < 4 {
699            return Err("NotificationResponse too short".to_string());
700        }
701        let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
702
703        // Channel name (null-terminated)
704        let mut i = 4;
705        let channel_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
706        let channel = String::from_utf8_lossy(&payload[i..channel_end]).to_string();
707        i = channel_end + 1;
708
709        // Payload (null-terminated)
710        let payload_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
711        let notification_payload = String::from_utf8_lossy(&payload[i..payload_end]).to_string();
712
713        Ok(BackendMessage::NotificationResponse {
714            process_id,
715            channel,
716            payload: notification_payload,
717        })
718    }
719}
720
721#[cfg(test)]
722mod tests {
723    use super::*;
724
725    /// Helper: build a raw wire message from type byte + payload.
726    fn wire_msg(msg_type: u8, payload: &[u8]) -> Vec<u8> {
727        let len = (payload.len() + 4) as u32;
728        let mut buf = vec![msg_type];
729        buf.extend_from_slice(&len.to_be_bytes());
730        buf.extend_from_slice(payload);
731        buf
732    }
733
734    // ========== Buffer boundary tests ==========
735
736    #[test]
737    fn decode_empty_buffer_returns_error() {
738        assert!(BackendMessage::decode(&[]).is_err());
739    }
740
741    #[test]
742    fn decode_too_short_buffer_returns_error() {
743        // 1-4 bytes are all too short for the 5-byte header
744        for len in 1..5 {
745            let buf = vec![b'Z'; len];
746            let result = BackendMessage::decode(&buf);
747            assert!(result.is_err(), "Expected error for {}-byte buffer", len);
748        }
749    }
750
751    #[test]
752    fn decode_incomplete_message_returns_error() {
753        // Header says length=100 but only 10 bytes present
754        let mut buf = vec![b'Z'];
755        buf.extend_from_slice(&100u32.to_be_bytes());
756        buf.extend_from_slice(&[0u8; 5]); // only 5 payload bytes, need 96
757        assert!(BackendMessage::decode(&buf).unwrap_err().contains("Incomplete"));
758    }
759
760    #[test]
761    fn decode_unknown_message_type_returns_error() {
762        let buf = wire_msg(b'@', &[0]);
763        let result = BackendMessage::decode(&buf);
764        assert!(result.unwrap_err().contains("Unknown message type"));
765    }
766
767    // ========== Auth decode tests ==========
768
769    #[test]
770    fn decode_auth_ok() {
771        let payload = 0i32.to_be_bytes();
772        let buf = wire_msg(b'R', &payload);
773        let (msg, consumed) = BackendMessage::decode(&buf).unwrap();
774        assert!(matches!(msg, BackendMessage::AuthenticationOk));
775        assert_eq!(consumed, buf.len());
776    }
777
778    #[test]
779    fn decode_auth_payload_too_short() {
780        // Auth needs at least 4 bytes for type field
781        let buf = wire_msg(b'R', &[0, 0]);
782        assert!(BackendMessage::decode(&buf).unwrap_err().contains("too short"));
783    }
784
785    #[test]
786    fn decode_auth_md5_missing_salt() {
787        // Auth type 5 (MD5) needs 8 bytes total (4 type + 4 salt)
788        let mut payload = 5i32.to_be_bytes().to_vec();
789        payload.extend_from_slice(&[0, 0, 0]); // only 3 salt bytes, need 4
790        let buf = wire_msg(b'R', &payload);
791        assert!(BackendMessage::decode(&buf).unwrap_err().contains("MD5"));
792    }
793
794    #[test]
795    fn decode_auth_md5_valid_salt() {
796        let mut payload = 5i32.to_be_bytes().to_vec();
797        payload.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
798        let buf = wire_msg(b'R', &payload);
799        let (msg, _) = BackendMessage::decode(&buf).unwrap();
800        match msg {
801            BackendMessage::AuthenticationMD5Password(salt) => {
802                assert_eq!(salt, [0xDE, 0xAD, 0xBE, 0xEF]);
803            }
804            _ => panic!("Expected MD5 auth"),
805        }
806    }
807
808    #[test]
809    fn decode_auth_unknown_type_returns_error() {
810        let payload = 99i32.to_be_bytes();
811        let buf = wire_msg(b'R', &payload);
812        assert!(BackendMessage::decode(&buf).unwrap_err().contains("Unknown auth type"));
813    }
814
815    #[test]
816    fn decode_auth_sasl_mechanisms() {
817        let mut payload = 10i32.to_be_bytes().to_vec();
818        payload.extend_from_slice(b"SCRAM-SHA-256\0\0"); // one mechanism + double null
819        let buf = wire_msg(b'R', &payload);
820        let (msg, _) = BackendMessage::decode(&buf).unwrap();
821        match msg {
822            BackendMessage::AuthenticationSASL(mechs) => {
823                assert_eq!(mechs, vec!["SCRAM-SHA-256"]);
824            }
825            _ => panic!("Expected SASL auth"),
826        }
827    }
828
829    // ========== ReadyForQuery tests ==========
830
831    #[test]
832    fn decode_ready_for_query_idle() {
833        let buf = wire_msg(b'Z', &[b'I']);
834        let (msg, _) = BackendMessage::decode(&buf).unwrap();
835        assert!(matches!(msg, BackendMessage::ReadyForQuery(TransactionStatus::Idle)));
836    }
837
838    #[test]
839    fn decode_ready_for_query_in_transaction() {
840        let buf = wire_msg(b'Z', &[b'T']);
841        let (msg, _) = BackendMessage::decode(&buf).unwrap();
842        assert!(matches!(msg, BackendMessage::ReadyForQuery(TransactionStatus::InBlock)));
843    }
844
845    #[test]
846    fn decode_ready_for_query_failed() {
847        let buf = wire_msg(b'Z', &[b'E']);
848        let (msg, _) = BackendMessage::decode(&buf).unwrap();
849        assert!(matches!(msg, BackendMessage::ReadyForQuery(TransactionStatus::Failed)));
850    }
851
852    #[test]
853    fn decode_ready_for_query_empty_payload() {
854        let buf = wire_msg(b'Z', &[]);
855        assert!(BackendMessage::decode(&buf).unwrap_err().contains("empty"));
856    }
857
858    #[test]
859    fn decode_ready_for_query_unknown_status() {
860        let buf = wire_msg(b'Z', &[b'X']);
861        assert!(BackendMessage::decode(&buf).unwrap_err().contains("Unknown transaction"));
862    }
863
864    // ========== DataRow tests ==========
865
866    #[test]
867    fn decode_data_row_empty_columns() {
868        let payload = 0i16.to_be_bytes();
869        let buf = wire_msg(b'D', &payload);
870        let (msg, _) = BackendMessage::decode(&buf).unwrap();
871        match msg {
872            BackendMessage::DataRow(cols) => assert!(cols.is_empty()),
873            _ => panic!("Expected DataRow"),
874        }
875    }
876
877    #[test]
878    fn decode_data_row_with_null() {
879        let mut payload = 1i16.to_be_bytes().to_vec();
880        payload.extend_from_slice(&(-1i32).to_be_bytes()); // NULL
881        let buf = wire_msg(b'D', &payload);
882        let (msg, _) = BackendMessage::decode(&buf).unwrap();
883        match msg {
884            BackendMessage::DataRow(cols) => {
885                assert_eq!(cols.len(), 1);
886                assert!(cols[0].is_none());
887            }
888            _ => panic!("Expected DataRow"),
889        }
890    }
891
892    #[test]
893    fn decode_data_row_with_value() {
894        let mut payload = 1i16.to_be_bytes().to_vec();
895        let data = b"hello";
896        payload.extend_from_slice(&(data.len() as i32).to_be_bytes());
897        payload.extend_from_slice(data);
898        let buf = wire_msg(b'D', &payload);
899        let (msg, _) = BackendMessage::decode(&buf).unwrap();
900        match msg {
901            BackendMessage::DataRow(cols) => {
902                assert_eq!(cols.len(), 1);
903                assert_eq!(cols[0].as_deref(), Some(b"hello".as_slice()));
904            }
905            _ => panic!("Expected DataRow"),
906        }
907    }
908
909    #[test]
910    fn decode_data_row_negative_count_returns_error() {
911        let payload = (-1i16).to_be_bytes();
912        let buf = wire_msg(b'D', &payload);
913        assert!(BackendMessage::decode(&buf).unwrap_err().contains("invalid column count"));
914    }
915
916    #[test]
917    fn decode_data_row_truncated_column_data() {
918        let mut payload = 1i16.to_be_bytes().to_vec();
919        // Claims 100 bytes of data but payload ends immediately
920        payload.extend_from_slice(&100i32.to_be_bytes());
921        let buf = wire_msg(b'D', &payload);
922        assert!(BackendMessage::decode(&buf).unwrap_err().contains("truncated"));
923    }
924
925    #[test]
926    fn decode_data_row_payload_too_short() {
927        let buf = wire_msg(b'D', &[0]); // only 1 byte, need 2
928        assert!(BackendMessage::decode(&buf).unwrap_err().contains("too short"));
929    }
930
931    #[test]
932    fn decode_data_row_claims_too_many_columns() {
933        // Claims 1000 columns but only a few bytes of payload
934        let payload = 1000i16.to_be_bytes();
935        let buf = wire_msg(b'D', &payload);
936        assert!(BackendMessage::decode(&buf).unwrap_err().contains("claims"));
937    }
938
939    // ========== RowDescription tests ==========
940
941    #[test]
942    fn decode_row_description_zero_fields() {
943        let payload = 0i16.to_be_bytes();
944        let buf = wire_msg(b'T', &payload);
945        let (msg, _) = BackendMessage::decode(&buf).unwrap();
946        match msg {
947            BackendMessage::RowDescription(fields) => assert!(fields.is_empty()),
948            _ => panic!("Expected RowDescription"),
949        }
950    }
951
952    #[test]
953    fn decode_row_description_negative_count() {
954        let payload = (-1i16).to_be_bytes();
955        let buf = wire_msg(b'T', &payload);
956        assert!(BackendMessage::decode(&buf).unwrap_err().contains("invalid field count"));
957    }
958
959    #[test]
960    fn decode_row_description_truncated_field() {
961        let mut payload = 1i16.to_be_bytes().to_vec();
962        payload.extend_from_slice(b"id\0"); // field name
963        // Missing 18 bytes of fixed field data
964        let buf = wire_msg(b'T', &payload);
965        assert!(BackendMessage::decode(&buf).unwrap_err().contains("truncated"));
966    }
967
968    #[test]
969    fn decode_row_description_single_field() {
970        let mut payload = 1i16.to_be_bytes().to_vec();
971        payload.extend_from_slice(b"id\0");         // name
972        payload.extend_from_slice(&0u32.to_be_bytes()); // table_oid
973        payload.extend_from_slice(&0i16.to_be_bytes()); // column_attr
974        payload.extend_from_slice(&23u32.to_be_bytes()); // type_oid (int4)
975        payload.extend_from_slice(&4i16.to_be_bytes()); // type_size
976        payload.extend_from_slice(&(-1i32).to_be_bytes()); // type_modifier
977        payload.extend_from_slice(&0i16.to_be_bytes()); // format (text)
978        let buf = wire_msg(b'T', &payload);
979        let (msg, _) = BackendMessage::decode(&buf).unwrap();
980        match msg {
981            BackendMessage::RowDescription(fields) => {
982                assert_eq!(fields.len(), 1);
983                assert_eq!(fields[0].name, "id");
984                assert_eq!(fields[0].type_oid, 23); // int4
985            }
986            _ => panic!("Expected RowDescription"),
987        }
988    }
989
990    // ========== BackendKeyData tests ==========
991
992    #[test]
993    fn decode_backend_key_data() {
994        let mut payload = 42i32.to_be_bytes().to_vec();
995        payload.extend_from_slice(&99i32.to_be_bytes());
996        let buf = wire_msg(b'K', &payload);
997        let (msg, _) = BackendMessage::decode(&buf).unwrap();
998        match msg {
999            BackendMessage::BackendKeyData { process_id, secret_key } => {
1000                assert_eq!(process_id, 42);
1001                assert_eq!(secret_key, 99);
1002            }
1003            _ => panic!("Expected BackendKeyData"),
1004        }
1005    }
1006
1007    #[test]
1008    fn decode_backend_key_too_short() {
1009        let buf = wire_msg(b'K', &[0, 0, 0, 42]); // only 4 bytes, need 8
1010        assert!(BackendMessage::decode(&buf).unwrap_err().contains("too short"));
1011    }
1012
1013    // ========== ErrorResponse tests ==========
1014
1015    #[test]
1016    fn decode_error_response_with_fields() {
1017        let mut payload = Vec::new();
1018        payload.push(b'S'); payload.extend_from_slice(b"ERROR\0");
1019        payload.push(b'C'); payload.extend_from_slice(b"42P01\0");
1020        payload.push(b'M'); payload.extend_from_slice(b"relation does not exist\0");
1021        payload.push(0); // terminator
1022        let buf = wire_msg(b'E', &payload);
1023        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1024        match msg {
1025            BackendMessage::ErrorResponse(fields) => {
1026                assert_eq!(fields.severity, "ERROR");
1027                assert_eq!(fields.code, "42P01");
1028                assert_eq!(fields.message, "relation does not exist");
1029            }
1030            _ => panic!("Expected ErrorResponse"),
1031        }
1032    }
1033
1034    #[test]
1035    fn decode_error_response_empty() {
1036        let buf = wire_msg(b'E', &[0]); // just terminator
1037        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1038        match msg {
1039            BackendMessage::ErrorResponse(fields) => {
1040                assert!(fields.message.is_empty());
1041            }
1042            _ => panic!("Expected ErrorResponse"),
1043        }
1044    }
1045
1046    // ========== CommandComplete tests ==========
1047
1048    #[test]
1049    fn decode_command_complete() {
1050        let buf = wire_msg(b'C', b"INSERT 0 1\0");
1051        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1052        match msg {
1053            BackendMessage::CommandComplete(tag) => assert_eq!(tag, "INSERT 0 1"),
1054            _ => panic!("Expected CommandComplete"),
1055        }
1056    }
1057
1058    // ========== Simple type tests ==========
1059
1060    #[test]
1061    fn decode_parse_complete() {
1062        let buf = wire_msg(b'1', &[]);
1063        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1064        assert!(matches!(msg, BackendMessage::ParseComplete));
1065    }
1066
1067    #[test]
1068    fn decode_bind_complete() {
1069        let buf = wire_msg(b'2', &[]);
1070        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1071        assert!(matches!(msg, BackendMessage::BindComplete));
1072    }
1073
1074    #[test]
1075    fn decode_no_data() {
1076        let buf = wire_msg(b'n', &[]);
1077        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1078        assert!(matches!(msg, BackendMessage::NoData));
1079    }
1080
1081    #[test]
1082    fn decode_empty_query_response() {
1083        let buf = wire_msg(b'I', &[]);
1084        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1085        assert!(matches!(msg, BackendMessage::EmptyQueryResponse));
1086    }
1087
1088    // ========== NotificationResponse tests ==========
1089
1090    #[test]
1091    fn decode_notification_response() {
1092        let mut payload = 1i32.to_be_bytes().to_vec();
1093        payload.extend_from_slice(b"my_channel\0");
1094        payload.extend_from_slice(b"hello world\0");
1095        let buf = wire_msg(b'A', &payload);
1096        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1097        match msg {
1098            BackendMessage::NotificationResponse { process_id, channel, payload } => {
1099                assert_eq!(process_id, 1);
1100                assert_eq!(channel, "my_channel");
1101                assert_eq!(payload, "hello world");
1102            }
1103            _ => panic!("Expected NotificationResponse"),
1104        }
1105    }
1106
1107    #[test]
1108    fn decode_notification_too_short() {
1109        let buf = wire_msg(b'A', &[0, 0]); // need at least 4 bytes
1110        assert!(BackendMessage::decode(&buf).unwrap_err().contains("too short"));
1111    }
1112
1113    // ========== CopyInResponse / CopyOutResponse tests ==========
1114
1115    #[test]
1116    fn decode_copy_in_response_empty_payload() {
1117        let buf = wire_msg(b'G', &[]);
1118        assert!(BackendMessage::decode(&buf).unwrap_err().contains("Empty"));
1119    }
1120
1121    #[test]
1122    fn decode_copy_out_response_empty_payload() {
1123        let buf = wire_msg(b'H', &[]);
1124        assert!(BackendMessage::decode(&buf).unwrap_err().contains("Empty"));
1125    }
1126
1127    #[test]
1128    fn decode_copy_in_response_text_format() {
1129        let mut payload = vec![0u8]; // text format
1130        payload.extend_from_slice(&1i16.to_be_bytes()); // 1 column
1131        payload.push(0); // column format: text
1132        let buf = wire_msg(b'G', &payload);
1133        let (msg, _) = BackendMessage::decode(&buf).unwrap();
1134        match msg {
1135            BackendMessage::CopyInResponse { format, column_formats } => {
1136                assert_eq!(format, 0);
1137                assert_eq!(column_formats, vec![0]);
1138            }
1139            _ => panic!("Expected CopyInResponse"),
1140        }
1141    }
1142
1143    // ========== Message consumed length test ==========
1144
1145    #[test]
1146    fn decode_consumed_length_is_correct() {
1147        let buf = wire_msg(b'Z', &[b'I']);
1148        let (_, consumed) = BackendMessage::decode(&buf).unwrap();
1149        assert_eq!(consumed, buf.len());
1150    }
1151
1152    #[test]
1153    fn decode_with_trailing_data_only_consumes_one_message() {
1154        let mut buf = wire_msg(b'Z', &[b'I']);
1155        buf.extend_from_slice(&wire_msg(b'Z', &[b'T'])); // second message appended
1156        let (msg, consumed) = BackendMessage::decode(&buf).unwrap();
1157        assert!(matches!(msg, BackendMessage::ReadyForQuery(TransactionStatus::Idle)));
1158        // Should only consume the first message
1159        assert_eq!(consumed, 6); // 1 type + 4 length + 1 payload
1160    }
1161
1162    // ========== FrontendMessage encode roundtrip tests ==========
1163
1164    #[test]
1165    fn encode_sync() {
1166        let msg = FrontendMessage::Sync;
1167        let encoded = msg.encode();
1168        assert_eq!(encoded, vec![b'S', 0, 0, 0, 4]);
1169    }
1170
1171    #[test]
1172    fn encode_terminate() {
1173        let msg = FrontendMessage::Terminate;
1174        let encoded = msg.encode();
1175        assert_eq!(encoded, vec![b'X', 0, 0, 0, 4]);
1176    }
1177}
1178