Skip to main content

qail_pg/protocol/wire/
backend.rs

//! BackendMessage decoder — server-to-client wire format.
2
3use super::types::*;
4
/// Largest frame length header (64 MiB) that the backend decoder will accept.
///
/// `BackendMessage::decode` refuses any frame that declares a larger length,
/// so the decoder fails closed on oversized input even when used on its own.
/// This mirrors the guards applied on the driver side.
pub(crate) const MAX_BACKEND_FRAME_LEN: usize = 64 << 20;
10
11impl BackendMessage {
12    /// Decode a message from wire bytes.
13    pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
14        if buf.len() < 5 {
15            return Err("Buffer too short".to_string());
16        }
17
18        let msg_type = buf[0];
19        let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
20
21        // PG protocol: length includes itself (4 bytes), so minimum valid length is 4.
22        // Anything less is a malformed message.
23        if len < 4 {
24            return Err(format!("Invalid message length: {} (minimum is 4)", len));
25        }
26        if len > MAX_BACKEND_FRAME_LEN {
27            return Err(format!(
28                "Message too large: {} bytes (max {})",
29                len, MAX_BACKEND_FRAME_LEN
30            ));
31        }
32
33        let frame_len = len
34            .checked_add(1)
35            .ok_or_else(|| "Message length overflow".to_string())?;
36
37        if buf.len() < frame_len {
38            return Err("Incomplete message".to_string());
39        }
40
41        let payload = &buf[5..frame_len];
42
43        let message = match msg_type {
44            b'R' => Self::decode_auth(payload)?,
45            b'S' => Self::decode_parameter_status(payload)?,
46            b'K' => Self::decode_backend_key(payload)?,
47            b'Z' => Self::decode_ready_for_query(payload)?,
48            b'T' => Self::decode_row_description(payload)?,
49            b'D' => Self::decode_data_row(payload)?,
50            b'C' => Self::decode_command_complete(payload)?,
51            b'E' => Self::decode_error_response(payload)?,
52            b'1' => {
53                if !payload.is_empty() {
54                    return Err("ParseComplete must have empty payload".to_string());
55                }
56                BackendMessage::ParseComplete
57            }
58            b'2' => {
59                if !payload.is_empty() {
60                    return Err("BindComplete must have empty payload".to_string());
61                }
62                BackendMessage::BindComplete
63            }
64            b'3' => {
65                if !payload.is_empty() {
66                    return Err("CloseComplete must have empty payload".to_string());
67                }
68                BackendMessage::CloseComplete
69            }
70            b'n' => {
71                if !payload.is_empty() {
72                    return Err("NoData must have empty payload".to_string());
73                }
74                BackendMessage::NoData
75            }
76            b's' => {
77                if !payload.is_empty() {
78                    return Err("PortalSuspended must have empty payload".to_string());
79                }
80                BackendMessage::PortalSuspended
81            }
82            b't' => Self::decode_parameter_description(payload)?,
83            b'G' => Self::decode_copy_in_response(payload)?,
84            b'H' => Self::decode_copy_out_response(payload)?,
85            b'W' => Self::decode_copy_both_response(payload)?,
86            b'd' => BackendMessage::CopyData(payload.to_vec()),
87            b'c' => {
88                if !payload.is_empty() {
89                    return Err("CopyDone must have empty payload".to_string());
90                }
91                BackendMessage::CopyDone
92            }
93            b'A' => Self::decode_notification_response(payload)?,
94            b'I' => {
95                if !payload.is_empty() {
96                    return Err("EmptyQueryResponse must have empty payload".to_string());
97                }
98                BackendMessage::EmptyQueryResponse
99            }
100            b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
101            _ => return Err(format!("Unknown message type: {}", msg_type as char)),
102        };
103
104        Ok((message, frame_len))
105    }
106
107    fn decode_auth(payload: &[u8]) -> Result<Self, String> {
108        if payload.len() < 4 {
109            return Err("Auth payload too short".to_string());
110        }
111        let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
112        match auth_type {
113            0 => {
114                if payload.len() != 4 {
115                    return Err(format!(
116                        "AuthenticationOk invalid payload length: {}",
117                        payload.len()
118                    ));
119                }
120                Ok(BackendMessage::AuthenticationOk)
121            }
122            2 => {
123                if payload.len() != 4 {
124                    return Err(format!(
125                        "AuthenticationKerberosV5 invalid payload length: {}",
126                        payload.len()
127                    ));
128                }
129                Ok(BackendMessage::AuthenticationKerberosV5)
130            }
131            3 => {
132                if payload.len() != 4 {
133                    return Err(format!(
134                        "AuthenticationCleartextPassword invalid payload length: {}",
135                        payload.len()
136                    ));
137                }
138                Ok(BackendMessage::AuthenticationCleartextPassword)
139            }
140            5 => {
141                if payload.len() != 8 {
142                    return Err("MD5 auth payload too short (need salt)".to_string());
143                }
144                let mut salt = [0u8; 4];
145                salt.copy_from_slice(&payload[4..8]);
146                Ok(BackendMessage::AuthenticationMD5Password(salt))
147            }
148            6 => {
149                if payload.len() != 4 {
150                    return Err(format!(
151                        "AuthenticationSCMCredential invalid payload length: {}",
152                        payload.len()
153                    ));
154                }
155                Ok(BackendMessage::AuthenticationSCMCredential)
156            }
157            7 => {
158                if payload.len() != 4 {
159                    return Err(format!(
160                        "AuthenticationGSS invalid payload length: {}",
161                        payload.len()
162                    ));
163                }
164                Ok(BackendMessage::AuthenticationGSS)
165            }
166            8 => Ok(BackendMessage::AuthenticationGSSContinue(
167                payload[4..].to_vec(),
168            )),
169            9 => {
170                if payload.len() != 4 {
171                    return Err(format!(
172                        "AuthenticationSSPI invalid payload length: {}",
173                        payload.len()
174                    ));
175                }
176                Ok(BackendMessage::AuthenticationSSPI)
177            }
178            10 => {
179                // SASL - parse mechanism list
180                let mut mechanisms = Vec::new();
181                let mut pos = 4;
182                while pos < payload.len() {
183                    if payload[pos] == 0 {
184                        break; // list terminator
185                    }
186                    let end = payload[pos..]
187                        .iter()
188                        .position(|&b| b == 0)
189                        .map(|p| pos + p)
190                        .ok_or("SASL mechanism list missing null terminator")?;
191                    mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
192                    pos = end + 1;
193                }
194                if pos >= payload.len() {
195                    return Err("SASL mechanism list missing final terminator".to_string());
196                }
197                if pos + 1 != payload.len() {
198                    return Err("SASL mechanism list has trailing bytes".to_string());
199                }
200                if mechanisms.is_empty() {
201                    return Err("SASL mechanism list is empty".to_string());
202                }
203                Ok(BackendMessage::AuthenticationSASL(mechanisms))
204            }
205            11 => {
206                // SASL Continue - server challenge
207                Ok(BackendMessage::AuthenticationSASLContinue(
208                    payload[4..].to_vec(),
209                ))
210            }
211            12 => {
212                // SASL Final - server signature
213                Ok(BackendMessage::AuthenticationSASLFinal(
214                    payload[4..].to_vec(),
215                ))
216            }
217            _ => Err(format!("Unknown auth type: {}", auth_type)),
218        }
219    }
220
221    fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
222        let name_end = payload
223            .iter()
224            .position(|&b| b == 0)
225            .ok_or("ParameterStatus missing name terminator")?;
226        let value_start = name_end + 1;
227        if value_start > payload.len() {
228            return Err("ParameterStatus missing value".to_string());
229        }
230        let value_end_rel = payload[value_start..]
231            .iter()
232            .position(|&b| b == 0)
233            .ok_or("ParameterStatus missing value terminator")?;
234        let value_end = value_start + value_end_rel;
235        if value_end + 1 != payload.len() {
236            return Err("ParameterStatus has trailing bytes".to_string());
237        }
238        Ok(BackendMessage::ParameterStatus {
239            name: String::from_utf8_lossy(&payload[..name_end]).to_string(),
240            value: String::from_utf8_lossy(&payload[value_start..value_end]).to_string(),
241        })
242    }
243
244    fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
245        if payload.len() != 8 {
246            return Err("BackendKeyData payload too short".to_string());
247        }
248        Ok(BackendMessage::BackendKeyData {
249            process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
250            secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
251        })
252    }
253
254    fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
255        if payload.len() != 1 {
256            return Err("ReadyForQuery payload empty".to_string());
257        }
258        let status = match payload[0] {
259            b'I' => TransactionStatus::Idle,
260            b'T' => TransactionStatus::InBlock,
261            b'E' => TransactionStatus::Failed,
262            _ => return Err("Unknown transaction status".to_string()),
263        };
264        Ok(BackendMessage::ReadyForQuery(status))
265    }
266
267    fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
268        if payload.len() < 2 {
269            return Err("RowDescription payload too short".to_string());
270        }
271
272        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
273        if raw_count < 0 {
274            return Err(format!("RowDescription invalid field count: {}", raw_count));
275        }
276        let field_count = raw_count as usize;
277        let mut fields = Vec::with_capacity(field_count);
278        let mut pos = 2;
279
280        for _ in 0..field_count {
281            // Field name (null-terminated string)
282            let name_end = payload[pos..]
283                .iter()
284                .position(|&b| b == 0)
285                .ok_or("Missing null terminator in field name")?;
286            let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
287            pos += name_end + 1; // Skip null terminator
288
289            // Ensure we have enough bytes for the fixed fields
290            if pos + 18 > payload.len() {
291                return Err("RowDescription field truncated".to_string());
292            }
293
294            let table_oid = u32::from_be_bytes([
295                payload[pos],
296                payload[pos + 1],
297                payload[pos + 2],
298                payload[pos + 3],
299            ]);
300            pos += 4;
301
302            let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
303            pos += 2;
304
305            let type_oid = u32::from_be_bytes([
306                payload[pos],
307                payload[pos + 1],
308                payload[pos + 2],
309                payload[pos + 3],
310            ]);
311            pos += 4;
312
313            let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
314            pos += 2;
315
316            let type_modifier = i32::from_be_bytes([
317                payload[pos],
318                payload[pos + 1],
319                payload[pos + 2],
320                payload[pos + 3],
321            ]);
322            pos += 4;
323
324            let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
325            if !(0..=1).contains(&format) {
326                return Err(format!("RowDescription invalid format code: {}", format));
327            }
328            pos += 2;
329
330            fields.push(FieldDescription {
331                name,
332                table_oid,
333                column_attr,
334                type_oid,
335                type_size,
336                type_modifier,
337                format,
338            });
339        }
340
341        if pos != payload.len() {
342            return Err("RowDescription has trailing bytes".to_string());
343        }
344
345        Ok(BackendMessage::RowDescription(fields))
346    }
347
348    fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
349        if payload.len() < 2 {
350            return Err("DataRow payload too short".to_string());
351        }
352
353        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
354        if raw_count < 0 {
355            return Err(format!("DataRow invalid column count: {}", raw_count));
356        }
357        let column_count = raw_count as usize;
358        // Sanity check: each column needs at least 4 bytes (length field)
359        if column_count > (payload.len() - 2) / 4 + 1 {
360            return Err(format!(
361                "DataRow claims {} columns but payload is only {} bytes",
362                column_count,
363                payload.len()
364            ));
365        }
366        let mut columns = Vec::with_capacity(column_count);
367        let mut pos = 2;
368
369        for _ in 0..column_count {
370            if pos + 4 > payload.len() {
371                return Err("DataRow truncated".to_string());
372            }
373
374            let len = i32::from_be_bytes([
375                payload[pos],
376                payload[pos + 1],
377                payload[pos + 2],
378                payload[pos + 3],
379            ]);
380            pos += 4;
381
382            if len == -1 {
383                // NULL value
384                columns.push(None);
385            } else {
386                if len < -1 {
387                    return Err(format!("DataRow invalid column length: {}", len));
388                }
389                let len = len as usize;
390                if len > payload.len().saturating_sub(pos) {
391                    return Err("DataRow column data truncated".to_string());
392                }
393                let data = payload[pos..pos + len].to_vec();
394                pos += len;
395                columns.push(Some(data));
396            }
397        }
398
399        if pos != payload.len() {
400            return Err("DataRow has trailing bytes".to_string());
401        }
402
403        Ok(BackendMessage::DataRow(columns))
404    }
405
406    fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
407        if payload.last().copied() != Some(0) {
408            return Err("CommandComplete missing null terminator".to_string());
409        }
410        let tag_bytes = &payload[..payload.len() - 1];
411        if tag_bytes.contains(&0) {
412            return Err("CommandComplete contains interior null byte".to_string());
413        }
414        let tag = String::from_utf8_lossy(tag_bytes).to_string();
415        Ok(BackendMessage::CommandComplete(tag))
416    }
417
418    fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
419        Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
420            payload,
421        )?))
422    }
423
424    fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
425        if payload.last().copied() != Some(0) {
426            return Err("ErrorResponse missing final terminator".to_string());
427        }
428        let mut fields = ErrorFields::default();
429        let mut i = 0;
430        while i < payload.len() && payload[i] != 0 {
431            let field_type = payload[i];
432            i += 1;
433            let end = payload[i..]
434                .iter()
435                .position(|&b| b == 0)
436                .map(|p| p + i)
437                .ok_or("ErrorResponse field missing null terminator")?;
438            let value = String::from_utf8_lossy(&payload[i..end]).to_string();
439            i = end + 1;
440
441            match field_type {
442                b'S' => fields.severity = value,
443                b'C' => fields.code = value,
444                b'M' => fields.message = value,
445                b'D' => fields.detail = Some(value),
446                b'H' => fields.hint = Some(value),
447                _ => {}
448            }
449        }
450        if i + 1 != payload.len() {
451            return Err("ErrorResponse has trailing bytes after terminator".to_string());
452        }
453        Ok(fields)
454    }
455
456    fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
457        if payload.len() < 2 {
458            return Err("ParameterDescription payload too short".to_string());
459        }
460        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
461        if raw_count < 0 {
462            return Err(format!("ParameterDescription invalid count: {}", raw_count));
463        }
464        let count = raw_count as usize;
465        let expected_len = 2 + count * 4;
466        if payload.len() < expected_len {
467            return Err(format!(
468                "ParameterDescription truncated: expected {} bytes, got {}",
469                expected_len,
470                payload.len()
471            ));
472        }
473        let mut oids = Vec::with_capacity(count);
474        let mut pos = 2;
475        for _ in 0..count {
476            oids.push(u32::from_be_bytes([
477                payload[pos],
478                payload[pos + 1],
479                payload[pos + 2],
480                payload[pos + 3],
481            ]));
482            pos += 4;
483        }
484        if pos != payload.len() {
485            return Err("ParameterDescription has trailing bytes".to_string());
486        }
487        Ok(BackendMessage::ParameterDescription(oids))
488    }
489
490    fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
491        if payload.len() < 3 {
492            return Err("CopyInResponse payload too short".to_string());
493        }
494        let format = payload[0];
495        if format > 1 {
496            return Err(format!(
497                "CopyInResponse invalid overall format code: {}",
498                format
499            ));
500        }
501        let num_columns = if payload.len() >= 3 {
502            let raw = i16::from_be_bytes([payload[1], payload[2]]);
503            if raw < 0 {
504                return Err(format!(
505                    "CopyInResponse invalid negative column count: {}",
506                    raw
507                ));
508            }
509            raw as usize
510        } else {
511            0
512        };
513        let mut column_formats = Vec::with_capacity(num_columns);
514        let mut pos = 3usize;
515        for _ in 0..num_columns {
516            if pos + 2 > payload.len() {
517                return Err("CopyInResponse truncated column format list".to_string());
518            }
519            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
520            if !(0..=1).contains(&raw) {
521                return Err(format!("CopyInResponse invalid format code: {}", raw));
522            }
523            column_formats.push(raw as u8);
524            pos += 2;
525        }
526        if pos != payload.len() {
527            return Err("CopyInResponse has trailing bytes".to_string());
528        }
529        Ok(BackendMessage::CopyInResponse {
530            format,
531            column_formats,
532        })
533    }
534
535    fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
536        if payload.len() < 3 {
537            return Err("CopyOutResponse payload too short".to_string());
538        }
539        let format = payload[0];
540        if format > 1 {
541            return Err(format!(
542                "CopyOutResponse invalid overall format code: {}",
543                format
544            ));
545        }
546        let num_columns = if payload.len() >= 3 {
547            let raw = i16::from_be_bytes([payload[1], payload[2]]);
548            if raw < 0 {
549                return Err(format!(
550                    "CopyOutResponse invalid negative column count: {}",
551                    raw
552                ));
553            }
554            raw as usize
555        } else {
556            0
557        };
558        let mut column_formats = Vec::with_capacity(num_columns);
559        let mut pos = 3usize;
560        for _ in 0..num_columns {
561            if pos + 2 > payload.len() {
562                return Err("CopyOutResponse truncated column format list".to_string());
563            }
564            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
565            if !(0..=1).contains(&raw) {
566                return Err(format!("CopyOutResponse invalid format code: {}", raw));
567            }
568            column_formats.push(raw as u8);
569            pos += 2;
570        }
571        if pos != payload.len() {
572            return Err("CopyOutResponse has trailing bytes".to_string());
573        }
574        Ok(BackendMessage::CopyOutResponse {
575            format,
576            column_formats,
577        })
578    }
579
580    fn decode_copy_both_response(payload: &[u8]) -> Result<Self, String> {
581        if payload.len() < 3 {
582            return Err("CopyBothResponse payload too short".to_string());
583        }
584        let format = payload[0];
585        if format > 1 {
586            return Err(format!(
587                "CopyBothResponse invalid overall format code: {}",
588                format
589            ));
590        }
591        let num_columns = if payload.len() >= 3 {
592            let raw = i16::from_be_bytes([payload[1], payload[2]]);
593            if raw < 0 {
594                return Err(format!(
595                    "CopyBothResponse invalid negative column count: {}",
596                    raw
597                ));
598            }
599            raw as usize
600        } else {
601            0
602        };
603        let mut column_formats = Vec::with_capacity(num_columns);
604        let mut pos = 3usize;
605        for _ in 0..num_columns {
606            if pos + 2 > payload.len() {
607                return Err("CopyBothResponse truncated column format list".to_string());
608            }
609            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
610            if !(0..=1).contains(&raw) {
611                return Err(format!("CopyBothResponse invalid format code: {}", raw));
612            }
613            column_formats.push(raw as u8);
614            pos += 2;
615        }
616        if pos != payload.len() {
617            return Err("CopyBothResponse has trailing bytes".to_string());
618        }
619        Ok(BackendMessage::CopyBothResponse {
620            format,
621            column_formats,
622        })
623    }
624
625    fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
626        if payload.len() < 6 {
627            // Minimum: 4 (process_id) + 1 (channel NUL) + 1 (payload NUL)
628            return Err("NotificationResponse too short".to_string());
629        }
630        let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
631
632        // Channel name (null-terminated)
633        let mut i = 4;
634        let remaining = payload.get(i..).unwrap_or(&[]);
635        let channel_end = remaining
636            .iter()
637            .position(|&b| b == 0)
638            .ok_or("NotificationResponse: missing channel null terminator")?;
639        let channel = String::from_utf8_lossy(&remaining[..channel_end]).to_string();
640        i += channel_end + 1;
641
642        // Payload (null-terminated)
643        let remaining = payload.get(i..).unwrap_or(&[]);
644        let payload_end = remaining
645            .iter()
646            .position(|&b| b == 0)
647            .ok_or("NotificationResponse: missing payload null terminator")?;
648        let notification_payload = String::from_utf8_lossy(&remaining[..payload_end]).to_string();
649        if i + payload_end + 1 != payload.len() {
650            return Err("NotificationResponse has trailing bytes".to_string());
651        }
652
653        Ok(BackendMessage::NotificationResponse {
654            process_id,
655            channel,
656            payload: notification_payload,
657        })
658    }
659}