Skip to main content

qail_pg/protocol/
wire.rs

1//! PostgreSQL Wire Protocol Messages
2//!
3//! Implementation of the PostgreSQL Frontend/Backend Protocol.
4//! Reference: https://www.postgresql.org/docs/current/protocol-message-formats.html
5
/// Frontend (client → server) message types.
///
/// Each variant is serialized to wire bytes by `FrontendMessage::encode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrontendMessage {
    /// Startup message (sent first, no type byte).
    Startup { user: String, database: String },
    /// Password response (`'p'`), sent as a NUL-terminated string.
    PasswordMessage(String),
    /// Simple query (`'Q'`): SQL text, sent NUL-terminated.
    Query(String),
    /// Parse (`'P'`): create a prepared statement.
    Parse {
        /// Prepared-statement name.
        name: String,
        /// SQL text of the statement.
        query: String,
        /// Parameter type OIDs, one per declared parameter.
        param_types: Vec<u32>,
    },
    /// Bind (`'B'`): bind parameter values to a prepared statement, creating a portal.
    Bind {
        /// Destination portal name.
        portal: String,
        /// Source prepared-statement name.
        statement: String,
        /// Parameter values in text format; `None` is sent as SQL NULL.
        params: Vec<Option<Vec<u8>>>,
    },
    /// Execute (`'E'`) a portal, returning at most `max_rows` rows
    /// (the protocol treats 0 as "no limit").
    Execute { portal: String, max_rows: i32 },
    /// Sync (`'S'`): end of an extended-query message sequence.
    Sync,
    /// Terminate (`'X'`): close the connection.
    Terminate,
    /// SASL initial response (first message in SCRAM).
    SASLInitialResponse { mechanism: String, data: Vec<u8> },
    /// SASL response (subsequent messages in SCRAM).
    SASLResponse(Vec<u8>),
}
34
35/// Backend (server → client) message types
36#[derive(Debug, Clone)]
37pub enum BackendMessage {
38    /// Authentication request
39    AuthenticationOk,
40    AuthenticationMD5Password([u8; 4]),
41    AuthenticationSASL(Vec<String>),
42    AuthenticationSASLContinue(Vec<u8>),
43    AuthenticationSASLFinal(Vec<u8>),
44    /// Parameter status (server config)
45    ParameterStatus {
46        name: String,
47        value: String,
48    },
49    /// Backend key data (for cancel)
50    BackendKeyData {
51        process_id: i32,
52        secret_key: i32,
53    },
54    ReadyForQuery(TransactionStatus),
55    RowDescription(Vec<FieldDescription>),
56    DataRow(Vec<Option<Vec<u8>>>),
57    CommandComplete(String),
58    ErrorResponse(ErrorFields),
59    ParseComplete,
60    BindComplete,
61    NoData,
62    /// Copy in response (server ready to receive COPY data)
63    CopyInResponse {
64        format: u8,
65        column_formats: Vec<u8>,
66    },
67    /// Copy out response (server will send COPY data)
68    CopyOutResponse {
69        format: u8,
70        column_formats: Vec<u8>,
71    },
72    CopyData(Vec<u8>),
73    CopyDone,
74    /// Notification response (async notification from LISTEN/NOTIFY)
75    NotificationResponse {
76        process_id: i32,
77        channel: String,
78        payload: String,
79    },
80    EmptyQueryResponse,
81    /// Notice response (warning/info messages, not errors)
82    NoticeResponse(ErrorFields),
83    /// Parameter description (OIDs of parameters in a prepared statement)
84    /// Sent by server in response to Describe(Statement)
85    ParameterDescription(Vec<u32>),
86}
87
/// Transaction status reported by the server in a `ReadyForQuery` message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionStatus {
    /// Not inside a transaction block (wire byte `'I'`).
    Idle,
    /// Inside a transaction block (wire byte `'T'`).
    InBlock,
    /// Inside a failed transaction block (wire byte `'E'`).
    Failed,
}
95
/// Description of one result column, as carried in a `RowDescription` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the source table (0 when the column is not a plain table column).
    pub table_oid: u32,
    /// Attribute number of the column within its table (0 when not applicable).
    pub column_attr: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Size of the data type in bytes (negative for variable-width types).
    pub type_size: i16,
    /// Type-specific modifier.
    pub type_modifier: i32,
    /// Format code of the column values (0 = text, 1 = binary).
    pub format: i16,
}
107
/// Fields extracted from an `ErrorResponse` or `NoticeResponse` message.
///
/// Only the severity (`S`), SQLSTATE code (`C`), message (`M`), detail (`D`) and
/// hint (`H`) fields are kept. String fields the server omitted stay empty, and
/// optional ones stay `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorFields {
    /// Severity, e.g. `ERROR` or `NOTICE`.
    pub severity: String,
    /// SQLSTATE error code.
    pub code: String,
    /// Primary human-readable message.
    pub message: String,
    /// Optional secondary detail message.
    pub detail: Option<String>,
    /// Optional suggestion for fixing the problem.
    pub hint: Option<String>,
}
117
118impl FrontendMessage {
119    /// Encode message to bytes for sending over the wire.
120    pub fn encode(&self) -> Vec<u8> {
121        match self {
122            FrontendMessage::Startup { user, database } => {
123                let mut buf = Vec::new();
124                // Protocol version 3.0
125                buf.extend_from_slice(&196608i32.to_be_bytes());
126                // Parameters
127                buf.extend_from_slice(b"user\0");
128                buf.extend_from_slice(user.as_bytes());
129                buf.push(0);
130                buf.extend_from_slice(b"database\0");
131                buf.extend_from_slice(database.as_bytes());
132                buf.push(0);
133                buf.push(0); // Terminator
134
135                // Prepend length (includes length itself)
136                let len = (buf.len() + 4) as i32;
137                let mut result = len.to_be_bytes().to_vec();
138                result.extend(buf);
139                result
140            }
141            FrontendMessage::Query(sql) => {
142                let mut buf = Vec::new();
143                buf.push(b'Q');
144                let content = format!("{}\0", sql);
145                let len = (content.len() + 4) as i32;
146                buf.extend_from_slice(&len.to_be_bytes());
147                buf.extend_from_slice(content.as_bytes());
148                buf
149            }
150            FrontendMessage::Terminate => {
151                vec![b'X', 0, 0, 0, 4]
152            }
153            FrontendMessage::SASLInitialResponse { mechanism, data } => {
154                let mut buf = Vec::new();
155                buf.push(b'p'); // SASLInitialResponse uses 'p'
156
157                let mut content = Vec::new();
158                content.extend_from_slice(mechanism.as_bytes());
159                content.push(0); // null-terminated mechanism
160                content.extend_from_slice(&(data.len() as i32).to_be_bytes());
161                content.extend_from_slice(data);
162
163                let len = (content.len() + 4) as i32;
164                buf.extend_from_slice(&len.to_be_bytes());
165                buf.extend_from_slice(&content);
166                buf
167            }
168            FrontendMessage::SASLResponse(data) => {
169                let mut buf = Vec::new();
170                buf.push(b'p');
171
172                let len = (data.len() + 4) as i32;
173                buf.extend_from_slice(&len.to_be_bytes());
174                buf.extend_from_slice(data);
175                buf
176            }
177            FrontendMessage::PasswordMessage(password) => {
178                let mut buf = Vec::new();
179                buf.push(b'p');
180                let content = format!("{}\0", password);
181                let len = (content.len() + 4) as i32;
182                buf.extend_from_slice(&len.to_be_bytes());
183                buf.extend_from_slice(content.as_bytes());
184                buf
185            }
186            FrontendMessage::Parse { name, query, param_types } => {
187                let mut buf = Vec::new();
188                buf.push(b'P');
189
190                let mut content = Vec::new();
191                content.extend_from_slice(name.as_bytes());
192                content.push(0);
193                content.extend_from_slice(query.as_bytes());
194                content.push(0);
195                content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
196                for oid in param_types {
197                    content.extend_from_slice(&oid.to_be_bytes());
198                }
199
200                let len = (content.len() + 4) as i32;
201                buf.extend_from_slice(&len.to_be_bytes());
202                buf.extend_from_slice(&content);
203                buf
204            }
205            FrontendMessage::Bind { portal, statement, params } => {
206                let mut buf = Vec::new();
207                buf.push(b'B');
208
209                let mut content = Vec::new();
210                content.extend_from_slice(portal.as_bytes());
211                content.push(0);
212                content.extend_from_slice(statement.as_bytes());
213                content.push(0);
214                // Format codes (0 = all text)
215                content.extend_from_slice(&0i16.to_be_bytes());
216                // Parameter count
217                content.extend_from_slice(&(params.len() as i16).to_be_bytes());
218                for param in params {
219                    match param {
220                        Some(data) => {
221                            content.extend_from_slice(&(data.len() as i32).to_be_bytes());
222                            content.extend_from_slice(data);
223                        }
224                        None => content.extend_from_slice(&(-1i32).to_be_bytes()),
225                    }
226                }
227                // Result format codes (0 = all text)
228                content.extend_from_slice(&0i16.to_be_bytes());
229
230                let len = (content.len() + 4) as i32;
231                buf.extend_from_slice(&len.to_be_bytes());
232                buf.extend_from_slice(&content);
233                buf
234            }
235            FrontendMessage::Execute { portal, max_rows } => {
236                let mut buf = Vec::new();
237                buf.push(b'E');
238
239                let mut content = Vec::new();
240                content.extend_from_slice(portal.as_bytes());
241                content.push(0);
242                content.extend_from_slice(&max_rows.to_be_bytes());
243
244                let len = (content.len() + 4) as i32;
245                buf.extend_from_slice(&len.to_be_bytes());
246                buf.extend_from_slice(&content);
247                buf
248            }
249            FrontendMessage::Sync => {
250                vec![b'S', 0, 0, 0, 4]
251            }
252        }
253    }
254}
255
256impl BackendMessage {
257    /// Decode a message from wire bytes.
258    pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
259        if buf.len() < 5 {
260            return Err("Buffer too short".to_string());
261        }
262
263        let msg_type = buf[0];
264        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
265
266        if buf.len() < len + 1 {
267            return Err("Incomplete message".to_string());
268        }
269
270        let payload = &buf[5..len + 1];
271
272        let message = match msg_type {
273            b'R' => Self::decode_auth(payload)?,
274            b'S' => Self::decode_parameter_status(payload)?,
275            b'K' => Self::decode_backend_key(payload)?,
276            b'Z' => Self::decode_ready_for_query(payload)?,
277            b'T' => Self::decode_row_description(payload)?,
278            b'D' => Self::decode_data_row(payload)?,
279            b'C' => Self::decode_command_complete(payload)?,
280            b'E' => Self::decode_error_response(payload)?,
281            b'1' => BackendMessage::ParseComplete,
282            b'2' => BackendMessage::BindComplete,
283            b'n' => BackendMessage::NoData,
284            b't' => Self::decode_parameter_description(payload)?,
285            b'G' => Self::decode_copy_in_response(payload)?,
286            b'H' => Self::decode_copy_out_response(payload)?,
287            b'd' => BackendMessage::CopyData(payload.to_vec()),
288            b'c' => BackendMessage::CopyDone,
289            b'A' => Self::decode_notification_response(payload)?,
290            b'I' => BackendMessage::EmptyQueryResponse,
291            b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
292            _ => return Err(format!("Unknown message type: {}", msg_type as char)),
293        };
294
295        Ok((message, len + 1))
296    }
297
298    fn decode_auth(payload: &[u8]) -> Result<Self, String> {
299        let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
300        match auth_type {
301            0 => Ok(BackendMessage::AuthenticationOk),
302            5 => {
303                let salt: [u8; 4] = payload[4..8].try_into().unwrap();
304                Ok(BackendMessage::AuthenticationMD5Password(salt))
305            }
306            10 => {
307                // SASL - parse mechanism list
308                let mut mechanisms = Vec::new();
309                let mut pos = 4;
310                while pos < payload.len() && payload[pos] != 0 {
311                    let end = payload[pos..]
312                        .iter()
313                        .position(|&b| b == 0)
314                        .map(|p| pos + p)
315                        .unwrap_or(payload.len());
316                    mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
317                    pos = end + 1;
318                }
319                Ok(BackendMessage::AuthenticationSASL(mechanisms))
320            }
321            11 => {
322                // SASL Continue - server challenge
323                Ok(BackendMessage::AuthenticationSASLContinue(
324                    payload[4..].to_vec(),
325                ))
326            }
327            12 => {
328                // SASL Final - server signature
329                Ok(BackendMessage::AuthenticationSASLFinal(
330                    payload[4..].to_vec(),
331                ))
332            }
333            _ => Err(format!("Unknown auth type: {}", auth_type)),
334        }
335    }
336
337    fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
338        let parts: Vec<&[u8]> = payload.split(|&b| b == 0).collect();
339        let empty: &[u8] = b"";
340        Ok(BackendMessage::ParameterStatus {
341            name: String::from_utf8_lossy(parts.first().unwrap_or(&empty)).to_string(),
342            value: String::from_utf8_lossy(parts.get(1).unwrap_or(&empty)).to_string(),
343        })
344    }
345
346    fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
347        Ok(BackendMessage::BackendKeyData {
348            process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
349            secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
350        })
351    }
352
353    fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
354        let status = match payload[0] {
355            b'I' => TransactionStatus::Idle,
356            b'T' => TransactionStatus::InBlock,
357            b'E' => TransactionStatus::Failed,
358            _ => return Err("Unknown transaction status".to_string()),
359        };
360        Ok(BackendMessage::ReadyForQuery(status))
361    }
362
363    fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
364        if payload.len() < 2 {
365            return Err("RowDescription payload too short".to_string());
366        }
367
368        let field_count = i16::from_be_bytes([payload[0], payload[1]]) as usize;
369        let mut fields = Vec::with_capacity(field_count);
370        let mut pos = 2;
371
372        for _ in 0..field_count {
373            // Field name (null-terminated string)
374            let name_end = payload[pos..]
375                .iter()
376                .position(|&b| b == 0)
377                .ok_or("Missing null terminator in field name")?;
378            let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
379            pos += name_end + 1; // Skip null terminator
380
381            // Ensure we have enough bytes for the fixed fields
382            if pos + 18 > payload.len() {
383                return Err("RowDescription field truncated".to_string());
384            }
385
386            let table_oid = u32::from_be_bytes([
387                payload[pos],
388                payload[pos + 1],
389                payload[pos + 2],
390                payload[pos + 3],
391            ]);
392            pos += 4;
393
394            let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
395            pos += 2;
396
397            let type_oid = u32::from_be_bytes([
398                payload[pos],
399                payload[pos + 1],
400                payload[pos + 2],
401                payload[pos + 3],
402            ]);
403            pos += 4;
404
405            let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
406            pos += 2;
407
408            let type_modifier = i32::from_be_bytes([
409                payload[pos],
410                payload[pos + 1],
411                payload[pos + 2],
412                payload[pos + 3],
413            ]);
414            pos += 4;
415
416            let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
417            pos += 2;
418
419            fields.push(FieldDescription {
420                name,
421                table_oid,
422                column_attr,
423                type_oid,
424                type_size,
425                type_modifier,
426                format,
427            });
428        }
429
430        Ok(BackendMessage::RowDescription(fields))
431    }
432
433    fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
434        if payload.len() < 2 {
435            return Err("DataRow payload too short".to_string());
436        }
437
438        let column_count = i16::from_be_bytes([payload[0], payload[1]]) as usize;
439        let mut columns = Vec::with_capacity(column_count);
440        let mut pos = 2;
441
442        for _ in 0..column_count {
443            if pos + 4 > payload.len() {
444                return Err("DataRow truncated".to_string());
445            }
446
447            let len = i32::from_be_bytes([
448                payload[pos],
449                payload[pos + 1],
450                payload[pos + 2],
451                payload[pos + 3],
452            ]);
453            pos += 4;
454
455            if len == -1 {
456                // NULL value
457                columns.push(None);
458            } else {
459                let len = len as usize;
460                if pos + len > payload.len() {
461                    return Err("DataRow column data truncated".to_string());
462                }
463                let data = payload[pos..pos + len].to_vec();
464                pos += len;
465                columns.push(Some(data));
466            }
467        }
468
469        Ok(BackendMessage::DataRow(columns))
470    }
471
472    fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
473        let tag = String::from_utf8_lossy(payload)
474            .trim_end_matches('\0')
475            .to_string();
476        Ok(BackendMessage::CommandComplete(tag))
477    }
478
479    fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
480        Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
481            payload,
482        )?))
483    }
484
485    fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
486        let mut fields = ErrorFields::default();
487        let mut i = 0;
488        while i < payload.len() && payload[i] != 0 {
489            let field_type = payload[i];
490            i += 1;
491            let end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
492            let value = String::from_utf8_lossy(&payload[i..end]).to_string();
493            i = end + 1;
494
495            match field_type {
496                b'S' => fields.severity = value,
497                b'C' => fields.code = value,
498                b'M' => fields.message = value,
499                b'D' => fields.detail = Some(value),
500                b'H' => fields.hint = Some(value),
501                _ => {}
502            }
503        }
504        Ok(fields)
505    }
506
507    fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
508        let count = if payload.len() >= 2 {
509            i16::from_be_bytes([payload[0], payload[1]]) as usize
510        } else {
511            0
512        };
513        let mut oids = Vec::with_capacity(count);
514        let mut pos = 2;
515        for _ in 0..count {
516            if pos + 4 <= payload.len() {
517                oids.push(u32::from_be_bytes([
518                    payload[pos], payload[pos + 1], payload[pos + 2], payload[pos + 3],
519                ]));
520                pos += 4;
521            }
522        }
523        Ok(BackendMessage::ParameterDescription(oids))
524    }
525
526    fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
527        if payload.is_empty() {
528            return Err("Empty CopyInResponse payload".to_string());
529        }
530        let format = payload[0];
531        let num_columns = if payload.len() >= 3 {
532            i16::from_be_bytes([payload[1], payload[2]]) as usize
533        } else {
534            0
535        };
536        let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
537            payload[3..].iter().take(num_columns).copied().collect()
538        } else {
539            vec![]
540        };
541        Ok(BackendMessage::CopyInResponse {
542            format,
543            column_formats,
544        })
545    }
546
547    fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
548        if payload.is_empty() {
549            return Err("Empty CopyOutResponse payload".to_string());
550        }
551        let format = payload[0];
552        let num_columns = if payload.len() >= 3 {
553            i16::from_be_bytes([payload[1], payload[2]]) as usize
554        } else {
555            0
556        };
557        let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
558            payload[3..].iter().take(num_columns).copied().collect()
559        } else {
560            vec![]
561        };
562        Ok(BackendMessage::CopyOutResponse {
563            format,
564            column_formats,
565        })
566    }
567
568    fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
569        if payload.len() < 4 {
570            return Err("NotificationResponse too short".to_string());
571        }
572        let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
573
574        // Channel name (null-terminated)
575        let mut i = 4;
576        let channel_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
577        let channel = String::from_utf8_lossy(&payload[i..channel_end]).to_string();
578        i = channel_end + 1;
579
580        // Payload (null-terminated)
581        let payload_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
582        let notification_payload = String::from_utf8_lossy(&payload[i..payload_end]).to_string();
583
584        Ok(BackendMessage::NotificationResponse {
585            process_id,
586            channel,
587            payload: notification_payload,
588        })
589    }
590}