Skip to main content

qail_pg/protocol/wire/
backend.rs

1//! BackendMessage decoder — server-to-client wire format.
2
3use super::types::*;
4
/// Upper bound, in bytes, on the length word of a backend frame that the
/// decoder will accept (64 MiB).
///
/// Driver-side code applies a matching guard; repeating it here keeps direct
/// calls to `BackendMessage::decode` fail-closed when a peer announces an
/// oversized frame.
pub(crate) const MAX_BACKEND_FRAME_LEN: usize = 64 << 20;
10
11impl BackendMessage {
12    /// Decode a message from wire bytes.
13    pub fn decode(buf: &[u8]) -> Result<(Self, usize), String> {
14        if buf.len() < 5 {
15            return Err("Buffer too short".to_string());
16        }
17
18        let msg_type = buf[0];
19        let len = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]) as usize;
20
21        // PG protocol: length includes itself (4 bytes), so minimum valid length is 4.
22        // Anything less is a malformed message.
23        if len < 4 {
24            return Err(format!("Invalid message length: {} (minimum is 4)", len));
25        }
26        if len > MAX_BACKEND_FRAME_LEN {
27            return Err(format!(
28                "Message too large: {} bytes (max {})",
29                len, MAX_BACKEND_FRAME_LEN
30            ));
31        }
32
33        let frame_len = len
34            .checked_add(1)
35            .ok_or_else(|| "Message length overflow".to_string())?;
36
37        if buf.len() < frame_len {
38            return Err("Incomplete message".to_string());
39        }
40
41        let payload = &buf[5..frame_len];
42
43        let message = match msg_type {
44            b'R' => Self::decode_auth(payload)?,
45            b'S' => Self::decode_parameter_status(payload)?,
46            b'K' => Self::decode_backend_key(payload)?,
47            b'Z' => Self::decode_ready_for_query(payload)?,
48            b'T' => Self::decode_row_description(payload)?,
49            b'D' => Self::decode_data_row(payload)?,
50            b'C' => Self::decode_command_complete(payload)?,
51            b'E' => Self::decode_error_response(payload)?,
52            b'1' => {
53                if !payload.is_empty() {
54                    return Err("ParseComplete must have empty payload".to_string());
55                }
56                BackendMessage::ParseComplete
57            }
58            b'2' => {
59                if !payload.is_empty() {
60                    return Err("BindComplete must have empty payload".to_string());
61                }
62                BackendMessage::BindComplete
63            }
64            b'3' => {
65                if !payload.is_empty() {
66                    return Err("CloseComplete must have empty payload".to_string());
67                }
68                BackendMessage::CloseComplete
69            }
70            b'n' => {
71                if !payload.is_empty() {
72                    return Err("NoData must have empty payload".to_string());
73                }
74                BackendMessage::NoData
75            }
76            b's' => {
77                if !payload.is_empty() {
78                    return Err("PortalSuspended must have empty payload".to_string());
79                }
80                BackendMessage::PortalSuspended
81            }
82            b't' => Self::decode_parameter_description(payload)?,
83            b'G' => Self::decode_copy_in_response(payload)?,
84            b'H' => Self::decode_copy_out_response(payload)?,
85            b'W' => Self::decode_copy_both_response(payload)?,
86            b'd' => BackendMessage::CopyData(payload.to_vec()),
87            b'c' => {
88                if !payload.is_empty() {
89                    return Err("CopyDone must have empty payload".to_string());
90                }
91                BackendMessage::CopyDone
92            }
93            b'A' => Self::decode_notification_response(payload)?,
94            b'I' => {
95                if !payload.is_empty() {
96                    return Err("EmptyQueryResponse must have empty payload".to_string());
97                }
98                BackendMessage::EmptyQueryResponse
99            }
100            b'N' => BackendMessage::NoticeResponse(Self::parse_error_fields(payload)?),
101            _ => return Err(format!("Unknown message type: {}", msg_type as char)),
102        };
103
104        Ok((message, frame_len))
105    }
106
107    fn decode_auth(payload: &[u8]) -> Result<Self, String> {
108        if payload.len() < 4 {
109            return Err("Auth payload too short".to_string());
110        }
111        let auth_type = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
112        match auth_type {
113            0 => {
114                if payload.len() != 4 {
115                    return Err(format!(
116                        "AuthenticationOk invalid payload length: {}",
117                        payload.len()
118                    ));
119                }
120                Ok(BackendMessage::AuthenticationOk)
121            }
122            2 => {
123                if payload.len() != 4 {
124                    return Err(format!(
125                        "AuthenticationKerberosV5 invalid payload length: {}",
126                        payload.len()
127                    ));
128                }
129                Ok(BackendMessage::AuthenticationKerberosV5)
130            }
131            3 => {
132                if payload.len() != 4 {
133                    return Err(format!(
134                        "AuthenticationCleartextPassword invalid payload length: {}",
135                        payload.len()
136                    ));
137                }
138                Ok(BackendMessage::AuthenticationCleartextPassword)
139            }
140            5 => {
141                if payload.len() != 8 {
142                    return Err("MD5 auth payload too short (need salt)".to_string());
143                }
144                let mut salt = [0u8; 4];
145                salt.copy_from_slice(&payload[4..8]);
146                Ok(BackendMessage::AuthenticationMD5Password(salt))
147            }
148            7 => {
149                if payload.len() != 4 {
150                    return Err(format!(
151                        "AuthenticationGSS invalid payload length: {}",
152                        payload.len()
153                    ));
154                }
155                Ok(BackendMessage::AuthenticationGSS)
156            }
157            8 => Ok(BackendMessage::AuthenticationGSSContinue(
158                payload[4..].to_vec(),
159            )),
160            9 => {
161                if payload.len() != 4 {
162                    return Err(format!(
163                        "AuthenticationSSPI invalid payload length: {}",
164                        payload.len()
165                    ));
166                }
167                Ok(BackendMessage::AuthenticationSSPI)
168            }
169            10 => {
170                // SASL - parse mechanism list
171                let mut mechanisms = Vec::new();
172                let mut pos = 4;
173                while pos < payload.len() {
174                    if payload[pos] == 0 {
175                        break; // list terminator
176                    }
177                    let end = payload[pos..]
178                        .iter()
179                        .position(|&b| b == 0)
180                        .map(|p| pos + p)
181                        .ok_or("SASL mechanism list missing null terminator")?;
182                    mechanisms.push(String::from_utf8_lossy(&payload[pos..end]).to_string());
183                    pos = end + 1;
184                }
185                if pos >= payload.len() {
186                    return Err("SASL mechanism list missing final terminator".to_string());
187                }
188                if pos + 1 != payload.len() {
189                    return Err("SASL mechanism list has trailing bytes".to_string());
190                }
191                if mechanisms.is_empty() {
192                    return Err("SASL mechanism list is empty".to_string());
193                }
194                Ok(BackendMessage::AuthenticationSASL(mechanisms))
195            }
196            11 => {
197                // SASL Continue - server challenge
198                Ok(BackendMessage::AuthenticationSASLContinue(
199                    payload[4..].to_vec(),
200                ))
201            }
202            12 => {
203                // SASL Final - server signature
204                Ok(BackendMessage::AuthenticationSASLFinal(
205                    payload[4..].to_vec(),
206                ))
207            }
208            _ => Err(format!("Unknown auth type: {}", auth_type)),
209        }
210    }
211
212    fn decode_parameter_status(payload: &[u8]) -> Result<Self, String> {
213        let name_end = payload
214            .iter()
215            .position(|&b| b == 0)
216            .ok_or("ParameterStatus missing name terminator")?;
217        let value_start = name_end + 1;
218        if value_start > payload.len() {
219            return Err("ParameterStatus missing value".to_string());
220        }
221        let value_end_rel = payload[value_start..]
222            .iter()
223            .position(|&b| b == 0)
224            .ok_or("ParameterStatus missing value terminator")?;
225        let value_end = value_start + value_end_rel;
226        if value_end + 1 != payload.len() {
227            return Err("ParameterStatus has trailing bytes".to_string());
228        }
229        Ok(BackendMessage::ParameterStatus {
230            name: String::from_utf8_lossy(&payload[..name_end]).to_string(),
231            value: String::from_utf8_lossy(&payload[value_start..value_end]).to_string(),
232        })
233    }
234
235    fn decode_backend_key(payload: &[u8]) -> Result<Self, String> {
236        if payload.len() != 8 {
237            return Err("BackendKeyData payload too short".to_string());
238        }
239        Ok(BackendMessage::BackendKeyData {
240            process_id: i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]),
241            secret_key: i32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]),
242        })
243    }
244
245    fn decode_ready_for_query(payload: &[u8]) -> Result<Self, String> {
246        if payload.len() != 1 {
247            return Err("ReadyForQuery payload empty".to_string());
248        }
249        let status = match payload[0] {
250            b'I' => TransactionStatus::Idle,
251            b'T' => TransactionStatus::InBlock,
252            b'E' => TransactionStatus::Failed,
253            _ => return Err("Unknown transaction status".to_string()),
254        };
255        Ok(BackendMessage::ReadyForQuery(status))
256    }
257
258    fn decode_row_description(payload: &[u8]) -> Result<Self, String> {
259        if payload.len() < 2 {
260            return Err("RowDescription payload too short".to_string());
261        }
262
263        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
264        if raw_count < 0 {
265            return Err(format!("RowDescription invalid field count: {}", raw_count));
266        }
267        let field_count = raw_count as usize;
268        let mut fields = Vec::with_capacity(field_count);
269        let mut pos = 2;
270
271        for _ in 0..field_count {
272            // Field name (null-terminated string)
273            let name_end = payload[pos..]
274                .iter()
275                .position(|&b| b == 0)
276                .ok_or("Missing null terminator in field name")?;
277            let name = String::from_utf8_lossy(&payload[pos..pos + name_end]).to_string();
278            pos += name_end + 1; // Skip null terminator
279
280            // Ensure we have enough bytes for the fixed fields
281            if pos + 18 > payload.len() {
282                return Err("RowDescription field truncated".to_string());
283            }
284
285            let table_oid = u32::from_be_bytes([
286                payload[pos],
287                payload[pos + 1],
288                payload[pos + 2],
289                payload[pos + 3],
290            ]);
291            pos += 4;
292
293            let column_attr = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
294            pos += 2;
295
296            let type_oid = u32::from_be_bytes([
297                payload[pos],
298                payload[pos + 1],
299                payload[pos + 2],
300                payload[pos + 3],
301            ]);
302            pos += 4;
303
304            let type_size = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
305            pos += 2;
306
307            let type_modifier = i32::from_be_bytes([
308                payload[pos],
309                payload[pos + 1],
310                payload[pos + 2],
311                payload[pos + 3],
312            ]);
313            pos += 4;
314
315            let format = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
316            if !(0..=1).contains(&format) {
317                return Err(format!("RowDescription invalid format code: {}", format));
318            }
319            pos += 2;
320
321            fields.push(FieldDescription {
322                name,
323                table_oid,
324                column_attr,
325                type_oid,
326                type_size,
327                type_modifier,
328                format,
329            });
330        }
331
332        if pos != payload.len() {
333            return Err("RowDescription has trailing bytes".to_string());
334        }
335
336        Ok(BackendMessage::RowDescription(fields))
337    }
338
339    fn decode_data_row(payload: &[u8]) -> Result<Self, String> {
340        if payload.len() < 2 {
341            return Err("DataRow payload too short".to_string());
342        }
343
344        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
345        if raw_count < 0 {
346            return Err(format!("DataRow invalid column count: {}", raw_count));
347        }
348        let column_count = raw_count as usize;
349        // Sanity check: each column needs at least 4 bytes (length field)
350        if column_count > (payload.len() - 2) / 4 + 1 {
351            return Err(format!(
352                "DataRow claims {} columns but payload is only {} bytes",
353                column_count,
354                payload.len()
355            ));
356        }
357        let mut columns = Vec::with_capacity(column_count);
358        let mut pos = 2;
359
360        for _ in 0..column_count {
361            if pos + 4 > payload.len() {
362                return Err("DataRow truncated".to_string());
363            }
364
365            let len = i32::from_be_bytes([
366                payload[pos],
367                payload[pos + 1],
368                payload[pos + 2],
369                payload[pos + 3],
370            ]);
371            pos += 4;
372
373            if len == -1 {
374                // NULL value
375                columns.push(None);
376            } else {
377                if len < -1 {
378                    return Err(format!("DataRow invalid column length: {}", len));
379                }
380                let len = len as usize;
381                if len > payload.len().saturating_sub(pos) {
382                    return Err("DataRow column data truncated".to_string());
383                }
384                let data = payload[pos..pos + len].to_vec();
385                pos += len;
386                columns.push(Some(data));
387            }
388        }
389
390        if pos != payload.len() {
391            return Err("DataRow has trailing bytes".to_string());
392        }
393
394        Ok(BackendMessage::DataRow(columns))
395    }
396
397    fn decode_command_complete(payload: &[u8]) -> Result<Self, String> {
398        if payload.last().copied() != Some(0) {
399            return Err("CommandComplete missing null terminator".to_string());
400        }
401        let tag_bytes = &payload[..payload.len() - 1];
402        if tag_bytes.contains(&0) {
403            return Err("CommandComplete contains interior null byte".to_string());
404        }
405        let tag = String::from_utf8_lossy(tag_bytes).to_string();
406        Ok(BackendMessage::CommandComplete(tag))
407    }
408
409    fn decode_error_response(payload: &[u8]) -> Result<Self, String> {
410        Ok(BackendMessage::ErrorResponse(Self::parse_error_fields(
411            payload,
412        )?))
413    }
414
415    fn parse_error_fields(payload: &[u8]) -> Result<ErrorFields, String> {
416        if payload.last().copied() != Some(0) {
417            return Err("ErrorResponse missing final terminator".to_string());
418        }
419        let mut fields = ErrorFields::default();
420        let mut i = 0;
421        while i < payload.len() && payload[i] != 0 {
422            let field_type = payload[i];
423            i += 1;
424            let end = payload[i..]
425                .iter()
426                .position(|&b| b == 0)
427                .map(|p| p + i)
428                .ok_or("ErrorResponse field missing null terminator")?;
429            let value = String::from_utf8_lossy(&payload[i..end]).to_string();
430            i = end + 1;
431
432            match field_type {
433                b'S' => fields.severity = value,
434                b'C' => fields.code = value,
435                b'M' => fields.message = value,
436                b'D' => fields.detail = Some(value),
437                b'H' => fields.hint = Some(value),
438                _ => {}
439            }
440        }
441        if i + 1 != payload.len() {
442            return Err("ErrorResponse has trailing bytes after terminator".to_string());
443        }
444        Ok(fields)
445    }
446
447    fn decode_parameter_description(payload: &[u8]) -> Result<Self, String> {
448        if payload.len() < 2 {
449            return Err("ParameterDescription payload too short".to_string());
450        }
451        let raw_count = i16::from_be_bytes([payload[0], payload[1]]);
452        if raw_count < 0 {
453            return Err(format!("ParameterDescription invalid count: {}", raw_count));
454        }
455        let count = raw_count as usize;
456        let expected_len = 2 + count * 4;
457        if payload.len() < expected_len {
458            return Err(format!(
459                "ParameterDescription truncated: expected {} bytes, got {}",
460                expected_len,
461                payload.len()
462            ));
463        }
464        let mut oids = Vec::with_capacity(count);
465        let mut pos = 2;
466        for _ in 0..count {
467            oids.push(u32::from_be_bytes([
468                payload[pos],
469                payload[pos + 1],
470                payload[pos + 2],
471                payload[pos + 3],
472            ]));
473            pos += 4;
474        }
475        if pos != payload.len() {
476            return Err("ParameterDescription has trailing bytes".to_string());
477        }
478        Ok(BackendMessage::ParameterDescription(oids))
479    }
480
481    fn decode_copy_in_response(payload: &[u8]) -> Result<Self, String> {
482        if payload.len() < 3 {
483            return Err("CopyInResponse payload too short".to_string());
484        }
485        let format = payload[0];
486        if format > 1 {
487            return Err(format!(
488                "CopyInResponse invalid overall format code: {}",
489                format
490            ));
491        }
492        let num_columns = if payload.len() >= 3 {
493            let raw = i16::from_be_bytes([payload[1], payload[2]]);
494            if raw < 0 {
495                return Err(format!(
496                    "CopyInResponse invalid negative column count: {}",
497                    raw
498                ));
499            }
500            raw as usize
501        } else {
502            0
503        };
504        let mut column_formats = Vec::with_capacity(num_columns);
505        let mut pos = 3usize;
506        for _ in 0..num_columns {
507            if pos + 2 > payload.len() {
508                return Err("CopyInResponse truncated column format list".to_string());
509            }
510            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
511            if !(0..=1).contains(&raw) {
512                return Err(format!("CopyInResponse invalid format code: {}", raw));
513            }
514            column_formats.push(raw as u8);
515            pos += 2;
516        }
517        if pos != payload.len() {
518            return Err("CopyInResponse has trailing bytes".to_string());
519        }
520        Ok(BackendMessage::CopyInResponse {
521            format,
522            column_formats,
523        })
524    }
525
526    fn decode_copy_out_response(payload: &[u8]) -> Result<Self, String> {
527        if payload.len() < 3 {
528            return Err("CopyOutResponse payload too short".to_string());
529        }
530        let format = payload[0];
531        if format > 1 {
532            return Err(format!(
533                "CopyOutResponse invalid overall format code: {}",
534                format
535            ));
536        }
537        let num_columns = if payload.len() >= 3 {
538            let raw = i16::from_be_bytes([payload[1], payload[2]]);
539            if raw < 0 {
540                return Err(format!(
541                    "CopyOutResponse invalid negative column count: {}",
542                    raw
543                ));
544            }
545            raw as usize
546        } else {
547            0
548        };
549        let mut column_formats = Vec::with_capacity(num_columns);
550        let mut pos = 3usize;
551        for _ in 0..num_columns {
552            if pos + 2 > payload.len() {
553                return Err("CopyOutResponse truncated column format list".to_string());
554            }
555            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
556            if !(0..=1).contains(&raw) {
557                return Err(format!("CopyOutResponse invalid format code: {}", raw));
558            }
559            column_formats.push(raw as u8);
560            pos += 2;
561        }
562        if pos != payload.len() {
563            return Err("CopyOutResponse has trailing bytes".to_string());
564        }
565        Ok(BackendMessage::CopyOutResponse {
566            format,
567            column_formats,
568        })
569    }
570
571    fn decode_copy_both_response(payload: &[u8]) -> Result<Self, String> {
572        if payload.len() < 3 {
573            return Err("CopyBothResponse payload too short".to_string());
574        }
575        let format = payload[0];
576        if format > 1 {
577            return Err(format!(
578                "CopyBothResponse invalid overall format code: {}",
579                format
580            ));
581        }
582        let num_columns = if payload.len() >= 3 {
583            let raw = i16::from_be_bytes([payload[1], payload[2]]);
584            if raw < 0 {
585                return Err(format!(
586                    "CopyBothResponse invalid negative column count: {}",
587                    raw
588                ));
589            }
590            raw as usize
591        } else {
592            0
593        };
594        let mut column_formats = Vec::with_capacity(num_columns);
595        let mut pos = 3usize;
596        for _ in 0..num_columns {
597            if pos + 2 > payload.len() {
598                return Err("CopyBothResponse truncated column format list".to_string());
599            }
600            let raw = i16::from_be_bytes([payload[pos], payload[pos + 1]]);
601            if !(0..=1).contains(&raw) {
602                return Err(format!("CopyBothResponse invalid format code: {}", raw));
603            }
604            column_formats.push(raw as u8);
605            pos += 2;
606        }
607        if pos != payload.len() {
608            return Err("CopyBothResponse has trailing bytes".to_string());
609        }
610        Ok(BackendMessage::CopyBothResponse {
611            format,
612            column_formats,
613        })
614    }
615
616    fn decode_notification_response(payload: &[u8]) -> Result<Self, String> {
617        if payload.len() < 6 {
618            // Minimum: 4 (process_id) + 1 (channel NUL) + 1 (payload NUL)
619            return Err("NotificationResponse too short".to_string());
620        }
621        let process_id = i32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
622
623        // Channel name (null-terminated)
624        let mut i = 4;
625        let remaining = payload.get(i..).unwrap_or(&[]);
626        let channel_end = remaining
627            .iter()
628            .position(|&b| b == 0)
629            .ok_or("NotificationResponse: missing channel null terminator")?;
630        let channel = String::from_utf8_lossy(&remaining[..channel_end]).to_string();
631        i += channel_end + 1;
632
633        // Payload (null-terminated)
634        let remaining = payload.get(i..).unwrap_or(&[]);
635        let payload_end = remaining
636            .iter()
637            .position(|&b| b == 0)
638            .ok_or("NotificationResponse: missing payload null terminator")?;
639        let notification_payload = String::from_utf8_lossy(&remaining[..payload_end]).to_string();
640        if i + payload_end + 1 != payload.len() {
641            return Err("NotificationResponse has trailing bytes".to_string());
642        }
643
644        Ok(BackendMessage::NotificationResponse {
645            process_id,
646            channel,
647            payload: notification_payload,
648        })
649    }
650}