Skip to main content

qail_pg/protocol/
wire.rs

1//! PostgreSQL Wire Protocol Messages
2//!
3//! Implementation of the PostgreSQL Frontend/Backend Protocol.
4//! Reference: https://www.postgresql.org/docs/current/protocol-message-formats.html
5
/// Frontend (client → server) message types.
///
/// Serialized for the wire with [`FrontendMessage::encode`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrontendMessage {
    /// Startup message (sent first, no type byte). Requests protocol 3.0 and
    /// sends the `user` and `database` connection parameters.
    Startup { user: String, database: String },
    /// Password message (`p`): the password string, sent NUL-terminated.
    PasswordMessage(String),
    /// Simple query (`Q`): the SQL text, sent NUL-terminated.
    Query(String),
    /// Parse (`P`): create a prepared statement.
    Parse {
        /// Statement name; an empty string selects the unnamed statement.
        name: String,
        /// SQL text of the statement.
        query: String,
        /// Parameter type OIDs; per the protocol, 0 leaves a type unspecified.
        param_types: Vec<u32>,
    },
    /// Bind (`B`): bind parameter values to a prepared statement, creating a
    /// portal. Parameters and results are both requested in text format.
    Bind {
        /// Destination portal name; empty selects the unnamed portal.
        portal: String,
        /// Source prepared-statement name.
        statement: String,
        /// Parameter values in order; `None` is sent as SQL NULL (length -1).
        params: Vec<Option<Vec<u8>>>,
    },
    /// Execute (`E`): run a bound portal.
    /// `max_rows` caps the rows returned; per the protocol, 0 means no limit.
    Execute { portal: String, max_rows: i32 },
    /// Sync (`S`): end an extended-query cycle.
    Sync,
    /// Terminate (`X`): close the session.
    Terminate,
    /// SASL initial response (`p`, first message in SCRAM): the chosen
    /// mechanism name plus the client's initial data.
    SASLInitialResponse { mechanism: String, data: Vec<u8> },
    /// SASL response (`p`, subsequent messages in SCRAM): raw client data.
    SASLResponse(Vec<u8>),
    /// CopyFail (`f`): abort a COPY IN with an error message.
    CopyFail(String),
    /// Close (`C`): explicitly release a prepared statement
    /// (`is_portal == false`) or a portal (`is_portal == true`).
    Close { is_portal: bool, name: String },
}
38
39/// Backend (server → client) message types
40#[derive(Debug, Clone)]
41pub enum BackendMessage {
42    /// Authentication request
43    AuthenticationOk,
44    AuthenticationMD5Password([u8; 4]),
45    AuthenticationSASL(Vec<String>),
46    AuthenticationSASLContinue(Vec<u8>),
47    AuthenticationSASLFinal(Vec<u8>),
48    /// Parameter status (server config)
49    ParameterStatus {
50        name: String,
51        value: String,
52    },
53    /// Backend key data (for cancel)
54    BackendKeyData {
55        process_id: i32,
56        secret_key: i32,
57    },
58    ReadyForQuery(TransactionStatus),
59    RowDescription(Vec<FieldDescription>),
60    DataRow(Vec<Option<Vec<u8>>>),
61    CommandComplete(String),
62    ErrorResponse(ErrorFields),
63    ParseComplete,
64    BindComplete,
65    NoData,
66    /// Copy in response (server ready to receive COPY data)
67    CopyInResponse {
68        format: u8,
69        column_formats: Vec<u8>,
70    },
71    /// Copy out response (server will send COPY data)
72    CopyOutResponse {
73        format: u8,
74        column_formats: Vec<u8>,
75    },
76    CopyData(Vec<u8>),
77    CopyDone,
78    /// Notification response (async notification from LISTEN/NOTIFY)
79    NotificationResponse {
80        process_id: i32,
81        channel: String,
82        payload: String,
83    },
84    EmptyQueryResponse,
85    /// Notice response (warning/info messages, not errors)
86    NoticeResponse(ErrorFields),
87    /// Parameter description (OIDs of parameters in a prepared statement)
88    /// Sent by server in response to Describe(Statement)
89    ParameterDescription(Vec<u32>),
90    /// Close complete (server confirmation that a prepared statement/portal was released)
91    CloseComplete,
92}
93
/// Transaction status carried by `ReadyForQuery`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare directly
/// (`status == TransactionStatus::Idle`) or use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionStatus {
    /// Wire byte `'I'`: not inside a transaction block.
    Idle,
    /// Wire byte `'T'`: inside a transaction block.
    InBlock,
    /// Wire byte `'E'`: inside a failed transaction block; per the protocol,
    /// queries are rejected until the block is ended.
    Failed,
}
101
/// Metadata for one column, as carried in a `RowDescription` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the source table, or 0 if the column is not a plain table column.
    pub table_oid: u32,
    /// Attribute number of the column in its table, or 0 if not a table column.
    pub column_attr: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Data type size; per the protocol, negative values denote
    /// variable-width types.
    pub type_size: i16,
    /// Type modifier; its meaning is type-specific.
    pub type_modifier: i32,
    /// Format code: 0 = text, 1 = binary.
    pub format: i16,
}
113
/// Fields extracted from an `ErrorResponse` or `NoticeResponse` body.
///
/// Only the fields below are kept; other field types in the message are
/// skipped by the decoder.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorFields {
    /// Severity (`S` field), e.g. `ERROR` or `WARNING`.
    pub severity: String,
    /// SQLSTATE code (`C` field), e.g. `42601`.
    pub code: String,
    /// Primary human-readable message (`M` field).
    pub message: String,
    /// Optional secondary detail (`D` field).
    pub detail: Option<String>,
    /// Optional suggestion for fixing the problem (`H` field).
    pub hint: Option<String>,
}
123
124impl FrontendMessage {
125    /// Encode message to bytes for sending over the wire.
126    pub fn encode(&self) -> Vec<u8> {
127        match self {
128            FrontendMessage::Startup { user, database } => {
129                let mut buf = Vec::new();
130                // Protocol version 3.0
131                buf.extend_from_slice(&196608i32.to_be_bytes());
132                // Parameters
133                buf.extend_from_slice(b"user\0");
134                buf.extend_from_slice(user.as_bytes());
135                buf.push(0);
136                buf.extend_from_slice(b"database\0");
137                buf.extend_from_slice(database.as_bytes());
138                buf.push(0);
139                buf.push(0); // Terminator
140
141                // Prepend length (includes length itself)
142                let len = (buf.len() + 4) as i32;
143                let mut result = len.to_be_bytes().to_vec();
144                result.extend(buf);
145                result
146            }
147            FrontendMessage::Query(sql) => {
148                let mut buf = Vec::new();
149                buf.push(b'Q');
150                let content = format!("{}\0", sql);
151                let len = (content.len() + 4) as i32;
152                buf.extend_from_slice(&len.to_be_bytes());
153                buf.extend_from_slice(content.as_bytes());
154                buf
155            }
156            FrontendMessage::Terminate => {
157                vec![b'X', 0, 0, 0, 4]
158            }
159            FrontendMessage::SASLInitialResponse { mechanism, data } => {
160                let mut buf = Vec::new();
161                buf.push(b'p'); // SASLInitialResponse uses 'p'
162
163                let mut content = Vec::new();
164                content.extend_from_slice(mechanism.as_bytes());
165                content.push(0); // null-terminated mechanism
166                content.extend_from_slice(&(data.len() as i32).to_be_bytes());
167                content.extend_from_slice(data);
168
169                let len = (content.len() + 4) as i32;
170                buf.extend_from_slice(&len.to_be_bytes());
171                buf.extend_from_slice(&content);
172                buf
173            }
174            FrontendMessage::SASLResponse(data) => {
175                let mut buf = Vec::new();
176                buf.push(b'p');
177
178                let len = (data.len() + 4) as i32;
179                buf.extend_from_slice(&len.to_be_bytes());
180                buf.extend_from_slice(data);
181                buf
182            }
183            FrontendMessage::PasswordMessage(password) => {
184                let mut buf = Vec::new();
185                buf.push(b'p');
186                let content = format!("{}\0", password);
187                let len = (content.len() + 4) as i32;
188                buf.extend_from_slice(&len.to_be_bytes());
189                buf.extend_from_slice(content.as_bytes());
190                buf
191            }
192            FrontendMessage::Parse { name, query, param_types } => {
193                let mut buf = Vec::new();
194                buf.push(b'P');
195
196                let mut content = Vec::new();
197                content.extend_from_slice(name.as_bytes());
198                content.push(0);
199                content.extend_from_slice(query.as_bytes());
200                content.push(0);
201                content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
202                for oid in param_types {
203                    content.extend_from_slice(&oid.to_be_bytes());
204                }
205
206                let len = (content.len() + 4) as i32;
207                buf.extend_from_slice(&len.to_be_bytes());
208                buf.extend_from_slice(&content);
209                buf
210            }
211            FrontendMessage::Bind { portal, statement, params } => {
212                let mut buf = Vec::new();
213                buf.push(b'B');
214
215                let mut content = Vec::new();
216                content.extend_from_slice(portal.as_bytes());
217                content.push(0);
218                content.extend_from_slice(statement.as_bytes());
219                content.push(0);
220                // Format codes (0 = all text)
221                content.extend_from_slice(&0i16.to_be_bytes());
222                // Parameter count
223                content.extend_from_slice(&(params.len() as i16).to_be_bytes());
224                for param in params {
225                    match param {
226                        Some(data) => {
227                            content.extend_from_slice(&(data.len() as i32).to_be_bytes());
228                            content.extend_from_slice(data);
229                        }
230                        None => content.extend_from_slice(&(-1i32).to_be_bytes()),
231                    }
232                }
233                // Result format codes (0 = all text)
234                content.extend_from_slice(&0i16.to_be_bytes());
235
236                let len = (content.len() + 4) as i32;
237                buf.extend_from_slice(&len.to_be_bytes());
238                buf.extend_from_slice(&content);
239                buf
240            }
241            FrontendMessage::Execute { portal, max_rows } => {
242                let mut buf = Vec::new();
243                buf.push(b'E');
244
245                let mut content = Vec::new();
246                content.extend_from_slice(portal.as_bytes());
247                content.push(0);
248                content.extend_from_slice(&max_rows.to_be_bytes());
249
250                let len = (content.len() + 4) as i32;
251                buf.extend_from_slice(&len.to_be_bytes());
252                buf.extend_from_slice(&content);
253                buf
254            }
255            FrontendMessage::Sync => {
256                vec![b'S', 0, 0, 0, 4]
257            }
258            FrontendMessage::CopyFail(msg) => {
259                let mut buf = Vec::new();
260                buf.push(b'f');
261                let content = format!("{}\0", msg);
262                let len = (content.len() + 4) as i32;
263                buf.extend_from_slice(&len.to_be_bytes());
264                buf.extend_from_slice(content.as_bytes());
265                buf
266            }
267            FrontendMessage::Close { is_portal, name } => {
268                let mut buf = Vec::new();
269                buf.push(b'C');
270                let type_byte = if *is_portal { b'P' } else { b'S' };
271                let mut content = vec![type_byte];
272                content.extend_from_slice(name.as_bytes());
273                content.push(0);
274                let len = (content.len() + 4) as i32;
275                buf.extend_from_slice(&len.to_be_bytes());
276                buf.extend_from_slice(&content);
277                buf
278            }
279        }
280    }
281}
282
283impl BackendMessage {
284    /// Decode a message from wire bytes.
285    pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
286        if buf.len() < 5 {
287            return Err("Buffer too short".to_string());
288        }
289
290        let msg_type = buf[0];
291        let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
292
293        if buf.len() < len + 1 {
294            return Err("Incomplete message".to_string());
295        }
296
297        let payload = &buf[5..len + 1];
298
299        let message = match msg_type {
300            b'R' => Self::decode_auth(payload)?,
301            b'S' => Self::decode_parameter_status(payload)?,
302            b'K' => Self::decode_backend_key(payload)?,
303            b'Z' => Self::decode_ready_for_query(payload)?,
304            b'T' => Self::decode_row_description(payload)?,
305            b'D' => Self::decode_data_row(payload)?,
306            b'C' => Self::decode_command_complete(payload)?,
307            b'E' => Self::decode_error_response(payload)?,
308            b'1' => BackendMessage::ParseComplete,
309            b'2' => BackendMessage::BindComplete,
310            b'3' => BackendMessage::CloseComplete,
311            b'n' => BackendMessage::NoData,
312            b't' => Self::decode_parameter_description(payload)?,
313            b'G' => Self::decode_copy_in_response(payload)?,
314            b'H' => Self::decode_copy_out_response(payload)?,
315            b'd' => BackendMessage::CopyData(payload.to_vec()),
316            b'c' => BackendMessage::CopyDone,
317            b'A' => Self::decode_notification_response(payload)?,
318            b'I' => BackendMessage::EmptyQueryResponse,
319            b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
320            _ => return Err(format!("Unknown message type: {}", msg_type as char)),
321        };
322
323        Ok((message, len + 1))
324    }
325
326    fn decode_auth(payload: &[u8]) -> Result<Self, String> {
327        if payload.len() < 4 {
328            return Err("Auth payload too short".to_string());
329        }
330        let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
331        match auth_type {
332            0 => Ok(BackendMessage::AuthenticationOk),
333            5 => {
334                if payload.len() < 8 {
335                    return Err("MD5 auth payload too short (need salt)".to_string());
336                }
337                let salt: [u8; 4] = payload[4..8].try_into().expect("salt length verified above");
338                Ok(BackendMessage::AuthenticationMD5Password(salt))
339            }
340            10 => {
341                // SASL - parse mechanism list
342                let mut mechanisms = Vec::new();
343                let mut pos = 4;
344                while pos < payload.len() && payload[pos] != 0 {
345                    let end = payload[pos..]
346                        .iter()
347                        .position(|&b| b == 0)
348                        .map(|p| pos + p)
349                        .unwrap_or(payload.len());
350                    mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
351                    pos = end + 1;
352                }
353                Ok(BackendMessage::AuthenticationSASL(mechanisms))
354            }
355            11 => {
356                // SASL Continue - server challenge
357                Ok(BackendMessage::AuthenticationSASLContinue(
358                    payload[4..].to_vec(),
359                ))
360            }
361            12 => {
362                // SASL Final - server signature
363                Ok(BackendMessage::AuthenticationSASLFinal(
364                    payload[4..].to_vec(),
365                ))
366            }
367            _ => Err(format!("Unknown auth type: {}", auth_type)),
368        }
369    }
370
371    fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
372        let parts: Vec<&[u8]> = payload.split(|&b| b == 0).collect();
373        let empty: &[u8] = b"";
374        Ok(BackendMessage::ParameterStatus {
375            name: String::from_utf8_lossy(parts.first().unwrap_or(&empty)).to_string(),
376            value: String::from_utf8_lossy(parts.get(1).unwrap_or(&empty)).to_string(),
377        })
378    }
379
380    fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
381        if payload.len() < 8 {
382            return Err("BackendKeyData payload too short".to_string());
383        }
384        Ok(BackendMessage::BackendKeyData {
385            process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
386            secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
387        })
388    }
389
390    fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
391        if payload.is_empty() {
392            return Err("ReadyForQuery payload empty".to_string());
393        }
394        let status = match payload[0] {
395            b'I' => TransactionStatus::Idle,
396            b'T' => TransactionStatus::InBlock,
397            b'E' => TransactionStatus::Failed,
398            _ => return Err("Unknown transaction status".to_string()),
399        };
400        Ok(BackendMessage::ReadyForQuery(status))
401    }
402
403    fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
404        if payload.len() < 2 {
405            return Err("RowDescription payload too short".to_string());
406        }
407
408        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
409        if raw_count < 0 {
410            return Err(format!("RowDescription invalid field count: {}", raw_count));
411        }
412        let field_count = raw_count as usize;
413        let mut fields = Vec::with_capacity(field_count);
414        let mut pos = 2;
415
416        for _ in 0..field_count {
417            // Field name (null-terminated string)
418            let name_end = payload[pos..]
419                .iter()
420                .position(|&b| b == 0)
421                .ok_or("Missing null terminator in field name")?;
422            let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
423            pos += name_end + 1; // Skip null terminator
424
425            // Ensure we have enough bytes for the fixed fields
426            if pos + 18 > payload.len() {
427                return Err("RowDescription field truncated".to_string());
428            }
429
430            let table_oid = u32::from_be_bytes([
431                payload[pos],
432                payload[pos + 1],
433                payload[pos + 2],
434                payload[pos + 3],
435            ]);
436            pos += 4;
437
438            let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
439            pos += 2;
440
441            let type_oid = u32::from_be_bytes([
442                payload[pos],
443                payload[pos + 1],
444                payload[pos + 2],
445                payload[pos + 3],
446            ]);
447            pos += 4;
448
449            let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
450            pos += 2;
451
452            let type_modifier = i32::from_be_bytes([
453                payload[pos],
454                payload[pos + 1],
455                payload[pos + 2],
456                payload[pos + 3],
457            ]);
458            pos += 4;
459
460            let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
461            pos += 2;
462
463            fields.push(FieldDescription {
464                name,
465                table_oid,
466                column_attr,
467                type_oid,
468                type_size,
469                type_modifier,
470                format,
471            });
472        }
473
474        Ok(BackendMessage::RowDescription(fields))
475    }
476
477    fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
478        if payload.len() < 2 {
479            return Err("DataRow payload too short".to_string());
480        }
481
482        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
483        if raw_count < 0 {
484            return Err(format!("DataRow invalid column count: {}", raw_count));
485        }
486        let column_count = raw_count as usize;
487        // Sanity check: each column needs at least 4 bytes (length field)
488        if column_count > (payload.len() - 2) / 4 + 1 {
489            return Err(format!(
490                "DataRow claims {} columns but payload is only {} bytes",
491                column_count,
492                payload.len()
493            ));
494        }
495        let mut columns = Vec::with_capacity(column_count);
496        let mut pos = 2;
497
498        for _ in 0..column_count {
499            if pos + 4 > payload.len() {
500                return Err("DataRow truncated".to_string());
501            }
502
503            let len = i32::from_be_bytes([
504                payload[pos],
505                payload[pos + 1],
506                payload[pos + 2],
507                payload[pos + 3],
508            ]);
509            pos += 4;
510
511            if len == -1 {
512                // NULL value
513                columns.push(None);
514            } else {
515                let len = len as usize;
516                if pos + len > payload.len() {
517                    return Err("DataRow column data truncated".to_string());
518                }
519                let data = payload[pos..pos + len].to_vec();
520                pos += len;
521                columns.push(Some(data));
522            }
523        }
524
525        Ok(BackendMessage::DataRow(columns))
526    }
527
528    fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
529        let tag = String::from_utf8_lossy(payload)
530            .trim_end_matches('\0')
531            .to_string();
532        Ok(BackendMessage::CommandComplete(tag))
533    }
534
535    fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
536        Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
537            payload,
538        )?))
539    }
540
541    fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
542        let mut fields = ErrorFields::default();
543        let mut i = 0;
544        while i < payload.len() && payload[i] != 0 {
545            let field_type = payload[i];
546            i += 1;
547            let end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
548            let value = String::from_utf8_lossy(&payload[i..end]).to_string();
549            i = end + 1;
550
551            match field_type {
552                b'S' => fields.severity = value,
553                b'C' => fields.code = value,
554                b'M' => fields.message = value,
555                b'D' => fields.detail = Some(value),
556                b'H' => fields.hint = Some(value),
557                _ => {}
558            }
559        }
560        Ok(fields)
561    }
562
563    fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
564        let count = if payload.len() >= 2 {
565            i16::from_be_bytes([payload[0], payload[1]]) as usize
566        } else {
567            0
568        };
569        let mut oids = Vec::with_capacity(count);
570        let mut pos = 2;
571        for _ in 0..count {
572            if pos + 4 <= payload.len() {
573                oids.push(u32::from_be_bytes([
574                    payload[pos], payload[pos + 1], payload[pos + 2], payload[pos + 3],
575                ]));
576                pos += 4;
577            }
578        }
579        Ok(BackendMessage::ParameterDescription(oids))
580    }
581
582    fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
583        if payload.is_empty() {
584            return Err("Empty CopyInResponse payload".to_string());
585        }
586        let format = payload[0];
587        let num_columns = if payload.len() >= 3 {
588            i16::from_be_bytes([payload[1], payload[2]]) as usize
589        } else {
590            0
591        };
592        let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
593            payload[3..].iter().take(num_columns).copied().collect()
594        } else {
595            vec![]
596        };
597        Ok(BackendMessage::CopyInResponse {
598            format,
599            column_formats,
600        })
601    }
602
603    fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
604        if payload.is_empty() {
605            return Err("Empty CopyOutResponse payload".to_string());
606        }
607        let format = payload[0];
608        let num_columns = if payload.len() >= 3 {
609            i16::from_be_bytes([payload[1], payload[2]]) as usize
610        } else {
611            0
612        };
613        let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
614            payload[3..].iter().take(num_columns).copied().collect()
615        } else {
616            vec![]
617        };
618        Ok(BackendMessage::CopyOutResponse {
619            format,
620            column_formats,
621        })
622    }
623
624    fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
625        if payload.len() < 4 {
626            return Err("NotificationResponse too short".to_string());
627        }
628        let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
629
630        // Channel name (null-terminated)
631        let mut i = 4;
632        let channel_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
633        let channel = String::from_utf8_lossy(&payload[i..channel_end]).to_string();
634        i = channel_end + 1;
635
636        // Payload (null-terminated)
637        let payload_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
638        let notification_payload = String::from_utf8_lossy(&payload[i..payload_end]).to_string();
639
640        Ok(BackendMessage::NotificationResponse {
641            process_id,
642            channel,
643            payload: notification_payload,
644        })
645    }
646}