// qail_pg/protocol/wire.rs

//! PostgreSQL Wire Protocol Messages
//!
//! Implementation of the PostgreSQL Frontend/Backend Protocol.
//! Reference: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
5
/// Frontend (client → server) message types.
///
/// Each variant is serialized to wire bytes by [`FrontendMessage::encode`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrontendMessage {
    /// Startup message (sent first, no type byte) carrying the `user` and
    /// `database` connection parameters.
    Startup { user: String, database: String },
    /// Password message (`'p'`): the password as a NUL-terminated string.
    PasswordMessage(String),
    /// Simple query (`'Q'`): the SQL text as a NUL-terminated string.
    Query(String),
    /// Parse (`'P'`): create prepared statement `name` from `query`, declaring
    /// the parameter type OIDs in `param_types`.
    Parse {
        name: String,
        query: String,
        param_types: Vec<u32>,
    },
    /// Bind (`'B'`): bind `params` to prepared `statement`, creating `portal`.
    /// A `None` parameter is sent with length -1 (NULL); all parameters and
    /// results use the text format.
    Bind {
        portal: String,
        statement: String,
        params: Vec<Option<Vec<u8>>>,
    },
    /// Execute (`'E'`) the named `portal` with the given `max_rows` limit.
    Execute { portal: String, max_rows: i32 },
    /// Sync (`'S'`), a message with an empty body.
    Sync,
    /// Terminate (`'X'`), a message with an empty body.
    Terminate,
    /// SASL initial response (first client message in SCRAM): the chosen
    /// `mechanism` name followed by the length-prefixed `data`.
    SASLInitialResponse { mechanism: String, data: Vec<u8> },
    /// SASL response (subsequent client messages in SCRAM): raw `data`.
    SASLResponse(Vec<u8>),
}
34
35/// Backend (server → client) message types
36#[derive(Debug, Clone)]
37pub enum BackendMessage {
38    /// Authentication request
39    AuthenticationOk,
40    AuthenticationMD5Password([u8; 4]),
41    AuthenticationSASL(Vec<String>),
42    AuthenticationSASLContinue(Vec<u8>),
43    AuthenticationSASLFinal(Vec<u8>),
44    /// Parameter status (server config)
45    ParameterStatus {
46        name: String,
47        value: String,
48    },
49    /// Backend key data (for cancel)
50    BackendKeyData {
51        process_id: i32,
52        secret_key: i32,
53    },
54    ReadyForQuery(TransactionStatus),
55    RowDescription(Vec<FieldDescription>),
56    DataRow(Vec<Option<Vec<u8>>>),
57    CommandComplete(String),
58    ErrorResponse(ErrorFields),
59    ParseComplete,
60    BindComplete,
61    NoData,
62    /// Copy in response (server ready to receive COPY data)
63    CopyInResponse {
64        format: u8,
65        column_formats: Vec<u8>,
66    },
67    /// Copy out response (server will send COPY data)
68    CopyOutResponse {
69        format: u8,
70        column_formats: Vec<u8>,
71    },
72    CopyData(Vec<u8>),
73    CopyDone,
74    /// Notification response (async notification from LISTEN/NOTIFY)
75    NotificationResponse {
76        process_id: i32,
77        channel: String,
78        payload: String,
79    },
80    EmptyQueryResponse,
81    /// Notice response (warning/info messages, not errors)
82    NoticeResponse(ErrorFields),
83}
84
/// Transaction status reported by a ReadyForQuery (`'Z'`) message.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare statuses directly
/// instead of pattern-matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransactionStatus {
    /// Status byte `'I'`: idle, not inside a transaction block.
    Idle,
    /// Status byte `'T'`: inside a transaction block.
    InBlock,
    /// Status byte `'E'`: inside a failed transaction block.
    Failed,
}
92
/// One column description from a RowDescription (`'T'`) message.
///
/// Field meanings follow the PostgreSQL protocol's RowDescription layout.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FieldDescription {
    /// Column name.
    pub name: String,
    /// OID of the source table if the column is a plain table column, else 0.
    pub table_oid: u32,
    /// Attribute number of the column within that table, else 0.
    pub column_attr: i16,
    /// OID of the column's data type.
    pub type_oid: u32,
    /// Data type size; negative values denote variable-width types.
    pub type_size: i16,
    /// Type-specific modifier (e.g. varchar length); -1 when unused.
    pub type_modifier: i32,
    /// Format code: 0 = text, 1 = binary.
    pub format: i16,
}
104
/// Fields extracted from an ErrorResponse (`'E'`) or NoticeResponse (`'N'`).
///
/// Only the field codes listed below are kept; other codes are ignored by the
/// decoder. Missing required fields stay empty (`Default`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorFields {
    /// Severity (field code `'S'`), e.g. `ERROR` or `NOTICE`.
    pub severity: String,
    /// SQLSTATE code (field code `'C'`).
    pub code: String,
    /// Primary human-readable message (field code `'M'`).
    pub message: String,
    /// Optional detail message (field code `'D'`).
    pub detail: Option<String>,
    /// Optional hint (field code `'H'`).
    pub hint: Option<String>,
}
114
115impl FrontendMessage {
116    /// Encode message to bytes for sending over the wire.
117    pub fn encode(&self) -> Vec<u8> {
118        match self {
119            FrontendMessage::Startup { user, database } => {
120                let mut buf = Vec::new();
121                // Protocol version 3.0
122                buf.extend_from_slice(&196608i32.to_be_bytes());
123                // Parameters
124                buf.extend_from_slice(b"user\0");
125                buf.extend_from_slice(user.as_bytes());
126                buf.push(0);
127                buf.extend_from_slice(b"database\0");
128                buf.extend_from_slice(database.as_bytes());
129                buf.push(0);
130                buf.push(0); // Terminator
131
132                // Prepend length (includes length itself)
133                let len = (buf.len() + 4) as i32;
134                let mut result = len.to_be_bytes().to_vec();
135                result.extend(buf);
136                result
137            }
138            FrontendMessage::Query(sql) => {
139                let mut buf = Vec::new();
140                buf.push(b'Q');
141                let content = format!("{}\0", sql);
142                let len = (content.len() + 4) as i32;
143                buf.extend_from_slice(&len.to_be_bytes());
144                buf.extend_from_slice(content.as_bytes());
145                buf
146            }
147            FrontendMessage::Terminate => {
148                vec![b'X', 0, 0, 0, 4]
149            }
150            FrontendMessage::SASLInitialResponse { mechanism, data } => {
151                let mut buf = Vec::new();
152                buf.push(b'p'); // SASLInitialResponse uses 'p'
153
154                let mut content = Vec::new();
155                content.extend_from_slice(mechanism.as_bytes());
156                content.push(0); // null-terminated mechanism
157                content.extend_from_slice(&(data.len() as i32).to_be_bytes());
158                content.extend_from_slice(data);
159
160                let len = (content.len() + 4) as i32;
161                buf.extend_from_slice(&len.to_be_bytes());
162                buf.extend_from_slice(&content);
163                buf
164            }
165            FrontendMessage::SASLResponse(data) => {
166                let mut buf = Vec::new();
167                buf.push(b'p');
168
169                let len = (data.len() + 4) as i32;
170                buf.extend_from_slice(&len.to_be_bytes());
171                buf.extend_from_slice(data);
172                buf
173            }
174            FrontendMessage::PasswordMessage(password) => {
175                let mut buf = Vec::new();
176                buf.push(b'p');
177                let content = format!("{}\0", password);
178                let len = (content.len() + 4) as i32;
179                buf.extend_from_slice(&len.to_be_bytes());
180                buf.extend_from_slice(content.as_bytes());
181                buf
182            }
183            FrontendMessage::Parse { name, query, param_types } => {
184                let mut buf = Vec::new();
185                buf.push(b'P');
186
187                let mut content = Vec::new();
188                content.extend_from_slice(name.as_bytes());
189                content.push(0);
190                content.extend_from_slice(query.as_bytes());
191                content.push(0);
192                content.extend_from_slice(&(param_types.len() as i16).to_be_bytes());
193                for oid in param_types {
194                    content.extend_from_slice(&oid.to_be_bytes());
195                }
196
197                let len = (content.len() + 4) as i32;
198                buf.extend_from_slice(&len.to_be_bytes());
199                buf.extend_from_slice(&content);
200                buf
201            }
202            FrontendMessage::Bind { portal, statement, params } => {
203                let mut buf = Vec::new();
204                buf.push(b'B');
205
206                let mut content = Vec::new();
207                content.extend_from_slice(portal.as_bytes());
208                content.push(0);
209                content.extend_from_slice(statement.as_bytes());
210                content.push(0);
211                // Format codes (0 = all text)
212                content.extend_from_slice(&0i16.to_be_bytes());
213                // Parameter count
214                content.extend_from_slice(&(params.len() as i16).to_be_bytes());
215                for param in params {
216                    match param {
217                        Some(data) => {
218                            content.extend_from_slice(&(data.len() as i32).to_be_bytes());
219                            content.extend_from_slice(data);
220                        }
221                        None => content.extend_from_slice(&(-1i32).to_be_bytes()),
222                    }
223                }
224                // Result format codes (0 = all text)
225                content.extend_from_slice(&0i16.to_be_bytes());
226
227                let len = (content.len() + 4) as i32;
228                buf.extend_from_slice(&len.to_be_bytes());
229                buf.extend_from_slice(&content);
230                buf
231            }
232            FrontendMessage::Execute { portal, max_rows } => {
233                let mut buf = Vec::new();
234                buf.push(b'E');
235
236                let mut content = Vec::new();
237                content.extend_from_slice(portal.as_bytes());
238                content.push(0);
239                content.extend_from_slice(&max_rows.to_be_bytes());
240
241                let len = (content.len() + 4) as i32;
242                buf.extend_from_slice(&len.to_be_bytes());
243                buf.extend_from_slice(&content);
244                buf
245            }
246            FrontendMessage::Sync => {
247                vec![b'S', 0, 0, 0, 4]
248            }
249        }
250    }
251}
252
253impl BackendMessage {
254    /// Decode a message from wire bytes.
255    pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
256        if buf.len() < 5 {
257            return Err("Buffer too short".to_string());
258        }
259
260        let msg_type = buf[0];
261        let len = i32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
262
263        if buf.len() < len + 1 {
264            return Err("Incomplete message".to_string());
265        }
266
267        let payload = &buf[5..len + 1];
268
269        let message = match msg_type {
270            b'R' => Self::decode_auth(payload)?,
271            b'S' => Self::decode_parameter_status(payload)?,
272            b'K' => Self::decode_backend_key(payload)?,
273            b'Z' => Self::decode_ready_for_query(payload)?,
274            b'T' => Self::decode_row_description(payload)?,
275            b'D' => Self::decode_data_row(payload)?,
276            b'C' => Self::decode_command_complete(payload)?,
277            b'E' => Self::decode_error_response(payload)?,
278            b'1' => BackendMessage::ParseComplete,
279            b'2' => BackendMessage::BindComplete,
280            b'n' => BackendMessage::NoData,
281            b'G' => Self::decode_copy_in_response(payload)?,
282            b'H' => Self::decode_copy_out_response(payload)?,
283            b'd' => BackendMessage::CopyData(payload.to_vec()),
284            b'c' => BackendMessage::CopyDone,
285            b'A' => Self::decode_notification_response(payload)?,
286            b'I' => BackendMessage::EmptyQueryResponse,
287            b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
288            _ => return Err(format!("Unknown message type: {}", msg_type as char)),
289        };
290
291        Ok((message, len + 1))
292    }
293
294    fn decode_auth(payload: &[u8]) -> Result<Self, String> {
295        let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
296        match auth_type {
297            0 => Ok(BackendMessage::AuthenticationOk),
298            5 => {
299                let salt: [u8; 4] = payload[4..8].try_into().unwrap();
300                Ok(BackendMessage::AuthenticationMD5Password(salt))
301            }
302            10 => {
303                // SASL - parse mechanism list
304                let mut mechanisms = Vec::new();
305                let mut pos = 4;
306                while pos < payload.len() && payload[pos] != 0 {
307                    let end = payload[pos..]
308                        .iter()
309                        .position(|&b| b == 0)
310                        .map(|p| pos + p)
311                        .unwrap_or(payload.len());
312                    mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
313                    pos = end + 1;
314                }
315                Ok(BackendMessage::AuthenticationSASL(mechanisms))
316            }
317            11 => {
318                // SASL Continue - server challenge
319                Ok(BackendMessage::AuthenticationSASLContinue(
320                    payload[4..].to_vec(),
321                ))
322            }
323            12 => {
324                // SASL Final - server signature
325                Ok(BackendMessage::AuthenticationSASLFinal(
326                    payload[4..].to_vec(),
327                ))
328            }
329            _ => Err(format!("Unknown auth type: {}", auth_type)),
330        }
331    }
332
333    fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
334        let parts: Vec<&[u8]> = payload.split(|&b| b == 0).collect();
335        let empty: &[u8] = b"";
336        Ok(BackendMessage::ParameterStatus {
337            name: String::from_utf8_lossy(parts.first().unwrap_or(&empty)).to_string(),
338            value: String::from_utf8_lossy(parts.get(1).unwrap_or(&empty)).to_string(),
339        })
340    }
341
342    fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
343        Ok(BackendMessage::BackendKeyData {
344            process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
345            secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
346        })
347    }
348
349    fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
350        let status = match payload[0] {
351            b'I' => TransactionStatus::Idle,
352            b'T' => TransactionStatus::InBlock,
353            b'E' => TransactionStatus::Failed,
354            _ => return Err("Unknown transaction status".to_string()),
355        };
356        Ok(BackendMessage::ReadyForQuery(status))
357    }
358
359    fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
360        if payload.len() < 2 {
361            return Err("RowDescription payload too short".to_string());
362        }
363
364        let field_count = i16::from_be_bytes([payload[0], payload[1]]) as usize;
365        let mut fields = Vec::with_capacity(field_count);
366        let mut pos = 2;
367
368        for _ in 0..field_count {
369            // Field name (null-terminated string)
370            let name_end = payload[pos..]
371                .iter()
372                .position(|&b| b == 0)
373                .ok_or("Missing null terminator in field name")?;
374            let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
375            pos += name_end + 1; // Skip null terminator
376
377            // Ensure we have enough bytes for the fixed fields
378            if pos + 18 > payload.len() {
379                return Err("RowDescription field truncated".to_string());
380            }
381
382            let table_oid = u32::from_be_bytes([
383                payload[pos],
384                payload[pos + 1],
385                payload[pos + 2],
386                payload[pos + 3],
387            ]);
388            pos += 4;
389
390            let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
391            pos += 2;
392
393            let type_oid = u32::from_be_bytes([
394                payload[pos],
395                payload[pos + 1],
396                payload[pos + 2],
397                payload[pos + 3],
398            ]);
399            pos += 4;
400
401            let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
402            pos += 2;
403
404            let type_modifier = i32::from_be_bytes([
405                payload[pos],
406                payload[pos + 1],
407                payload[pos + 2],
408                payload[pos + 3],
409            ]);
410            pos += 4;
411
412            let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
413            pos += 2;
414
415            fields.push(FieldDescription {
416                name,
417                table_oid,
418                column_attr,
419                type_oid,
420                type_size,
421                type_modifier,
422                format,
423            });
424        }
425
426        Ok(BackendMessage::RowDescription(fields))
427    }
428
429    fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
430        if payload.len() < 2 {
431            return Err("DataRow payload too short".to_string());
432        }
433
434        let column_count = i16::from_be_bytes([payload[0], payload[1]]) as usize;
435        let mut columns = Vec::with_capacity(column_count);
436        let mut pos = 2;
437
438        for _ in 0..column_count {
439            if pos + 4 > payload.len() {
440                return Err("DataRow truncated".to_string());
441            }
442
443            let len = i32::from_be_bytes([
444                payload[pos],
445                payload[pos + 1],
446                payload[pos + 2],
447                payload[pos + 3],
448            ]);
449            pos += 4;
450
451            if len == -1 {
452                // NULL value
453                columns.push(None);
454            } else {
455                let len = len as usize;
456                if pos + len > payload.len() {
457                    return Err("DataRow column data truncated".to_string());
458                }
459                let data = payload[pos..pos + len].to_vec();
460                pos += len;
461                columns.push(Some(data));
462            }
463        }
464
465        Ok(BackendMessage::DataRow(columns))
466    }
467
468    fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
469        let tag = String::from_utf8_lossy(payload)
470            .trim_end_matches('\0')
471            .to_string();
472        Ok(BackendMessage::CommandComplete(tag))
473    }
474
475    fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
476        Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
477            payload,
478        )?))
479    }
480
481    fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
482        let mut fields = ErrorFields::default();
483        let mut i = 0;
484        while i < payload.len() && payload[i] != 0 {
485            let field_type = payload[i];
486            i += 1;
487            let end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
488            let value = String::from_utf8_lossy(&payload[i..end]).to_string();
489            i = end + 1;
490
491            match field_type {
492                b'S' => fields.severity = value,
493                b'C' => fields.code = value,
494                b'M' => fields.message = value,
495                b'D' => fields.detail = Some(value),
496                b'H' => fields.hint = Some(value),
497                _ => {}
498            }
499        }
500        Ok(fields)
501    }
502
503    fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
504        if payload.is_empty() {
505            return Err("Empty CopyInResponse payload".to_string());
506        }
507        let format = payload[0];
508        let num_columns = if payload.len() >= 3 {
509            i16::from_be_bytes([payload[1], payload[2]]) as usize
510        } else {
511            0
512        };
513        let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
514            payload[3..].iter().take(num_columns).copied().collect()
515        } else {
516            vec![]
517        };
518        Ok(BackendMessage::CopyInResponse {
519            format,
520            column_formats,
521        })
522    }
523
524    fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
525        if payload.is_empty() {
526            return Err("Empty CopyOutResponse payload".to_string());
527        }
528        let format = payload[0];
529        let num_columns = if payload.len() >= 3 {
530            i16::from_be_bytes([payload[1], payload[2]]) as usize
531        } else {
532            0
533        };
534        let column_formats: Vec<u8> = if payload.len() > 3 && num_columns > 0 {
535            payload[3..].iter().take(num_columns).copied().collect()
536        } else {
537            vec![]
538        };
539        Ok(BackendMessage::CopyOutResponse {
540            format,
541            column_formats,
542        })
543    }
544
545    fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
546        if payload.len() < 4 {
547            return Err("NotificationResponse too short".to_string());
548        }
549        let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
550
551        // Channel name (null-terminated)
552        let mut i = 4;
553        let channel_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
554        let channel = String::from_utf8_lossy(&payload[i..channel_end]).to_string();
555        i = channel_end + 1;
556
557        // Payload (null-terminated)
558        let payload_end = payload[i..].iter().position(|&b| b == 0).unwrap_or(0) + i;
559        let notification_payload = String::from_utf8_lossy(&payload[i..payload_end]).to_string();
560
561        Ok(BackendMessage::NotificationResponse {
562            process_id,
563            channel,
564            payload: notification_payload,
565        })
566    }
567}